ChainRulesTestUtils
ChainRulesTestUtils.jl helps you test ChainRulesCore.frule and ChainRulesCore.rrule methods when adding rules for functions in your own packages. For information about ChainRules, including how to write rules, refer to the general ChainRules documentation.
Testing Method Table Sensibility
A basic feature of ChainRulesTestUtils is its ability to check that the method tables for rrule and frule remain sensible. It searches the method tables for methods that should not exist and, when it finds one, tells you where it was defined. By calling test_method_tables, ChainRulesTestUtils will check for things such as having attached a rule to DataType rather than to a constructor. Basically all packages using ChainRulesTestUtils can use test_method_tables, as it is independent of what rules you have written.
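A minimal sketch of wiring this check into a package's test suite. Placing it in test/runtests.jl is the usual convention, and the testset name is arbitrary:

```julia
using Test
using ChainRulesTestUtils

# Typically placed in test/runtests.jl, after the package that defines the rules is loaded.
@testset "rule method tables are sensible" begin
    test_method_tables()
end
```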
Canonical example of testing frule and rrule
Let's suppose a custom transformation has been defined
function two2three(x1::Float64, x2::Float64)
return 1.0, 2.0*x1, 3.0*x2
end
# output
two2three (generic function with 1 method)
along with the frule
using ChainRulesCore
function ChainRulesCore.frule((Δf, Δx1, Δx2), ::typeof(two2three), x1, x2)
y = two2three(x1, x2)
∂y = Tangent{Tuple{Float64, Float64, Float64}}(ZeroTangent(), 2.0*Δx1, 3.0*Δx2)
return y, ∂y
end
# output
and the rrule, which contains a deliberate mistake in the cotangent for x1 (a factor of 2.1 instead of 2.0):
function ChainRulesCore.rrule(::typeof(two2three), x1, x2)
y = two2three(x1, x2)
function two2three_pullback(Ȳ)
return (NoTangent(), 2.1*Ȳ[2], 3.0*Ȳ[3])
end
return y, two2three_pullback
end
# output
The test_frule/test_rrule helper functions compare the frule/rrule outputs to the gradients obtained by finite differencing. They can be used for any type and number of inputs and outputs.
Testing the frule
test_frule takes in the function f and the primal input x. The call will test the frule for function f at the point x in the domain. Keep this in mind when testing rules for functions whose derivative is discontinuous, like ReLU, which should ideally be tested at points both above and below zero.
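As a sketch of testing on both sides of a kink, the loop below uses a hypothetical myabs function whose derivative jumps at zero. Its rule is written with @scalar_rule (covered in the Scalar example section). The test points deliberately avoid zero itself, where finite differencing would straddle the kink:

```julia
using ChainRulesCore, ChainRulesTestUtils

# Hypothetical function: its derivative is -1 below zero and +1 above zero.
myabs(x::Real) = abs(x)
@scalar_rule myabs(x::Real) sign(x)

# Check the frule on each side of the kink.
for x in (-1.5, 1.5)
    test_frule(myabs, x)
end
```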
julia> using ChainRulesTestUtils;
julia> test_frule(two2three, 3.33, -7.77);
Test Summary: | Pass Total Time
test_frule: two2three on Float64,Float64 | 6 6 2.7s
Testing the rrule
test_rrule takes in the function f and primal inputs x. The call will test the rrule for function f at the point x, and, as with frule, some rules should be tested at multiple points in the domain.
julia> test_rrule(two2three, 3.33, -7.77);
test_rrule: two2three on Float64,Float64: Test Failed at /home/lior/.julia/dev/ChainRulesTestUtils/src/check_result.jl:24
Expression: isapprox(actual, expected; kwargs...)
Problem: cotangent for input 2, Float64
Evaluated: isapprox(-4.032, -3.840000000001641; rtol = 1.0e-9, atol = 1.0e-9)
[...]
The output of the test indicates the cause of the failure under "Problem:", together with the actual (rrule-derived) and expected (finite-difference) results. The problem lies with the cotangent corresponding to input 2 of the rrule (input 1 is the function itself), which is the cotangent for x1, exactly where the mistake was introduced.
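Restoring the factor of 2.0 for x1 should make the test pass. A sketch, repeating the two2three definition so the block runs on its own; unthunking the incoming cotangent is a defensive choice, in case the tester passes it as a thunk:

```julia
using ChainRulesCore, ChainRulesTestUtils

two2three(x1::Float64, x2::Float64) = (1.0, 2.0 * x1, 3.0 * x2)

# Corrected rrule: ∂y[2]/∂x1 == 2.0 and ∂y[3]/∂x2 == 3.0.
function ChainRulesCore.rrule(::typeof(two2three), x1, x2)
    y = two2three(x1, x2)
    function two2three_pullback(Ȳ)
        Ȳu = unthunk(Ȳ)  # materialize the cotangent if it arrives as a thunk
        return (NoTangent(), 2.0 * Ȳu[2], 3.0 * Ȳu[3])
    end
    return y, two2three_pullback
end

test_rrule(two2three, 3.33, -7.77)
```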
Scalar example
For functions with a single argument and a single output, such as ReLU,
function relu(x::Real)
return max(0, x)
end
# output
relu (generic function with 1 method)
with the frule and rrule defined with the help of the @scalar_rule macro
@scalar_rule relu(x::Real) x <= 0 ? zero(x) : one(x)
# output
the test_scalar function is provided to test both the frule and the rrule with a single call.
julia> test_scalar(relu, 0.5);
Test Summary: | Pass Total Time
test_scalar: relu at 0.5 | 12 12 1.2s
julia> test_scalar(relu, -0.5);
Test Summary: | Pass Total Time
test_scalar: relu at -0.5 | 12 12 0.0s
Testing constructors and functors (callable objects)
Testing constructors and functors works as you would expect. For a struct Foo,
struct Foo
a::Float64
end
(f::Foo)(x) = f.a + x  # functor: adds the stored field to x
Base.length(::Foo) = 1
Base.iterate(f::Foo) = iterate(f.a)
Base.iterate(f::Foo, state) = iterate(f.a, state)
after defining the constructor and functor rrules,
function ChainRulesCore.rrule(::Type{Foo}, val) # constructor rrule
y = Foo(val)
Foo_pb(ΔFoo) = (NoTangent(), unthunk(ΔFoo).a)
return y, Foo_pb
end
function ChainRulesCore.rrule(foo::Foo, val) # functor rrule
y = foo(val)
function foo_pb(Δ)
Δut = unthunk(Δ)
return (Tangent{Foo}(;a=Δut), Δut)
end
return y, foo_pb
end
both rrules can be tested by
test_rrule(Foo, rand()) # constructor
foo = Foo(rand())
test_rrule(foo, rand()) # functor
# it is also possible to provide tangents for `foo` explicitly
test_rrule(foo ⊢ Tangent{Foo}(;a=rand()), rand())
Specifying Tangents
test_frule and test_rrule allow you to specify the tangents used for testing. By default, tangents are generated automatically via FiniteDifferences.rand_tangent. To specify a tangent explicitly, pass in x ⊢ Δx, where x is the primal and Δx is the tangent, in place of the primal input. (You can enter ⊢ via \vdash + Tab in the Julia REPL and supporting editors.) A special case is x ⊢ NoTangent(), in which case finite differencing will not be used on that input. Similarly, by setting the output_tangent keyword argument, you can specify the tangent for the primal output.
This can be useful when the default provided by FiniteDifferences.rand_tangent doesn't produce the desired tangent for your type. For example, the default tangent for an Int is NoTangent(), which is correct when the Int represents a discrete integer, as in indexing. But if you are testing something where the Int is actually a special case of a real number, then you would want to specify the tangent as a Float64.
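A sketch of both mechanisms, using a hypothetical cube function so the example stands alone: ⊢ fixes the input tangent seen by test_frule, and output_tangent fixes the cotangent fed to the pullback in test_rrule.

```julia
using ChainRulesCore, ChainRulesTestUtils

# Hypothetical scalar function used only for this illustration.
cube(x::Real) = x^3
@scalar_rule cube(x::Real) 3 * x^2

# Use Δx = 1.0 as the input tangent instead of a random one.
test_frule(cube, 2.0 ⊢ 1.0)

# Use ȳ = 0.5 as the output cotangent instead of a random one.
test_rrule(cube, 2.0; output_tangent=0.5)
```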
Care must be taken when manually specifying tangents, in particular the input tangents to test_frule and the output tangent to test_rrule. These tangents are used to seed the derivative computation, so inserting inappropriate zeros can hide errors.
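To illustrate, here is a hypothetical double_it with a deliberately wrong rrule (a factor of 2.1 instead of 2.0). With a zero output tangent, both the rrule result and the finite-difference result are zero, so the bug is expected to go unnoticed; a nonzero seed exposes it.

```julia
using ChainRulesCore, ChainRulesTestUtils

# Hypothetical function with a deliberately wrong rule.
double_it(x::Float64) = 2.0 * x
function ChainRulesCore.rrule(::typeof(double_it), x::Float64)
    double_it_pullback(ȳ) = (NoTangent(), 2.1 * unthunk(ȳ))  # should be 2.0
    return double_it(x), double_it_pullback
end

# Zero seed: 2.1 * 0.0 == 2.0 * 0.0, so this passes despite the bug.
test_rrule(double_it, 1.5; output_tangent=0.0)

# Nonzero seed: this would report a failure for the cotangent of x.
# test_rrule(double_it, 1.5; output_tangent=1.0)
```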
Testing higher order functions
Higher order functions, such as map, take a function (or a functor) f as an argument. The frules/rrules for these functions call back into AD to compute the frule/rrule of f. To test these functions, we use a dummy AD system, which simply calls the appropriate rule for f directly. For this reason, when testing map(f, collection), the rules for f need to be defined. The RuleConfig for this dummy AD system is the default one and does not need to be provided.
test_rrule(map, x->2x, [1, 2, 3.]) # fails, because there is no rrule for x->2x
mydouble(x) = 2x
function ChainRulesCore.rrule(::typeof(mydouble), x)
mydouble_pullback(ȳ) = (NoTangent(), ȳ)
return mydouble(x), mydouble_pullback
end
test_rrule(map, mydouble, [1, 2, 3.]) # works
Testing AD systems
The gradients computed by AD systems can also be tested using test_rrule (or test_frule). To do that, one needs to provide an rrule_f/frule_f keyword argument, as well as the RuleConfig used by the AD system. rrule_f is a function that wraps the gradient computation by an AD system in the same API as the rrule. RuleConfig is an object that determines which sets of rules are defined for an AD system. For example, let's say we have a complicated function
function complicated(x, y)
    # crunch, some, hard and maths are placeholders for arbitrary functions
    return crunch(x + y) + some(x) * hard(y) + maths(x * y)
end
that we do not have an rrule for, and we want to check whether the gradients provided by the AD system are correct.
To test gradients computed by the AD system, you need to provide an rrule_f function that acts like calling rrule but uses AD rather than a defined rule. This has exactly the same semantics as are required to overload ChainRulesCore.rrule_via_ad, so almost all AD systems should simply overload that and then call test_rrule(MyADConfig, f, xs; rrule_f = ChainRulesCore.rrule_via_ad), passing in their config. See the ChainRules documentation for more information on rrule_via_ad and rule configs. For some AD systems (e.g. Zygote) rrule_via_ad already exists. If it does not, see the section How to write rrule_via_ad function below.
We use the test_rrule function to test the gradients with the config used by the AD system:
config = MyAD.CustomRuleConfig()
test_rrule(config, complicated, 2.3, 6.1; rrule_f=rrule_via_ad)
by providing the rule config and specifying rrule_via_ad as the rrule_f keyword argument.
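For a self-contained illustration that does not need a real AD package, the sketch below defines a toy rrule_f built on ForwardDiff.derivative. It only handles scalar functions of one real argument and is not how a real reverse-mode AD works; toy_rrule_via_ad and wobbly are hypothetical names. It assumes, matching the rrule_via_ad signature, that test_rrule passes its RuleConfig as the first argument to rrule_f.

```julia
using ChainRulesCore, ChainRulesTestUtils
import ForwardDiff

# Toy "AD system": exposes an rrule-like API for scalar functions of one real argument.
function toy_rrule_via_ad(config::RuleConfig, f, x::Real)
    y = f(x)
    dydx = ForwardDiff.derivative(f, x)
    # NoTangent() for the function itself, then the cotangent for x
    toy_pullback(ȳ) = (NoTangent(), unthunk(ȳ) * dydx)
    return y, toy_pullback
end

# Hypothetical function with no hand-written rrule.
wobbly(x) = sin(x) * exp(x)

test_rrule(wobbly, 0.7; rrule_f=toy_rrule_via_ad)
```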
How to write rrule_via_ad function
rrule_via_ad will use the AD system to compute gradients and will package them in the rrule-like API.
Let's say the AD package uses some custom differential types and does not provide a gradient w.r.t. the function itself. In order to make the pullback compatible with the rrule API, we need to add a NoTangent() to represent the differential w.r.t. the function itself. We also need to transform the ChainRules differential types to the custom types (cr2custom) before feeding the Δ to the AD-generated pullback, and back to ChainRules differential types (custom2cr) when returning from the rrule.
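function ChainRulesCore.rrule_via_ad(config::MyAD.CustomRuleConfig, f::Function, args...)
    # run the AD system's own reverse pass
    y, ad_pullback = MyAD.pullback(f, args...)
    function rrulelike_pullback(Δ)
        # convert the ChainRules cotangent, pull back, then convert the result back
        diffs = custom2cr(ad_pullback(cr2custom(Δ)))
        # prepend NoTangent() for the function itself, as the rrule API expects
        return NoTangent(), diffs...
    end
    return y, rrulelike_pullback
end
# package-specific conversions between ChainRules and MyAD differential types
custom2cr(differential) = ...
cr2custom(differential) = ...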
Custom finite differencing
If a package uses a custom finite differencing method to test frules and rrules, the test_approx function provides a convenient way of comparing various types of differentials.
It is effectively (a, b) -> @test isapprox(a, b), but it preprocesses thunks and the ChainRules differential types ZeroTangent(), NoTangent(), and Tangent so that the error messages are helpful.
For example,
test_approx((@thunk 2*2.0), 4.1)
shows both the expression and the evaluated thunks
Expression: isapprox(actual, expected; kwargs...)
Evaluated: isapprox(4.0, 4.1)
ERROR: There was an error during testing
compared to
julia> @test isapprox(@thunk 2*2.0, 4.0)
Test Failed at REPL[52]:1
Expression: isapprox(#= REPL[52]:1 =# @thunk((2 * 2.0, 4.0)))
Evaluated: isapprox(Thunk(var"#24#25"()))
ERROR: There was an error during testing
which should have passed the test.
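A few more comparisons that test_approx is expected to accept, sketched from the preprocessing described above (thunks are evaluated, ZeroTangent() is treated as zero, and Tangents are compared componentwise, as we understand its behavior). The specific values are illustrative only:

```julia
using ChainRulesCore, ChainRulesTestUtils

# Thunks are unthunked before comparison.
test_approx(@thunk(2 * 2.0), 4.0)

# ZeroTangent() compares equal to a numeric zero.
test_approx(ZeroTangent(), 0.0)

# Tangents of the same primal type are compared componentwise, up to isapprox tolerance.
test_approx(
    Tangent{Tuple{Float64,Float64}}(1.0, 2.0),
    Tangent{Tuple{Float64,Float64}}(1.0, 2.0 + 1e-12),
)
```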
Inference tests
By default, all functions for testing rules check whether the output type (as well as that of the pullback for rrules) can be completely inferred, such that everything is type stable:
julia> function ChainRulesCore.rrule(::typeof(abs), x)
abs_pullback(Δ) = (NoTangent(), x >= 0 ? Δ : big(-1.0) * Δ)
return abs(x), abs_pullback
end
julia> test_rrule(abs, 1.)
test_rrule: abs on Float64: Error During Test at /home/runner/work/ChainRulesTestUtils.jl/ChainRulesTestUtils.jl/src/testers.jl:170
Got exception outside of a @test
return type Tuple{ChainRulesCore.NoTangent, Float64} does not match inferred return type Tuple{ChainRulesCore.NoTangent, Union{Float64, BigFloat}}
[...]
This can be disabled on a per-rule basis using the check_inferred keyword argument:
julia> test_rrule(abs, 1.; check_inferred=false)
Test Summary: | Pass Total
test_rrule: abs on Float64 | 5 5
Test.DefaultTestSet("test_rrule: abs on Float64", Any[], 5, false, false)
This behavior can also be overridden globally by setting the environment variable CHAINRULES_TEST_INFERRED before ChainRulesTestUtils is loaded, or by changing ChainRulesTestUtils.TEST_INFERRED[] from inside Julia. ChainRulesTestUtils detects whether a test is run as part of PkgEval and, in that case, disables inference tests automatically. Packages can use @maybe_inferred to get the same behavior for other inference tests.
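A sketch of @maybe_inferred in a package's own tests, with halve as a hypothetical function. Like Test's @inferred, the macro returns the value of the call, so it can be wrapped in @test:

```julia
using Test, ChainRulesTestUtils

# Hypothetical type-stable function.
halve(x::Float64) = x / 2

@testset "inference" begin
    # Checks inference like @inferred, unless inference tests are disabled
    # (CHAINRULES_TEST_INFERRED, TEST_INFERRED[], or a PkgEval run).
    @test @maybe_inferred(halve(3.0)) == 1.5
end
```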