Using ChainRules in your AD system
This section is for authors of AD systems. It assumes a solid understanding of both Julia and automatic differentiation. It explains how to use ChainRules' rule sets (frules and rrules) to avoid having to code all your own AD primitives / custom sensitivities.
There are 3 main ways to access ChainRules rule sets in your AutoDiff system.
- Operator Overloading Generation
  - Use ChainRulesOverloadGeneration.jl.
  - This is primarily intended for operator-overloading-based AD systems, which generate overloads of primal functions for their overloaded types based on the existence of an rrule/frule.
  - A source-code-generation-based AD can also use this by overloading its transform-generating function directly, so that it returns the rule instead of recursively generating a transform.
  - This does not play nicely with Revise.jl: adding or modifying rules in loaded files will not be reflected until a manual refresh, and deleting rules will not be reflected at all.
- Source code transform based on inserting branches that check whether the rrule/frule returns nothing
  - If the rrule/frule returns a rule result then use it; if it returns nothing then take the normal AD path.
  - In theory type inference optimizes these branches out; in practice it may not.
  - This is a fairly simple Cassette overdub (or similar) of all calls, and is suitable for overloading-based AD or source code transformation.
- Source code transform based on the rrule/frule method table
  - If an applicable rrule/frule exists in the method table then use it; otherwise generate the normal AD path.
  - This avoids having branches in your generated code.
  - This requires maintaining your own back-edges.
  - This is pretty hardcore even by the standard of source code transformations.
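To make the difference between the second and third options concrete, here is a minimal sketch of both decisions for reverse mode. It is not how any particular AD system implements them. So that it runs without packages, it defines local stand-ins for `rrule` and `NoTangent` that mirror the ChainRulesCore convention of a fallback `rrule` method returning `nothing`. The names `generic_ad`, `pullback_via_branch`, `has_custom_rrule`, `pullback_via_method_table` and `FALLBACK_RRULE` are hypothetical, and `generic_ad` uses a central difference purely as a placeholder for the AD system's normal path.

```julia
# Stand-ins for ChainRulesCore names, so the sketch runs without any packages.
# With ChainRulesCore loaded you would use its `rrule` and `NoTangent` instead.
struct NoTangent end

# Fallback meaning "no rule for this call"; it returns `nothing`.
rrule(::Any, ::Vararg{Any}) = nothing
# Remember the fallback method; it is the only method defined at this point.
const FALLBACK_RRULE = first(methods(rrule))

# One hand-written rule, shaped like a ChainRules rrule: (primal, pullback).
function rrule(::typeof(sin), x::Real)
    sin_pullback(ybar) = (NoTangent(), ybar * cos(x))
    return sin(x), sin_pullback
end

# Hypothetical placeholder for the AD system's normal path.
# A central difference is used only to keep the sketch short.
function generic_ad(f, x::Real; h=1e-6)
    slope = (f(x + h) - f(x - h)) / (2h)
    generic_pullback(ybar) = (NoTangent(), ybar * slope)
    return f(x), generic_pullback
end

# Option 2: call the rule and branch on whether it returned `nothing`.
function pullback_via_branch(f, x)
    res = rrule(f, x)
    return res === nothing ? generic_ad(f, x) : res
end

# Option 3: ask the method table whether anything more specific than the
# fallback applies, and only then call the rule.
function has_custom_rrule(f, x)
    return which(rrule, Tuple{typeof(f), typeof(x)}) !== FALLBACK_RRULE
end

function pullback_via_method_table(f, x)
    return has_custom_rrule(f, x) ? rrule(f, x) : generic_ad(f, x)
end

# Usage: `sin` has a rule, the anonymous function does not.
y, back = pullback_via_branch(sin, 1.0)
_, dx = back(1.0)
y2, back2 = pullback_via_method_table(x -> x^2, 3.0)
_, dx2 = back2(1.0)
```

Note that this sketch performs the method-table lookup at run time, which sidesteps the back-edge problem entirely. A real source code transform would make the same decision while generating code, and that is the point at which it has to track back-edges so that generated code is invalidated when rules are added or removed.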