Operator Overloading

The principal interface for using the operator overload generation method is on_new_rule. This function allows one to register a hook to be run every time a new rule is defined. The hook receives a signature type (a Tuple type itself, not an instance of one) as input, and generally will use eval to define an overload of an AD system's overloaded type. For example, it might use the signature type Tuple{typeof(+), Real, Real} to make +(::DualNumber, ::DualNumber) call the frule for +. A signature type tuple always has the form: Tuple{typeof(operation), typeof(pos_arg1), typeof(pos_arg2), ...}, where pos_arg1 is the first positional argument. One can dispatch on the signature type so that eval is not called for rules with argument types your AD does not support; or, more simply, you can just use conditions for this. For example, if your AD only supports AbstractMatrix{Float64} and Float64 inputs, you might write:

# Argument types this (hypothetical) AD can handle.
const ACCEPT_TYPE = Union{Float64, AbstractMatrix{Float64}}

# Selected by dispatch only for signatures whose declared argument types are all subtypes
# of `ACCEPT_TYPE`, e.g. `Tuple{typeof(+), Float64, Float64}`. A rule declared on wider
# types such as `Tuple{typeof(+), Real, Real}` does not match and falls through below.
function define_overload(sig::Type{<:Tuple{F, Vararg{ACCEPT_TYPE}}}) where F
    # `@eval begin ... end` evaluates the enclosed definitions at global scope;
    # `@eval quote ... end` would merely construct (and discard) an expression.
    @eval begin
        # ... define the overload for `sig` here ...
    end
end
define_overload(::Any) = nothing  # don't do anything for any other signature

# Register `define_overload` as a hook to be run on the signature of every `frule`,
# both those already defined and any defined later.
on_new_rule(define_overload, frule)

or you might write:

# Argument types this (hypothetical) AD can handle.
const ACCEPT_TYPES = (Float64, AbstractMatrix{Float64})
function define_overload(sig)
    sig = Base.unwrap_unionall(sig)  # not really handling most UnionAlls
    opT, argTs = Iterators.peel(sig.parameters)
    # An argument is acceptable when the rule's declared type admits (is a supertype of)
    # at least one of `ACCEPT_TYPES`. Parameters that are not types (e.g. a free
    # `TypeVar` left behind by `unwrap_unionall`) are rejected rather than handed to
    # `<:`, which would throw a `TypeError`.
    accepts(argT) = argT isa Type && any(acceptT <: argT for acceptT in ACCEPT_TYPES)
    all(accepts, argTs) || return
    # `@eval begin ... end` evaluates the enclosed definitions at global scope;
    # `@eval quote ... end` would merely construct (and discard) an expression.
    @eval begin
        # ... define the overload for `opT` applied to `argTs` here ...
    end
end

# Register `define_overload` as a hook to be run on the signature of every `frule`,
# both those already defined and any defined later.
on_new_rule(define_overload, frule)

The generation of overloaded code is the responsibility of the AD implementor. Packages like ExprTools.jl can be helpful for this. It's generally fairly simple, though it can become complex if you need to handle complicated type constraints. Examples are shown below.

The hook is automatically triggered whenever a package is loaded. It can also be triggered manually using [`refresh_rules`](@ref). This is useful, for example, if new rules are defined in the REPL, or if a package defining rules is modified (Revise.jl will not automatically trigger a refresh). When the rules are refreshed (automatically or manually), the hooks are only triggered on new or modified rules, not on ones the hooks have already been triggered on.

[`clear_new_rule_hooks!`](@ref) clears all registered hooks. It is useful for undoing [`on_new_rule`](@ref) hook registrations if you are iteratively developing your overload generation function.

Examples

ForwardDiffZero

The overload generation hook in this example is: define_dual_overload.

"The simplest viable forward mode AD, only supports `Float64`"
module ForwardDiffZero
using ChainRulesCore
using Test

#########################################
# Initial rule setup
@scalar_rule x + y (1, 1)
@scalar_rule x - y (1, -1)
##########################
# Define the AD

# Note that we never directly define Dual Number Arithmetic on Dual numbers
# instead it is automatically defined from the `frules`
"A dual number: a `primal` value paired with its derivative (`partial`)."
struct Dual <: Real
    primal::Float64
    partial::Float64
end

primal(d::Dual) = d.primal
partial(d::Dual) = d.partial

# Any other real is a constant: its primal is itself and its derivative is zero.
primal(d::Real) = d
partial(d::Real) = 0.0

