On writing good `rrule` / `frule` methods
Use `One()` as a return value
`One()` differential objects exist as an alternative to directly returning an identity such as `I`. They let chained pullbacks/pushforwards skip work that would otherwise be spent multiplying by an explicit identity, so use them wherever possible.
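To show why a dedicated identity object saves work, here is a minimal sketch using only the standard library. `SketchOne` and `mul` are hypothetical names standing in for the idea behind `One()`. This is not ChainRules' actual type or its multiplication rules. Multiplying by the dense identity matrix allocates and does O(n²) work, while the special identity simply hands back its argument:

```julia
using LinearAlgebra

# Hypothetical stand-in for a One()-style differential: a multiplicative
# identity that never materialises an array.
struct SketchOne end

# Chaining step: multiplying by SketchOne() returns the other operand unchanged.
mul(::SketchOne, x) = x
mul(x, ::SketchOne) = x
mul(::SketchOne, ::SketchOne) = SketchOne()   # resolves the ambiguity between the two methods above
mul(a, b) = a * b                             # ordinary multiplication otherwise

n = 1000
v = rand(n)

dense_identity = Matrix{Float64}(I, n, n)  # n×n array that is mostly zeros
r_dense = mul(dense_identity, v)           # O(n^2) multiply, allocates a new vector
r_one   = mul(SketchOne(), v)              # no arithmetic: returns v itself
```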
If work is only required for one of the returned differentials, wrap it in a `@thunk` (possibly using a `begin`...`end` block for multi-line computations).
If there are multiple return values, their computations should almost always be wrapped in `@thunk`.
Do not wrap *variables* in a `@thunk`; wrap the *computations* that fill those variables in `@thunk`:
```julia
# good:
∂A = @thunk(foo(x))
return ∂A

# bad:
∂A = foo(x)
return @thunk(∂A)
```
In the bad example, `foo(x)` is computed eagerly, and the thunk only wraps the already calculated result in a function that returns it.
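The difference between the good and bad patterns can be observed directly by counting how often the expensive function runs. The sketch below uses only the standard library. `SketchThunk`, `force_thunk` and `@sketch_thunk` are hypothetical stand-ins for a ChainRulesCore-style thunk and are not its real implementation. `foo` is a placeholder computation that increments a counter each time it runs:

```julia
# Hypothetical stand-in for a thunk: a wrapper around a zero-argument
# closure that is only called when the value is demanded.
struct SketchThunk{F}
    f::F
end

# Evaluate the deferred computation.
force_thunk(t::SketchThunk) = t.f()

# Wrap the *expression* (not its value) in a closure.
macro sketch_thunk(ex)
    return :(SketchThunk(() -> $(esc(ex))))
end

# Hypothetical "expensive" computation that counts how often it runs.
const CALLS = Ref(0)
foo(x) = (CALLS[] += 1; 2x)

x = 3

# good: the computation itself is deferred
CALLS[] = 0
good = @sketch_thunk(foo(x))
calls_before_force = CALLS[]    # 0: foo has not run yet
good_value = force_thunk(good)  # 6: foo runs only now

# bad: foo(x) runs eagerly; the thunk just wraps the finished result
CALLS[] = 0
∂A = foo(x)
bad = @sketch_thunk(∂A)
calls_bad = CALLS[]             # 1: the work was already done
```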
Do not use `@thunk` if building the thunk would take as much work as, or more than, evaluating the expression itself. Examples:
- The expression is a constant.
- The expression merely wraps something in a `struct`, such as `Adjoint(x)` or `Diagonal(x)`.
- The expression is itself a thunk.
- The expression comes from another `rrule` or `frule`; it would already be `@thunk`ed if the defining rule required it.
- Only one derivative is returned: since the user called `frule`/`rrule`, they clearly want to use that one.
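The struct-wrapping case is cheap because the wrappers in Julia's standard `LinearAlgebra` library are lazy. For a matrix, `adjoint(A)` returns an `Adjoint` view that shares memory with `A`, and `Diagonal(v)` stores `v` itself without copying. Constructing either is already O(1), so putting it behind a thunk would allocate a closure and save nothing. A small check, assuming a plain `Matrix` and `Vector` input:

```julia
using LinearAlgebra

A = [1.0 2.0; 3.0 4.0]
v = [1.0, 2.0]

At = adjoint(A)   # lazy Adjoint wrapper: no data is copied
D  = Diagonal(v)  # lazy Diagonal wrapper around v

# Both wrappers point at the original arrays, so building them costs O(1).
shares_A = parent(At) === A
shares_v = D.diag === v
```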
Be careful with using `adjoint` when you mean `transpose`
Remember that for complex numbers, `a'` (i.e. `adjoint(a)`) takes the complex conjugate. You probably want `transpose(a)` instead, unless you have already restricted `a` to real elements (e.g. an `AbstractMatrix{<:Real}`).
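The difference is easy to check in the REPL. Here is a short sketch with arbitrary example values. For a complex scalar, `'` conjugates while `transpose` leaves the value unchanged. For a complex matrix, the adjoint is the conjugate of the transpose. For a real matrix the two agree, which is why `'` is only safe once the element type is known to be real:

```julia
using LinearAlgebra

z = 1 + 2im
z_adj   = z'            # adjoint of a complex scalar is its conjugate: 1 - 2im
z_trans = transpose(z)  # transpose of a scalar is the scalar itself: 1 + 2im

M = [1+1im 2; 3 4im]
M_adj   = M'            # conjugate transpose
M_trans = transpose(M)  # plain transpose, no conjugation

R = [1.0 2.0; 3.0 4.0]  # real element type: adjoint and transpose coincide
R_same = R' == transpose(R)
```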
Use named local functions for the pushforward in an `frule` and the pullback in an `rrule`:
```julia
# good:
function frule(::typeof(foo), x)
    Y = foo(x)
    # named local function: its generated type name contains "foo_pushforward"
    function foo_pushforward(_, ẋ)
        return bar(ẋ)
    end
    return Y, foo_pushforward
end
#== output
julia> frule(foo, 2)
(4, var"#foo_pushforward#11"())
==#

# bad:
function frule(::typeof(foo), x)
    # anonymous closure: gets an opaque generated name
    return foo(x), (_, ẋ) -> bar(ẋ)
end
#== output
julia> frule(foo, 2)
(4, var"##9#10"())
==#
```
While this is more verbose, it ensures that if an error is thrown during the pushforward or pullback, the `gensym` name of the local function includes the name you gave it. This makes the stacktrace much easier to debug.
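The naming difference can be inspected without any package. The sketch below uses hypothetical stand-alone functions `sq_rule_named` and `sq_rule_anon` for `sq(x) = x^2`. They mimic the shape of the example above and do not use the real `frule` signature. The exact generated names vary between Julia versions, but only the named local function's type carries `sq_pushforward`, and that is the name a stack trace would show:

```julia
# Hypothetical stand-alone rules for sq(x) = x^2 (not the real frule API).
function sq_rule_named(x)
    y = x^2
    # named local function: its closure type carries the name sq_pushforward
    function sq_pushforward(_, ẋ)
        return 2x * ẋ
    end
    return y, sq_pushforward
end

function sq_rule_anon(x)
    # anonymous closure: the compiler picks a generated name
    return x^2, (_, ẋ) -> 2x * ẋ
end

y1, pf_named = sq_rule_named(2)
y2, pf_anon  = sq_rule_anon(2)

name_named = string(nameof(typeof(pf_named)))  # contains "sq_pushforward"
name_anon  = string(nameof(typeof(pf_anon)))   # generated; exact form depends on Julia version
```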
There are fairly decent tools for writing tests based on FiniteDifferences.jl. They are in `tests/test_utils.jl`. Look at the existing tests to see how they are used.
Use finite differencing to test derivatives. Don't use analytical derivations of the derivatives in the tests: those are what you used to define the rules, so they cannot be trusted to check them. If you misread or misunderstood a derivation, your tests and your implementation will share the same mistake.
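The idea behind such tests can be shown with a hand-rolled central difference. This is a minimal stdlib-only sketch and does not use the FiniteDifferences.jl API, which provides more accurate higher-order methods. `central_diff`, `check_pushforward` and the two `sin` pushforwards are hypothetical names. One pushforward is correct and the other contains a deliberate sign slip, and the numerical comparison catches that slip without using any hand-derived formula:

```julia
# Central finite difference: approximates f'(x) numerically.
central_diff(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2h)

# Hypothetical hand-written pushforwards for sin: one correct, one with a sign error.
sin_pushforward(x, ẋ)     = cos(x) * ẋ
bad_sin_pushforward(x, ẋ) = -cos(x) * ẋ

# Compare a rule's pushforward against the finite-difference estimate at x.
function check_pushforward(f, pushforward, x, ẋ; rtol = 1e-6)
    return isapprox(pushforward(x, ẋ), central_diff(f, x) * ẋ; rtol = rtol)
end

points = (0.3, 1.1, 2.0)
good_ok = all(x -> check_pushforward(sin, sin_pushforward, x, 1.5), points)
bad_ok  = any(x -> check_pushforward(sin, bad_sin_pushforward, x, 1.5), points)
```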
CAS systems are your friends.
It is very easy to check gradients or derivatives with a computer algebra system (CAS) like WolframAlpha.
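As a small illustration (the function is an arbitrary example, not taken from any rule in the package), a hand derivation like the following is exactly what you would cross-check by asking a CAS for the gradient of the same expression:

$$
f(x, y) = x^2 y + \sin y, \qquad
\nabla f(x, y) = \begin{pmatrix} 2xy \\ x^2 + \cos y \end{pmatrix}
$$

If the CAS output disagrees with the formula you plan to encode in the rule, the derivation needs another look before any code is written.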