Operator Overloading
The principal interface for using the operator overload generation method is `on_new_rule`.
This function allows one to register a hook to be run every time a new rule is defined.
The hook receives a signature type (the type itself, passed as a value) as input, and will generally use `eval` to define an overload of an AD system's overloaded type.
For example, it might use the signature type `Tuple{typeof(+), Real, Real}` to make `+(::DualNumber, ::DualNumber)` call the `frule` for `+`.
A signature type tuple always has the form `Tuple{typeof(operation), typeof(pos_arg1), typeof(pos_arg2), ...}`, where `pos_arg1` is the first positional argument.
One can dispatch on the signature type so that rules whose argument types your AD does not support never reach `eval`; or, more simply, one can use conditions for this.
For example, if your AD only supports `AbstractMatrix{Float64}` and `Float64` inputs, you might write:
const ACCEPT_TYPE = Union{Float64, AbstractMatrix{Float64}}

# Hook run for every new rule whose signature type matches: the operation's type (first
# parameter) is unconstrained, and every positional argument type must be a subtype of
# `ACCEPT_TYPE`.
# NOTE: because this is dispatch, it only matches rules declared on subtypes of
# `ACCEPT_TYPE` (e.g. a rule written exactly for `Float64`); a rule declared for an
# abstract supertype such as `Real` falls through to the fallback method below.
function define_overload(sig::Type{<:Tuple{Any, Vararg{ACCEPT_TYPE}}})
    # `@eval begin ... end` evaluates the definitions; `@eval quote ... end` would only
    # build and return an expression without defining anything.
    @eval begin
        # ...
    end
end
define_overload(::Any) = nothing # don't do anything for any other signature
# Register `define_overload` as the hook to be run every time a new `frule` is defined.
on_new_rule(define_overload, frule)
or you might write:
const ACCEPT_TYPES = (Float64, AbstractMatrix{Float64})

# Hook run for every new rule: defines an overload only when each positional argument
# of the rule can accept at least one of `ACCEPT_TYPES` (e.g. a rule for `Real`
# accepts `Float64`), and otherwise returns `nothing` without doing anything.
function define_overload(sig)
    sig = Base.unwrap_unionall(sig) # not really handling most UnionAlls
    _, argTs = Iterators.peel(sig.parameters) # skip the operation's type
    # `argT isa Type` rejects leftover `TypeVar`s (from `unwrap_unionall`), on which
    # `acceptT <: argT` would throw.
    supported = all(
        argT isa Type && any(acceptT <: argT for acceptT in ACCEPT_TYPES)
        for argT in argTs
    )
    supported || return
    # `@eval begin ... end` evaluates the definitions; `@eval quote ... end` would only
    # build and return an expression without defining anything.
    @eval begin
        # ...
    end
end
# Register `define_overload` as the hook to be run every time a new `frule` is defined.
on_new_rule(define_overload, frule)
The generation of overloaded code is the responsibility of the AD implementor. Packages like ExprTools.jl can be helpful for this. It is generally fairly simple, though it can become complex if you need to handle complicated type constraints. Examples are shown below.
The hook is automatically triggered whenever a package is loaded. It can also be triggered manually using [`refresh_rules`](@ref). This is useful, for example, if new rules are defined in the REPL, or if a package defining rules is modified (Revise.jl will not trigger the hooks automatically). When the rules are refreshed (automatically or manually), the hooks are only triggered on new or modified rules, not on rules that have already had the hooks triggered on them.
[`clear_new_rule_hooks!`](@ref) clears all registered hooks. It is useful for undoing [`on_new_rule`](@ref) hook registration if you are iteratively developing your overload generation function.
Examples
ForwardDiffZero
The overload generation hook in this example is `define_dual_overload`.
"The simplest viable forward-mode AD, only supports `Float64`"
module ForwardDiffZero
using ChainRulesCore
using Test

#########################################
# Initial rule setup
@scalar_rule x + y (1, 1)
@scalar_rule x - y (1, -1)

##########################
# Define the AD

# Note that we never directly define Dual Number Arithmetic on Dual numbers
# instead it is automatically defined from the `frules`
struct Dual <: Real
    primal::Float64
    partial::Float64