# needed for `^` to work from having `*` defined
Base.to_power_type(x::Dual) = x

"""
    define_dual_overload(sig)

Hook for `on_new_rule`, run for each `frule` signature type `sig`. Defines a method of
the rule's operation that accepts `Dual`/`Float64` arguments and computes the result
(primal and derivative) by calling the `frule`. Constructors, functors (callables with
fields), and rules that do not accept `Float64` for every argument are skipped.
"""
function define_dual_overload(sig)
    sig = Base.unwrap_unionall(sig)  # Not really handling most UnionAlls
    opT, argTs = Iterators.peel(sig.parameters)
    opT isa Type{<:Type} && return  # not handling constructors
    fieldcount(opT) == 0 || return  # not handling functors
    all(argT isa Type && Float64 <: argT for argT in argTs) || return  # only handling purely Float64 ops.

    N = length(sig.parameters) - 1  # skip the op
    fdef = quote
        # we use the function call overloading form as it lets us avoid namespacing issues
        # as we can directly interpolate the function type into to the AST.
        function (op::$opT)(dual_args::Vararg{Union{Dual, Float64}, $N}; kwargs...)
            ȧrgs = (NO_FIELDS, partial.(dual_args)...)
            args = (op, primal.(dual_args)...)
            y, ẏ = frule(ȧrgs, args...; kwargs...)
            return Dual(y, ẏ)  # if y, ẏ are not `Float64` this will error.
        end
    end
    eval(fdef)
end

# !Important!: Attach the define function to the `on_new_rule` hook
on_new_rule(define_dual_overload, frule)

"Do a calculus. `f` should have a single input."
function derv(f, arg)
    # Seed the input's derivative with one, so the output's partial is df/darg.
    dual = Dual(arg, one(arg))
    return partial(f(dual))
end

# End AD definition
################################

# add a rule later also
function ChainRulesCore.frule((_, Δx, Δy), ::typeof(*), x::Number, y::Number)
    return (x * y, Δx * y + x * Δy)
end

# Manual refresh needed as new rule added in same file as AD after the `on_new_rule` call
refresh_rules();

@testset "ForwardDiffZero" begin
    foo(x) = x + x
    @test derv(foo, 1.6) == 2

    bar(x) = x + 2.1 * x
    @test derv(bar, 1.2) == 3.1

    baz(x) = 2.0 * x^2 + 3.0*x + 1.2
    @test derv(baz, 1.7) == 2*2.0*1.7 + 3.0

    qux(x) = foo(x) + bar(x) + baz(x)
    @test derv(qux, 1.7) == (2*2.0*1.7 + 3.0) + 3.1 + 2

    function quux(x)
        y = 2.0*x + 3.0*x
        return 4.0*y + 5.0*y
    end
    @test derv(quux, 11.1) == 4*(2+3) + 5*(2+3)
end

end  # module

ReverseDiffZero

The overload generation hook in this example is: define_tracked_overload.

"The simplest viable reverse mode AD, only supports `Float64`"
module ReverseDiffZero
using ChainRulesCore
using Test

#########################################
# Initial rule setup
@scalar_rule x + y (1, 1)
@scalar_rule x - y (1, -1)
##########################
#Define the AD

"""
A value recorded for the backward pass. `propagate` is called with the accumulated
sensitivity stored in `partial` to push it back to this value's inputs.
"""
struct Tracked{F} <: Real
    propagate::F
    primal::Float64
    tape::Vector{Tracked}  # a reference to a shared tape
    partial::Base.RefValue{Float64} # current accumulated sensitivity
end

"An intermediate value, a Branch in Nabla terms."
function Tracked(propagate, primal, tape)
    # The partial store must be a `RefValue{Float64}`: build it from `0.0` rather than
    # `zero(primal)`, which would give e.g. a `RefValue{Int}` for an `Int` primal.
    v = Tracked(propagate, convert(Float64, primal), tape, Ref(0.0))
    push!(tape, v)
    return v
end

"Marker for inputs (leaves) that don't need to propagate."
struct NoPropagate end

"An input, a Leaf in Nabla terms. No inputs of its own to propagate to."
function Tracked(primal, tape)
    # don't actually need to put these on the tape, since they don't need to propagate
    return Tracked(NoPropagate(), convert(Float64, primal), tape, Ref(0.0))
end

primal(d::Tracked) = d.primal
primal(d) = d

