ForwardDiffZero
This is a fairly standard operator-overloading-based forward-mode AD system. It defines a `Dual`
type which holds the primal value paired with its partial derivative. It doesn't handle chunked mode or perturbation confusion. The overload generation hook in this example is `define_dual_overload`.
"The simplest viable forward mode a AD, only supports `Float64`"
module ForwardDiffZero
using ChainRulesCore
using ChainRulesOverloadGeneration
# Both ChainRulesCore and ChainRulesOverloadGeneration are loaded with `using` above.
# Resolve conflicts while this code exists in both (presumably each package exports these
# two names — TODO confirm) by binding the ChainRulesOverloadGeneration versions explicitly.
const on_new_rule = ChainRulesOverloadGeneration.on_new_rule
const refresh_rules = ChainRulesOverloadGeneration.refresh_rules
using Test
#########################################
# Initial rule setup
# Register rules for `+` and `-` before the AD is defined. Each trailing tuple presumably
# lists the partial derivatives with respect to `x` and `y` (see the `@scalar_rule` docs).
@scalar_rule x + y (1, 1)
@scalar_rule x - y (1, -1)
##########################
# Define the AD
# Note that we never directly define Dual Number Arithmetic on Dual numbers
# instead it is automatically defined from the `frules`
"""
    Dual(primal, partial)

Dual number pairing a primal `Float64` value with its `Float64` partial derivative.
"""
struct Dual <: Real
    primal::Float64
    partial::Float64
end

# Accessors: a `Dual` yields its stored components, while any other `Real` is treated
# as a constant, i.e. it is its own primal and has a zero partial.
function primal(d::Dual)
    return getfield(d, :primal)
end

function partial(d::Dual)
    return getfield(d, :partial)
end

function primal(d::Real)
    return d
end

function partial(d::Real)
    return 0.0
end
# Needed so that `^` works on `Dual`s purely from having `*` defined:
# a `Dual` is used as-is as its own power type.
function Base.to_power_type(x::Dual)
    return x
end
"""
    define_dual_overload(sig)

Hook that generates a `Dual` overload for the primal call signature `sig`, a `Tuple` type
whose first parameter is the operator type and whose remaining parameters are the
argument types.

An overload is `eval`ed only when the operator is a plain function type (not a
constructor, not a functor with fields) and every argument type accepts `Float64`.
In every other case the function returns `nothing` without defining anything.

The generated method accepts any mix of `Dual` and `Float64` arguments, passes their
primals and partials to `frule`, and wraps the result in a `Dual`.
"""
function define_dual_overload(sig)
    sig = Base.unwrap_unionall(sig)  # Not really handling most UnionAlls
    opT, argTs = Iterators.peel(sig.parameters)
    opT isa Type{<:Type} && return  # not handling constructors
    # `fieldcount` throws for operator types without a definite field layout (abstract
    # types, `Union`s, free type variables); skip those instead of aborting the hook.
    (opT isa DataType || opT isa UnionAll) || return
    isabstracttype(opT) && return
    fieldcount(opT) == 0 || return  # not handling functors
    # only handling purely Float64 ops.
    all(argT isa Type && Float64 <: argT for argT in argTs) || return

    N = length(sig.parameters) - 1  # skip the op
    fdef = quote
        # we use the function call overloading form as it lets us avoid namespacing issues
        # as we can directly interpolate the function type into the AST.
        function (op::$opT)(dual_args::Vararg{Union{Dual, Float64}, $N}; kwargs...)
            ȧrgs = (NoTangent(), partial.(dual_args)...)
            args = (op, primal.(dual_args)...)
            y, ẏ = frule(ȧrgs, args...; kwargs...)
            return Dual(y, ẏ)  # if y, ẏ are not `Float64` this will error.
        end
    end
    eval(fdef)
end
# !Important!: Attach the define function to the `on_new_rule` hook.
# This registers `define_dual_overload` to be run for `frule` definitions
# (presumably once per rule signature — see the ChainRulesOverloadGeneration docs).
on_new_rule(define_dual_overload, frule)
"Do a calculus. `f` should have a single input."
function derv(f, arg)
    # Seed the input with a unit partial so the partial of the output is the derivative.
    seeded = Dual(arg, one(arg))
    result = f(seeded)
    # `partial` of a plain `Real` result is zero, so constant functions are handled too.
    return partial(result)
end
# End AD definition
################################
# add a rule later also
# Forward rule for `*` on numbers: returns the product and its tangent (product rule).
function ChainRulesCore.frule((_, dx, dy), ::typeof(*), x::Number, y::Number)
    value = x * y
    tangent = dx * y + x * dy
    return (value, tangent)
end
# Manual refresh needed: the `*` rule above was added in the same file as the AD,
# after the `on_new_rule` call, so the hooks are re-run here to pick it up.
refresh_rules();
# Checks `derv` against hand-computed derivatives. Results are compared with `≈` rather
# than `==` because they are floating-point values whose last bits depend on the order in
# which the rules combine the partials.
@testset "ForwardDiffZero" begin
    foo(x) = x + x
    @test derv(foo, 1.6) ≈ 2

    bar(x) = x + 2.1 * x
    @test derv(bar, 1.2) ≈ 3.1

    baz(x) = 2.0 * x^2 + 3.0 * x + 1.2
    @test derv(baz, 1.7) ≈ 2 * 2.0 * 1.7 + 3.0

    qux(x) = foo(x) + bar(x) + baz(x)
    @test derv(qux, 1.7) ≈ (2 * 2.0 * 1.7 + 3.0) + 3.1 + 2

    function quux(x)
        y = 2.0 * x + 3.0 * x
        return 4.0 * y + 5.0 * y
    end
    @test derv(quux, 11.1) ≈ 4 * (2 + 3) + 5 * (2 + 3)
end
end # module