API Documentation
Rules
ChainRulesCore.frule — Method
frule((Δf, Δx...), f, x...)
Expressing the output of f(x...) as Ω, return the tuple:
(Ω, ΔΩ)
The second return value is the differential w.r.t. the output.
If no method matching frule((Δf, Δx...), f, x...) has been defined, then return nothing.
Examples:
unary input, unary output scalar function:
julia> dself = NO_FIELDS;
julia> x = rand()
0.8236475079774124
julia> sinx, Δsinx = frule((dself, 1), sin, x)
(0.7336293678134624, 0.6795498147167869)
julia> sinx == sin(x)
true
julia> Δsinx == cos(x)
true
Unary input, binary output scalar function:
julia> sincosx, Δsincosx = frule((dself, 1), sincos, x);
julia> sincosx == sincos(x)
true
julia> Δsincosx[1] == cos(x)
true
julia> Δsincosx[2] == -sin(x)
true
Note that, technically speaking, Julia does not have multiple-output functions, just functions that return a single output that is iterable, like a Tuple. So this is actually a Tangent:
julia> Δsincosx
Tangent{Tuple{Float64, Float64}}(0.6795498147167869, -0.7336293678134624)
See also: rrule, @scalar_rule
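The frule contract can be illustrated without ChainRulesCore. The following minimal sketch is plain Julia. It hand-writes a forward rule for sincos under the hypothetical name `sincos_frule` and drops the Δf slot for brevity. The rule returns the primal output Ω together with ΔΩ, the directional derivative along Δx, and the result is compared with a central finite difference. It is not the library's method.

```julia
# Hypothetical hand-written forward rule for sincos; not ChainRulesCore's frule.
# Returns (Ω, ΔΩ), where ΔΩ is the derivative of Ω in the direction Δx.
function sincos_frule(Δx, x)
    s, c = sincos(x)
    Ω = (s, c)
    ΔΩ = (c * Δx, -s * Δx)   # d/dx sin = cos, d/dx cos = -sin
    return Ω, ΔΩ
end

# Central finite-difference approximation of the same directional derivative.
function fd_sincos(Δx, x; h = 1e-6)
    plus = sincos(x + h * Δx)
    minus = sincos(x - h * Δx)
    return ((plus[1] - minus[1]) / 2h, (plus[2] - minus[2]) / 2h)
end

x = 0.8236475079774124
Ω, ΔΩ = sincos_frule(1.0, x)
fd = fd_sincos(1.0, x)
println(Ω == sincos(x))
println(all(isapprox.(ΔΩ, fd; atol = 1e-8)))
```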
ChainRulesCore.rrule — Method
rrule(f, x...)
Expressing x as the tuple (x₁, x₂, ...) and the output tuple of f(x...) as Ω, return the tuple:
(Ω, (Ω̄₁, Ω̄₂, ...) -> (s̄elf, x̄₁, x̄₂, ...))
The second return value is the propagation rule, or pullback. It takes in differentials corresponding to the outputs (Ω̄₁, Ω̄₂, ...). It returns s̄elf, the differential with respect to the internal values of the function itself (for closures), followed by the differentials with respect to the inputs (x̄₁, x̄₂, ...).
If no method matching rrule(f, xs...) has been defined, then return nothing.
Examples:
unary input, unary output scalar function:
julia> x = rand();
julia> sinx, sin_pullback = rrule(sin, x);
julia> sinx == sin(x)
true
julia> sin_pullback(1) == (NO_FIELDS, cos(x))
true
binary input, unary output scalar function:
julia> x, y = rand(2);
julia> hypotxy, hypot_pullback = rrule(hypot, x, y);
julia> hypotxy == hypot(x, y)
true
julia> hypot_pullback(1) == (NO_FIELDS, (x / hypot(x, y)), (y / hypot(x, y)))
true
See also: frule, @scalar_rule
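To make the pullback's role concrete, here is a plain-Julia sketch of a reverse rule for hypot under the hypothetical name `hypot_rrule`. It returns the primal value and a closure that maps an output sensitivity Ω̄ to (s̄elf, x̄, ȳ). In this sketch `nothing` stands in for NO_FIELDS, which is an assumption made only so the code runs without ChainRulesCore.

```julia
# Hypothetical hand-written reverse rule for hypot; not ChainRulesCore's rrule.
function hypot_rrule(x, y)
    Ω = hypot(x, y)
    # The pullback maps an output sensitivity Ω̄ to input sensitivities.
    # The first entry stands in for s̄elf: hypot has no internal fields,
    # so `nothing` is used here where ChainRulesCore would return NO_FIELDS.
    hypot_pullback(Ω̄) = (nothing, Ω̄ * x / Ω, Ω̄ * y / Ω)
    return Ω, hypot_pullback
end

x, y = 3.0, 4.0
Ω, back = hypot_rrule(x, y)
_, x̄, ȳ = back(1.0)   # gradient of hypot at (3, 4): approximately (0.6, 0.8)
```

The pullback is linear in Ω̄, so `back(2.0)` returns twice the sensitivities of `back(1.0)`.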
Rule Definition Tools
ChainRulesCore.@non_differentiable — Macro
@non_differentiable(signature_expression)
A helper to make it easier to declare that a method is not differentiable. This is shorthand for defining an frule and rrule that return NoTangent() for all partials (even for the s̄elf-partial of the function itself).
Keyword arguments should not be included.
julia> @non_differentiable Base.:(==)(a, b)
julia> _, pullback = rrule(==, 2.0, 3.0);
julia> pullback(1.0)
(NoTangent(), NoTangent(), NoTangent())
You can place type-constraints in the signature:
julia> @non_differentiable Base.length(xs::Union{Number, Array})
julia> frule((ZeroTangent(), 1), length, [2.0, 3.0])
(2, NoTangent())
This helper macro covers only the simple common cases. It does not support where-clauses. For these, you can declare the rrule and frule directly.
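The shape of what the macro produces can be mimicked in plain Julia. The rule returns the primal result and a pullback that yields one "no tangent" marker for the function and one for every positional argument. `NoTangentStandIn`, `nondiff_rrule`, and `nondiff_frule` are hypothetical names used only for this sketch. The real macro instead defines methods of ChainRulesCore.frule and ChainRulesCore.rrule that return NoTangent().

```julia
# Hypothetical stand-in for NoTangent(); illustration only.
struct NoTangentStandIn end

# Reverse mode: primal result plus a pullback that returns one marker for s̄elf
# and one per positional argument, regardless of the incoming sensitivity.
function nondiff_rrule(f, args...)
    y = f(args...)
    pullback(_) = ntuple(_ -> NoTangentStandIn(), length(args) + 1)
    return y, pullback
end

# Forward mode: primal result plus a single marker for the output differential.
nondiff_frule(Δs, f, args...) = (f(args...), NoTangentStandIn())

_, pb = nondiff_rrule(==, 2.0, 3.0)
pb(1.0)   # returns (NoTangentStandIn(), NoTangentStandIn(), NoTangentStandIn())
```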
ChainRulesCore.@scalar_rule — Macro
@scalar_rule(f(x₁, x₂, ...),
             @setup(statement₁, statement₂, ...),
             (∂f₁_∂x₁, ∂f₁_∂x₂, ...),
             (∂f₂_∂x₁, ∂f₂_∂x₂, ...),
             ...)
A convenience macro that generates simple scalar forward or reverse rules using the provided partial derivatives. Specifically, it generates the corresponding methods for frule and rrule:
# Schematic expansion; `...` marks elided arguments and terms.
# The first element of the forward differential tuple is ignored, hence `_`.
function ChainRulesCore.frule((_, Δx₁, Δx₂, ...), ::typeof(f), x₁::Number, x₂::Number, ...)
    Ω = f(x₁, x₂, ...)
    $(statement₁, statement₂, ...)
    return Ω, (
        (∂f₁_∂x₁ * Δx₁ + ∂f₁_∂x₂ * Δx₂ + ...),
        (∂f₂_∂x₁ * Δx₁ + ∂f₂_∂x₂ * Δx₂ + ...),
        ...
    )
end

# The pullback takes a single tuple of output sensitivities, hence the trailing comma.
function ChainRulesCore.rrule(::typeof(f), x₁::Number, x₂::Number, ...)
    Ω = f(x₁, x₂, ...)
    $(statement₁, statement₂, ...)
    return Ω, ((ΔΩ₁, ΔΩ₂, ...),) -> (
        NO_FIELDS,
        ∂f₁_∂x₁ * ΔΩ₁ + ∂f₂_∂x₁ * ΔΩ₂ + ...,
        ∂f₁_∂x₂ * ΔΩ₁ + ∂f₂_∂x₂ * ΔΩ₂ + ...,
        ...
    )
end
If no type constraints in f(x₁, x₂, ...) within the call to @scalar_rule are provided, each parameter in the resulting frule/rrule definition is given a type constraint of Number. Constraints may also be provided explicitly to override the Number constraint, e.g. f(x₁::Complex, x₂), which will constrain x₁ to Complex and x₂ to Number.
At present this does not support defining rules for closures/functors. Thus in reverse mode, the first returned partial, representing the derivative with respect to the function itself, is always NO_FIELDS. And in forward mode, the first element of the input differential tuple (the one corresponding to the function itself) is always ignored.
The result of f(x₁, x₂, ...) is automatically bound to Ω. This allows the primal result to be conveniently referenced (as Ω) within the derivative/setup expressions.
This macro assumes complex functions are holomorphic. In general, for non-holomorphic functions, the frule and rrule must be defined manually.
If the derivative is one (e.g. for identity functions), true can be used as the most general multiplicative identity.
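The reason true is preferred over a literal such as 1.0 can be checked in plain Julia. Multiplying by true returns a value of the other operand's type, while a Float64 literal can promote it. This snippet only exercises Base arithmetic and does not call @scalar_rule.

```julia
# `true` behaves as a multiplicative identity that does not promote the other operand.
x32 = 2.5f0
@show typeof(true * x32)        # Float32: the Bool does not widen the result
@show typeof(1.0 * x32)         # Float64: the literal 1.0 promotes the result
@show typeof(true * big(2.5))   # BigFloat
@show true * (1 + 2im)          # complex values pass through unchanged
```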
The @setup argument can be elided if no setup code is needed. In other words:
@scalar_rule(f(x₁, x₂, ...),
(∂f₁_∂x₁, ∂f₁_∂x₂, ...),
(∂f₂_∂x₁, ∂f₂_∂x₂, ...),
...)
is equivalent to:
@scalar_rule(f(x₁, x₂, ...),
@setup(nothing),
(∂f₁_∂x₁, ∂f₁_∂x₂, ...),
(∂f₂_∂x₁, ∂f₂_∂x₂, ...),
...)
For examples, see ChainRules' rulesets directory.
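The expansion above can be written out by hand for a one-input, two-output function such as sincos. The partials are (cos(x),) and (-sin(x),), and sinx and cosx play the role of @setup statements. The sketch uses hypothetical names and plain Julia, with `nothing` standing in for NO_FIELDS. It checks that the generated forward and reverse rules are transposes of each other: ⟨ΔΩ, J Δx⟩ equals ⟨Jᵀ ΔΩ, Δx⟩.

```julia
# Forward: each output differential is one row of the Jacobian times Δx.
function sincos_frule_sketch(Δx, x)
    Ω = sincos(x)
    sinx, cosx = Ω                      # plays the role of an @setup statement
    return Ω, (cosx * Δx, -sinx * Δx)
end

# Reverse: the input sensitivity sums ∂fᵢ/∂x * ΔΩᵢ over the outputs.
function sincos_rrule_sketch(x)
    Ω = sincos(x)
    sinx, cosx = Ω
    pullback((ΔΩ₁, ΔΩ₂)) = (nothing, cosx * ΔΩ₁ + (-sinx) * ΔΩ₂)
    return Ω, pullback
end

x, Δx = 0.3, 0.7
ΔΩ = (1.5, -2.0)
_, ΔΩ_fwd = sincos_frule_sketch(Δx, x)
_, back = sincos_rrule_sketch(x)
x̄ = back(ΔΩ)[2]
lhs = ΔΩ[1] * ΔΩ_fwd[1] + ΔΩ[2] * ΔΩ_fwd[2]   # ⟨ΔΩ, J Δx⟩
rhs = x̄ * Δx                                  # ⟨Jᵀ ΔΩ, Δx⟩
```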
Differentials
ChainRulesCore.AbstractZero — Type
AbstractZero <: AbstractTangent
Supertype for zero-like differentials, i.e., differentials that act like zero when added to or multiplied by other values. If an AD system encounters a propagator that takes as input only subtypes of AbstractZero, then it can stop performing AD operations. All propagators are linear functions, so the final result will be zero.
All AbstractZero subtypes are singleton types. There are two of them: ZeroTangent() and NoTangent().
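The zero-like behaviour can be sketched with a hypothetical singleton, `StrongZero`, that absorbs multiplication and vanishes under addition. This is not ChainRulesCore's type. When a linear propagator receives only such zeros, it returns the zero singleton without performing any floating-point work.

```julia
# Hypothetical zero-like singleton illustrating the AbstractZero idea.
struct StrongZero end

Base.:+(::StrongZero, x) = x
Base.:+(x, ::StrongZero) = x
Base.:+(::StrongZero, ::StrongZero) = StrongZero()
Base.:*(::StrongZero, x) = StrongZero()
Base.:*(x, ::StrongZero) = StrongZero()
Base.:*(::StrongZero, ::StrongZero) = StrongZero()

# A linear "propagator": with only zero inputs, the result is the zero singleton.
propagate(a, b) = 2.0 * a + 3.0 * b

propagate(StrongZero(), StrongZero())   # StrongZero()
propagate(1.0, StrongZero())            # 2.0
```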
ChainRulesCore.NoTangent — Type
NoTangent() <: AbstractZero
This differential indicates that the derivative does not exist. It is the differential for primal types that are not differentiable, such as integers or booleans (when they are not being used to represent floating-point values). The only valid way to perturb such values is to not change them at all. As a consequence, NoTangent is functionally identical to ZeroTangent(), but it provides additional semantic information.
Adding this differential to a primal is generally wrong: gradient-based methods cannot be used to optimize over discrete variables. An optimization package making use of this might want to check for such a case.
This does not indicate that the derivative is not implemented, but rather that mathematically it is not defined.
This mostly shows up as the derivative with respect to dimension, index, or size arguments.
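using ChainRulesCore

function ChainRulesCore.rrule(::typeof(fill), x, len::Int)
    y = fill(x, len)
    # `len` is an integer size argument, so its derivative does not exist: NoTangent().
    fill_pullback(Ȳ) = (NO_FIELDS, @thunk(sum(Ȳ)), NoTangent())
    return y, fill_pullback
end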
ChainRulesCore.ZeroTangent — Type
ZeroTangent() <: AbstractZero
The additive identity for differentials. This is basically the same as 0. A derivative of ZeroTangent() does not propagate through the primal function.
ChainRulesCore.One — Type
One()
The differential that is the multiplicative identity. Basically, this represents 1.
ChainRulesCore.NO_FIELDS — Constant
NO_FIELDS
Constant for the reverse-mode derivative with respect to a structure that has no fields. The most notable use for this is for the reverse-mode derivative with respect to the function itself, when that function is not a closure.
ChainRulesCore.Tangent — Type
Tangent{P, T} <: AbstractTangent
This type represents the differential for a struct, NamedTuple, or Tuple. P is the corresponding primal type that this is a differential for.
Tangent{P} should have fields (technically properties) that match a subset of the fields of the primal type, and each should be a differential type matching the primal type of that field. Fields of P that are not present in the Tangent are treated as ZeroTangent().
T is an implementation detail representing the backing data structure. For a Tuple it will be a Tuple, and for everything else it will be a NamedTuple. It should not be passed in by the user.
For Tangents of Tuples, iterate and getindex are overloaded to behave as they do for a tuple. For Tangents of structs, getproperty is overloaded to allow accessing values via comp.fieldname. Any fields not explicitly present in the Tangent are treated as being set to ZeroTangent(). To make a Tangent have all the fields of the primal, use the canonicalize function.
ChainRulesCore.canonicalize — Method
canonicalize(comp::Tangent{P}) -> Tangent{P}
Return the canonical Tangent for the primal type P. The property names of the returned Tangent match the field names of the primal, and all fields of P not present in the input comp are explicitly set to ZeroTangent().
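The partial-fields idea and the canonicalizing step can be sketched with a plain NamedTuple and a hypothetical struct `Point`. In this sketch the helper `fill_missing_fields` is hypothetical and uses 0.0 where ChainRulesCore would use ZeroTangent(). It rejects names that are not fields of the primal and fills in every omitted field, ordering the result by the primal's field names.

```julia
# Hypothetical primal struct used only for illustration.
struct Point
    x::Float64
    y::Float64
end

# Sketch of a canonicalize-like step: `partial` lists a subset of the fields of P,
# and every field it omits is filled in with 0.0 (standing in for ZeroTangent()).
function fill_missing_fields(::Type{P}, partial::NamedTuple) where {P}
    names = fieldnames(P)
    for k in keys(partial)
        k in names || throw(ArgumentError("$P has no field $k"))
    end
    vals = map(n -> get(partial, n, 0.0), names)
    return NamedTuple{names}(vals)
end

fill_missing_fields(Point, (x = 1.5,))   # (x = 1.5, y = 0.0)
```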
ChainRulesCore.InplaceableThunk — Type
InplaceableThunk(val::Thunk, add!::Function)
A wrapper for a Thunk that allows it to define an in-place add! function.
add! should be defined such that ithunk.add!(Δ) = Δ .+= ithunk.val, but it should do this more efficiently than simply doing this directly. (Otherwise one can just use a normal Thunk.)
Most operations on an InplaceableThunk treat it just like a normal Thunk, and destroy its in-placeability.
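The pairing of an out-of-place value with a cheaper in-place accumulation can be sketched in plain Julia. `ToyInplaceableThunk` and `accumulate_sketch` are hypothetical names, and the choice of "mutable Array" as the in-place criterion is an assumption for this example. The differential here is A' * ȳ. Its in-place form uses the five-argument `mul!` from LinearAlgebra, which computes Δ = A' * ȳ + Δ without allocating a temporary.

```julia
using LinearAlgebra

# Hypothetical toy type mirroring the idea of InplaceableThunk, not its implementation.
struct ToyInplaceableThunk{V,A}
    val::V    # zero-argument closure computing the differential out of place
    add!::A   # function accumulating the differential into its argument in place
end

# Accumulate `t` into `x`, taking the in-place path only for a mutable Array.
function accumulate_sketch(x, t::ToyInplaceableThunk)
    if x isa Array
        return t.add!(x)      # no temporary array for the differential
    else
        return x + t.val()    # out of place: evaluate the differential, then add
    end
end

A = [1.0 2.0; 3.0 4.0]
ȳ = [1.0, 1.0]
# mul!(Δ, A', ȳ, true, true) computes Δ = A' * ȳ + Δ in place.
t = ToyInplaceableThunk(() -> A' * ȳ, Δ -> mul!(Δ, A', ȳ, true, true))

x = zeros(2)
accumulate_sketch(x, t)              # mutates x to A' * ȳ
y = accumulate_sketch(1.0:2.0, t)    # a range is immutable, so a new Vector is returned
```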
ChainRulesCore.Thunk — Type
Thunk(()->v)
A thunk is a deferred computation. It wraps a zero-argument closure that, when invoked, returns a differential. @thunk(v) is a macro that expands into Thunk(()->v).
Calling a thunk calls the wrapped closure. If you are unsure whether you have a Thunk, call unthunk, which is a no-op when the argument is not a Thunk. If you need to unthunk recursively, call extern, which also externs the differential that the closure returns.
julia> t = @thunk(@thunk(3))
Thunk(var"#4#6"())
julia> extern(t)
3
julia> t()
Thunk(var"#5#7"())
julia> t()()
3
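The difference between removing one layer of thunking and removing all layers can be reproduced with a hypothetical toy type, `ToyThunk`. The macro `@toythunk` expands to `ToyThunk(() -> expr)`, mirroring the expansion described above. `unthunk_once` behaves like unthunk, and `unthunk_all` imitates only the recursive thunk-removal part of extern. None of these names are ChainRulesCore API.

```julia
# Hypothetical toy thunk; not ChainRulesCore's Thunk.
struct ToyThunk{F}
    f::F   # zero-argument closure
end

# Calling a thunk runs the wrapped closure.
(t::ToyThunk)() = t.f()

# Remove exactly one layer of thunking; identity on everything else.
unthunk_once(x) = x
unthunk_once(t::ToyThunk) = t()

# Remove all layers, in the spirit of extern removing thunks recursively.
unthunk_all(x) = x
unthunk_all(t::ToyThunk) = unthunk_all(t())

# Wrap an expression in a zero-argument closure, like @thunk.
macro toythunk(ex)
    return :(ToyThunk(() -> $(esc(ex))))
end

t = @toythunk(@toythunk(3))
```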
When to @thunk?
When writing rrules (and, to a lesser extent, frules), it is important to @thunk appropriately. Propagation rules that return multiple derivatives may not have all derivatives used. By @thunking the work required for each derivative, only what is needed is computed.
How do thunks prevent work?
If we have res = pullback(...) = @thunk(f(x)), @thunk(g(x)), then computing dx + res[1] evaluates only f(x), not g(x). Also, computing ZeroTangent() * res[1] gives ZeroTangent() without evaluating f(x) at all.
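Both savings can be observed in plain Julia by counting how often the deferred work runs. `Lazy`, `force`, `ZeroStandIn`, `expensive_f`, and `expensive_g` are hypothetical names. `ZeroStandIn` imitates the property that zero times anything is zero, so multiplying by it never forces the thunk.

```julia
# Hypothetical lazy wrapper; `force` plays the role of unthunk.
struct Lazy{F}
    f::F
end
force(x) = x
force(l::Lazy) = l.f()

# Zero stand-in: multiplying it by a lazy value never evaluates that value.
struct ZeroStandIn end
Base.:*(::ZeroStandIn, ::Lazy) = ZeroStandIn()

const CALLS = Ref(0)                        # counts evaluations of the expensive work
expensive_f(x) = (CALLS[] += 1; 2x)
expensive_g(x) = (CALLS[] += 1; x^2)

x = 3.0
res = (Lazy(() -> expensive_f(x)), Lazy(() -> expensive_g(x)))

dx = 1.0 + force(res[1])    # only expensive_f runs
z = ZeroStandIn() * res[2]  # expensive_g never runs
```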
So why not thunk everything?
@thunk creates a closure over the expression, which (effectively) creates a struct with a field for each variable used in the expression, and with call overloaded.
Do not use @thunk if this would be equal to or more work than actually evaluating the expression itself. This is commonly the case for scalar operators.
For more details, see the manual section on using thunks effectively.
ChainRulesCore.unthunk — Method
unthunk(x)
On AbstractThunks this removes 1 layer of thunking. On any other type, it is the identity operation.
In contrast to extern, this is nonrecursive.
ChainRulesCore.@thunk — Macro
@thunk expr
Define a Thunk wrapping the expr, to lazily defer its evaluation.
ChainRulesCore.extern — Method
extern(x)
Makes a best-effort attempt to convert a differential into a primal value. This is not always a well-defined operation, for two reasons:
- It may not be possible to determine the primal type for a given differential. For example, ZeroTangent() is a valid differential for any primal.
- The primal type might not be a vector space, and thus might not be a valid differential type. For example, if the primal type is DateTime, it is not a valid differential type, as two DateTimes cannot be added (fun fact: Millisecond is a differential for DateTime).
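The DateTime remark can be checked with the Dates standard library alone. The difference of two DateTime values is a Millisecond, and Milliseconds can be added and scaled. Adding a Millisecond to a DateTime yields a DateTime again, which matches the primal/differential relationship described above.

```julia
using Dates

t0 = DateTime(2021, 1, 1)
t1 = DateTime(2021, 1, 1, 0, 0, 1)

# The difference of two DateTimes is a Millisecond: a differential-like quantity.
δ = t1 - t0
# Differentials can be scaled and added to each other.
δ2 = 2 * δ + δ
# Adding a differential to a primal DateTime gives back a primal DateTime.
t2 = t0 + δ
```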
Where it is defined, the operation of extern for a primal type P should be extern(x) = zero(P) + x.
Because of its limitations, extern should only really be used for testing. It can be useful, if you know what you are getting out, as it recursively removes thunks and otherwise makes outputs more consistent with finite differencing.
The more useful action in general is to call +, or in the case of a Thunk, to call unthunk.
extern may return an alias (not necessarily a copy) to data wrapped by x, such that mutating extern(x) might mutate x itself.
ChainRulesCore.@not_implemented — Macro
@not_implemented(info)
Create a differential that indicates that the derivative is not implemented.
The info should be useful information about the missing differential for debugging.
This macro should be used only if the automatic differentiation would error otherwise. It is mostly useful if the function has multiple inputs or outputs, and one has worked out analytically and implemented some but not all differentials.
It is good practice to include a link to a GitHub issue about the missing differential in the debugging information.
Accumulation
ChainRulesCore.add!! — Function
add!!(x, y)
Returns x+y, potentially mutating x in-place to hold this value. This avoids allocations when x can be mutated in this way.
add!!(x, t::InplaceableThunk)
The specialization of add!! for InplaceableThunk promises to only call t.add! on x if x is suitably mutable; otherwise it will be out of place.
ChainRulesCore.is_inplaceable_destination — Function
is_inplaceable_destination(x) -> Bool
Returns true if x is suitable for storing in-place accumulation of gradients. For arrays, this boils down to whether x .= y will work to mutate x, if y is an appropriate differential. Wrapper array types do not need to overload this if they overload Base.parent; they are is_inplaceable_destination if and only if their parent array is. Other types should overload this, as it defaults to false.
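The interplay between the two functions can be sketched in plain Julia. `is_inplaceable_sketch` and `add_sketch!!` are hypothetical names, and treating only `Array` as mutable is a simplifying assumption. Wrapper arrays defer to Base.parent, which returns the array itself when it is not a wrapper, and the accumulation falls back to `x + y` whenever in-place mutation is not known to be safe.

```julia
# Hypothetical sketch loosely following the description above; not ChainRulesCore's code.
is_inplaceable_sketch(x) = false
is_inplaceable_sketch(::Array) = true
# Wrapper arrays defer to their parent; `parent` returns the array itself when unwrapped.
function is_inplaceable_sketch(x::AbstractArray)
    p = parent(x)
    return p === x ? false : is_inplaceable_sketch(p)
end

# Return x + y, mutating x in place when that is known to be safe.
function add_sketch!!(x, y)
    if is_inplaceable_sketch(x)
        x .+= y
        return x
    end
    return x + y
end

acc = zeros(2)
add_sketch!!(acc, [1.0, 2.0])     # mutates and returns acc
add_sketch!!(1:3, [1, 1, 1])      # ranges are immutable: returns a new Vector
```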
Ruleset Loading
ChainRulesCore.on_new_rule — Method
on_new_rule(hook, frule | rrule)
Register a hook function to run when new rules are defined. The hook receives a signature type-type as input, and generally will use eval to define an overload of an AD system's overloaded type. For example, it might use the signature type Tuple{typeof(+), Real, Real} to make +(::DualNumber, ::DualNumber) call the frule for +. A signature type tuple always has the form Tuple{typeof(operation), typeof(pos_arg1), typeof(pos_arg2), ...}, where pos_arg1 is the first positional argument.
The hooks are automatically run on new rules whenever a package is loaded. They can be manually triggered by refresh_rules. When a hook is first registered with on_new_rule, it is run on all existing rules.
ChainRulesCore.refresh_rules — Method
refresh_rules()
refresh_rules(frule | rrule)
This triggers all on_new_rule hooks to run on any newly defined rules. It is automatically run whenever a package is loaded. It can also be called manually, for example if a rule was defined in the REPL or within the same file as the AD function.
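The registration and refresh behaviour described for on_new_rule and refresh_rules can be modelled with a small, hypothetical registry in plain Julia. Signatures are ordinary tuple types such as Tuple{typeof(+), Real, Real}. A newly registered hook runs on every existing rule, while rules defined later are only passed to hooks on the next refresh. All names here are illustrative and do not reflect ChainRulesCore's internals.

```julia
# Hypothetical toy hook registry; not ChainRulesCore's implementation.
const RULE_SIGS = Any[Tuple{typeof(+), Real, Real}]  # rules already processed
const PENDING_SIGS = Any[]                           # rules defined since the last refresh
const HOOKS = Function[]

# Registering a hook immediately runs it on every existing rule.
function on_new_rule_sketch(hook)
    push!(HOOKS, hook)
    foreach(hook, RULE_SIGS)
    return hook
end

# Defining a rule only records its signature; hooks run on the next refresh.
define_rule_sketch!(sig) = push!(PENDING_SIGS, sig)

# Run every hook on rules defined since the last refresh, then mark them processed.
function refresh_rules_sketch()
    for sig in PENDING_SIGS, hook in HOOKS
        hook(sig)
    end
    append!(RULE_SIGS, PENDING_SIGS)
    empty!(PENDING_SIGS)
    return nothing
end

# Example hook: record the operation type, i.e. the first parameter of the signature.
seen_ops = Any[]
on_new_rule_sketch(sig -> push!(seen_ops, sig.parameters[1]))

define_rule_sketch!(Tuple{typeof(*), Real, Real})
refresh_rules_sketch()
```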
Internal
ChainRulesCore.AbstractTangent — Type
The subtypes of AbstractTangent define a custom "algebra" for chain rule evaluation that attempts to factor various features like complex derivative support, broadcast fusion, zero-elision, etc. into nicely separated parts.
In general, a differential type is the type of a derivative of a value. The type of the value is, for contrast, called the primal type. Differential types correspond to primal types, although the relation is not one-to-one. Subtypes of AbstractTangent are not the only differential types. In fact, for the most common primal types, such as Real or AbstractArray{Real}, the differential type is the same as the primal type.
In a circular definition: the most important property of a differential is that it should be able to be added (by defining +) to another differential of the same primal type. That allows for gradients to be accumulated.
It generally also should be able to be added to a primal to give back another primal, as this facilitates gradient descent.
All subtypes of AbstractTangent implement the following operations:
- +(a, b): linearly combine differential a and differential b
- *(a, b): multiply the differential b by the scaling factor a
- Base.zero(x) = ZeroTangent(): a zero.
Further, they often implement other linear operators, such as conj, adjoint, and dot. Pullbacks/pushforwards are linear operators, and their inputs are often AbstractTangent subtypes. Pullbacks/pushforwards in turn call other linear operators on those inputs. Thus it is desirable to have all common linear operators work on AbstractTangents.
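The required algebra can be illustrated with a hypothetical primal struct `Particle` and a hand-written differential type `ParticleTangent`. The differential type implements accumulation (+), scaling (*), and a zero, and it can also be added to a primal to take a gradient-descent step. One deviation is deliberate: `zero` here returns a concrete all-zeros tangent rather than ZeroTangent(), so the sketch runs without ChainRulesCore.

```julia
# Hypothetical primal type and a matching differential type; illustration only.
struct Particle
    pos::Float64
    vel::Float64
end

struct ParticleTangent
    pos::Float64
    vel::Float64
end

# Linear combination of differentials (accumulation of gradients).
Base.:+(a::ParticleTangent, b::ParticleTangent) = ParticleTangent(a.pos + b.pos, a.vel + b.vel)
# Scaling a differential by a number.
Base.:*(s::Real, d::ParticleTangent) = ParticleTangent(s * d.pos, s * d.vel)
# The additive identity (concrete zeros here, instead of ZeroTangent()).
Base.zero(::ParticleTangent) = ParticleTangent(0.0, 0.0)
# Adding a differential to a primal gives another primal (one gradient-descent step).
Base.:+(p::Particle, d::ParticleTangent) = Particle(p.pos + d.pos, p.vel + d.vel)

p = Particle(1.0, 2.0)
g = ParticleTangent(0.5, -1.0) + ParticleTangent(0.25, 0.5)   # accumulated gradient
p_next = p + (-0.1) * g                                       # step against the gradient
```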
ChainRulesCore.debug_mode — Function
debug_mode() -> Bool
Determines if ChainRulesCore is in debug_mode. Defaults to false, but if the user redefines it to return true, then extra information will be shown when errors occur.
Enable via:
ChainRulesCore.debug_mode() = true
ChainRulesCore.clear_new_rule_hooks! — Function
clear_new_rule_hooks!(frule|rrule)
Clears all hooks that were registered with the corresponding on_new_rule. This is useful while working interactively to define your rule-generating hooks. If you previously wrote an incorrect hook, you can use this to get rid of the old one.
This absolutely should not be used in a package, as it will break any other AD system using the rule hooks that might happen to be loaded.