# Which functions need rules?

In principle, a perfect AD system only needs rules for basic operations and can infer the rules for more complicated functions automatically. In practice, performance needs to be considered as well.

Some functions use `ccall` internally, for example `^`. These functions cannot be differentiated through by AD systems, and need custom rules.

Other functions can in principle be differentiated through by an AD system, but there exists a mathematical insight that can dramatically improve the computation of the derivative. An example is numerical integration, where writing a rule implementing the fundamental theorem of calculus removes the need to perform AD through numerical integration.
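As a rough illustration of the numerical integration case, the sketch below defines a hypothetical `integrate` helper (a plain trapezoid rule, not part of any package) and an `rrule` that gets the derivatives with respect to the integration limits from the fundamental theorem of calculus instead of differentiating the quadrature loop. It assumes the integrand `f` carries no differentiable parameters, which is why its tangent is `NoTangent()`.

```julia
using ChainRulesCore

# Hypothetical helper: composite trapezoid rule on n equal subintervals.
function integrate(f, a, b; n=1_000)
    h = (b - a) / n
    s = (f(a) + f(b)) / 2
    for i in 1:(n - 1)
        s += f(a + i * h)
    end
    return s * h
end

function ChainRulesCore.rrule(::typeof(integrate), f, a, b; n=1_000)
    y = integrate(f, a, b; n=n)
    function integrate_pullback(ȳ)
        # Fundamental theorem of calculus: ∂y/∂a = -f(a) and ∂y/∂b = f(b).
        # Assumption: f has no differentiable fields, hence NoTangent() for it.
        return NoTangent(), NoTangent(), -f(a) * ȳ, f(b) * ȳ
    end
    return y, integrate_pullback
end

# Calling the rule directly, without any AD system:
y, pullback = rrule(integrate, sin, 0.0, 1.0)
_, _, da, db = pullback(1.0)
```

The pullback evaluates `f` only twice, regardless of how many quadrature nodes the forward pass uses.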

Furthermore, AD systems make different trade-offs in performance due to their design. This means that a certain rule will help one AD system, but not improve (and also not harm) another. Below, we list some patterns relevant for the Zygote.jl AD system.

Rules for functions which mutate their arguments, e.g. `sort!`, should not be written at the moment. While technically they are supported, they would break Zygote.jl such that it would sometimes quietly return the wrong answer. This may be resolved in the future by allowing AD systems to opt in to or out of certain types of rules.

### Patterns that need rules in Zygote.jl

There are a few classes of functions that Zygote cannot differentiate through. Custom rules will need to be written for these to make AD work.

Other patterns can be AD'ed through, but the backward pass performance can be greatly improved by writing a rule.

#### Functions which mutate arrays

For example, calling `gradient` on

```
function addone(a::AbstractArray)
    b = similar(a)
    b .= a .+ 1
    return sum(b)
end
```

fails with

```
julia> using Zygote
julia> a = rand(3);
julia> gradient(addone, a)
ERROR: Mutating arrays is not supported
```

However, upon adding the `rrule` (restart the REPL after calling `gradient`)

```
using ChainRules

function ChainRules.rrule(::typeof(addone), a)
    y = addone(a)
    function addone_pullback(ȳ)
        # each element of `a` contributes 1 to the sum, scaled by the incoming cotangent ȳ
        return NoTangent(), fill(ȳ, size(a))
    end
    return y, addone_pullback
end
```

the gradient can be evaluated:

```
julia> gradient(addone, a)
([1.0, 1.0, 1.0],)
```

Notice that `addone(a)` mutates another array `b` internally, but **not** its input. This is commonly done in less trivial functions, and is often what Zygote's `Mutating arrays is not supported` error is telling you, even though you did not intend to mutate anything. Functions which mutate their own input are much more problematic. These are the ones named (by convention) with an exclamation mark, such as `fill!(a, x)` or `push!(a, x)`. It is not possible to write rules which handle all uses of such a function correctly, on current Zygote.

When `gradient` is called in `Zygote` for a function with no `rrule` defined, a backward pass for the function call is generated and cached. When `gradient` is called for the second time on the same function signature, the backward pass is reused without checking whether an `rrule` has been defined between the two calls to `gradient`.

If an `rrule` is defined before the first call to `gradient`, Zygote should pick up the rule and use it, but that prevents comparing what happens before and after the `rrule` is defined. To compare the versions with and without an `rrule` in the same REPL session, define a function `f(x) = <body>` (no `rrule`), another function `f_cr(x) = f(x)`, and an `rrule` for `f_cr`.
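The wrapper pattern described above might look like the following sketch. The function names `sumsq` and `sumsq_cr` are made up for illustration; the only point is that the plain function never gets a rule, while the wrapper does, so both code paths can be compared in one session.

```julia
using ChainRulesCore, Zygote

# Hypothetical example function: a plain loop with no rrule, so Zygote
# differentiates through the loop itself.
function sumsq(x)
    s = zero(eltype(x))
    for i in eachindex(x)
        s += x[i]^2
    end
    return s
end

# Thin wrapper that exists only to carry the custom rule.
sumsq_cr(x) = sumsq(x)

function ChainRulesCore.rrule(::typeof(sumsq_cr), x)
    y = sumsq_cr(x)
    # d(sum of squares)/dx = 2x, scaled by the incoming cotangent ȳ
    sumsq_cr_pullback(ȳ) = (NoTangent(), 2 .* x .* ȳ)
    return y, sumsq_cr_pullback
end

x = [1.0, 2.0, 3.0]
g_plain = gradient(sumsq, x)[1]     # Zygote's generated backward pass
g_rule  = gradient(sumsq_cr, x)[1]  # the hand-written rrule
```

Both gradients should agree, and each call can be benchmarked separately without restarting the REPL.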

Calling `Zygote.refresh()` will often have the same effect as restarting the REPL.

#### Exception handling

Zygote does not support differentiating through `try`/`catch` statements. For example, differentiating through

```
function exception(x)
    try
        return x^2
    catch e
        println("could not square input")
        throw(e)
    end
end
```

does not work

```
julia> gradient(exception, 3.0)
ERROR: Compiling Tuple{typeof(exception),Float64}: try/catch is not supported.
```

without an `rrule` defined (restart the REPL after calling `gradient`)

```
using ChainRulesCore

function ChainRulesCore.rrule(::typeof(exception), x)
    y = exception(x)
    function exception_pullback(ȳ)
        # d(x^2)/dx = 2x, scaled by the incoming cotangent ȳ
        return NoTangent(), 2 * x * ȳ
    end
    return y, exception_pullback
end
```

```
julia> gradient(exception, 3.0)
(6.0,)
```

#### Loops

Julia runs loops fast. Unfortunately Zygote differentiates through loops slowly. So, for example, computing the mean squared error by using a loop

```
function mse(y, ŷ)
    N = length(y)
    s = 0.0
    for i in 1:N
        s += (y[i] - ŷ[i])^2.0
    end
    return s/N
end
```

takes a lot longer to AD through

```
julia> y = rand(30);
julia> ŷ = rand(30);
julia> @btime gradient(mse, $y, $ŷ)
38.180 μs (993 allocations: 65.00 KiB)
```

than if we supply an `rrule` (restart the REPL after calling `gradient`)

```
using ChainRules

function ChainRules.rrule(::typeof(mse), x, x̂)
    output = mse(x, x̂)
    function mse_pullback(ȳ)
        N = length(x)
        # ∂mse/∂x = (2/N) .* (x .- x̂); the gradient w.r.t. x̂ has the opposite sign
        g = (2 / N) .* (x .- x̂) .* ȳ
        return NoTangent(), g, -g
    end
    return output, mse_pullback
end
```

which is much faster

```
julia> @btime gradient(mse, $y, $ŷ)
143.697 ns (2 allocations: 672 bytes)
```
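A hand-written rule is only useful if its derivative is right. One package-free way to sanity-check the formula used in the `mse` pullback is to compare it against central finite differences. The helper names `analytic_grad` and `fd_grad` and the step size `h` are arbitrary choices made for this sketch.

```julia
# Same loop-based definition as above.
function mse(y, ŷ)
    N = length(y)
    s = 0.0
    for i in 1:N
        s += (y[i] - ŷ[i])^2.0
    end
    return s / N
end

# Gradient w.r.t. the first argument, as used in the pullback (with ȳ = 1).
analytic_grad(y, ŷ) = (2 / length(y)) .* (y .- ŷ)

# Central finite differences w.r.t. each entry of the first argument.
function fd_grad(y, ŷ; h=1e-6)
    g = similar(y)
    for i in eachindex(y)
        yp = copy(y); yp[i] += h
        ym = copy(y); ym[i] -= h
        g[i] = (mse(yp, ŷ) - mse(ym, ŷ)) / (2h)
    end
    return g
end

y = rand(30)
ŷ = rand(30)
agrees = isapprox(fd_grad(y, ŷ), analytic_grad(y, ŷ); atol=1e-6)
```

By symmetry, the gradient with respect to the second argument is the negative of the first, which is why the pullback returns `g` and `-g`.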

#### In-place accumulation

In-place accumulation of gradients is slow in `Zygote`. The issue, demonstrated in the following example, is that the gradient of `getindex` allocates an array of zeros with a single non-zero element.

```
function sum3(array)
    x = array[1]
    y = array[2]
    z = array[3]
    return x+y+z
end
```

```
julia> @btime gradient(sum3, rand(30))
424.510 ns (9 allocations: 2.06 KiB)
```
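To see where the allocations come from, the following is a schematic model of the situation, not Zygote's actual implementation. The helper `getindex_cotangent` is hypothetical and simply mimics "an array of zeros with a single non-zero element" for each indexing operation.

```julia
# Schematic model of the cotangent produced by a single `array[i]` read:
# zeros everywhere except at index i.
function getindex_cotangent(n, i, ȳ)
    g = zeros(n)
    g[i] = ȳ
    return g
end

n, ȳ = 30, 1.0

# Three reads give three temporary arrays, plus one more for their sum.
naive = getindex_cotangent(n, 1, ȳ) +
        getindex_cotangent(n, 2, ȳ) +
        getindex_cotangent(n, 3, ȳ)

# Same result with a single allocation that is filled in place.
fused = zeros(n)
fused[1:3] .= ȳ
```

Both arrays hold the same values; the difference is only in how many intermediate arrays are created along the way.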

Computing the gradient with only a single array allocation using an `rrule` (restart the REPL after calling `gradient`)

```
using ChainRulesCore

function ChainRulesCore.rrule(::typeof(sum3), a)
    y = sum3(a)
    function sum3_pullback(ȳ)
        # one allocation for the whole gradient; only the first three entries are non-zero
        grad = zeros(length(a))
        grad[1:3] .+= ȳ
        return NoTangent(), grad
    end
    return y, sum3_pullback
end
```

turns out to be significantly faster

```
julia> @btime gradient(sum3, rand(30))
192.818 ns (3 allocations: 784 bytes)
```