ReverseDiffZero

This is a fairly standard operator-overloading-based reverse-mode AD system. It defines a `Tracked` type which carries the primal value, a reference to the tape that it is using, a partially accumulated partial derivative, and a `propagate` function that propagates its partial back to its inputs. A perhaps unusual thing about it is how little information about its creating operator's inputs it carries around: that information is entirely wrapped up in the `propagate` function. The overload generation hook in this example is `define_tracked_overload`.

"The simplest viable reverse-mode AD; only supports `Float64`."
module ReverseDiffZero
using ChainRulesCore
using ChainRulesOverloadGeneration
# resolve conflicts while this code exists in both.
# NOTE(review): presumably both `ChainRulesCore` and `ChainRulesOverloadGeneration`
# provide `on_new_rule` / `refresh_rules`, which makes the unqualified names ambiguous
# after the two `using`s above; bind them explicitly to the
# `ChainRulesOverloadGeneration` versions.
const on_new_rule = ChainRulesOverloadGeneration.on_new_rule
const refresh_rules = ChainRulesOverloadGeneration.refresh_rules

using Test

#########################################
# Initial rule setup
# Rules for `+` and `-`; the tuple gives the partial derivatives with respect to `(x, y)`.
# These are declared before the `on_new_rule` hook is attached further down. The tests'
# use of `+` relies on the hook also generating `Tracked` overloads for such
# pre-existing rules.
@scalar_rule x + y (1, 1)
@scalar_rule x - y (1, -1)
##########################
#Define the AD

"""
    Tracked{F} <: Real

A value being tracked for reverse-mode AD.

Inputs (leaves) are built with `Tracked(primal, tape)`. Intermediate values (branches) are
built with `Tracked(propagate, primal, tape)`, and only branches are pushed onto the tape.
"""
struct Tracked{F} <: Real
    propagate::F  # called by `propagate!` with the accumulated partial; `NoPropagate()` for leaves
    primal::Float64  # the primal (forward) value
    tape::Vector{Tracked}  # a reference to a shared tape
    partial::Base.RefValue{Float64}  # current accumulated sensitivity
end

"""
    Tracked(propagate, primal, tape)

Create an intermediate value (a Branch in Nabla terms) with a zero partial and record it
on `tape`, so that `derv` can later run its `propagate` during the backward pass.
"""
function Tracked(propagate, primal, tape)
    initial_partial = Ref(zero(primal))
    branch = Tracked(propagate, primal, tape, initial_partial)
    push!(tape, branch)
    return branch
end

"""
    Tracked(primal, tape)

Create an input (a Leaf in Nabla terms). A leaf has no inputs of its own to propagate to.
"""
function Tracked(primal, tape)
    # Leaves are deliberately left off the tape: there is nothing for them to propagate.
    initial_partial = Ref(zero(primal))
    return Tracked(NoPropagate(), primal, tape, initial_partial)
end

"""
Marker for inputs (leaves) that don't need to propagate.

Stored as the `propagate` field of leaf `Tracked` values. No call method is defined for it
here. That is fine because leaves are never pushed onto the tape, and `derv` only runs
`propagate!` on tape entries.
"""
struct NoPropagate end

"Primal (forward) value of `d`; an untracked value is returned unchanged."
function primal(d::Tracked)
    return d.primal
end
primal(d) = d

"Sensitivity accumulated so far in a `Tracked` value; `nothing` for untracked values."
function partial(d::Tracked)
    return d.partial[]
end
partial(d) = nothing

"The shared tape referenced by a `Tracked` value; `nothing` for untracked values."
function tape(d::Tracked)
    return d.tape
end
tape(d) = nothing

"""
    get_tape(ds)

We have many inputs: return the tape of the first one that is tracked.
Throws an `ArgumentError` (from `something`) if none of `ds` is tracked.
"""
function get_tape(ds)
    candidate_tapes = (tape(d) for d in ds)
    return something(candidate_tapes...)
end

"""
    propagate!(d::Tracked)

Propagate the currently stored partial of `d` back to its inputs by calling `d`'s
`propagate` function with it.
"""
function propagate!(d::Tracked)
    accumulated = d.partial[]
    return d.propagate(accumulated)
end

"""
    accum!(d, x̄)

Accumulate the sensitivity `x̄` into `d`'s partial if `d` is being tracked, returning the
summed value. For untracked values this does nothing and returns `nothing`.
"""
function accum!(d::Tracked, x̄)
    updated = d.partial[] + x̄
    d.partial[] = updated
    return updated
end
accum!(d, x̄) = nothing

# needed for `^` to work from having `*` defined.
# NOTE(review): presumably Base's integer-power path converts the base with
# `to_power_type` before multiplying; handing back the `Tracked` value itself keeps the
# computation on the overloaded `*`.
function Base.to_power_type(x::Tracked)
    return x
end

