
# On writing good `rrule` / `frule` methods

## Use `Zero()` or `One()` as return value

The `Zero()` and `One()` differential objects exist as an alternative to directly returning `0` or `zeros(n)`, and `1` or `I`. They make chaining pullbacks/pushforwards cheaper, because operations involving them can skip work. For example, multiplying by `Zero()` can return `Zero()` without doing any arithmetic. Use them wherever possible.
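To make the saving concrete, here is a minimal sketch of the idea rather than ChainRulesCore's actual implementation. `HypZero` is a hypothetical stand-in for `Zero()`, and `pullback_step` is a hypothetical pullback for `y = W * x`. Dispatching on the "known zero" type lets the chain skip the dense matrix products entirely.

```julia
using LinearAlgebra

# Hypothetical stand-in for ChainRulesCore's `Zero()`; NOT the real type.
struct HypZero end

# Counts dense matrix-vector products so the saving is observable.
const MULS = Ref(0)

# Hypothetical pullback step for y = W * x: the input cotangent is W' * ȳ.
pullback_step(W::AbstractMatrix, ȳ::AbstractVector) = (MULS[] += 1; W' * ȳ)
# If the incoming cotangent is known to be zero, skip the product entirely.
pullback_step(W::AbstractMatrix, ::HypZero) = HypZero()

W1 = rand(3, 3)
W2 = rand(3, 3)

# A dense zero vector still forces both products: MULS[] ends up at 2.
MULS[] = 0
pullback_step(W1, pullback_step(W2, zeros(3)))

# The zero type short-circuits both steps: MULS[] stays at 0.
MULS[] = 0
pullback_step(W1, pullback_step(W2, HypZero()))
```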

## Use `Thunk`s appropriately

If work is only required for one of the returned differentials, then it should be wrapped in a `@thunk` (potentially using a `begin`-`end` block).

If there are multiple return values, their computation should almost always be wrapped in a `@thunk`.

Do not wrap variables in a `@thunk`; wrap the computations that fill those variables in `@thunk`:

```julia
# good:
∂A = @thunk(foo(x))
return ∂A

# bad:
∂A = foo(x)
return @thunk(∂A)
```

In the bad example `foo(x)` gets computed eagerly, and all that the thunk is doing is wrapping the already calculated result in a function that returns it.
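The difference between the two patterns can be observed with a small sketch. It does not use ChainRulesCore. `HypThunk` and `force` are hypothetical stand-ins for what `@thunk` builds and for forcing it. `expensive` plays the role of `foo` and counts how often it runs.

```julia
# Hypothetical minimal thunk: a wrapped zero-argument closure.
struct HypThunk{F}
    f::F
end
# Hypothetical helper that runs the deferred computation.
force(t::HypThunk) = t.f()

# Hypothetical expensive computation that counts its own calls.
const CALLS = Ref(0)
expensive(x) = (CALLS[] += 1; 2x)

# good: the computation itself is deferred until the thunk is forced
function good_rule(x)
    ∂A = HypThunk(() -> expensive(x))
    return ∂A
end

# bad: expensive(x) runs immediately; the thunk only hands back the finished value
function bad_rule(x)
    ∂A = expensive(x)
    return HypThunk(() -> ∂A)
end

CALLS[] = 0
t = good_rule(3)  # CALLS[] is still 0
force(t)          # returns 6; only now does CALLS[] become 1

CALLS[] = 0
bad_rule(3)       # CALLS[] is already 1, even if the result is never used
```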

Do not use `@thunk` if doing so would take as much work as evaluating the expression itself, or more. For example, when:

• The expression is a constant
• The expression merely wraps something in a `struct`, such as `Adjoint(x)` or `Diagonal(x)`
• The expression is itself a `thunk`
• The expression comes from another `rrule` or `frule`; the defining rule would already have `@thunk`ed it if required
• Only one derivative is being returned; since the user called `frule`/`rrule`, they clearly want to use it

## Be careful with using `adjoint` when you mean `transpose`

Remember that for complex numbers `a'` (i.e. `adjoint(a)`) also takes the complex conjugate. You probably want `transpose(a)` instead, unless you have already restricted `a` to be an `AbstractMatrix{<:Real}`.
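A quick check with the `LinearAlgebra` standard library shows the difference. The matrices `A` and `B` are arbitrary examples. For real inputs the two operations agree, which is why using `'` by mistake can go unnoticed until a complex input arrives.

```julia
using LinearAlgebra

A = [1+2im 3; 4 5im]

adjoint(A)    # same as A': transposes AND conjugates, giving [1-2im 4; 3 -5im]
transpose(A)  # transposes only, giving [1+2im 4; 3 5im]

# For a real matrix the two coincide.
B = [1.0 2.0; 3.0 4.0]
B' == transpose(B)  # true
```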

## Code Style

Use named local functions for the `pushforward`/`pullback`:

```julia
# good:
function frule(::typeof(foo), x)
    Y = foo(x)
    function foo_pushforward(_, ẋ)
        return bar(ẋ)
    end
    return Y, foo_pushforward
end
#== output:
julia> frule(foo, 2)
(4, var"#foo_pushforward#11"())
==#

# bad:
function frule(::typeof(foo), x)
    return foo(x), (_, ẋ) -> bar(ẋ)
end
#== output:
julia> frule(foo, 2)
(4, var"##9#10"())
==#
```

While this is more verbose, it ensures that if an error is thrown during the `pullback`/`pushforward`, the `gensym`-generated name of the local function includes the name you gave it. This makes it much simpler to debug from the stacktrace.
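You can check this without ChainRulesCore by inspecting the closure types directly. In this sketch `frule_named`, `frule_anon`, and `foo(x) = x^2` are hypothetical plain functions, not the package's `frule`. The pushforward here is the actual derivative `2x * ẋ` rather than the placeholder `bar`. The exact generated type names vary between Julia versions, so only the presence of `foo_pushforward` is meaningful.

```julia
# Hypothetical primal, chosen so that foo(2) == 4 as in the example above.
foo(x) = x^2

# Named local function: its generated type name contains "foo_pushforward".
function frule_named(x)
    Y = foo(x)
    function foo_pushforward(_, ẋ)
        return 2x * ẋ
    end
    return Y, foo_pushforward
end

# Anonymous function: its type gets only a compiler-generated name.
function frule_anon(x)
    return foo(x), (_, ẋ) -> 2x * ẋ
end

_, pf_named = frule_named(2)
_, pf_anon = frule_anon(2)
println(typeof(pf_named))  # the printed name includes foo_pushforward
println(typeof(pf_anon))   # the printed name is just generated symbols
```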

## Write tests

There are fairly decent tools for writing tests based on FiniteDifferences.jl, in `tests/test_utils.jl`. Take a look at the existing tests to see how they are used.

Warning

Use finite differencing to test derivatives. Do not use analytically derived derivatives in the tests. Those are what you used to define the rules, so they cannot be trusted to test them: if you misread or misunderstood the derivation, your tests and your implementation will contain the same mistake.
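The idea can be sketched in plain Julia. The function `f` and its hand-derived rule `df_rule` are hypothetical. `central_diff` is a simplified stand-in for FiniteDifferences.jl, whose `central_fdm` provides more accurate higher-order stencils, and for the helpers in `tests/test_utils.jl`, whose API is not reproduced here. The check compares the rule against a numerical estimate at a few points, so it does not rely on the derivation being correct.

```julia
# Hypothetical primal function and the hand-derived derivative under test.
f(x) = sin(x)^2
df_rule(x) = 2sin(x) * cos(x)

# Central finite difference with a fixed step: a simple stand-in for FiniteDifferences.jl.
central_diff(g, x; h=1e-6) = (g(x + h) - g(x - h)) / (2h)

# Compare the rule against the numerical estimate at several points.
for x in (-1.3, 0.0, 0.7, 2.5)
    @assert isapprox(df_rule(x), central_diff(f, x); atol=1e-8)
end
```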

## CAS systems are your friends.

It is very easy to check gradients or derivatives with a computer algebra system (CAS) like WolframAlpha.