partial(d::Tracked) = d.partial[]
partial(d) = nothing

tape(d::Tracked) = d.tape
tape(d) = nothing

"We have many inputs; grab the tape from the first one that is tracked."
get_tape(ds) = something(tape.(ds)...)

"propagate the currently stored partial back to my inputs."
propagate!(d::Tracked) = d.propagate(d.partial[])

"Accumulate the sensitivity, if the value is being tracked."
accum!(d::Tracked, x̄) = d.partial[] += x̄
accum!(d, x̄) = nothing

# needed for `^` to work from having `*` defined
Base.to_power_type(x::Tracked) = x

"""
    define_tracked_overload(sig)

Hook for `on_new_rule`, run when a new `rrule` is declared. Defines a method of the
rule's operation that accepts `Tracked`/`Float64` arguments, computes the primal result
via the `rrule`, and records a `Tracked` output on the shared tape whose propagation step
pulls the output's sensitivity back into the tracked inputs. Constructors, functors
(callables with fields), and rules that do not accept `Float64` for every argument are
skipped.
"""
function define_tracked_overload(sig)
    sig = Base.unwrap_unionall(sig)  # not really handling most UnionAll
    opT, argTs = Iterators.peel(sig.parameters)
    opT isa Type{<:Type} && return  # not handling constructors
    fieldcount(opT) == 0 || return  # not handling functors
    all(argT isa Type && Float64 <: argT for argT in argTs) || return # only handling purely Float64 ops.

    N = length(sig.parameters) - 1  # skip the op
    fdef = quote
        # we use the function call overloading form as it lets us avoid namespacing issues
        # as we can directly interpolate the function type into to the AST.
        function (op::$opT)(tracked_args::Vararg{Union{Tracked, Float64}, $N}; kwargs...)
            args = (op, primal.(tracked_args)...)
            y, y_pullback = rrule(args...; kwargs...)
            the_tape = get_tape(tracked_args)
            y_tracked = Tracked(y, the_tape) do ȳ
                # pull this partial back and propagate it to the input's partial store
                _, ārgs = Iterators.peel(y_pullback(ȳ))
                accum!.(tracked_args, ārgs)
            end
            return y_tracked
        end
    end
    eval(fdef)
end

# !Important!: Attach the define function to the `on_new_rule` hook
on_new_rule(define_tracked_overload, rrule)

"""
Do a calculus. `f` should have a single output.

Returns a tuple holding the partial derivative of `f`'s output with respect to each of
`args`.
"""
function derv(f, args::Vararg; kwargs...)
    the_tape = Vector{Tracked}()
    tracked_inputs = Tracked.(args, Ref(the_tape))
    tracked_output = f(tracked_inputs...; kwargs...)
    @assert tape(tracked_output) === the_tape

    # Now the backward pass
    out = primal(tracked_output)
    ōut = one(out)
    accum!(tracked_output, ōut)
    # By going down the tape backwards we know we will have fully accumulated partials
    # before propagating them onwards
    for op in reverse(the_tape)
        propagate!(op)
    end
    return partial.(tracked_inputs)
end

# End AD definition
################################

# add a rule later also
function ChainRulesCore.rrule(::typeof(*), x::Number, y::Number)
    function times_pullback(ΔΩ)
        # we will use thunks here to show we handle them fine.
        return (NO_FIELDS,  @thunk(ΔΩ * y'), @thunk(x' * ΔΩ))
    end
    return x * y, times_pullback
end

# Manual refresh needed as new rule added in same file as AD after the `on_new_rule` call
refresh_rules();

@testset "ReverseDiffZero" begin
    foo(x) = x + x
    @test derv(foo, 1.6) == (2.0,)
    @test derv(foo, 1) == (2.0,)  # non-`Float64` inputs are converted

    bar(x) = x + 2.1 * x
    @test derv(bar, 1.2) == (3.1,)

    baz(x) = 2.0 * x^2 + 3.0*x + 1.2
    @test derv(baz, 1.7) == (2 * 2.0 * 1.7 + 3.0,)

    qux(x) = foo(x) + bar(x) + baz(x)
    @test derv(qux, 1.7) == ((2 * 2.0 * 1.7 + 3.0) + 3.1 + 2,)

    function quux(x)
        y = 2.0*x + 3.0*x
        return 4.0*y + 5.0*y
    end
    @test derv(quux, 11.1) == (4*(2+3) + 5*(2+3),)
end
end  # module