# ChainRules

Automatic differentiation (AD) is a set of techniques for obtaining derivatives of arbitrary functions. There are surprisingly many packages for doing AD in Julia. ChainRules isn't one of these packages.

The AD packages essentially combine derivatives of simple functions into derivatives of more complicated functions. They differ in the way they break down complicated functions into simple ones, but they all require a common set of derivatives of simple functions (rules).

ChainRules is an AD-independent set of rules, and a system for defining and testing rules.

A rule encodes knowledge about propagating derivatives, e.g. that the derivative with respect to `x` of `a*x` is `a`, and the derivative of `sin(x)` is `cos(x)`, etc.

## ChainRules ecosystem organisation

The ChainRules ecosystem comprises:

- ChainRulesCore.jl: a system for defining rules, and a collection of tangent types.
- ChainRules.jl: a collection of rules for Julia Base and standard libraries.
- ChainRulesTestUtils.jl: utilities for testing rules using finite differences.

AD systems depend on ChainRulesCore.jl to get access to tangent types and the core rule definition functionality (`frule` and `rrule`), and on ChainRules.jl to benefit from the collection of rules for Julia Base and the standard libraries.

Packages that just want to define rules only need to depend on ChainRulesCore.jl, which is an exceptionally light dependency. They should also have a test-only dependency on ChainRulesTestUtils.jl to test the rules using finite differences.
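As a rough sketch of what this looks like in a rule-defining package, the following defines a rule with `@scalar_rule` and checks it against finite differences with `test_frule` and `test_rrule` from ChainRulesTestUtils.jl. The function `halfsquare` is hypothetical and exists only for illustration.

```julia
using ChainRulesCore       # the only runtime dependency needed to define rules
using ChainRulesTestUtils  # test-only dependency

# `halfsquare` is a hypothetical function, used only for illustration.
halfsquare(x) = x^2 / 2

# d/dx (x^2 / 2) = x; this defines both the frule and the rrule.
@scalar_rule(halfsquare(x), x)

# Compare the rules against finite differences at x = 1.5.
test_frule(halfsquare, 1.5)
test_rrule(halfsquare, 1.5)
```

Both test functions run inside a `Test` testset, so in a real package they would normally sit in `test/runtests.jl`.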

Note that the packages with rules do not have to depend on AD systems, and neither do the AD systems have to depend on individual packages.

## ChainRules roll-out status

Numerous packages depend on ChainRulesCore to define rules for their functions.

Six AD engines currently use ChainRules to get access to rules:

Zygote.jl is a reverse-mode AD that supports using `rrule`s, calling back into AD, and opting out of rules. However, its own ZygoteRules.jl primitives (`@adjoint`s) take precedence over `rrule`s when both are defined, even if the `@adjoint` is less specific than the `rrule`. Internally it uses its own set of tangent types, e.g. `nothing` instead of `NoTangent`/`ZeroTangent`. It also `unthunk`s every tangent.

Diffractor.jl is a forward- and reverse-mode AD that fully supports ChainRules, including calling back into AD and opting out of rules, and uses tangent types internally.

Yota is a reverse-mode AD that fully supports ChainRules, including calling back into AD and opting out of rules, and uses tangent types internally.

ReverseDiff is a reverse-mode AD that supports using `rrule`s, but not calling back into AD or opting out of rules.

Nabla.jl is a reverse-mode AD that supports using `rrule`s, but not opting out of rules, nor calling back into AD.

ReversePropagation.jl is a reverse-mode AD that supports using `rrule`s for scalar functions, but not calling back into AD or opting out of rules.

## Key functionality

Consider a relationship $y = f(x)$, where $f$ is some function. Computing $y$ from $x$ is the original problem, called the *primal* computation, in contrast to the problem of computing derivatives. We say that the *primal function* $f$ takes a *primal input* $x$ and returns the *primal output* $y$.

ChainRules rules are concerned with propagating *tangents* of primal inputs to *tangents* of primal outputs (`frule`, from forwards mode AD), and propagating *cotangents* of primal outputs to *cotangents* of primal inputs (`rrule`, from reverse mode AD). To be able to do that, ChainRules also defines a small number of tangent types to represent tangents and cotangents.

Strictly speaking tangents, $ẋ = \frac{dx}{da}$, are propagated in `frule`s, and cotangents, $x̄ = \frac{da}{dx}$, are propagated in `rrule`s. However, in practice there is rarely a need to distinguish between the two: both are represented by the same tangent types. Thus, except when the detail might clarify, we refer to both as tangents.

`frule` and `rrule` are ChainRules-specific terms. Their exact functioning is fairly ChainRules-specific, though other tools have similar functions. The core notion is sometimes called *custom AD primitives*, *custom adjoints*, *custom gradients*, or *custom sensitivities*. The whole field is a mess for terminology.

### Forward-mode AD rules (`frule`s)

If we know the value of $ẋ = \frac{dx}{da}$ for some $a$ and we want to know $ẏ = \frac{dy}{da}$, the chain rule tells us that $ẏ = \frac{dy}{dx} ẋ$. Intuitively, we are pushing the derivative forward. This is the basis for forward-mode AD.
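As a small worked instance of this, take the composition $x = a^2$ followed by $y = \sin(x)$ (an example chosen here purely for illustration):

$$
ẋ = \frac{dx}{da} = 2a, \qquad ẏ = \frac{dy}{dx}\, ẋ = \cos(x) \cdot 2a = 2a \cos(a^2)
$$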

The `frule` for $f$ encodes how to propagate the tangent of the primal input ($ẋ$) to the tangent of the primal output ($ẏ$).

The `frule` signature for a function `foo(args...; kwargs...)` is

```
function frule((Δself, Δargs...), ::typeof(foo), args...; kwargs...)
    ...
    return y, ∂Y
end
```

where `y = foo(args...; kwargs...)` is the primal output, and `∂Y` is the result of propagating the input tangents `Δself`, `Δargs...` forwards at the point in the domain of `foo` described by `args`. This propagation is called the pushforward. Often we will think of the `frule` as having the primal computation `y = foo(args...; kwargs...)`, and the pushforward `∂Y = pushforward(Δself, Δargs...)`, even though they are not present in separate forms in the code.

For example, the `frule` for `sin(x)` is:

```
function frule((_, Δx), ::typeof(sin), x)
    return sin(x), cos(x) * Δx
end
```
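To see such a rule in action, the sketch below calls the `frule` for `sin` that ChainRules.jl provides. The input value `0.5` and the seed tangent `1.0` are arbitrary choices made here for illustration.

```julia
using ChainRulesCore, ChainRules

x = 0.5
ẋ = 1.0   # seed tangent: we differentiate with respect to x itself

# The first entry of the tangent tuple belongs to the function `sin`,
# which has no fields, so it gets NoTangent().
y, ẏ = frule((NoTangent(), ẋ), sin, x)

y ≈ sin(x)   # the primal output
ẏ ≈ cos(x)   # the tangent pushed forward through `sin`
```

A single call returns both the primal output and its tangent.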

### Reverse-mode AD rules (`rrule`s)

If we know the value of $ȳ = \frac{da}{dy}$ for some $a$ and we want to know $x̄ = \frac{da}{dx}$, the chain rule tells us that $x̄ = ȳ \frac{dy}{dx}$. Intuitively, we are pushing the derivative backward. This is the basis for reverse-mode AD.
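As a small worked instance, take $y = \sin(x)$ followed by $a = y^2$ (an example chosen here purely for illustration):

$$
ȳ = \frac{da}{dy} = 2y, \qquad x̄ = ȳ\, \frac{dy}{dx} = 2y \cos(x) = 2 \sin(x) \cos(x)
$$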

The `rrule` for $f$ encodes how to propagate the cotangent of the primal output ($ȳ$) to the cotangent of the primal input ($x̄$).

The `rrule` signature for a function `foo(args...; kwargs...)` is

```
function rrule(::typeof(foo), args...; kwargs...)
    ...
    return y, pullback
end
```

where `y` (the primal output) must be equal to `foo(args...; kwargs...)`. `pullback` is a function to propagate the derivative information backwards at the point in the domain of `foo` described by `args`. That pullback function is used like: `∂self, ∂args... = pullback(Δy)`.

Almost always the *pullback* will be declared locally within the `rrule`, and will be a *closure* over some of the other arguments, and potentially over the primal result too.

For example, the `rrule` for `sin(x)` is:

```
function rrule(::typeof(sin), x)
    sin_pullback(Δy) = (NoTangent(), cos(x)' * Δy)
    return sin(x), sin_pullback
end
```
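A sketch of how a caller would use such a rule, relying on the `rrule` for `sin` that ChainRules.jl provides. The input `0.5` and the seed cotangent `1.0` are arbitrary choices made here for illustration.

```julia
using ChainRulesCore, ChainRules

x = 0.5

# Forward pass: compute the primal output and obtain the pullback.
y, sin_pullback = rrule(sin, x)

# Backward pass: seed with the cotangent of the output and pull it back.
ȳ = 1.0
∂self, x̄ = sin_pullback(ȳ)

∂self isa NoTangent     # `sin` has no fields to differentiate
unthunk(x̄) ≈ cos(x)     # `unthunk` is a no-op if the result is not a thunk
```

The two phases are separate calls, so an AD system can hold on to `sin_pullback` until the output cotangent is known.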

While `rrule` takes only the arguments to the original function (the primal arguments) and returns a function (the pullback) that operates with the derivative information, the `frule` does it all at once. This is because the `frule` fuses the primal computation and the pushforward. This is an optimization that allows `frule`s to contain single large operations that perform both the primal computation and the pushforward at the same time (for example solving an ODE). This operation is only possible in forward mode (where `frule` is used) because the derivative information needed by the pushforward is available when the `frule` is invoked: it is about the primal function's inputs. In contrast, in reverse mode the derivative information needed by the pullback is about the primal function's output. Thus the reverse mode returns the pullback function, which the caller (usually an AD system) keeps hold of until derivative information about the output is available.
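The contrast can be sketched with a pair of hand-written rules for a hypothetical function `scaled_exp`, which is not part of ChainRules and is used only for illustration. Its derivative equals its own output, so the `frule` can reuse the primal result immediately, while the `rrule` captures it in a closure for later.

```julia
using ChainRulesCore

# `scaled_exp` is a hypothetical function, used only for illustration.
scaled_exp(x) = 2 * exp(x)

# frule: primal computation and pushforward are fused in one call,
# because the input tangent Δx is already available.
function ChainRulesCore.frule((_, Δx), ::typeof(scaled_exp), x)
    y = scaled_exp(x)
    return y, y * Δx          # d/dx (2 exp(x)) = 2 exp(x) = y
end

# rrule: the output cotangent Δy is not known yet, so return a closure
# over the primal result `y` that the caller invokes later.
function ChainRulesCore.rrule(::typeof(scaled_exp), x)
    y = scaled_exp(x)
    scaled_exp_pullback(Δy) = (NoTangent(), y' * Δy)
    return y, scaled_exp_pullback
end

y_f, ẏ = frule((NoTangent(), 1.0), scaled_exp, 0.3)   # one call does everything
y_r, pullback = rrule(scaled_exp, 0.3)                # forward pass only
_, x̄ = pullback(1.0)                                  # backward pass, later
```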

### Tangent types

The types of (co)-tangents depend on the types of the primals. Scalar primals are represented by scalar tangents (e.g. `Float64` tangent for a `Float64` primal). Vector, matrix, and higher rank tensor primals can be represented by vector, matrix and tensor tangents.

ChainRules defines a `Tangent` tangent type to represent tangents of `struct`s, `Tuple`s, `NamedTuple`s, and `Dict`s.
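A minimal sketch of a structural tangent, assuming a hypothetical struct `Point` that is defined here only for illustration:

```julia
using ChainRulesCore

# `Point` is a hypothetical struct, used only for illustration.
struct Point
    x::Float64
    y::Float64
end

# A tangent for a `Point` in which only the `x` field was perturbed.
dp = Tangent{Point}(x = 1.5)

dp.x          # the stored tangent of the `x` field
dp.y          # fields that were not supplied read back as ZeroTangent()
(dp + dp).x   # tangents of the same primal type add field by field
```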

Additionally, for signalling semantics, we distinguish between two tangent types representing a zero tangent. The `NoTangent` type represents situations in which the tangent space does not exist, e.g. an index into an array cannot be perturbed. `ZeroTangent` is used for cases where the tangent happens to be zero, e.g. because the primal argument is not used in the computation.
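The distinction can be sketched with a rule for a hypothetical indexing function `pick`, written here only for illustration. The integer index gets `NoTangent()`, while `ZeroTangent()` behaves as an ordinary additive zero.

```julia
using ChainRulesCore

# `pick` is a hypothetical function, used only for illustration.
pick(v::Vector{Float64}, i::Int) = v[i]

function ChainRulesCore.rrule(::typeof(pick), v::Vector{Float64}, i::Int)
    function pick_pullback(Δy)
        ∂v = zeros(length(v))
        ∂v[i] = Δy   # only the selected element receives the cotangent
        # Neither the function `pick` nor the integer index can be
        # perturbed, so both get NoTangent().
        return NoTangent(), ∂v, NoTangent()
    end
    return pick(v, i), pick_pullback
end

y, pick_pullback = rrule(pick, [1.0, 2.0, 3.0], 2)
∂self, ∂v, ∂i = pick_pullback(1.0)

# ZeroTangent() is an additive identity and absorbs multiplication.
ZeroTangent() + 1.0
ZeroTangent() * 5.0
```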

We also define `Thunk`s to allow certain optimisations. `Thunk`s are a wrapper over a computation that can potentially be avoided, depending on the downstream use.
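A small sketch of the deferred evaluation, using the `@thunk` macro and `unthunk` from ChainRulesCore. The counter `calls` is a helper introduced here only to make the laziness observable.

```julia
using ChainRulesCore

calls = Ref(0)   # counts how often the wrapped computation actually runs

# Wrap a computation; nothing is evaluated at this point.
t = @thunk begin
    calls[] += 1
    2.0 * 3.0
end

before = calls[]      # still 0: the body has not run yet
value = unthunk(t)    # forces the computation
after = calls[]       # now 1
```

If downstream code never needs the value, `unthunk` is never called and the work is skipped.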

See the section on tangent types for more details.

## Example of using ChainRules directly

While ChainRules is largely intended as a backend for autodiff systems, it can be used directly. In fact, this can be very useful if you can constrain the code you need to differentiate to only use things that have rules defined for them. This was once how all neural network code worked.

Using ChainRules directly also helps get a feel for it.

```
using ChainRulesCore
function foo(x)
    a = sin(x)
    b = 0.2 + a
    c = asin(b)
    return c
end
# Define rules (alternatively get them for free via `using ChainRules`)
@scalar_rule(sin(x), cos(x))
@scalar_rule(+(x, y), (1.0, 1.0))
@scalar_rule(asin(x), inv(sqrt(1 - x^2)))
```

```
#### Find dfoo/dx via rrules
#### First the forward pass, gathering up the pullbacks
x = 3;
a, a_pullback = rrule(sin, x);
b, b_pullback = rrule(+, 0.2, a);
c, c_pullback = rrule(asin, b)
#### Then the backward pass calculating gradients
c̄ = 1; # ∂c/∂c
_, b̄ = c_pullback(c̄); # ∂c/∂b = ∂c/∂b ⋅ ∂c/∂c
_, _, ā = b_pullback(b̄); # ∂c/∂a = ∂c/∂b ⋅ ∂b/∂a
_, x̄ = a_pullback(ā); # ∂c/∂x = ∂c/∂a ⋅ ∂a/∂x
x̄ # ∂c/∂x = ∂foo/∂x
# output
-1.0531613736418153
```

```
#### Find dfoo/dx via frules
x = 3;
ẋ = 1; # ∂x/∂x
nofields = ZeroTangent(); # ∂self/∂self
a, ȧ = frule((nofields, ẋ), sin, x); # ∂a/∂x = ∂a/∂x ⋅ ∂x/∂x
b, ḃ = frule((nofields, ZeroTangent(), ȧ), +, 0.2, a); # ∂b/∂x = ∂b/∂a ⋅ ∂a/∂x
c, ċ = frule((nofields, ḃ), asin, b); # ∂c/∂x = ∂c/∂b ⋅ ∂b/∂x
ċ # ∂c/∂x = ∂foo/∂x
# output
-1.0531613736418153
```

```
#### Find dfoo/dx via FiniteDifferences.jl
using FiniteDifferences
central_fdm(5, 1)(foo, x)
# output
-1.0531613736418257
#### Find dfoo/dx via ForwardDiff.jl
using ForwardDiff
ForwardDiff.derivative(foo, x)
# output
-1.0531613736418153
#### Find dfoo/dx via Zygote.jl
using Zygote
Zygote.gradient(foo, x)
# output
(-1.0531613736418153,)
```