# Pedagogical Example

This pedagogical example will show you how to write an `rrule`

. See On writing good `rrule`

/ `frule`

methods section for more tips and gotchas. If you want to learn about `frule`

s, you should still read and understand this example as many concepts are shared, and then look for real world `frule`

examples in ChainRules.jl.

## The primal

We define a struct `Foo`

```
struct Foo{T}
A::Matrix{T}
c::Float64
end
```

and a function that multiplies `Foo`

with an `AbstractArray`

:

```
function foo_mul(foo::Foo, b::AbstractArray)
return foo.A * b
end
```

Note that field `c`

is ignored in the calculation.

## The `rrule`

The `rrule`

method for our primal computation should extend the `ChainRulesCore.rrule`

function.

```
function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo{T}, b::AbstractArray) where T
y = foo_mul(foo, b)
function foo_mul_pullback(ȳ)
f̄ = NoTangent()
f̄oo = Tangent{Foo{T}}(; A=ȳ * b', c=ZeroTangent())
b̄ = @thunk(foo.A' * ȳ)
return f̄, f̄oo, b̄
end
return y, foo_mul_pullback
end
```

We can check this rule against a finite-differences approach using `ChainRulesTestUtils`

:

```
julia> using ChainRulesTestUtils
julia> test_rrule(foo_mul, Foo(rand(3, 3), 3.0), rand(3, 3))
Test Summary: | Pass Total
test_rrule: foo_mul on Foo{Float64},Matrix{Float64} | 10 10
Test.DefaultTestSet("test_rrule: foo_mul on Foo{Float64},Matrix{Float64}", Any[], 10, false, false)
```

Now let's examine the rule in more detail:

```
function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo, b::AbstractArray)
...
return y, foo_mul_pullback
end
```

The `rrule`

dispatches on the `typeof`

of the function we are writing the `rrule`

for, as well as the types of its arguments. Read more about writing rules for constructors and callable objects here. The `rrule`

returns the primal result `y`

, and the pullback function. It is a *very* good idea to name your pullback function, so that they are helpful when appearing in the stacktrace.

`y = foo_mul(foo, b)`

Computes the primal result. It is possible to change the primal computation so that work can be shared between the primal and the pullback. See e.g. the rule for `sort`

, where the sorting is done only once.

```
function foo_mul_pullback(ȳ)
...
return f̄, f̄oo, b̄
end
```

The pullback function takes in the tangent of the primal output (`ȳ`

) and returns the tangents of the primal inputs. Note that it returns a tangent for the primal function in addition to the tangents of primal arguments.

Finally, computing the tangents of primal inputs:

`f̄ = NoTangent()`

The function `foo_mul`

has no fields (i.e. it is not a closure) and can not be perturbed. Therefore its tangent (`f̄`

) is a `NoTangent`

.

`f̄oo = Tangent{Foo}(; A=ȳ * b', c=ZeroTangent())`

The struct `foo::Foo`

gets a `Tangent{Foo}`

structural tangent, which stores the tangents of fields of `foo`

.

The tangent of the field `A`

is `ȳ * b'`

,

The tangent of the field `c`

is `ZeroTangent()`

, because `c`

can be perturbed but has no effect on the primal output.

`b̄ = @thunk(foo.A' * ȳ)`

The tangent of `b`

is `foo.A' * ȳ`

, but we have wrapped it into a `Thunk`

, a tangent type that represents delayed computation. The idea is that in case the tangent is not used anywhere, the computation never happens. Use `InplaceableThunk`

if you are interested in accumulating gradients in-place. Note that in practice one would also `@thunk`

the `f̄oo.A`

tangent, but it was omitted in this example for clarity.

As a final note, since `b`

is an `AbstractArray`

, its tangent `b̄`

should be projected to the right subspace. See the `ProjectTo`

the primal subspace section for more information and an example that motivates the projection operation.