API Documentation

Rules

ChainRulesCore.frule (Method)
frule([::RuleConfig,] (Δf, Δx...), f, x...)

Expressing the output of f(x...) as Ω, return the tuple:

(Ω, ΔΩ)

The second return value is the differential w.r.t. the output.

If no method matching frule((Δf, Δx...), f, x...) has been defined, then return nothing.

Examples:

Unary input, unary output scalar function:

julia> dself = NoTangent();

julia> x = rand()
0.8236475079774124

julia> sinx, Δsinx = frule((dself, 1), sin, x)
(0.7336293678134624, 0.6795498147167869)

julia> sinx == sin(x)
true

julia> Δsinx == cos(x)
true

Unary input, binary output scalar function:

julia> sincosx, Δsincosx = frule((dself, 1), sincos, x);

julia> sincosx == sincos(x)
true

julia> Δsincosx[1] == cos(x)
true

julia> Δsincosx[2] == -sin(x)
true

Note that technically speaking Julia does not have multiple-output functions, just functions that return a single output that is iterable, like a Tuple. So this is actually a Tangent:

julia> Δsincosx
Tangent{Tuple{Float64, Float64}}(0.6795498147167869, -0.7336293678134624)

The optional RuleConfig argument allows defining frules only for AD systems that support given features. If it is not needed, it can be omitted, and the frule without it will be hit as a fallback. This is the case for most rules.
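As a minimal sketch of writing a forward rule yourself (assuming ChainRulesCore v1 is installed; `mysquare` is a hypothetical function used only for illustration), the frule below returns the primal output together with its tangent. The tangent of the function itself, the first element of the input tangent tuple, is ignored because `mysquare` is not a closure.

```julia
using ChainRulesCore

# Hypothetical function, used only for illustration.
mysquare(x::Real) = x^2

# Forward rule: (Δself, Δx) are the input tangents; Δself is unused here.
function ChainRulesCore.frule((_, Δx), ::typeof(mysquare), x::Real)
    Ω = mysquare(x)
    ΔΩ = 2x * Δx  # d(x^2)/dx = 2x, pushed forward along Δx
    return Ω, ΔΩ
end

y, dy = frule((NoTangent(), 1.0), mysquare, 3.0)
```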

ChainRulesCore.rrule (Method)
rrule([::RuleConfig,] f, x...)

Expressing x as the tuple (x₁, x₂, ...) and the output tuple of f(x...) as Ω, return the tuple:

(Ω, (Ω̄₁, Ω̄₂, ...) -> (s̄elf, x̄₁, x̄₂, ...))

Here the second return value is the propagation rule, or pullback. It takes in cotangents corresponding to the outputs (Ω̄₁, Ω̄₂, ...) and returns cotangents for the inputs (x̄₁, x̄₂, ...), preceded by s̄elf, the cotangent for the internal values of the function itself (for closures).

If no method matching rrule(f, xs...) has been defined, then return nothing.

Examples:

Unary input, unary output scalar function:

julia> x = rand();

julia> sinx, sin_pullback = rrule(sin, x);

julia> sinx == sin(x)
true

julia> sin_pullback(1) == (NoTangent(), cos(x))
true

Binary input, unary output scalar function:

julia> x, y = rand(2);

julia> hypotxy, hypot_pullback = rrule(hypot, x, y);

julia> hypotxy == hypot(x, y)
true

julia> hypot_pullback(1) == (NoTangent(), (x / hypot(x, y)), (y / hypot(x, y)))
true

The optional RuleConfig argument allows defining rrules only for AD systems that support given features. If it is not needed, it can be omitted, and the rrule without it will be hit as a fallback. This is the case for most rules.
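A minimal sketch of a hand-written reverse rule, assuming ChainRulesCore v1 is installed (`mymul` is a hypothetical function used only for illustration). The pullback maps the output cotangent to the tuple (s̄elf, ā, b̄), and s̄elf is NoTangent() because `mymul` is not a closure.

```julia
using ChainRulesCore

# Hypothetical two-argument function, used only for illustration.
mymul(a::Real, b::Real) = a * b

function ChainRulesCore.rrule(::typeof(mymul), a::Real, b::Real)
    Ω = mymul(a, b)
    # The pullback maps the output cotangent ȳ to (s̄elf, ā, b̄).
    mymul_pullback(ȳ) = (NoTangent(), ȳ * b, ȳ * a)
    return Ω, mymul_pullback
end

z, back = rrule(mymul, 2.0, 5.0)
grads = back(1.0)
```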


Rule Definition Tools

ChainRulesCore.@non_differentiable (Macro)
@non_differentiable(signature_expression)

A helper to make it easier to declare that a method is not differentiable. This is shorthand for defining an frule and an rrule that return NoTangent() for all partials (even for the function's own s̄elf-partial).

Keyword arguments should not be included.

julia> @non_differentiable Base.:(==)(a, b)

julia> _, pullback = rrule(==, 2.0, 3.0);

julia> pullback(1.0)
(NoTangent(), NoTangent(), NoTangent())

You can place type-constraints in the signature:

julia> @non_differentiable Base.length(xs::Union{Number, Array})

julia> frule((ZeroTangent(), 1), length, [2.0, 3.0])
(2, NoTangent())
Warning

This helper macro covers only the simple common cases. It does not support where-clauses. For these, declare the rrule and frule directly.
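For a method whose signature needs a where-clause, the rules can be written out by hand. The sketch below assumes ChainRulesCore v1; `nbytes` is a hypothetical integer-valued function, and the rules simply return NoTangent() for every partial, which is what @non_differentiable would do for a signature it supports.

```julia
using ChainRulesCore

# Hypothetical integer-valued function whose method uses a `where` clause.
nbytes(xs::Vector{T}) where {T<:Real} = sizeof(T) * length(xs)

# Forward rule: the input tangents are ignored and the output tangent is NoTangent().
function ChainRulesCore.frule(_, ::typeof(nbytes), xs::Vector{T}) where {T<:Real}
    return nbytes(xs), NoTangent()
end

# Reverse rule: every partial, including s̄elf, is NoTangent().
function ChainRulesCore.rrule(::typeof(nbytes), xs::Vector{T}) where {T<:Real}
    nbytes_pullback(_) = (NoTangent(), NoTangent())
    return nbytes(xs), nbytes_pullback
end

n, back = rrule(nbytes, [1.0, 2.0, 3.0])
```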

ChainRulesCore.@opt_out (Macro)
@opt_out frule([config], _, f, args...)
@opt_out rrule([config], f, args...)

This allows you to opt out of an frule or an rrule by providing a more specific method that says to use the AD system to differentiate it instead.

For example, consider some function foo(x::AbstractArray). In general, you know an efficient and generic way to implement its rrule, and you do so (likely making use of ProjectTo). But it turns out that for some FancyArray type it is better to let the AD do its thing.

Then you would write something like:

function rrule(::typeof(foo), x::AbstractArray)
    foo_pullback(ȳ) = ...
    return foo(x), foo_pullback
end

@opt_out rrule(::typeof(foo), ::FancyArray)

This will generate an rrule that returns nothing, and will also add a similar entry to ChainRulesCore.no_rrule.

The same applies to frule and ChainRulesCore.no_frule.
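A runnable version of the scenario above, as a sketch assuming ChainRulesCore v1 with @opt_out available. `foo` and `FancyArray` are the hypothetical names from the text, given toy definitions here; the generic rule handles ordinary arrays, while the opt-out makes rrule return nothing for FancyArray so that the AD system differentiates foo itself.

```julia
using ChainRulesCore
import ChainRulesCore: rrule

# Hypothetical function and array type, following the names used above.
foo(x::AbstractArray) = 2 .* x

struct FancyArray{T} <: AbstractVector{T}
    data::Vector{T}
end
Base.size(a::FancyArray) = size(a.data)
Base.getindex(a::FancyArray, i::Int) = a.data[i]

# The efficient, generic rule.
function rrule(::typeof(foo), x::AbstractArray)
    foo_pullback(ȳ) = (NoTangent(), 2 .* ȳ)
    return foo(x), foo_pullback
end

# For FancyArray, defer to the AD system instead.
@opt_out rrule(::typeof(foo), ::FancyArray)

y, back = rrule(foo, [1.0, 2.0])
```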

ChainRulesCore.@scalar_rule (Macro)
@scalar_rule(f(x₁, x₂, ...),
@setup(statement₁, statement₂, ...),
(∂f₁_∂x₁, ∂f₁_∂x₂, ...),
(∂f₂_∂x₁, ∂f₂_∂x₂, ...),
...)

A convenience macro that generates simple scalar forward or reverse rules using the provided partial derivatives. Specifically, generates the corresponding methods for frule and rrule:

function ChainRulesCore.frule((NoTangent(), Δx₁, Δx₂, ...), ::typeof(f), x₁::Number, x₂::Number, ...)
    Ω = f(x₁, x₂, ...)
    $(statement₁, statement₂, ...)
    return Ω, (
        (∂f₁_∂x₁ * Δx₁ + ∂f₁_∂x₂ * Δx₂ + ...),
        (∂f₂_∂x₁ * Δx₁ + ∂f₂_∂x₂ * Δx₂ + ...),
        ...
    )
end

function ChainRulesCore.rrule(::typeof(f), x₁::Number, x₂::Number, ...)
    Ω = f(x₁, x₂, ...)
    $(statement₁, statement₂, ...)
    return Ω, ((ΔΩ₁, ΔΩ₂, ...)) -> (
        NoTangent(),
        ∂f₁_∂x₁ * ΔΩ₁ + ∂f₂_∂x₁ * ΔΩ₂ + ...,
        ∂f₁_∂x₂ * ΔΩ₁ + ∂f₂_∂x₂ * ΔΩ₂ + ...,
        ...
    )
end

If no type constraints in f(x₁, x₂, ...) are provided within the call to @scalar_rule, each parameter in the resulting frule/rrule definition is given a type constraint of Number. Constraints may also be provided explicitly to override the Number constraint, e.g. f(x₁::Complex, x₂), which will constrain x₁ to Complex and x₂ to Number.

At present this does not support defining rules for closures/functors. Thus in reverse mode, the first returned partial, representing the derivative with respect to the function itself, is always NoTangent(). And in forward mode, the first element of the input tangent tuple (the tangent of the function itself) is always ignored.

The result of f(x₁, x₂, ...) is automatically bound to Ω. This allows the primal result to be conveniently referenced (as Ω) within the derivative/setup expressions.

This macro assumes complex functions are holomorphic. In general, for non-holomorphic functions, the frule and rrule must be defined manually.

If the derivative is one (e.g. for identity functions), true can be used as the most general multiplicative identity.

The @setup argument can be elided if no setup code is needed. In other words:

@scalar_rule(f(x₁, x₂, ...),
(∂f₁_∂x₁, ∂f₁_∂x₂, ...),
(∂f₂_∂x₁, ∂f₂_∂x₂, ...),
...)

is equivalent to:

@scalar_rule(f(x₁, x₂, ...),
@setup(nothing),
(∂f₁_∂x₁, ∂f₁_∂x₂, ...),
(∂f₂_∂x₁, ∂f₂_∂x₂, ...),
...)

For examples, see ChainRules' rulesets directory.

See also: frule, rrule.
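A small sketch of the three common forms, assuming ChainRulesCore v1 (`mycube`, `myexp2`, and `myradius` are hypothetical functions used only for illustration): a single partial, a partial written in terms of the bound primal result Ω, and a two-input rule whose @setup code is evaluated once and shared by both partials.

```julia
using ChainRulesCore

# Hypothetical scalar functions, used only for illustration.
mycube(x) = x^3
myexp2(x) = exp(2x)
myradius(x, y) = sqrt(x^2 + y^2)

# One input, one output: the single partial derivative.
@scalar_rule mycube(x) 3 * x^2

# The primal result is bound to Ω, so d/dx exp(2x) = 2Ω.
@scalar_rule myexp2(x) 2Ω

# Two inputs, one output, with shared setup code evaluated once.
@scalar_rule(myradius(x, y), @setup(inv_r = inv(Ω)), (x * inv_r, y * inv_r))

r, radius_pullback = rrule(myradius, 3.0, 4.0)
```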


Differentials

ChainRulesCore.AbstractZero (Type)
AbstractZero <: AbstractTangent

Supertype for zero-like differentials, i.e., differentials that act like zero when added to or multiplied by other values. If an AD system encounters a propagator that takes as input only subtypes of AbstractZero, then it can stop performing AD operations for that propagator: all propagators are linear functions, so the result will be zero.

All AbstractZero subtypes are singleton types. There are two of them: ZeroTangent() and NoTangent().
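The zero-like behavior can be checked directly. This sketch assumes ChainRulesCore v1's tangent arithmetic: adding a zero tangent leaves the other operand unchanged, and multiplying by one yields the zero itself, which is what lets an AD system skip further work.

```julia
using ChainRulesCore

# Zero-like tangents are absorbed by addition...
a = ZeroTangent() + 2.5
b = NoTangent() + 2.5

# ...and annihilate under multiplication, so downstream work can be skipped.
c = ZeroTangent() * 4.0
```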

ChainRulesCore.NoTangent (Type)
NoTangent() <: AbstractZero

This differential indicates that the derivative does not exist. It is the differential for primal types that are not differentiable, such as integers or booleans (when they are not being used to represent floating-point values). The only valid way to perturb such values is to not change them at all. As a consequence, NoTangent is functionally identical to ZeroTangent(), but it provides additional semantic information.

Adding this differential to a primal is generally wrong: gradient-based methods cannot be used to optimize over discrete variables. An optimization package making use of this might want to check for such a case.

Note

This does not indicate that the derivative is not implemented, but rather that mathematically it is not defined.

This mostly shows up as the derivative with respect to dimension, index, or size arguments.

function rrule(::typeof(fill), x, len::Int)
    y = fill(x, len)
    # `len` is an integer size argument, so its tangent is NoTangent()
    fill_pullback(ȳ) = (NoTangent(), @thunk(sum(ȳ)), NoTangent())
    return y, fill_pullback
end
ChainRulesCore.ZeroTangent (Type)
ZeroTangent() <: AbstractZero

The additive identity for differentials. This is basically the same as 0. A derivative of ZeroTangent() does not propagate through the primal function.

ChainRulesCore.Tangent (Type)
Tangent{P, T} <: AbstractTangent

This type represents the differential for a struct, NamedTuple, or Tuple. P is the corresponding primal type that this is a differential for.

Tangent{P} should have fields (technically properties) that match a subset of the fields of the primal type, and each should be a differential type matching the primal type of that field. Fields of P that are not present in the Tangent are treated as zero.

T is an implementation detail representing the backing data structure. For a Tuple it will be a Tuple, and for everything else it will be a NamedTuple. It should not be passed in by the user.

For Tangents of Tuples, iterate and getindex are overloaded to behave as they do for a tuple. For Tangents of structs, getproperty is overloaded to allow accessing values via tangent.fieldname. Any fields not explicitly present in the Tangent are treated as being set to ZeroTangent(). To make a Tangent have all the fields of the primal, the canonicalize function is provided.
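A minimal sketch of these behaviors, assuming ChainRulesCore v1 (`Point` is a hypothetical struct used only for illustration): a struct tangent that stores only one field, property access that fills in ZeroTangent() for the missing field, canonicalize, and indexing a tuple tangent.

```julia
using ChainRulesCore

# Hypothetical struct, used only for illustration.
struct Point
    x::Float64
    y::Float64
end

# A tangent that only stores the `x` field; `y` is implicitly zero.
t = Tangent{Point}(; x = 1.5)
tx = t.x
ty = t.y  # not stored, so treated as ZeroTangent()

# canonicalize makes every field of Point explicit.
ct = ChainRulesCore.canonicalize(t)

# Tangents of tuples support indexing like a tuple.
tt = Tangent{Tuple{Float64, Float64}}(1.0, 2.0)
second = tt[2]
```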

ChainRulesCore.canonicalize (Method)
canonicalize(tangent::Tangent{P}) -> Tangent{P}

Return the canonical Tangent for the primal type P. The property names of the returned Tangent match the field names of the primal, and all fields of P not present in the input tangent are explicitly set to ZeroTangent().

ChainRulesCore.InplaceableThunk (Type)
InplaceableThunk(add!::Function, val::Thunk)

A wrapper for a Thunk, that allows it to define an inplace add! function.

add! should be defined such that ithunk.add!(Δ) = Δ .+= ithunk.val, but it should do this more efficiently than simply doing this directly (otherwise one can just use a normal Thunk).

Most operations on an InplaceableThunk treat it just like a normal Thunk, and destroy its inplaceability.
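A sketch of constructing and consuming an InplaceableThunk, assuming ChainRulesCore v1 (argument order add!, then val). The out-of-place value is a lazy thunk for 2x; the add! closure accumulates the same quantity into an existing buffer, and ChainRulesCore.add!! picks the in-place path when the destination is mutable.

```julia
using ChainRulesCore

x = [1.0, 2.0, 3.0]

# Out-of-place version: a lazily evaluated thunk for 2x.
val = @thunk(2 .* x)
# In-place version: accumulates 2x into an existing buffer Δ and returns it.
ithunk = InplaceableThunk(Δ -> (Δ .+= 2 .* x), val)

out_of_place = unthunk(ithunk)
buffer = zeros(3)
accumulated = ChainRulesCore.add!!(buffer, ithunk)  # may reuse `buffer`
```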

ChainRulesCore.Thunk (Type)
Thunk(()->v)

A thunk is a deferred computation. It wraps a zero argument closure that when invoked returns a differential. @thunk(v) is a macro that expands into Thunk(()->v).

To evaluate the wrapped closure, call unthunk which is a no-op when the argument is not a Thunk.

julia> t = @thunk(3)
Thunk(var"#4#5"())

julia> unthunk(t)
3

When to @thunk?

When writing rrules (and to a lesser extent frules), it is important to @thunk appropriately. Propagation rules that return multiple derivatives may not have all derivatives used. By @thunking the work required for each derivative, only the derivatives that are actually needed get computed.

How do thunks prevent work?

If we have res = pullback(...) = @thunk(f(x)), @thunk(g(x)), then computing dx + res[1] evaluates only f(x), not g(x). Also, ZeroTangent() * res[1] returns ZeroTangent() without evaluating f(x) at all.

So why not thunk everything?

@thunk creates a closure over the expression, which (effectively) creates a struct with a field for each variable used in the expression, plus an overloaded call method.

Do not use @thunk if this would be equal or more work than actually evaluating the expression itself. This is commonly the case for scalar operators.

For more details see the manual section on using thunks effectively.
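The deferral can be observed with a call counter. In this sketch (assuming ChainRulesCore v1; `calls` and `expensive` are hypothetical names for illustration), creating a thunk runs nothing, unthunk forces it exactly once, and multiplying a thunk by ZeroTangent() never forces it.

```julia
using ChainRulesCore

# Hypothetical counter recording how many times the expensive work runs.
const calls = Ref(0)
expensive(x) = (calls[] += 1; 2x)

t = @thunk(expensive(3.0))
before = calls[]   # nothing has run yet
v = unthunk(t)     # forces the deferred computation
after = calls[]

# Multiplying by a zero tangent returns the zero without forcing the thunk.
z = ZeroTangent() * @thunk(expensive(4.0))
```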

ChainRulesCore.@not_implemented (Macro)
@not_implemented(info)

Create a differential that indicates that the derivative is not implemented.

The info should be useful information about the missing differential for debugging.

Note

This macro should be used only if the automatic differentiation would error otherwise. It is mostly useful if the function has multiple inputs or outputs, and one has worked out analytically and implemented some but not all differentials.

Note

It is good practice to include a link to a GitHub issue about the missing differential in the debugging information.
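A sketch of the intended use, assuming ChainRulesCore v1 (`scaled_power` is a hypothetical function, and we pretend, purely for illustration, that its derivative with respect to the exponent has not been worked out). The pullback returns a real cotangent for x and a NotImplemented placeholder for n.

```julia
using ChainRulesCore

# Hypothetical function; pretend the derivative w.r.t. `n` is not yet worked out.
scaled_power(x::Real, n::Real) = 2 * x^n

function ChainRulesCore.rrule(::typeof(scaled_power), x::Real, n::Real)
    y = scaled_power(x, n)
    function scaled_power_pullback(ȳ)
        x̄ = ȳ * 2 * n * x^(n - 1)
        n̄ = @not_implemented("derivative w.r.t. the exponent n is not implemented")
        return NoTangent(), x̄, n̄
    end
    return y, scaled_power_pullback
end

y, back = rrule(scaled_power, 2.0, 3.0)
g = back(1.0)
```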


Accumulation

ChainRulesCore.add!! (Function)
add!!(x, y)

Returns x+y, potentially mutating x in-place to hold this value. This avoids allocations when x can be mutated in this way.

add!!(x, t::InplaceableThunk)

The specialization of add!! for InplaceableThunk promises to only call t.add! on x if x is suitably mutable; otherwise it will be out of place.

ChainRulesCore.is_inplaceable_destination (Function)
is_inplaceable_destination(x) -> Bool

Returns true if x is suitable for storing inplace accumulation of gradients. For arrays this boils down to whether x .= y will work to mutate x, if y is an appropriate differential. Wrapper array types do not need to overload this if they overload Base.parent; they are is_inplaceable_destination if and only if their parent array is. Other types should overload this, as it defaults to false.
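A quick sketch of both accumulation helpers, assuming ChainRulesCore v1 (neither name is necessarily exported, so they are qualified here). A Float64 vector and a view of one are valid in-place destinations, a plain number is not, and add!! returns the sum in every case.

```julia
using ChainRulesCore

acc = [1.0, 2.0]
# A mutable Float64 vector can hold the accumulated result in place.
can_mutate = ChainRulesCore.is_inplaceable_destination(acc)
acc = ChainRulesCore.add!!(acc, [3.0, 4.0])

# A plain number cannot be mutated, so add!! falls back to x + y.
s = ChainRulesCore.add!!(1.0, 2.0)
number_mutable = ChainRulesCore.is_inplaceable_destination(1.0)

# Wrappers defer to their parent array via Base.parent.
view_ok = ChainRulesCore.is_inplaceable_destination(view([1.0, 2.0, 3.0], 1:2))
```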


RuleConfig

ChainRulesCore.RuleConfig (Type)
RuleConfig{T}

The configuration for what rules to use. T: traits. This should be a Union of all special traits needed for rules to be allowed to be defined for your AD. If there is nothing special, it should be set to Union{}.

AD authors should define a subtype of RuleConfig to use when calling frule/rrule.

Rule authors can dispatch on this config when defining rules. For example:

# only define rrule for pop! on AD systems where mutation is supported.
rrule(::RuleConfig{>:SupportsMutation}, ::typeof(pop!), ::Vector) = ...

# this definition of map is for any AD that defines a forwards mode
rrule(conf::RuleConfig{>:HasForwardsMode}, ::typeof(map), ::Vector) = ...

# this definition of map is for any AD that only defines a reverse mode.
# It is not as good as the rrule that can be used if the AD defines a forward-mode as well.
rrule(conf::RuleConfig{>:Union{NoForwardsMode, HasReverseMode}}, ::typeof(map), ::Vector) = ...

For more details see rule configurations and calling back into AD.
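A minimal sketch of both sides of this contract, assuming ChainRulesCore v1 (`MyADConfig`, `mydouble`, and `mytriple` are hypothetical names). The AD author declares a config subtype listing its traits; a rule author dispatches on RuleConfig{>:HasReverseMode}; and a rule defined without a config is still reached through the config-free fallback.

```julia
using ChainRulesCore

# Hypothetical config for an AD system with reverse mode but no forward mode.
struct MyADConfig <: RuleConfig{Union{HasReverseMode, NoForwardsMode}} end

# Hypothetical functions, used only for illustration.
mydouble(x) = 2x
mytriple(x) = 3x

# Only applies to configs whose trait Union includes HasReverseMode.
function ChainRulesCore.rrule(::RuleConfig{>:HasReverseMode}, ::typeof(mydouble), x::Real)
    mydouble_pullback(ȳ) = (NoTangent(), 2ȳ)
    return mydouble(x), mydouble_pullback
end

# A rule without a config argument...
ChainRulesCore.rrule(::typeof(mytriple), x::Real) = (mytriple(x), ȳ -> (NoTangent(), 3ȳ))

y1, back1 = rrule(MyADConfig(), mydouble, 3.0)
# ...is still reached through the config-free fallback.
y2, back2 = rrule(MyADConfig(), mytriple, 3.0)
```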

ChainRulesCore.frule_via_ad (Function)
frule_via_ad(::RuleConfig{>:HasForwardsMode}, ȧrgs, f, args...; kwargs...)

This function has the same API as frule, but operates via performing forwards mode automatic differentiation. Any RuleConfig subtype that supports the HasForwardsMode special feature must provide an implementation of it.

See also: rrule_via_ad, RuleConfig and the documentation on rule configurations and calling back into AD

ChainRulesCore.rrule_via_ad (Function)
rrule_via_ad(::RuleConfig{>:HasReverseMode}, f, args...; kwargs...)

This function has the same API as rrule, but operates via performing reverse mode automatic differentiation. Any RuleConfig subtype that supports the HasReverseMode special feature must provide an implementation of it.

See also: frule_via_ad, RuleConfig and the documentation on rule configurations and calling back into AD
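A sketch of a higher-order rule that calls back into AD. Because ChainRulesCore alone ships no AD engine, this assumes a hypothetical toy config, `ToyAD`, whose rrule_via_ad simply looks up an existing rrule; `square` and `apply_twice` are likewise hypothetical. A real AD system would provide its own rrule_via_ad.

```julia
using ChainRulesCore

# Toy "AD system" (hypothetical): its reverse mode only looks up existing rrules.
struct ToyAD <: RuleConfig{Union{HasReverseMode, NoForwardsMode}} end
ChainRulesCore.rrule_via_ad(::ToyAD, f, args...) = rrule(f, args...)

# Hypothetical primitive with a hand-written rrule.
square(x) = x^2
ChainRulesCore.rrule(::typeof(square), x::Real) = (square(x), ȳ -> (NoTangent(), 2x * ȳ))

# Hypothetical higher-order function whose rule calls back into AD for `f`.
apply_twice(f, x) = f(f(x))
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(apply_twice), f, x)
    y1, back1 = rrule_via_ad(config, f, x)
    y2, back2 = rrule_via_ad(config, f, y1)
    function apply_twice_pullback(ȳ)
        f̄2, ȳ1 = back2(ȳ)
        f̄1, x̄ = back1(ȳ1)
        return NoTangent(), f̄1 + f̄2, x̄
    end
    return y2, apply_twice_pullback
end

y, back = rrule(ToyAD(), apply_twice, square, 3.0)
```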


ProjectTo

ChainRulesCore.ProjectTo (Type)
(p::ProjectTo{T})(dx)

Projects the differential dx onto a specific tangent space.

The type T is meant to encode the largest acceptable space, so usually this enforces p(dx)::T. But some subspaces which aren't subtypes of T may be allowed, and in particular dx::AbstractZero always passes through.

Usually T is the "outermost" part of the type, and p stores additional properties such as projectors for each constituent field. Arrays have either one projector p.element expressing the element type for an array of numbers, or else an array of projectors p.elements. These properties can be supplied as keyword arguments on construction, p = ProjectTo{T}(; field=data, element=ProjectTo(x)). For each T in use, corresponding methods should be written for (::ProjectTo{T})(dx) with nonzero dx.

When called on dx::Thunk, the projection is inserted into the thunk.
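A sketch of typical projections, assuming ChainRulesCore v1 plus the LinearAlgebra standard library: a complex-valued gradient is projected onto the real tangent space of a Float64 vector, a zero tangent passes through unchanged, and a dense gradient for a Diagonal primal keeps only its diagonal.

```julia
using ChainRulesCore
using LinearAlgebra

x = [1.0, 2.0, 3.0]
project = ProjectTo(x)

# A complex-valued gradient is projected back onto the real tangent space of x.
dx = project([1.0 + 2.0im, 3.0 + 0.0im, -1.0 + 1.0im])

# Zero tangents pass through unchanged.
dz = project(ZeroTangent())

# A structured primal keeps its structure: only the diagonal survives.
D = Diagonal([1.0, 2.0])
dD = ProjectTo(D)([1.0 5.0; 6.0 2.0])
```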


ChainRulesCore.ignore_derivatives (Function)
ignore_derivatives(f::Function)

Tells the AD system to ignore the gradients of the wrapped closure. The primal computation (forward pass) is executed normally.

ignore_derivatives() do
    value = rand()
    push!(collection, value)
end

Using this incorrectly could lead to incorrect gradients. For example, the following function will have zero gradients with respect to its argument:

function wrong_grads(x)
    y = ones(3)
    ignore_derivatives() do
        push!(y, x)
    end
    return sum(y)
end
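Outside of AD, ignore_derivatives simply runs the closure (or returns its argument), so the primal computation is unaffected. This sketch assumes a ChainRulesCore v1 release that exports ignore_derivatives; `seen_sums` and `scaled_sum` are hypothetical names for a logging side effect that should stay invisible to differentiation.

```julia
using ChainRulesCore

# Hypothetical side-channel log, used only for illustration.
const seen_sums = Float64[]

function scaled_sum(x::Vector{Float64})
    # Logging is a side effect that should not take part in differentiation.
    ignore_derivatives() do
        push!(seen_sums, sum(x))
    end
    return 2 * sum(x)
end

r = scaled_sum([1.0, 2.0])
```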
ignore_derivatives(x)

Tells the AD system to ignore the gradients of the argument. Can be used to avoid unnecessary computation of gradients.

ignore_derivatives(x) * w

Internal

ChainRulesCore.AbstractTangent (Type)

The subtypes of AbstractTangent define a custom "algebra" for chain rule evaluation that attempts to factor various features like complex derivative support, broadcast fusion, zero-elision, etc. into nicely separated parts.

In general a differential type is the type of a derivative of a value. The type of the value is for contrast called the primal type. Differential types correspond to primal types, although the relation is not one-to-one. Subtypes of AbstractTangent are not the only differential types. In fact, for the most common primal types, such as Real or AbstractArray{Real}, the differential type is the same as the primal type.

In a circular definition: the most important property of a differential is that it should be able to be added (by defining +) to another differential of the same primal type. That allows for gradients to be accumulated.

It generally also should be able to be added to a primal to give back another primal, as this facilitates gradient descent.

All subtypes of AbstractTangent implement the following operations:

• +(a, b): linearly combine differential a and differential b
• *(a, b): multiply the differential b by the scaling factor a
• Base.zero(x) = ZeroTangent(): a zero.

Further, they often implement other linear operators, such as conj, adjoint, dot. Pullbacks/pushforwards are linear operators, and their inputs are often AbstractTangent subtypes. Pullbacks/pushforwards in turn call other linear operators on those inputs. Thus it is desirable to have all common linear operators work on AbstractTangents.
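The required operations can be exercised on Tangent, one concrete AbstractTangent subtype. This sketch assumes ChainRulesCore v1 arithmetic for tuple tangents: two tangents of the same primal type add elementwise, a scalar rescales each component, and zero gives ZeroTangent().

```julia
using ChainRulesCore

P = Tuple{Float64, Float64}
t1 = Tangent{P}(1.0, 2.0)
t2 = Tangent{P}(3.0, 4.0)

s = t1 + t2        # accumulate two tangents of the same primal type
scaled = 2.0 * t1  # scale a tangent
z = zero(t1)       # the zero tangent
```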

ChainRulesCore.debug_mode (Function)
debug_mode() -> Bool

Determines if ChainRulesCore is in debug_mode. Defaults to false, but if the user redefines it to return true then extra information will be shown when errors occur.

Enable via:

ChainRulesCore.debug_mode() = true
ChainRulesCore.no_rrule (Function)
no_rrule

This is a piece of infrastructure supporting opting out of rrule. It follows the signature for rrule exactly. A collection of type-tuples is stored in its method table. If a signature has a method of no_rrule defined, then it must also have an rrule defined that returns nothing.

Warning

Do not overload no_rrule directly. It is fine and intended to query the method table of no_rrule. It is not safe to add to it directly, as corresponding changes also need to be made to rrule. The @opt_out macro does both of these things, and so should almost always be used rather than defining a method of no_rrule directly.

Mechanics

Note: when the text below says two methods are ==, it actually means that their signature type tuples, parameters(m.sig)[2:end], are equal, rather than the method objects m themselves.

To decide whether to opt out using this mechanism:

• find the most specific method of rrule and of no_rrule, e.g. with Base.which
• if the method of no_rrule == the method of rrule, then opt out

To simply ignore the fact that rules can be opted out of, and that some rules thus return nothing, filter the list of methods of rrule to remove those that are == to ones that occur in the method table of no_rrule.

Note also that when doing this you must still handle falling back from the rule with config to the rule without config.

On the other hand, if your AD can work with rrules that return nothing, then it is simpler to just use that mechanism for opting out, and you don't need to worry about this at all.

See also ChainRulesCore.no_frule.
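The which-based check described above can be sketched as follows, assuming ChainRulesCore v1 with @opt_out. `bar` and the helper `opted_out_rrule` are hypothetical names; the helper compares the signature type tuples, with the function type dropped, of the most specific rrule and no_rrule methods.

```julia
using ChainRulesCore
import ChainRulesCore: rrule

# Hypothetical function with a general rrule and an opt-out for Int inputs.
bar(x) = 3x
rrule(::typeof(bar), x::Real) = (bar(x), ȳ -> (NoTangent(), 3ȳ))
@opt_out rrule(::typeof(bar), ::Int)

# Hypothetical helper: true if the most specific rrule for `sig` is an opt-out.
function opted_out_rrule(sig::Type{<:Tuple})
    hasmethod(ChainRulesCore.no_rrule, sig) || return false
    rrule_method = which(rrule, sig)
    no_rrule_method = which(ChainRulesCore.no_rrule, sig)
    # Compare signature type tuples with the function type dropped.
    return Base.tuple_type_tail(rrule_method.sig) == Base.tuple_type_tail(no_rrule_method.sig)
end
```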

ChainRulesCore.no_frule (Function)
no_frule

This is a piece of infrastructure supporting opting out of frule. It follows the signature for frule exactly. A collection of type-tuples is stored in its method table. If a signature has a method of no_frule defined, then it must also have an frule defined that returns nothing.

Warning

Do not overload no_frule directly. It is fine and intended to query the method table of no_frule. It is not safe to add to it directly, as corresponding changes also need to be made to frule. The @opt_out macro does both of these things, and so should almost always be used rather than defining a method of no_frule directly.

Mechanics

Note: when the text below says two methods are ==, it actually means that their signature type tuples, parameters(m.sig)[2:end], are equal, rather than the method objects m themselves.

To decide whether to opt out using this mechanism:

• find the most specific method of frule and of no_frule, e.g. with Base.which
• if the method of no_frule == the method of frule, then opt out

To simply ignore the fact that rules can be opted out of, and that some rules thus return nothing, filter the list of methods of frule to remove those that are == to ones that occur in the method table of no_frule.

Note also that when doing this you must still handle falling back from the rule with config to the rule without config.

On the other hand, if your AD can work with frules that return nothing, then it is simpler to just use that mechanism for opting out, and you don't need to worry about this at all.

See also ChainRulesCore.no_rrule.
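The method-table filtering approach can be sketched like this, assuming ChainRulesCore v1 with @opt_out (`baz` is a hypothetical function). Note the extra leading tangent-tuple argument in frule and no_frule signatures. Each method's signature type tuple, with the function type dropped, is compared against those stored in no_frule's method table.

```julia
using ChainRulesCore
import ChainRulesCore: frule

# Hypothetical function with a general frule and an opt-out for Int inputs.
baz(x) = 3x
frule((_, Δx), ::typeof(baz), x::Real) = (baz(x), 3Δx)
@opt_out frule(::Any, ::typeof(baz), ::Int)

# Signature type tuples (function type dropped) of every no_frule method.
opt_out_sigs = [Base.tuple_type_tail(m.sig) for m in methods(ChainRulesCore.no_frule)]

# frule methods that remain once opt-outs are filtered away.
usable = [m for m in methods(frule) if !(Base.tuple_type_tail(m.sig) in opt_out_sigs)]
usable_sigs = [Base.tuple_type_tail(m.sig) for m in usable]
```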