Documentation

DiffRules

Many differentiation methods rely on the notion of "primitive" differentiation rules that can be composed via various formulations of the chain rule. Using DiffRules, you can define new differentiation rules, query whether or not a given rule exists, and symbolically apply rules to simple Julia expressions.

Note that DiffRules is not a fully-fledged symbolic differentiation tool. It is a (very) simple global database of common derivative definitions, and was developed with the goal of improving derivative coverage in downstream tools.
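The sketch below shows how primitive rules compose. It builds a tiny symbolic derivative on top of DiffRules.hasdiffrule and DiffRules.diffrule by applying the chain rule recursively. The derive function is a hypothetical helper written for this illustration and is not part of DiffRules. It assumes every call is an unqualified function name with a rule registered under :Base, it treats numbers and any other symbols as constants, and it does no algebraic simplification.

```julia
using DiffRules

# Hypothetical helper (not part of DiffRules): differentiate `ex` with respect
# to the symbol `x` by combining DiffRules' primitive rules via the chain rule.
function derive(ex, x::Symbol)
    ex === x && return 1                            # dx/dx = 1
    (ex isa Number || ex isa Symbol) && return 0    # constants and other variables
    Meta.isexpr(ex, :call) || error("unsupported expression: $ex")
    f, args = ex.args[1], ex.args[2:end]
    f isa Symbol || error("only unqualified function names are supported: $f")
    n = length(args)
    DiffRules.hasdiffrule(:Base, f, n) || error("no rule for Base.$f with arity $n")
    # One partial derivative per argument; unary rules return a single expression.
    partials = DiffRules.diffrule(:Base, f, args...)
    n == 1 && (partials = (partials,))
    # Chain rule: sum over i of (partial of f w.r.t. arg i) * (d arg i / dx).
    terms = [:($p * $(derive(a, x))) for (p, a) in zip(partials, args)]
    return n == 1 ? terms[1] : Expr(:call, :+, terms...)
end

d = derive(:(sin(x) * x), :x)    # unsimplified Expr equivalent to x*cos(x) + sin(x)
eval(:(let x = 1.0; $d end))     # approximately cos(1.0) + sin(1.0)
```

Because Julia parses a * b * c as a single three-argument call, this sketch only handles expressions whose calls have rules at exactly that arity.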

DiffRules.@define_diffrule (Macro)

@define_diffrule M.f(x) = :(df_dx($x))
@define_diffrule M.f(x, y) = :(df_dx($x, $y)), :(df_dy($x, $y))
⋮

Define a new differentiation rule for the function M.f and the given arguments, which should be treated as bindings to Julia expressions. Return the defined rule's key.

The LHS should be a function call with a non-splatted argument list. The RHS should be the derivative expression or, in the n-ary case, an n-tuple of expressions where the i-th expression is the derivative of f with respect to the i-th argument. Arguments should be interpolated wherever they are used on the RHS.

Note that differentiation rules are purely symbolic, so no type annotations should be used.

Examples:

@define_diffrule Base.cos(x)          = :(-sin($x))
@define_diffrule Base.:/(x, y)        = :(inv($y)), :(-$x / ($y^2))
@define_diffrule Base.polygamma(m, x) = :NaN,       :(polygamma($m + 1, $x))
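The following is a minimal round trip. It assumes a hypothetical module MyMath containing a function cube, neither of which is part of DiffRules. The example defines a rule, checks that the returned key has the (M, f, arity) form, and applies the rule to an expression.

```julia
using DiffRules

# Hypothetical module and function, used only for illustration.
module MyMath
cube(x) = x^3
end

# d/dx cube(x) = 3x^2. The argument is interpolated with `$x` on the RHS,
# and the macro returns the key under which the rule was stored.
key = DiffRules.@define_diffrule MyMath.cube(x) = :(3 * $x^2)

key                                            # (:MyMath, :cube, 1)
DiffRules.hasdiffrule(:MyMath, :cube, 1)       # true
DiffRules.diffrule(:MyMath, :cube, :(a + b))   # :(3 * (a + b) ^ 2)
```

The rule is stored symbolically, so the RHS never calls MyMath.cube. It only describes the derivative expression to splice in.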
DiffRules.diffrule (Function)
diffrule(M::Union{Expr,Symbol}, f::Symbol, args...)

Return the derivative expression for M.f at the given argument(s), with the argument(s) interpolated into the returned expression.

In the n-ary case, an n-tuple of expressions is returned, where the i-th expression is the derivative of f with respect to the i-th argument.

Examples:

julia> DiffRules.diffrule(:Base, :sin, 1)
:(cos(1))

julia> DiffRules.diffrule(:Base, :sin, :x)
:(cos(x))

julia> DiffRules.diffrule(:Base, :sin, :(x * y^2))
:(cos(x * y ^ 2))

julia> DiffRules.diffrule(:Base, :^, :(x + 2), :c)
(:(c * (x + 2) ^ (c - 1)), :((x + 2) ^ c * log(x + 2)))
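The returned values are ordinary Julia expressions, so they can be spliced into generated code. In the sketch below, the two partials of Base./ are used to build an anonymous function that returns both partial derivatives at a point. The name grad_div is hypothetical and chosen only for this example.

```julia
using DiffRules

# Request both partials of x / y, passing plain symbols as the arguments.
dx, dy = DiffRules.diffrule(:Base, :/, :x, :y)   # partials w.r.t. x and y

# Splice the partials into an anonymous function (hypothetical name `grad_div`).
grad_div = eval(:((x, y) -> ($dx, $dy)))

grad_div(1.0, 2.0)   # (0.5, -0.25)
```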
DiffRules.hasdiffrule (Function)
hasdiffrule(M::Union{Expr,Symbol}, f::Symbol, arity::Int)

Return true if a differentiation rule is defined for M.f with the given arity. Otherwise, return false.

Here, arity refers to the number of arguments accepted by f.

Examples:

julia> DiffRules.hasdiffrule(:Base, :sin, 1)
true

julia> DiffRules.hasdiffrule(:Base, :sin, 2)
false

julia> DiffRules.hasdiffrule(:Base, :-, 1)
true

julia> DiffRules.hasdiffrule(:Base, :-, 2)
true

julia> DiffRules.hasdiffrule(:Base, :-, 3)
false
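diffrule looks rules up by module, name, and arity. Requesting an arity with no registered rule therefore raises an error. A simple guard checks hasdiffrule first. The helper tryrule below is a hypothetical sketch of that pattern and is not part of DiffRules.

```julia
using DiffRules

# Hypothetical helper (not part of DiffRules): look up a rule only if it
# exists, returning `nothing` instead of failing on a missing key.
function tryrule(M::Symbol, f::Symbol, args...)
    DiffRules.hasdiffrule(M, f, length(args)) || return nothing
    return DiffRules.diffrule(M, f, args...)
end

tryrule(:Base, :sin, :x)       # :(cos(x))
tryrule(:Base, :sin, :x, :y)   # nothing, since no two-argument rule for sin exists
```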
DiffRules.diffrules (Function)
diffrules()

Return a list of keys that can be used to access all defined differentiation rules.

Each key is of the form (M::Symbol, f::Symbol, arity::Int).

Here, arity refers to the number of arguments accepted by f.

Examples:

julia> first(DiffRules.diffrules())
(:Base, :asind, 1)
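Each key is an (M, f, arity) tuple, so the keys returned by diffrules() can be filtered with ordinary iteration. The sketch below collects the names of all unary rules registered under Base. The variable name unary_base is hypothetical, and the exact set of names depends on the installed DiffRules version.

```julia
using DiffRules

# Keep only the keys for single-argument rules registered under Base,
# destructuring each (M, f, arity) key in the comprehension.
unary_base = sort!([f for (M, f, arity) in DiffRules.diffrules() if M === :Base && arity == 1])

# Every key returned by diffrules() should satisfy hasdiffrule.
all(key -> DiffRules.hasdiffrule(key...), DiffRules.diffrules())   # true
```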