end

primal(d::Dual) = d.primal
partial(d::Dual) = d.partial
# Plain reals are constants: they are their own primal and have a zero partial.
primal(d::Real) = d
partial(d::Real) = 0.0

# needed for `^` to work from having `*` defined
Base.to_power_type(x::Dual) = x

"What to do when a new frule is declared"
function define_dual_overload(sig)
    sig = Base.unwrap_unionall(sig) # Not really handling most UnionAlls
    opT, argTs = Iterators.peel(sig.parameters)
    opT isa Type{<:Type} && return # not handling constructors
    fieldcount(opT) == 0 || return # not handling functors
    # only handling purely Float64 ops.
    all(argT isa Type && Float64 <: argT for argT in argTs) || return

    N = length(sig.parameters) - 1 # skip the op
    fdef = quote
        # we use the function call overloading form as it lets us avoid namespacing issues
        # as we can directly interpolate the function type into the AST.
        function (op::$opT)(dual_args::Vararg{Union{Dual, Float64}, $N}; kwargs...)
            ȧrgs = (NO_FIELDS, partial.(dual_args)...)
            args = (op, primal.(dual_args)...)
            y, ẏ = frule(ȧrgs, args...; kwargs...)
            # if y, ẏ cannot be converted to `Float64` this will error.
            return Dual(y, ẏ)
        end
    end
    eval(fdef)
end

# !Important!: Attach the define function to the `on_new_rule` hook
on_new_rule(define_dual_overload, frule)

"Do a calculus: return the derivative of `f` at `arg`. `f` should have a single input."
function derv(f, arg)
    # Seed the input's partial with 1 so the output's partial is the derivative.
    dual = Dual(arg, one(arg))
    return partial(f(dual))
end
# End AD definition
################################

# add a rule later also
function ChainRulesCore.frule((_, Δx, Δy), ::typeof(*), x::Number, y::Number)
    return (x * y, Δx * y + x * Δy)
end

# Manual refresh needed as new rule added in same file as AD after the `on_new_rule` call
refresh_rules();

@testset "ForwardDiffZero" begin
    foo(x) = x + x
    @test derv(foo, 1.6) == 2
    @test derv(foo, 2) == 2 # non-`Float64` real inputs are converted by `Dual`

    bar(x) = x + 2.1 * x
    @test derv(bar, 1.2) == 3.1

    baz(x) = 2.0 * x^2 + 3.0*x + 1.2
    @test derv(baz, 1.7) == 2*2.0*1.7 + 3.0

    qux(x) = foo(x) + bar(x) + baz(x)
    @test derv(qux, 1.7) == (2*2.0*1.7 + 3.0) + 3.1 + 2

    function quux(x)
        y = 2.0*x + 3.0*x
        return 4.0*y + 5.0*y
    end
    @test derv(quux, 11.1) == 4*(2+3) + 5*(2+3)
end

end # module
ReverseDiffZero
The overload generation hook in this example is `define_tracked_overload`.
"The simplest viable reverse-mode AD, only supports `Float64`"
module ReverseDiffZero
using ChainRulesCore
using Test

#########################################
# Initial rule setup
@scalar_rule x + y (1, 1)
@scalar_rule x - y (1, -1)

##########################
#Define the AD

struct Tracked{F} <: Real
    propagate::F
    primal::Float64
    tape::Vector{Tracked} # a reference to a shared tape
    partial::Base.RefValue{Float64} # current accumulated sensitivity
end

"An intermediate value, a Branch in Nabla terms."
function Tracked(propagate, primal, tape)
    # The 4-argument default constructor of this parametric struct only accepts
    # arguments already of the field types, so convert the primal and build the
    # sensitivity store as a `RefValue{Float64}` explicitly.
    v = Tracked(propagate, convert(Float64, primal), tape, Ref(0.0))
    push!(tape, v)
    return v
end

"Marker for inputs (leaves) that don't need to propagate."
struct NoPropagate end

"An input, a Leaf in Nabla terms. No inputs of its own to propagate to."
function Tracked(primal, tape)
    # don't actually need to put these on the tape, since they don't need to propagate.
    # Converting here lets non-`Float64` real inputs (e.g. integers) be tracked.
    return Tracked(NoPropagate(), convert(Float64, primal), tape, Ref(0.0))