"""
    define_tracked_overload(sig)

What to do when a new `rrule` is declared: the hook attached via `on_new_rule`.

`sig` has the form `Tuple{typeof(op), argtypes...}`. For a signature that can be handled,
`eval` a method of `op` that accepts any mix of `Tracked` and `Float64` arguments. The
method runs the `rrule` on the primal values and returns a `Tracked` result, whose
`propagate` function pulls its partial back into the tracked inputs.

Signatures that are skipped (nothing is defined; `nothing` is returned):
- constructors (the op type is itself a `Type`);
- ops whose type is not a concrete type without fields (functors, abstract ops, type variables);
- zero-argument ops (such an overload would just overwrite the primal method);
- any argument type that is not a `Type` containing `Float64`.
"""
function define_tracked_overload(sig)
    sig = Base.unwrap_unionall(sig)  # not really handling most UnionAll
    opT, argTs = Iterators.peel(sig.parameters)
    opT isa Type{<:Type} && return  # not handling constructors
    # `fieldcount` throws on abstract types and `TypeVar`s, so check concreteness first.
    opT isa DataType && isconcretetype(opT) || return
    fieldcount(opT) == 0 || return  # not handling functors

    N = length(sig.parameters) - 1  # skip the op
    # A method taking zero `Vararg` arguments has the same signature as the primal
    # zero-argument method and would replace it; there is nothing to track anyway.
    N > 0 || return
    # only handling purely Float64 ops.
    all(argT isa Type && Float64 <: argT for argT in argTs) || return

    fdef = quote
        # we use the function call overloading form as it lets us avoid namespacing issues
        # as we can directly interpolate the function type into the AST.
        function (op::$opT)(tracked_args::Vararg{Union{Tracked, Float64}, $N}; kwargs...)
            y, y_pullback = rrule(op, primal.(tracked_args)...; kwargs...)
            # If the primal's own method is less specific than this one (e.g. `f(::Real)`),
            # all-`Float64` calls land here too. With nothing tracked there is no tape to
            # record on, so return the plain primal result.
            any(arg -> arg isa Tracked, tracked_args) || return y
            the_tape = get_tape(tracked_args)
            y_tracked = Tracked(y, the_tape) do ȳ
                # pull this partial back and accumulate it into each input's partial
                # store; the first pullback output (the op's own tangent) is dropped.
                _, ārgs = Iterators.peel(y_pullback(ȳ))
                foreach(accum!, tracked_args, ārgs)
            end
            return y_tracked
        end
    end
    return eval(fdef)
end

# !Important!: Attach the define function to the `on_new_rule` hook.
# This registers `define_tracked_overload` to be run for `rrule` signatures so that the
# matching `Tracked` overloads get generated. Rules added later in this same file need an
# explicit `refresh_rules()` call (done after the `*` rule below).
on_new_rule(define_tracked_overload, rrule)

"""
    derv(f, args...; kwargs...)

Do a calculus: differentiate `f` (which should have a single, scalar output) at `args`
in reverse mode.

Each input is wrapped as a `Tracked` leaf, and all leaves share one tape. Then `f` is
called on the wrapped inputs, with `kwargs` passed through. Finally, the tape is walked
backwards to accumulate partials.

Returns a tuple with the partial derivative of the output with respect to each input.
If `f` returns an untracked value (it does not depend on its inputs through tracked
operations), every partial is zero.

Throws an `ArgumentError` if `f` returns a `Tracked` value belonging to a different tape.
"""
function derv(f, args::Vararg; kwargs...)
    the_tape = Vector{Tracked}()
    tracked_inputs = Tracked.(args, Ref(the_tape))
    tracked_output = f(tracked_inputs...; kwargs...)

    # Validate the user's output with a real error rather than `@assert`. An untracked
    # output (`tape` returns `nothing`) is allowed: the derivative is then simply zero.
    output_tape = tape(tracked_output)
    if output_tape !== nothing && output_tape !== the_tape
        throw(ArgumentError("the output of `f` is tracked on a different tape than its inputs"))
    end

    # Now the backward pass: seed with d(out)/d(out) = 1 (a no-op for untracked output).
    ōut = one(primal(tracked_output))
    accum!(tracked_output, ōut)
    # By going down the tape backwards we know we will have fully accumulated partials
    # before propagating them onwards. `Iterators.reverse` avoids copying the tape.
    for op in Iterators.reverse(the_tape)
        propagate!(op)
    end
    return partial.(tracked_inputs)
end

# End AD definition
################################

# Add a rule later too, i.e. after the `on_new_rule` hook has been attached.
function ChainRulesCore.rrule(::typeof(*), x::Number, y::Number)
    Ω = x * y
    # Thunks are used on purpose here, to show that the AD handles them fine.
    times_pullback(ΔΩ) = (NoTangent(), @thunk(ΔΩ * y'), @thunk(x' * ΔΩ))
    return Ω, times_pullback
end

# Manual refresh needed: the `*` rule above was added in the same file as the AD, after
# the `on_new_rule` call, so the hook must be triggered for it explicitly. The tests that
# multiply `Tracked` values depend on this.
refresh_rules();

@testset "ReverseDiffZero" begin
    # Partials are sums of floating-point contributions whose accumulation order depends on
    # the tape, so compare with `≈` instead of exact `==`.
    foo(x) = x + x
    @test derv(foo, 1.6) isa Tuple{Float64}
    @test only(derv(foo, 1.6)) ≈ 2.0

    bar(x) = x + 2.1 * x
    @test only(derv(bar, 1.2)) ≈ 3.1

    baz(x) = 2.0 * x^2 + 3.0 * x + 1.2
    @test only(derv(baz, 1.7)) ≈ 2 * 2.0 * 1.7 + 3.0

    qux(x) = foo(x) + bar(x) + baz(x)
    @test only(derv(qux, 1.7)) ≈ (2 * 2.0 * 1.7 + 3.0) + 3.1 + 2

    function quux(x)
        y = 2.0 * x + 3.0 * x
        return 4.0 * y + 5.0 * y
    end
    @test only(derv(quux, 11.1)) ≈ 4 * (2 + 3) + 5 * (2 + 3)
end
end  # module