DiffRules
Many differentiation methods rely on the notion of "primitive" differentiation rules that can be composed via various formulations of the chain rule. Using DiffRules, you can define new differentiation rules, query whether or not a given rule exists, and symbolically apply rules to simple Julia expressions.
Note that DiffRules is not a fully-fledged symbolic differentiation tool. It is a (very) simple global database of common derivative definitions, and was developed with the goal of improving derivative coverage in downstream tools.
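The sketch below shows how primitive rules compose via the chain rule. It differentiates a nested chain of unary Base calls such as sin(cos(x)): for each outer call it looks up the function's rule with DiffRules.diffrule and multiplies the result by the derivative of the inner expression. The helper chain_derivative is hypothetical and is not part of DiffRules. It assumes every call is a one-argument function whose rule is registered under :Base.

```julia
using DiffRules

# Hypothetical helper (not part of DiffRules): symbolic derivative of a nested
# chain of one-argument Base calls w.r.t. the symbol `x`, via the chain rule.
function chain_derivative(ex, x::Symbol)
    ex === x && return 1                         # d/dx x = 1
    ex isa Expr && ex.head === :call && length(ex.args) == 2 ||
        error("unsupported expression: $ex")
    f, inner = ex.args                           # ex has the form f(inner)
    outer = DiffRules.diffrule(:Base, f, inner)  # f'(u) evaluated at u = inner
    return :($outer * $(chain_derivative(inner, x)))
end

# Equivalent to cos(cos(x)) * (-sin(x)) * 1
d = chain_derivative(:(sin(cos(x))), :x)
```

Each step needs only f'(u) at u = inner, so the whole derivative is assembled from primitive rules. Supporting sums, products, or constants would require extra cases that this sketch leaves out.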
DiffRules.@define_diffrule
— Macro.
@define_diffrule M.f(x) = :(df_dx($x))
@define_diffrule M.f(x, y) = :(df_dx($x, $y)), :(df_dy($x, $y))
⋮
Define a new differentiation rule for the function M.f and the given arguments, which should be treated as bindings to Julia expressions. Return the defined rule's key.
The LHS should be a function call with a non-splatted argument list. The RHS should be the derivative expression or, in the n-ary case, an n-tuple of expressions where the i-th expression is the derivative of f with respect to the i-th argument. Arguments should be interpolated (using $) wherever they are used on the RHS.
Note that differentiation rules are purely symbolic, so no type annotations should be used.
Examples:
@define_diffrule Base.cos(x) = :(-sin($x))
@define_diffrule Base.:/(x, y) = :(inv($y)), :(-$x / ($y^2))
@define_diffrule Base.polygamma(m, x) = :NaN, :(polygamma($m + 1, $x))
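Here is a minimal sketch of registering a rule for a user-defined function, then querying and applying it. The module MyMath and its function cube are hypothetical. Rules are purely symbolic, so the macro only records the derivative expression and returns the key (:MyMath, :cube, 1).

```julia
using DiffRules

# Hypothetical user module and function, used only for illustration.
module MyMath
cube(x) = x^3
end

# Register d/dx cube(x) = 3x^2; the macro returns the rule's key.
key = DiffRules.@define_diffrule MyMath.cube(x) = :(3 * $x^2)

# The new rule can now be queried and applied like the built-in ones.
DiffRules.hasdiffrule(:MyMath, :cube, 1)
DiffRules.diffrule(:MyMath, :cube, :y)
```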
DiffRules.diffrule
— Function.
diffrule(M::Union{Expr,Symbol}, f::Symbol, args...)
Return the derivative expression for M.f at the given argument(s), with the argument(s) interpolated into the returned expression.
In the n-ary case, an n-tuple of expressions is returned where the i-th expression is the derivative of f with respect to the i-th argument.
Examples:
julia> DiffRules.diffrule(:Base, :sin, 1)
:(cos(1))
julia> DiffRules.diffrule(:Base, :sin, :x)
:(cos(x))
julia> DiffRules.diffrule(:Base, :sin, :(x * y^2))
:(cos(x * y ^ 2))
julia> DiffRules.diffrule(:Base, :^, :(x + 2), :c)
(:(c * (x + 2) ^ (c - 1)), :((x + 2) ^ c * log(x + 2)))
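The returned expressions are ordinary Julia code. One way to check or use them is to splice them into an anonymous function and evaluate it. The sketch below does this for the binary rule of /. The exact form of the expressions may differ between DiffRules versions, but only their values matter here.

```julia
using DiffRules

# Partial derivatives of x / y: a 2-tuple (df/dx, df/dy) of expressions in x and y.
dfdx, dfdy = DiffRules.diffrule(:Base, :/, :x, :y)

# Splice both expressions into an anonymous function of (x, y).
grad = eval(:((x, y) -> ($dfdx, $dfdy)))

# invokelatest avoids world-age issues when calling a freshly eval'd function.
g = Base.invokelatest(grad, 3.0, 2.0)   # mathematically (1/y, -x/y^2) = (0.5, -0.75)
```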
DiffRules.hasdiffrule
— Function.
hasdiffrule(M::Union{Expr,Symbol}, f::Symbol, arity::Int)
Return true if a differentiation rule is defined for M.f and arity, or return false otherwise.
Here, arity refers to the number of arguments accepted by f.
Examples:
julia> DiffRules.hasdiffrule(:Base, :sin, 1)
true
julia> DiffRules.hasdiffrule(:Base, :sin, 2)
false
julia> DiffRules.hasdiffrule(:Base, :-, 1)
true
julia> DiffRules.hasdiffrule(:Base, :-, 2)
true
julia> DiffRules.hasdiffrule(:Base, :-, 3)
false
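diffrule looks up rules by module, function, and number of arguments. Calling it for a combination that has no registered rule throws an error. The minimal sketch below guards the lookup with hasdiffrule. try_diffrule is a hypothetical helper and is not part of DiffRules.

```julia
using DiffRules

# Hypothetical helper: return the rule's expression(s) for M.f at `args`,
# or `nothing` if no rule is registered for that arity.
function try_diffrule(M::Symbol, f::Symbol, args...)
    DiffRules.hasdiffrule(M, f, length(args)) || return nothing
    return DiffRules.diffrule(M, f, args...)
end

try_diffrule(:Base, :sin, :x)       # a rule exists for arity 1
try_diffrule(:Base, :sin, :x, :y)   # no arity-2 rule: returns nothing
```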
DiffRules.diffrules
— Function.
diffrules()
Return an iterable collection of keys that can be used to access all defined differentiation rules.
Each key is of the form (M::Symbol, f::Symbol, arity::Int).
Here, arity refers to the number of arguments accepted by f.
Examples:
julia> first(DiffRules.diffrules())
(:Base, :asind, 1)
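Each key is a plain (M, f, arity) tuple, so the collection can be iterated and destructured directly. The sketch below groups the registered rules by arity. The name rules_by_arity is only illustrative. The contents depend on the installed DiffRules version and on which rules have been registered.

```julia
using DiffRules

# Group the (module, function) pairs of every registered rule by arity.
rules_by_arity = Dict{Int,Vector{Any}}()
for (M, f, arity) in DiffRules.diffrules()
    push!(get!(rules_by_arity, arity, Any[]), (M, f))
end

# Print how many rules exist for each arity (unary, binary, ...).
for arity in sort!(collect(keys(rules_by_arity)))
    println(arity, " => ", length(rules_by_arity[arity]), " rules")
end
```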