end

primal(d::Tracked) = d.primal
primal(d) = d

partial(d::Tracked) = d.partial[]
partial(d) = nothing

tape(d::Tracked) = d.tape
tape(d) = nothing

"we have many inputs grab the tape from the first one that is tracked"
get_tape(ds) = something(tape.(ds)...)

"propagate the currently stored partial back to my inputs."
propagate!(d::Tracked) = d.propagate(d.partial[])

"Accumulate the sensitivity, if the value is being tracked."
accum!(d::Tracked, x̄) = d.partial[] += x̄
accum!(d, x̄) = nothing

# needed for `^` to work from having `*` defined
Base.to_power_type(x::Tracked) = x

"What to do when a new rrule is declared"
function define_tracked_overload(sig)
    sig = Base.unwrap_unionall(sig) # not really handling most UnionAll
    opT, argTs = Iterators.peel(sig.parameters)
    opT isa Type{<:Type} && return # not handling constructors
    fieldcount(opT) == 0 || return # not handling functors
    # only handling purely Float64 ops.
    all(argT isa Type && Float64 <: argT for argT in argTs) || return

    N = length(sig.parameters) - 1 # skip the op
    fdef = quote
        # we use the function call overloading form as it lets us avoid namespacing issues
        # as we can directly interpolate the function type into the AST.
        function (op::$opT)(tracked_args::Vararg{Union{Tracked, Float64}, $N}; kwargs...)
            args = (op, primal.(tracked_args)...)
            y, y_pullback = rrule(args...; kwargs...)
            the_tape = get_tape(tracked_args)
            y_tracked = Tracked(y, the_tape) do ȳ
                # pull this partial back and propagate it to the input's partial store
                _, ārgs = Iterators.peel(y_pullback(ȳ))
                accum!.(tracked_args, ārgs)
            end
            return y_tracked
        end
    end
    eval(fdef)
end

# !Important!: Attach the define function to the `on_new_rule` hook
on_new_rule(define_tracked_overload, rrule)

"""
Do a calculus: return a tuple holding the partial of `f`'s output with respect to each
of `args`. `f` should have a single output.
"""
function derv(f, args::Vararg; kwargs...)
    the_tape = Vector{Tracked}()
    tracked_inputs = Tracked.(args, Ref(the_tape))
    tracked_output = f(tracked_inputs...; kwargs...)
    @assert tape(tracked_output) === the_tape

    # Now the backward pass
    out = primal(tracked_output)
    ōut = one(out)
    accum!(tracked_output, ōut)
    # By going down the tape backwards we know we will have fully accumulated partials
    # before propagating them onwards
    for op in reverse(the_tape)
        propagate!(op)
    end
    return partial.(tracked_inputs)
end
# End AD definition
################################

# add a rule later also
function ChainRulesCore.rrule(::typeof(*), x::Number, y::Number)
    function times_pullback(ΔΩ)
        # we will use thunks here to show we handle them fine.
        return (NO_FIELDS, @thunk(ΔΩ * y'), @thunk(x' * ΔΩ))
    end
    return x * y, times_pullback
end

# Manual refresh needed as new rule added in same file as AD after the `on_new_rule` call
refresh_rules();

@testset "ReverseDiffZero" begin
    foo(x) = x + x
    @test derv(foo, 1.6) == (2.0,)
    @test derv(foo, 2) == (2.0,) # non-`Float64` real inputs are converted

    bar(x) = x + 2.1 * x
    @test derv(bar, 1.2) == (3.1,)

    baz(x) = 2.0 * x^2 + 3.0*x + 1.2
    @test derv(baz, 1.7) == (2 * 2.0 * 1.7 + 3.0,)

    qux(x) = foo(x) + bar(x) + baz(x)
    @test derv(qux, 1.7) == ((2 * 2.0 * 1.7 + 3.0) + 3.1 + 2,)

    function quux(x)
        y = 2.0*x + 3.0*x
        return 4.0*y + 5.0*y
    end
    @test derv(quux, 11.1) == (4*(2+3) + 5*(2+3),)
end

end # module