Using ChainRules in your AD system
This section is for authors of AD systems. It assumes a pretty solid understanding of both Julia and automatic differentiation. It explains how to make use of ChainRules' "rulesets" (frules, rrules) to avoid having to code all your own AD primitives / custom sensitivities.
There are 3 main ways to access ChainRules rule sets in your AutoDiff system.
- Operator Overloading Generation
  - Use ChainRulesOverloadGeneration.jl.
  - This is primarily intended for operator overloading based AD systems, which generate overloads of primal functions for their overloaded types based on the existence of an rrule/frule.
  - A source code generation based AD can also use this by overloading their transform generating function directly so as not to recursively generate a transform but to just return the rule.
  - This does not play nicely with Revise.jl: adding or modifying rules in loaded files will not be reflected until a manual refresh, and deleting rules will not be reflected at all.
- Source code transform based on inserting branches that check whether the rrule/frule returns nothing
  - If the rrule/frule returns a rule result then use it; if it returns nothing then take the normal AD path.
  - In theory type inference optimizes these branches out; in practice it may not.
  - This is a fairly simple Cassette overdub (or similar) of all calls, and is suitable for overloading based AD or source code transformation.
- Source code transform based on the rrule/frule method table
  - If an applicable rrule/frule exists in the method table then use it, else generate the normal AD path.
  - This avoids having branches in your generated code.
  - This requires maintaining your own back-edges.
  - This is pretty hardcore even by the standard of source code transformations.
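For the operator overloading route, the kind of overload that gets generated can be sketched by hand. The block below is only an illustration built on ChainRulesCore: the `Dual` type, the `scaled_sin` primitive, and the `define_dual_overload` function are all hypothetical names, and the function is called manually here, whereas ChainRulesOverloadGeneration.jl is meant to drive such a function for each rule it finds (that registration step is not shown).

```julia
using ChainRulesCore

# Hypothetical overloaded type of a toy forward-mode AD system.
struct Dual
    primal::Float64
    partial::Float64
end

# Hypothetical primitive with a hand-written forward rule.
scaled_sin(x) = 2 * sin(x)
function ChainRulesCore.frule((_, dx), ::typeof(scaled_sin), x::Float64)
    return 2 * sin(x), 2 * cos(x) * dx
end

# Given a primal signature such as Tuple{typeof(scaled_sin), Float64},
# define the matching Dual overload, which simply defers to the frule.
function define_dual_overload(sig)
    N = length(sig.parameters) - 1          # number of primal arguments
    N >= 1 || return false
    opT, argTs = Iterators.peel(sig.parameters)
    all(T -> T === Float64, argTs) || return false  # sketch: Float64 arguments only
    @eval function (op::$opT)(duals::Vararg{Dual,$N})
        primals = map(d -> d.primal, duals)
        partials = map(d -> d.partial, duals)
        y, dy = frule((NoTangent(), partials...), op, primals...)
        return Dual(y, dy)
    end
    return true
end

# Manual call standing in for the automated generation step.
define_dual_overload(Tuple{typeof(scaled_sin), Float64})
result = scaled_sin(Dual(0.0, 1.0))
```

After the overload is defined, calling `scaled_sin` on a `Dual` never traces into the body of `scaled_sin`; the primal and the tangent both come from the rule.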
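The two source-transform options differ mainly in where the "is there a rule?" decision is made. The sketch below contrasts them at run time, using only ChainRulesCore. Everything named here (`square`, `cube`, `fallback_ad`, `pullback_by_branch`, `has_rrule`, `pullback_by_lookup`, `NoRuleSentinel`) is hypothetical, and the finite-difference `fallback_ad` is merely a stand-in for an AD system's normal path, not a recommendation.

```julia
using ChainRulesCore

# Hypothetical primitives: `square` has a hand-written rule, `cube` does not.
square(x) = x^2
cube(x) = x^3
function ChainRulesCore.rrule(::typeof(square), x)
    square_pullback(ybar) = (NoTangent(), 2 * x * ybar)
    return square(x), square_pullback
end

# Stand-in for the normal AD path (central finite differences, scalars only).
function fallback_ad(f, x; h=1e-6)
    slope = (f(x + h) - f(x - h)) / (2 * h)
    fallback_pullback(ybar) = (NoTangent(), slope * ybar)
    return f(x), fallback_pullback
end

# Branch-based option: call rrule and branch on whether it returned `nothing`.
function pullback_by_branch(f, x)
    result = rrule(f, x)
    return result === nothing ? fallback_ad(f, x) : result
end

# Method-table option: ask which rrule method would be selected, and compare it
# with the generic fallback method (found via a type that has no rule).
struct NoRuleSentinel end
const FALLBACK_RRULE = which(rrule, Tuple{NoRuleSentinel})
has_rrule(f, x) = which(rrule, Tuple{typeof(f), typeof(x)}) !== FALLBACK_RRULE
pullback_by_lookup(f, x) = has_rrule(f, x) ? rrule(f, x) : fallback_ad(f, x)

y, pb = pullback_by_branch(square, 3.0)   # uses the rule
y2, pb2 = pullback_by_branch(cube, 2.0)   # falls back
```

A real source transform would make the method-table decision while generating code rather than on every call, which is where the back-edge bookkeeping comes in: the generated code has to be invalidated when a rule is added later. Note also that the lookup shown here cannot detect a rule method that exists but itself returns `nothing`.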