Gradient Accumulation
Consider some function $f(x) = g(x) + h(x)$. To compute the derivative of $f$ with respect to $x$, we must differentiate each part and then sum the results, i.e. $\frac{\partial f}{\partial x} = \frac{\partial g}{\partial x} + \frac{\partial h}{\partial x}$. In general, we must accumulate (sum) the gradients from each sub-part of a program where a variable is used.
Consider for example:
function sum_first_and_second(X::Array{Float64})
a = X[1]
b = X[2]
y = a + b
return y
end
The AD software must transform that into something which sums up the gradient of each part: X̄ = ā + b̄. This requires that every differential type D implement +: +(::D, ::D)::D.
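Concretely, a hand-written reverse pass might look like this (an illustrative sketch assuming X is a vector, not the output of any particular AD system):
function sum_first_and_second_pullback(ȳ, X)
    ā = [ȳ; zeros(length(X) - 1)]       # partial from `a = X[1]`
    b̄ = [0.0; ȳ; zeros(length(X) - 2)]  # partial from `b = X[2]`
    X̄ = ā + b̄                           # accumulate: the two contributions sum
    return X̄
end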
We can note that in this particular case ā and b̄ will both be arrays. The operation X̄ = ā + b̄ will allocate one array to hold ā, another to hold b̄, and a third to hold ā + b̄. That is three allocations. Allocations are not free; they increase the program's running time by a nontrivial amount, even with a good allocator and a good garbage collector.
Maybe-mutating accumulation (add!!)
We can note that in the above, neither ā nor b̄ is ever used again after accumulating to get X̄. Furthermore, Arrays are mutable. That means we could overwrite either ā or b̄ and use the result as X̄:
ā .+= b̄
X̄ = ā
This cuts our allocations down to two: just ā and b̄.
However, not all types are mutable, so this pattern is hard to apply in general. To deal with that, ChainRulesCore provides add!!. Per the BangBang.jl convention, this is a maybe-mutating addition: it may mutate its first argument (if that argument is mutable), but it will always return the correct result. Using it, we would write X̄ = add!!(ā, b̄), which in this case gives us just two allocations. AD systems can generate add!! instead of + when accumulating gradients to take advantage of this.
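To make the convention concrete, here is a minimal sketch of a maybe-mutating addition, under the hypothetical name maybe_add!! (an illustration of the idea, not ChainRulesCore's actual method definitions):
maybe_add!!(a::Array, b) = (a .+= b)  # mutable case: accumulate into `a` in place and return it
maybe_add!!(a, b) = a + b             # immutable fallback: plain allocating addition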
Inplaceable Thunks (InplaceableThunks) avoid allocating values in the first place.
Using add!! got us down to two allocations, but can we do better? We can imagine a differential type that acts on a partially accumulated result, mutating it so that it contains its current value plus the partial derivative being accumulated. Rather than holding an actual computed value, it holds an action that performs the addition on a value. Let's illustrate this with our example.
b̄ is the partial for X[2], and its value can be computed by:
b̄ = zeros(size(X))
b̄[2] = ȳ # the scalar sensitivity of the `sum_first_and_second` output
b̄ is an array entirely of zeros, except at index 2, where it is set to the output sensitivity ȳ. ā is similar, except with the non-zero entry at index 1.
What is the action of b̄ upon ā, to get the same result as X̄ = add!!(ā, b̄) (or X̄ = ā + b̄, for that matter)? It is:
function b̄_add!(ā)
ā[2] += ȳ
return ā
end
We don't need to worry about all those zeros, since x + 0 == x.
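We can sanity-check that the action agrees with value-form accumulation (a quick illustration, using a length-2 X for brevity; b̄_add! is as defined above and reads the global ȳ):
ȳ = 1.0
ā = [ȳ, 0.0]              # value form of the partial for X[1]
b̄ = [0.0, ȳ]              # value form of the partial for X[2]
b̄_add!(copy(ā)) == ā + b̄  # true: the action and the value form accumulate the same result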
InplaceableThunk is the type we use to represent derivatives as gradient-accumulating actions. Note that to do this we still need a value form of ā for b̄ to act upon. For this reason every inplaceable thunk has both a val field holding the value representation, and an add! field holding the action representation. The val field uses a plain Thunk to avoid the computation (and thus the allocation) if the value is never used.
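For our running example, the partial for b could then be represented like so (a sketch; the positional argument order here mirrors the field order just described, val then add!, but consult ChainRulesCore for the exact constructor):
using ChainRulesCore: InplaceableThunk, @thunk

X = randn(4)  # example primal input
ȳ = 1.0       # output sensitivity
b̄ = InplaceableThunk(
    @thunk(setindex!(zeros(size(X)), ȳ, 2)),  # val: the value form, computed only if used
    X̄ -> (X̄[2] += ȳ; X̄),                      # add!: accumulate into X̄ in place
)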
Right now every InplaceableThunk has two fields that need to be specified: the value form (represented as the Thunk-typed val field) and the action form (represented as the add! field). It is possible that a future version of ChainRulesCore.jl will work out a clever way to find the zero differential for arbitrary primal values. Given that, we could always determine the value form from inplaceable.add!(zero_differential(primal)). There are some technical difficulties in finding the zero differentials, but this may be solved at some point.
The + operation on InplaceableThunks is overloaded to unthunk the val field to get the value form, whereas the add!! operation is overloaded to call add! to invoke the action.
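In sketch form, the two overloads behave roughly like this (hypothetical standalone names, simplified for intuition; not the actual method definitions in ChainRulesCore):
using ChainRulesCore: InplaceableThunk, unthunk

value_add(a, b::InplaceableThunk) = a + unthunk(b.val)  # what `+` does: force the value form, allocating
action_add!!(a, b::InplaceableThunk) = b.add!(a)        # what `add!!` does: apply the action to `a`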
With getindex defined to return an InplaceableThunk, X̄ = add!!(ā, b̄) now requires only a single allocation. This allocation occurs when unthunking ā, which is then mutated to become X̄. This is basically as good as we can get: if we want X̄ to be an Array, then at some point we need to allocate that array.
We could keep going and drop even that allocation if we really wanted. If we didn't care about X̄ being an Array, we could defer its computation too: X̄ = @thunk add!!(ā, b̄). This kind of deferral works fine and you can keep chaining it. It does start to burn stack space, and might make the compiler's optimization passes cry, but it is valid and should work fine.
Examples of InplaceableThunks
getindex
The aforementioned getindex
is really the poster child for this. Consider something like:
function mysum(X::Array{Float64})
total = 0.0
for i in eachindex(X)
total += X[i]
end
return total
end
If one only has the value representation of derivatives, one ends up allocating a derivative array for every single element of the original array X. That's terrible. With the action representation that InplaceableThunks provide, on the other hand, just a single Array is allocated. One can see the getindex rule in ChainRules.jl for the implementation.
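For intuition, a heavily simplified version of such a rule might look like the following (a hedged sketch restricted to vectors and linear indices, written against current ChainRulesCore tangent types; the actual rule in ChainRules.jl is more general):
using ChainRulesCore

function ChainRulesCore.rrule(::typeof(getindex), x::Vector{<:Real}, i::Int)
    y = x[i]
    function getindex_pullback(ȳ)
        x̄ = InplaceableThunk(
            @thunk(setindex!(zero(x), ȳ, i)),  # value form: a one-hot array
            X̄ -> (X̄[i] += ȳ; X̄),               # action form: in-place accumulation
        )
        return (NoTangent(), x̄, NoTangent())
    end
    return y, getindex_pullback
end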
matmul etc (*)
Multiplication of scalars/vectors/matrices of compatible dimensions can also have its derivatives represented as InplaceableThunks. These tend to pivot around an add! action defined along the lines of X̄ -> mul!(X̄, A', Ȳ, true, true), where the 5-argument mul! is the in-place multiply-add operation: mul!(X̄, A', Ȳ, true, true) has the same effect as X̄ .+= A'*Ȳ, but avoids allocating the matrix A'*Ȳ.
This multiply-add is one of the fundamental operations provided by BLAS, including the application of the conjugate transpose: the matrix-matrix form is GEMM (GEneral Matrix-Matrix multiplication), the matrix-vector form is GEMV (GEneral Matrix-Vector multiplication), etc. Under the hood, doing the multiplication out of place is going to call one of these methods anyway, just on a freshly allocated output array. So we hit a very efficient implementation and get the addition for free.
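For instance, a simplified matrix-matrix * rule built on this pattern might look like the following (a sketch under current ChainRulesCore conventions; see ChainRules.jl for the real, far more general implementations):
using ChainRulesCore
using LinearAlgebra: mul!

function ChainRulesCore.rrule(::typeof(*), A::Matrix{Float64}, B::Matrix{Float64})
    function times_pullback(Ȳ)
        Ā = InplaceableThunk(
            @thunk(Ȳ * B'),                   # value form: allocates Ȳ*B'
            X̄ -> mul!(X̄, Ȳ, B', true, true),  # action form: X̄ .+= Ȳ*B' without allocating
        )
        B̄ = InplaceableThunk(
            @thunk(A' * Ȳ),                   # value form: allocates A'*Ȳ
            X̄ -> mul!(X̄, A', Ȳ, true, true),  # action form: X̄ .+= A'*Ȳ without allocating
        )
        return (NoTangent(), Ā, B̄)
    end
    return A * B, times_pullback
end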
One can see the * rules in ChainRules.jl for the implementations.