# ChainRulesTestUtils

ChainRulesTestUtils.jl helps you test `ChainRulesCore.frule` and `ChainRulesCore.rrule` methods when adding rules for your functions in your own packages. For information about ChainRules, including how to write rules, refer to the general ChainRules documentation.

## Testing Method Table Sensibility

A basic feature of ChainRulesTestUtils is its ability to check that the method tables for `rrule` and `frule` remain sensible. This searches the method tables for methods that should not exist and, when it fails, tells you where they were defined. By calling `test_method_tables`, ChainRulesTestUtils will check for things such as having attached a rule to `DataType` rather than attaching it to a constructor. Basically all packages using ChainRulesTestUtils can use `test_method_tables`, as it is independent of what rules you have written.
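As a minimal sketch of how this check might be wired into a package's test suite: the zero-argument call form and the testset name are assumptions made for illustration, and the file is assumed to be a typical `test/runtests.jl`.

```julia
using ChainRulesTestUtils
using Test

# Load your own package first so that its rules are in the method tables.
# using MyPackage  # hypothetical package name

@testset "method tables" begin
    # Assumed call form: no arguments, since the check does not depend
    # on which rules have been written.
    test_method_tables()
end
```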

## Canonical example of testing frule and rrule

Let's suppose a custom transformation has been defined

```
function two2three(x1::Float64, x2::Float64)
    return 1.0, 2.0*x1, 3.0*x2
end
# output
two2three (generic function with 1 method)
```

along with the `frule`

```
using ChainRulesCore
function ChainRulesCore.frule((Δf, Δx1, Δx2), ::typeof(two2three), x1, x2)
    y = two2three(x1, x2)
    ∂y = Tangent{Tuple{Float64, Float64, Float64}}(ZeroTangent(), 2.0*Δx1, 3.0*Δx2)
    return y, ∂y
end
# output
```

and `rrule`, which contains a mistake in the first cotangent

```
function ChainRulesCore.rrule(::typeof(two2three), x1, x2)
    y = two2three(x1, x2)
    function two2three_pullback(Ȳ)
        return (NoTangent(), 2.1*Ȳ[2], 3.0*Ȳ[3])
    end
    return y, two2three_pullback
end
# output
```

The `test_frule`/`test_rrule` helper functions compare the `frule`/`rrule` outputs to the gradients obtained by finite differencing. They can be used for any type and number of inputs and outputs.

### Testing the `frule`

`test_frule` takes in the function `f` and the primal input `x`. The call will test the `frule` for function `f` at the point `x` in the domain. Keep this in mind when testing discontinuous rules for functions like ReLU, which should ideally be tested with `x` both above and below zero.

```
julia> using ChainRulesTestUtils;
julia> test_frule(two2three, 3.33, -7.77);
Test Summary: | Pass Total Time
test_frule: two2three on Float64,Float64 | 6 6 2.7s
```

### Testing the `rrule`

`test_rrule` takes in the function `f` and the primal inputs `x`. The call will test the `rrule` for function `f` at the point `x`, and, as with `test_frule`, some rules should be tested at multiple points in the domain.

```
julia> test_rrule(two2three, 3.33, -7.77);
test_rrule: two2three on Float64,Float64: Test Failed at /home/lior/.julia/dev/ChainRulesTestUtils/src/check_result.jl:24
Expression: isapprox(actual, expected; kwargs...)
Problem: cotangent for input 2, Float64
Evaluated: isapprox(-4.032, -3.840000000001641; rtol = 1.0e-9, atol = 1.0e-9)
[...]
```

The output of the test indicates the cause of the failure under "Problem:", and the "Evaluated:" line shows the actual (`rrule`-derived) and expected (finite difference) results. The problem lies with the cotangent corresponding to input 2 of `rrule`. Input 1 is the function itself, so input 2 is `x1`, and this is the first cotangent, as expected.
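For comparison, here is a sketch of the same `rrule` with the mistake removed, that is, with the factor `2.1` replaced by `2.0` so that it matches the derivative of `2.0*x1`. The block repeats the definition of `two2three` so that it is self-contained; with this rule the `test_rrule` call is expected to agree with finite differencing.

```julia
using ChainRulesCore, ChainRulesTestUtils

function two2three(x1::Float64, x2::Float64)
    return 1.0, 2.0*x1, 3.0*x2
end

function ChainRulesCore.rrule(::typeof(two2three), x1, x2)
    y = two2three(x1, x2)
    function two2three_pullback(Ȳ)
        # Ȳ[1] does not contribute, because the first output is constant.
        return (NoTangent(), 2.0*Ȳ[2], 3.0*Ȳ[3])
    end
    return y, two2three_pullback
end

test_rrule(two2three, 3.33, -7.77)
```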

## Scalar example

For functions with a single argument and a single output, such as ReLU,

```
function relu(x::Real)
    return max(0, x)
end
# output
relu (generic function with 1 method)
```

with the `frule` and `rrule` defined with the help of the `@scalar_rule` macro

```
@scalar_rule relu(x::Real) x <= 0 ? zero(x) : one(x)
# output
```

the `test_scalar` function is provided to test both the `frule` and the `rrule` with a single call.

```
julia> test_scalar(relu, 0.5);
Test Summary: | Pass Total Time
test_scalar: relu at 0.5 | 12 12 1.2s
julia> test_scalar(relu, -0.5);
Test Summary: | Pass Total Time
test_scalar: relu at -0.5 | 12 12 0.0s
```

## Testing constructors and functors (callable objects)

Testing constructors and functors works as you would expect. For struct `Foo`,

```
struct Foo
    a::Float64
end
(f::Foo)(x) = f.a + x
Base.length(::Foo) = 1
Base.iterate(f::Foo) = iterate(f.a)
Base.iterate(f::Foo, state) = iterate(f.a, state)
```

after defining the constructor and functor `frule`s and `rrule`s (only the `rrule`s are shown here),

```
function ChainRulesCore.rrule(::Type{Foo}, val) # constructor rrule
    y = Foo(val)
    Foo_pb(ΔFoo) = (NoTangent(), unthunk(ΔFoo).a)
    return y, Foo_pb
end
function ChainRulesCore.rrule(foo::Foo, val) # functor rrule
    y = foo(val)
    function foo_pb(Δ)
        Δut = unthunk(Δ)
        return (Tangent{Foo}(;a=Δut), Δut)
    end
    return y, foo_pb
end
```

both the constructor and the functor rules can be tested by

```
test_rrule(Foo, rand()) # constructor
foo = Foo(rand())
test_rrule(foo, rand()) # functor
# it is also possible to provide tangents for `foo` explicitly
test_frule(foo ⊢ Tangent{Foo}(;a=rand()), rand())
```
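The `test_frule` call above needs an `frule` for the functor, which is not shown. The following is only a sketch of what such a rule could look like, assuming the tangent of `foo` arrives as a `Tangent{Foo}` with a field `a`. The struct definition is repeated so that the block is self-contained.

```julia
using ChainRulesCore, ChainRulesTestUtils

struct Foo
    a::Float64
end
(f::Foo)(x) = f.a + x
Base.length(::Foo) = 1
Base.iterate(f::Foo) = iterate(f.a)
Base.iterate(f::Foo, state) = iterate(f.a, state)

# Sketch of a functor frule: the output f.a + x is perturbed by the
# tangent of the field `a` plus the tangent of the argument.
function ChainRulesCore.frule((Δfoo, Δval), foo::Foo, val)
    return foo(val), Δfoo.a + Δval
end

foo = Foo(rand())
test_frule(foo ⊢ Tangent{Foo}(; a=rand()), rand())
```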

## Specifying Tangents

`test_frule` and `test_rrule` allow you to specify the tangents used for testing. By default, tangents will be automatically generated via `FiniteDifferences.rand_tangent`. To explicitly specify a tangent, pass in `x ⊢ Δx`, where `x` is the primal and `Δx` is the tangent, in place of the primal input. (You can enter `⊢` via `\vdash` + tab in the Julia REPL and supporting editors.) As a special case, if you specify `x ⊢ NoTangent()`, finite differencing will not be used on that input. Similarly, by setting the `output_tangent` keyword argument, you can specify the tangent for the primal output.

This can be useful when the default `FiniteDifferences.rand_tangent` doesn't produce the desired tangent for your type. For example, the default tangent for an `Int` is `NoTangent()`, which is correct when, for example, the `Int` represents a discrete integer, as in indexing. But if you are testing something where the `Int` is actually a special case of a real number, then you would want to specify the tangent as a `Float64`.
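As a small illustration of the `x ⊢ Δx` syntax and the `output_tangent` keyword described above, the sketch below uses a hypothetical function `cube` with hand-written rules. The function, its rules, and the chosen tangent values are assumptions made for this example only.

```julia
using ChainRulesCore, ChainRulesTestUtils

# Hypothetical example function, d/dx x^3 = 3x^2.
cube(x) = x^3

function ChainRulesCore.frule((_, Δx), ::typeof(cube), x)
    return cube(x), 3x^2 * Δx
end

function ChainRulesCore.rrule(::typeof(cube), x)
    cube_pullback(ȳ) = (NoTangent(), 3x^2 * ȳ)
    return cube(x), cube_pullback
end

# Seed the forward rule with an explicit input tangent of 1.0.
test_frule(cube, 2.0 ⊢ 1.0)

# Seed the reverse rule with an explicit output tangent of 1.0.
test_rrule(cube, 2.0 ⊢ 0.5; output_tangent=1.0)
```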

Care must be taken when manually specifying tangents, in particular when specifying the input tangents to `test_frule` and the output tangent to `test_rrule`, as these tangents are used to seed the derivative computation. Inserting inappropriate zeros can thus hide errors.
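To see why a zero seed can hide an error, consider the faulty factor `2.1` from the `two2three` pullback earlier in this document. The plain-Julia sketch below (the helper names `buggy` and `correct` are made up for illustration) compares the faulty and the correct cotangent formulas for two different output seeds.

```julia
# Cotangent of x1 as computed by the faulty and by the correct pullback.
buggy(Ȳ) = 2.1 * Ȳ[2]
correct(Ȳ) = 2.0 * Ȳ[2]

zero_seed = (1.0, 0.0, 1.0)      # zero in the slot that exposes the bug
generic_seed = (1.0, -1.92, 0.7) # nonzero in every slot

# With the zero seed both formulas give 0.0, so a test would pass.
hidden = buggy(zero_seed) == correct(zero_seed)

# With a generic seed the two results differ and the error is caught.
caught = !isapprox(buggy(generic_seed), correct(generic_seed))
```

Both `hidden` and `caught` evaluate to `true`, which is why randomly generated seeds are the safer default.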

## Testing higher order functions

Higher order functions, such as `map`, take a function (or a functor) `f` as an argument. The `frule`s and `rrule`s for these functions call back into AD to compute the `frule` or `rrule` of `f`. To test these functions, we use a dummy AD system, which simply calls the appropriate rule for `f` directly. For this reason, when testing `map(f, collection)`, the rules for `f` need to be defined. The `RuleConfig` for this dummy AD system is the default one, and does not need to be provided.

```
test_rrule(map, x->2x, [1, 2, 3.]) # fails, because there is no rrule for x->2x
mydouble(x) = 2x
function ChainRulesCore.rrule(::typeof(mydouble), x)
    # the derivative of 2x is 2, so the cotangent is scaled by 2
    mydouble_pullback(ȳ) = (NoTangent(), 2 * ȳ)
    return mydouble(x), mydouble_pullback
end
test_rrule(map, mydouble, [1, 2, 3.]) # works
```

## Testing AD systems

The gradients computed by AD systems can also be tested using `test_rrule`. To do that, one needs to provide an `rrule_f`/`frule_f` keyword argument, as well as the `RuleConfig` used by the AD system. `rrule_f` is a function that wraps the gradient computation by an AD system in the same API as the `rrule`. `RuleConfig` is an object that determines which sets of rules are defined for an AD system. For example, let's say we have a complicated function

```
function complicated(x, y)
return do(x + y) + some(x) * hard(y) + maths(x * y)
end
```

that we do not know an `rrule` for, and we want to check whether the gradients provided by the AD system are correct.

To test gradients computed by the AD system you need to provide an `rrule_f` function that acts like calling `rrule` but uses AD rather than a defined rule. This has exactly the same semantics as are required to overload `ChainRulesCore.rrule_via_ad`, so almost all AD systems should just overload that, pass it in together with the config, and then call `test_rrule(MyADConfig, f, xs; rrule_f = ChainRulesCore.rrule_via_ad)`. See more info on `rrule_via_ad` and the rule configs in the ChainRules documentation. For some AD systems (e.g. Zygote) `rrule_via_ad` already exists. If it does not exist, see the "How to write `rrule_via_ad` function" section below.

We use the `test_rrule` function to test the gradients using the config used by the AD system

```
config = MyAD.CustomRuleConfig()
test_rrule(config, complicated, 2.3, 6.1; rrule_f=rrule_via_ad)
```

by providing the rule config and specifying `rrule_via_ad` as the `rrule_f` keyword argument.

### How to write `rrule_via_ad` function

`rrule_via_ad` will use the AD system to compute gradients and will package them in the `rrule`-like API.

Let's say the AD package uses some custom differential types and does not provide a gradient w.r.t. the function itself. In order to make the pullback compatible with the `rrule` API we need to add a `NoTangent()` to represent the differential w.r.t. the function itself. We also need to transform the `ChainRules` differential types to the custom types (`cr2custom`) before feeding the `Δ` to the AD-generated pullback, and back to `ChainRules` differential types when returning from the `rrule` (`custom2cr`).

```
# Qualify the name so that this adds a method to ChainRulesCore.rrule_via_ad.
function ChainRulesCore.rrule_via_ad(config::MyAD.CustomRuleConfig, f::Function, args...)
    y, ad_pullback = MyAD.pullback(f, args...)
    function rrulelike_pullback(Δ)
        diffs = custom2cr(ad_pullback(cr2custom(Δ)))
        return NoTangent(), diffs...
    end
    return y, rrulelike_pullback
end
custom2cr(differential) = ...
cr2custom(differential) = ...
```

## Custom finite differencing

If a package is using a custom finite differencing method of testing the `frule`s and `rrule`s, the `test_approx` function provides a convenient way of comparing various types of differentials.

It is effectively `(a, b) -> @test isapprox(a, b)`, but it preprocesses `thunk`s and `ChainRules` differential types `ZeroTangent()`, `NoTangent()`, and `Tangent`, such that the error messages are helpful.

For example, `test_approx((@thunk 2*2.0), 4.1)` shows both the expression and the evaluated `thunk`s

```
Expression: isapprox(actual, expected; kwargs...)
Evaluated: isapprox(4.0, 4.1)
ERROR: There was an error during testing
```

compared to

```
julia> @test isapprox(@thunk 2*2.0, 4.0)
Test Failed at REPL[52]:1
Expression: isapprox(#= REPL[52]:1 =# @thunk((2 * 2.0, 4.0)))
Evaluated: isapprox(Thunk(var"#24#25"()))
ERROR: There was an error during testing
```

which should have passed the test.
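A few passing comparisons may make the preprocessing more concrete. This is a sketch that assumes `test_approx` accepts a thunk or a `ZeroTangent()` as the first argument and forwards keyword arguments such as `atol` to `isapprox`, as the failure messages shown in this document suggest.

```julia
using ChainRulesCore, ChainRulesTestUtils, Test

@testset "test_approx sketches" begin
    # The thunk is evaluated before the comparison is made.
    test_approx(@thunk(2 * 2.0), 4.0)

    # A ZeroTangent is compared against a numeric zero.
    test_approx(ZeroTangent(), 0.0)

    # Tolerances are passed on to isapprox.
    test_approx(1.0, 1.0 + 1e-12; atol=1e-9)
end
```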

## Inference tests

By default, all functions for testing rules check whether the output type (as well as that of the pullback for `rrule`s) can be completely inferred, such that everything is type stable:

```
julia> function ChainRulesCore.rrule(::typeof(abs), x)
           abs_pullback(Δ) = (NoTangent(), x >= 0 ? Δ : big(-1.0) * Δ)
           return abs(x), abs_pullback
       end
julia> test_rrule(abs, 1.)
test_rrule: abs on Float64: Error During Test at /home/runner/work/ChainRulesTestUtils.jl/ChainRulesTestUtils.jl/src/testers.jl:170
Got exception outside of a @test
return type Tuple{ChainRulesCore.NoTangent, Float64} does not match inferred return type Tuple{ChainRulesCore.NoTangent, Union{Float64, BigFloat}}
[...]
```

This can be disabled on a per-rule basis using the `check_inferred` keyword argument:

```
julia> test_rrule(abs, 1.; check_inferred=false)
Test Summary: | Pass Total
test_rrule: abs on Float64 | 5 5
Test.DefaultTestSet("test_rrule: abs on Float64", Any[], 5, false, false)
```

This behavior can also be overridden globally by setting the environment variable `CHAINRULES_TEST_INFERRED` before ChainRulesTestUtils is loaded or by changing `ChainRulesTestUtils.TEST_INFERRED[]` from inside Julia. ChainRulesTestUtils can detect whether a test is run as part of PkgEval and in this case disables inference tests automatically. Packages can use `@maybe_inferred` to get the same behavior for other inference tests.
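The global switch might be used as in the sketch below. It assumes that `TEST_INFERRED` is a `Ref{Bool}`, as the `[]` syntax suggests. The accepted values of the environment variable are not stated here, so `"false"` in the comment is an assumption.

```julia
using ChainRulesTestUtils

# Alternative (assumed value format): set the environment variable
# before the package is loaded, e.g. ENV["CHAINRULES_TEST_INFERRED"] = "false".

# Disable inference checks for all subsequent rule tests.
ChainRulesTestUtils.TEST_INFERRED[] = false

# ... run tests whose rules are known not to be inferable ...

# Restore the default so that later tests are checked again.
ChainRulesTestUtils.TEST_INFERRED[] = true
```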