The Elements of Differentiable Programming
Mathieu Blondel
Google DeepMind
[email protected]
Vincent Roulet
Google DeepMind
[email protected]
1 Introduction 4
1.1 What is differentiable programming? . . . . . . . . . . . . 4
1.2 Book goals and scope . . . . . . . . . . . . . . . . . . . . 6
1.3 Intended audience . . . . . . . . . . . . . . . . . . . . . . 7
1.4 How to read this book? . . . . . . . . . . . . . . . . . . . 7
1.5 Related work . . . . . . . . . . . . . . . . . . . . . . . . . 7
I Fundamentals 9
2 Differentiation 10
2.1 Univariate functions . . . . . . . . . . . . . . . . . . . . . 10
2.1.1 Derivatives . . . . . . . . . . . . . . . . . . . . . . 10
2.1.2 Calculus rules . . . . . . . . . . . . . . . . . . . . 13
2.1.3 Leibniz’s notation . . . . . . . . . . . . . . . . . . 15
2.2 Multivariate functions . . . . . . . . . . . . . . . . . . . . 16
2.2.1 Directional derivatives . . . . . . . . . . . . . . . . 16
2.2.2 Gradients . . . . . . . . . . . . . . . . . . . . . . 17
2.2.3 Jacobians . . . . . . . . . . . . . . . . . . . . . . 20
2.3 Linear differentiation maps . . . . . . . . . . . . . . . . . 26
2.3.1 The need for linear maps . . . . . . . . . . . . . . 26
2.3.2 Euclidean spaces . . . . . . . . . . . . . . . . . . . 27
2.3.3 Linear maps and their adjoints . . . . . . . . . . . 28
2.3.4 Jacobian-vector products . . . . . . . . . . . . . . 29
2.3.5 Vector-Jacobian products . . . . . . . . . . . . . . 30
2.3.6 Chain rule . . . . . . . . . . . . . . . . . . . . . . 31
2.3.7 Functions of multiple inputs (fan-in) . . . . . . . . 32
2.3.8 Functions of multiple outputs (fan-out) . . . . . . 34
2.3.9 Extensions to non-Euclidean linear spaces . . . . . 34
2.4 Second-order differentiation . . . . . . . . . . . . . . . . . 36
2.4.1 Second derivatives . . . . . . . . . . . . . . . . . . 36
2.4.2 Second directional derivatives . . . . . . . . . . . . 36
2.4.3 Hessians . . . . . . . . . . . . . . . . . . . . . . . 37
2.4.4 Hessian-vector products . . . . . . . . . . . . . . . 39
2.4.5 Second-order Jacobians . . . . . . . . . . . . . . . 39
2.5 Higher-order differentiation . . . . . . . . . . . . . . . . . 40
2.5.1 Higher-order derivatives . . . . . . . . . . . . . . . 40
2.5.2 Higher-order directional derivatives . . . . . . . . . 41
2.5.3 Higher-order Jacobians . . . . . . . . . . . . . . . 41
2.5.4 Taylor expansions . . . . . . . . . . . . . . . . . . 42
2.6 Differential geometry . . . . . . . . . . . . . . . . . . . . 43
2.6.1 Differentiability on manifolds . . . . . . . . . . . . 43
2.6.2 Tangent spaces and pushforward operators . . . . . 44
2.6.3 Cotangent spaces and pullback operators . . . . . 45
2.7 Generalized derivatives . . . . . . . . . . . . . . . . . . . 48
2.7.1 Rademacher’s theorem . . . . . . . . . . . . . . . 49
2.7.2 Clarke derivatives . . . . . . . . . . . . . . . . . . 49
2.8 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 52
3 Probabilistic learning 54
3.1 Probability distributions . . . . . . . . . . . . . . . . . . . 54
3.1.1 Discrete probability distributions . . . . . . . . . . 54
3.1.2 Continuous probability distributions . . . . . . . . 55
3.2 Maximum likelihood estimation . . . . . . . . . . . . . . . 56
3.2.1 Negative log-likelihood . . . . . . . . . . . . . . . 56
3.2.2 Consistency w.r.t. the Kullback-Leibler divergence . 56
3.3 Probabilistic supervised learning . . . . . . . . . . . . . . 57
3.3.1 Conditional probability distributions . . . . . . . . 57
3.3.2 Inference . . . . . . . . . . . . . . . . . . . . . . . 57
3.3.3 Binary classification . . . . . . . . . . . . . . . . . 58
3.3.4 Multiclass classification . . . . . . . . . . . . . . . 60
3.3.5 Regression . . . . . . . . . . . . . . . . . . . . . . 61
3.3.6 Multivariate regression . . . . . . . . . . . . . . . 62
3.3.7 Integer regression . . . . . . . . . . . . . . . . . . 63
3.3.8 Loss functions . . . . . . . . . . . . . . . . . . . . 63
3.4 Exponential family distributions . . . . . . . . . . . . . . . 65
3.4.1 Definition . . . . . . . . . . . . . . . . . . . . . . 65
3.4.2 The log-partition function . . . . . . . . . . . . . . 67
3.4.3 Maximum entropy principle . . . . . . . . . . . . . 68
3.4.4 Maximum likelihood estimation . . . . . . . . . . . 69
3.4.5 Probabilistic learning with exponential families . . . 69
3.5 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 71
II Differentiable programs 72
4 Parameterized programs 73
4.1 Representing computer programs . . . . . . . . . . . . . . 73
4.1.1 Computation chains . . . . . . . . . . . . . . . . . 73
4.1.2 Directed acyclic graphs . . . . . . . . . . . . . . . . 74
4.1.3 Computer programs as DAGs . . . . . . . . . . . . 76
4.1.4 Arithmetic circuits . . . . . . . . . . . . . . . . . . 78
4.2 Feedforward networks . . . . . . . . . . . . . . . . . . . . 79
4.3 Multilayer perceptrons . . . . . . . . . . . . . . . . . . . . 79
4.3.1 Combining affine layers and activations . . . . . . . 79
4.3.2 Link with generalized linear models . . . . . . . . . 80
4.4 Activation functions . . . . . . . . . . . . . . . . . . . . . 81
4.4.1 Scalar-to-scalar nonlinearities . . . . . . . . . . . . 81
4.4.2 Vector-to-scalar nonlinearities . . . . . . . . . . . . 81
4.4.3 Scalar-to-scalar probability mappings . . . . . . . . 82
4.4.4 Vector-to-vector probability mappings . . . . . . . 83
4.5 Residual neural networks . . . . . . . . . . . . . . . . . . 85
4.6 Recurrent neural networks . . . . . . . . . . . . . . . . . . 86
4.6.1 Vector to sequence . . . . . . . . . . . . . . . . . 86
4.6.2 Sequence to vector . . . . . . . . . . . . . . . . . 88
4.6.3 Sequence to sequence (aligned) . . . . . . . . . . . 88
4.6.4 Sequence to sequence (unaligned) . . . . . . . . . 88
4.7 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 89
5 Control flows 90
5.1 Comparison operators . . . . . . . . . . . . . . . . . . . . 90
5.2 Soft inequality operators . . . . . . . . . . . . . . . . . . 92
5.2.1 Heuristic definition . . . . . . . . . . . . . . . . . 92
5.2.2 Stochastic process perspective . . . . . . . . . . . 92
5.3 Soft equality operators . . . . . . . . . . . . . . . . . . . 93
5.3.1 Heuristic definition . . . . . . . . . . . . . . . . . 93
5.3.2 Gaussian process perspective . . . . . . . . . . . . 94
5.4 Logical operators . . . . . . . . . . . . . . . . . . . . . . 95
5.5 Continuous extensions of logical operators . . . . . . . . . 96
5.5.1 Probabilistic continuous extension . . . . . . . . . 96
5.5.2 Triangular norms and co-norms . . . . . . . . . . . 98
5.6 If-else statements . . . . . . . . . . . . . . . . . . . . . . 98
5.6.1 Differentiating through branch variables . . . . . . 99
5.6.2 Differentiating through predicate variables . . . . . 100
5.6.3 Continuous relaxations . . . . . . . . . . . . . . . 101
5.7 Else-if statements . . . . . . . . . . . . . . . . . . . . . . 102
5.7.1 Encoding K branches . . . . . . . . . . . . . . . . 103
5.7.2 Conditionals . . . . . . . . . . . . . . . . . . . . . 104
5.7.3 Differentiating through branch variables . . . . . . 105
5.7.4 Differentiating through predicate variables . . . . . 106
5.7.5 Continuous relaxations . . . . . . . . . . . . . . . 106
5.8 For loops . . . . . . . . . . . . . . . . . . . . . . . . . . . 108
5.9 Scan functions . . . . . . . . . . . . . . . . . . . . . . . . 109
5.10 While loops . . . . . . . . . . . . . . . . . . . . . . . . . 110
5.10.1 While loops as cyclic graphs . . . . . . . . . . . . 110
5.10.2 Unrolled while loops . . . . . . . . . . . . . . . . . 111
5.10.3 Markov chain perspective . . . . . . . . . . . . . . 113
5.11 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 116
III Differentiating through programs 117
17 Duality 348
17.1 Dual norms . . . . . . . . . . . . . . . . . . . . . . . . . 348
17.2 Fenchel duality . . . . . . . . . . . . . . . . . . . . . . . . 349
17.3 Bregman divergences . . . . . . . . . . . . . . . . . . . . 352
17.4 Fenchel-Young loss functions . . . . . . . . . . . . . . . . 355
17.5 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 356
References 357
ABSTRACT
Artificial intelligence has recently experienced remarkable
advances, fueled by large models, vast datasets, acceler-
ated hardware, and, last but not least, the transformative
power of differentiable programming. This new programming
paradigm enables end-to-end differentiation of complex com-
puter programs (including those with control flows and data
structures), making gradient-based optimization of program
parameters possible.
As an emerging paradigm, differentiable programming builds
upon several areas of computer science and applied mathe-
matics, including automatic differentiation, graphical mod-
els, optimization and statistics. This book presents a com-
prehensive review of the fundamental concepts useful for
differentiable programming. We adopt two main perspec-
tives, that of optimization and that of probability, with clear
analogies between the two.
Differentiable programming is not merely the differentiation
of programs, but also the thoughtful design of programs
intended for differentiation. By making programs differen-
tiable, we inherently introduce probability distributions over
their execution, providing a means to quantify the uncer-
tainty associated with program outputs.
Notation

Notation        Description
X ⊆ R^D         Input space (e.g., features)
Y ⊆ R^M         Output space (e.g., classes)
S_k ⊆ R^{D_k}   Output space on layer or state k
W ⊆ R^P         Weight space
Λ ⊆ R^Q         Hyperparameter space
Θ ⊆ R^R         Distribution parameter space, logit space
N               Number of training samples
T               Number of optimization iterations
x ∈ X           Input vector
y ∈ Y           Target vector
s_k ∈ S_k       State vector k
w ∈ W           Network (model) weights
λ ∈ Λ           Hyperparameters
θ ∈ Θ           Distribution parameters, logits
π ∈ [0, 1]      Probability value
π ∈ △^M         Probability vector
f               Network function
f(·; x)         Network function with x fixed
L               Objective function
ℓ               Loss function
κ               Kernel function
ϕ               Output embedding, sufficient statistic
step            Heaviside step function
logistic_σ      Logistic function with temperature σ
logistic        Shorthand for logistic_1
p_θ             Model distribution with parameters θ
ρ               Data distribution over X × Y
ρ_X             Data distribution over X
µ, σ²           Mean and variance
Z               Random noise variable
1 Introduction
This book need not be read linearly, chapter by chapter. When needed, we indicate at the beginning of a chapter which chapters are recommended as prerequisites.
Fundamentals
2 Differentiation
2.1.1 Derivatives
if

lim_{v→w} |g(v)| / |f(v)| = 0.

In particular, continuity of f at w can be written as

f(w + δ) = f(w) + o(1) as δ → 0.
2.2.2 Gradients
We now introduce the gradient vector, which gathers the partial deriva-
tives. We first recall the definitions of linear map and linear form.
∇f(w) := (∂_1 f(w), …, ∂_P f(w)) = (∂f(w)[e_1], …, ∂f(w)[e_P]).
∂f(w)[v] = Σ_{i=1}^P v_i ∂f(w)[e_i] = ⟨v, ∇f(w)⟩.
In the definition above, the fact that the gradient can be used to
compute the directional derivative is a mere consequence of the linearity.
However, in more abstract cases presented in later sections, the gradient
is defined through this property.
As a simple example, any linear function of the form f(w) = a⊤w = Σ_{i=1}^P a_i w_i is differentiable, as we have (a⊤(w + v) − a⊤w − a⊤v)/∥v∥₂ = 0, so that ∇f(w) = a.
Linearity of gradients

The notion of differentiability for multi-input functions naturally inherits the linearity of derivatives of single-input functions. For any u_1, …, u_M ∈ R and any multi-input functions f_1, …, f_M differentiable at w, the function u_1 f_1 + ⋯ + u_M f_M is differentiable at w and its gradient is

∇(u_1 f_1 + ⋯ + u_M f_M)(w) = u_1 ∇f_1(w) + ⋯ + u_M ∇f_M(w).
2.2.3 Jacobians
Let us now consider a multi-output function f : RP → RM defined by
f (w) := (f1 (w), . . . , fM (w)), where fj : RP → R. A typical example
in machine learning is a neural network. The notion of directional derivative can be extended to such functions by defining it as the vector composed of the coordinate-wise directional derivatives:
∂f(w)[v] := lim_{δ→0} (f(w + δv) − f(w))/δ
          = ( lim_{δ→0} (f_1(w + δv) − f_1(w))/δ, …, lim_{δ→0} (f_M(w + δv) − f_M(w))/δ )⊤ ∈ R^M,
where the limits (provided that they exist) are applied coordinate-wise. The directional derivative of f in the direction v ∈ R^P is therefore the vector that gathers the directional derivatives of each f_j, i.e., ∂f(w)[v] = (∂f_j(w)[v])_{j=1}^M. In particular, we can define the partial derivatives of f at w as the vectors
f at w as the vectors
∂i f1 (w)
∂i f (w) := ∂f (w)[ei ] =
.. ∈ RM .
.
∂i fM (w)
As for the usual definition of the derivative, the directional derivative
can provide a linear approximation of a function around a current input
as illustrated in Fig. 2.3 for a parameterized curve f : R → R2 .
Just as in the single-output case, differentiability is defined not only
as the existence of directional derivatives in any direction but also by
the linearity in the chosen direction.
∂f(w) := [ ∂_1 f_1(w)  ⋯  ∂_P f_1(w)
               ⋮       ⋱       ⋮
           ∂_1 f_M(w)  ⋯  ∂_P f_M(w) ] ∈ R^{M×P}.
The Jacobian can be represented by stacking columns of partial
derivatives or rows of gradients,
∂f(w) = (∂_1 f(w), …, ∂_P f(w)) = [ ∇f_1(w)⊤
                                        ⋮
                                    ∇f_M(w)⊤ ].
∂f(w)[v] = Σ_{i=1}^P v_i ∂_i f(w) = ∂f(w) v ∈ R^M.
For a single-input function f : R → R^M, the Jacobian is the column vector

∂f(w) = f′(w) := (f_1′(w), …, f_M′(w))⊤ ∈ R^{M×1},

and for a single-input, single-output function f : R → R, it reduces to the derivative,

∂f(w) = f′(w) ∈ R.
Example 2.4 illustrates the form of the Jacobian matrix for the
element-wise application of a differentiable function such as the softplus
activation. This example already shows that the Jacobian takes a simple
diagonal matrix form. As a consequence, the directional derivative
associated with this function is simply given by an element-wise product
rather than a full matrix-vector product as suggested in Definition 2.8.
We will revisit this point in Section 2.3.
f(w) := (σ(w_1), …, σ(w_P))⊤ ∈ R^P, where σ(w) := log(1 + e^w).
∂f(w) = diag(σ′(w_1), …, σ′(w_P)) := [ σ′(w_1)    0    ⋯     0
                                          0       ⋱    ⋱     ⋮
                                          ⋮       ⋱    ⋱     0
                                          0       ⋯    0   σ′(w_P) ].
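To make this concrete, here is a minimal JAX sketch (ours, not the book's) checking that the JVP of the element-wise softplus is an element-wise product with σ′(w) = logistic(w), so that no diagonal matrix needs to be materialized.

```python
import jax
import jax.numpy as jnp

def f(w):
    # Element-wise softplus, sigma(w) = log(1 + e^w).
    return jnp.log1p(jnp.exp(w))

w = jnp.array([0.5, -1.0, 2.0])
v = jnp.array([1.0, 2.0, 3.0])

# Directional derivative (JVP) computed by autodiff.
_, t = jax.jvp(f, (w,), (v,))

# sigma'(w) = e^w / (1 + e^w) = logistic(w): an element-wise
# product suffices, no P x P Jacobian matrix is needed.
assert jnp.allclose(t, jax.nn.sigmoid(w) * v)
```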
The gradients of u⊤f for u = Σ_{j=1}^M u_j e_j ∈ R^M are given by

∇(u⊤f)(w) = Σ_{j=1}^M u_j ∇f_j(w) = ∂f(w)⊤u ∈ R^P,
Chain rule
Equipped with a generic definition of differentiability and the associated
objects, gradients and Jacobians, we can now generalize the chain rule,
y = (y_1, …, y_N)⊤ ∈ R^N.
get

∇f(w) = ∂f_1(w)⊤ ∇f_2(f_1(w)),

provided that f_1 and f_2 are differentiable at w and f_1(w), respectively. The function f_1 is linear, hence differentiable, with Jacobian ∂f_1(w) = X. On the other hand, the partial derivatives of f_2 are given by ∂_j f_2(p) = 2(p_j − y_j) for j ∈ {1, …, N}. Therefore, f_2 is differentiable at any p and its gradient is ∇f_2(p) = 2(p − y). By combining the computations of the Jacobian of f_1 and the gradient of f_2, we then get the gradient of f as

∇f(w) = 2X⊤(Xw − y).
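As a quick numerical check of this example (our sketch, with arbitrary data), autodiff recovers the closed-form gradient 2X⊤(Xw − y):

```python
import jax
import jax.numpy as jnp

X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 0.0, -1.0])
w = jnp.array([0.5, -0.5])

def f(w):
    # f = f2 o f1 with f1(w) = Xw and f2(p) = ||p - y||_2^2.
    return jnp.sum((X @ w - y) ** 2)

# Chain rule: grad f(w) = (d f1(w))^T grad f2(f1(w)) = 2 X^T (Xw - y).
assert jnp.allclose(jax.grad(f)(w), 2 * X.T @ (X @ w - y))
```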
Linear spaces, a.k.a. vector spaces, are spaces equipped with (and closed under) an addition rule compatible with multiplication by a scalar (we limit ourselves to the field of reals). Namely, in a vector space E, there exist operations + and ·, such that for any u, v ∈ E and a ∈ R, we have u + v ∈ E and a · u ∈ E. Euclidean spaces are linear spaces equipped with a basis e_1, …, e_P ∈ E. Any element v ∈ E can be decomposed as v = Σ_{i=1}^P v_i e_i for some unique scalars v_1, …, v_P ∈ R.
where tr(Z) := Σ_i Z_ii is the trace operator, defined for square matrices.
As this is true for any v ∈ E, ∂f(w)∗[u] is the gradient of ⟨u, f⟩_F by Proposition 2.4.
The proof follows that of Proposition 2.2. When the last function is scalar-valued, which is often the case in machine learning, we obtain the following simplified result.
∂f (x, W )[v, V ] = W v + V x ∈ F.
∂1 f (x, W )∗ [u] = W ⊤ u ∈ E1 ,
∂2 f (x, W )∗ [u] = ux⊤ ∈ E2 .
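The two adjoints above are easy to verify numerically; the following sketch (an illustration of ours, not the book's) uses jax.vjp on the bilinear function f(x, W) := Wx.

```python
import jax
import jax.numpy as jnp

def f(x, W):
    return W @ x  # bilinear in (x, W)

x = jnp.array([1.0, 2.0, 3.0])
W = jnp.arange(6.0).reshape(2, 3)
u = jnp.array([1.0, -1.0])  # output direction

_, vjp = jax.vjp(f, x, W)
gx, gW = vjp(u)
assert jnp.allclose(gx, W.T @ u)          # d1 f(x, W)*[u] = W^T u
assert jnp.allclose(gW, jnp.outer(u, x))  # d2 f(x, W)*[u] = u x^T
```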
∂h(w)[v] = Σ_{i=1}^T ∂_i f(g(w))[∂g_i(w)[v]].
The notions above can be extended to more general linear spaces.
For example, directional derivatives (see Definition 2.11) can
be defined in any linear space equipped with a norm and complete
with respect to this norm. Such spaces are called Banach spaces.
Completeness is a technical assumption that requires that any Cauchy
sequence converges (a Cauchy sequence is a sequence whose elements
become arbitrarily close to each other as the sequence progresses). A
function f : E → F defined from a Banach space E onto a Banach space
F is then called Gateaux differentiable if its directional derivative is
defined along any direction (where limits are defined w.r.t. the norm in
F). Some authors also require the directional derivative to be linear to
define a Gateaux differentiable function.
Fréchet differentiability can also naturally be generalized to Banach spaces. The only difference is that, in generic Banach spaces, the linear map l satisfying Definition 2.11 must be continuous, i.e., there must exist C > 0 such that ∥l[v]∥ ≤ C∥v∥, where the norms are those of F and E, respectively.
The definitions of gradient and VJPs require in addition a notion of
inner product. They can be defined in Hilbert spaces, that is, linear
spaces equipped with an inner product and complete with respect to
the norm induced by the inner product (they could also be defined
in a Banach space by considering operations in the dual space, see,
e.g. (Clarke et al., 2008)). The existence of the gradient is ensured by
Riesz’s representation theorem which states that any continuous
linear form in a Hilbert space can be represented by the inner product
with a vector. Since for a differentiable function f : E → R, the JVP ∂f(w) : E → R is a linear form, Riesz's representation theorem ensures the existence of the gradient as the element g ∈ E such that ∂f(w)[v] = ⟨g, v⟩ for any v ∈ E. The VJP is also well-defined as the adjoint of the JVP w.r.t. the inner product of the Hilbert space.
As an example, the space of square-integrable functions on R is a Hilbert space equipped with the inner product ⟨a, b⟩ := ∫_R a(x) b(x) dx.
Here, we cannot find a finite number of functions that can express all
possible functions on R. Therefore, this space is not a mere Euclidean
space. Nevertheless, we can consider functions on this Hilbert space
(called functionals to distinguish them from the elements of the space).
Figure 2.4: Points at which the second derivative is small are points around which the function is well approximated by its tangent line. On the other hand, points with a large second derivative tend to be badly approximated by the tangent line.
2.4.3 Hessians
For a multi-input function, twice differentiability is simply defined as
the differentiability of any directional derivative ∂f (w)[v] w.r.t. w.
∇²f(w) := [ ∂_{11} f(w)  ⋯  ∂_{1P} f(w)
                ⋮        ⋱       ⋮
            ∂_{P1} f(w)  ⋯  ∂_{PP} f(w) ],
provided that all second partial derivatives are well-defined.
The second directional derivative at w is bilinear in the directions v = Σ_{i=1}^P v_i e_i and v′ = Σ_{i=1}^P v_i′ e_i. Therefore,

∂²f(w)[v, v′] = Σ_{i,j=1}^P v_i v_j′ ∂²f(w)[e_i, e_j] = ⟨v, ∇²f(w) v′⟩.
Extending its definition for single-output functions, we have that the Hessian can be expressed as ∇²f(w) = ∇(∇f)(w), which justifies its notation.
As for the differentiability of a function f, twice differentiability of f at w is equivalent to the second partial derivatives being not only well-defined but also continuous in a neighborhood of w. Remarkably, by requiring twice differentiability, i.e., continuous second partial derivatives, the Hessian is guaranteed to be symmetric (Schwarz, 1873).
∂ 2 f (w)[v, v ′ ] = ⟨v ′ , ∇2 f (w)[v]⟩.
The function f is twice differentiable if and only if all its coordinates are
twice differentiable. The second directional derivative is then a bilinear
map. We can then compute second directional derivatives as
∂²f(w)[v, v′] = Σ_{i,j=1}^P v_i v_j′ ∂²f(w)[e_i, e_j] = (⟨v, ∇²f_j(w) v′⟩)_{j=1}^M.
∂ⁿf(w)[v_1, …, v_n] = ∂(∂^{n−1}f(·)[v_1, …, v_{n−1}])[v_n]
                    = lim_{δ→0} (∂^{n−1}f(w + δv_n)[v_1, …, v_{n−1}] − ∂^{n−1}f(w)[v_1, …, v_{n−1}]) / δ
In the case of the sphere in Fig. 2.5, the tangent space is a plane, that is, a Euclidean space. This property is generally true: tangent spaces are linear spaces.
Note that elements of the cotangent space are linear mappings, not
vectors. This distinction is important to define the pullback operator
as an operator on functions as done in measure theory. From a linear
algebra viewpoint, the cotangent space is exactly the dual space of
Ty N , that is, the set of linear maps from Ty N to R, called linear forms.
As Ty N is a Euclidean space, its dual space Ty∗ N is also a Euclidean
space. The pullback operator is then defined as the operator that gives
access to directional derivatives of β ◦ f given the directional derivative
of β at f (w).
Function                  f         : M → N
Pushforward               ∂f(w)     : T_w M → T_{f(w)} N
Pullback                  ∂f(w)⋆    : T⋆_{f(w)} N → T⋆_w M
Adjoint of pushforward    ∂f(w)∗    : T_{f(w)} N → T_w M
Table 2.1: For a differentiable function f defined from a manifold M onto a manifold
N , the JVP is generalized with the notion of pushforward ∂f (w). The counterpart
of the pushforward is the pullback operation ∂f (w)⋆ that acts on linear forms in the
tangent spaces. For Riemannian manifolds, the pullback operation can be identified
with the adjoint operator ∂f (w)∗ of the pushforward operator as any linear form is
represented by a vector.
Tw M = Null(2⟨w, ·⟩)
= {v ∈ RP : ⟨w, v⟩ = 0}
2.8 Summary
P(Y = y) := p(y),

and the expectation of a function ϕ of Y is

E[ϕ(Y)] = Σ_{y∈Y} p(y) ϕ(y),
the variance is

V[ϕ(Y)] = E[(ϕ(Y) − E[ϕ(Y)])²] = ∫_Y p(y) (ϕ(y) − E[ϕ(Y)])² dy
p_λ(y) := (1/(σ√(2π))) exp( −(1/2) ((y − µ)/σ)² ).
p, and

λ_∞ := argmin_{λ∈Λ} KL(p, p_λ), where KL(p, p_λ) = E_{Y∼p}[ log(p(Y)/p_λ(Y)) ],
then λ̂_N → λ_∞ in expectation over the observations, as N → ∞. This can be seen by using

KL(p, p_λ) ≈ (1/N) Σ_{i=1}^N log(p(y_i)/p_λ(y_i)) = (1/N) Σ_{i=1}^N [ log p(y_i) − log p_λ(y_i) ]

and the law of large numbers.
3.3.2 Inference
The main advantage of this probabilistic approach is that our prediction
model is much richer than if we just learned a function from X to Y.
See also Section 4.4.3 for more details. When g is linear in w, this is
known as binary logistic regression.
Remark 3.1 (Link with the logistic distribution). The logistic distribution with mean and scale parameters µ and σ is a continuous probability distribution with PDF

p_{µ,σ}(u) := (1/σ) p_{0,1}((u − µ)/σ), where p_{0,1}(z) := exp(−z) / (1 + exp(−z))².

The corresponding CDF is

P(U ≤ u) = ∫_{−∞}^u p_{µ,σ}(t) dt = logistic((u − µ)/σ).
Therefore, if

U ∼ Logistic(µ, σ) and Y ∼ Bernoulli(logistic((u − µ)/σ)),

then P(Y = 1) = P(U ≤ u).
Here, U can be interpreted as a latent continuous variable and u
as a threshold.
λ := π ∈ △^M, where

△^M := {π ∈ R₊^M : ⟨π, 1⟩ = 1},

and

Y ∼ Categorical(π).
where
ϕ(y) := ey
is the standard basis vector for the coordinate y ∈ [M ].
Since Y is a categorical variable, it does not make sense to compute
the expectation of Y but we can compute that of ϕ(Y ) = eY ,
EY ∼pλ [ϕ(Y )] = π.
That is, the mean and the probability distribution (represented by the
vector π) are the same in this case.
3.3.5 Regression
λ := (µ, σ),

p_λ(y) := (1/(σ√(2π))) exp( −(y − µ)²/(2σ²) ).
The expectation is
EY ∼pλ [Y ] = µ.
P(Y ≤ y) = (1/2) [ 1 + erf((y − µ)/(σ√2)) ],   (3.1)

P(a < Y ≤ b) = (1/2) [ erf((b − µ)/(σ√2)) − erf((a − µ)/(σ√2)) ].
Parameterization
Typically, in regression, the mean is output by a model, while the
standard deviation σ is kept fixed (typically set to 1). Since µ is uncon-
strained, we can simply set
µ := f (x, w) ∈ R,
λ := (µ, Σ),
EY ∼pλ [Y ] = µ.
Parameterization
Typically, in multivariate regression, the mean is output by a model,
while the covariance matrix is kept fixed (typically set to the identity
matrix). Since µ is again unconstrained, we can simply set
µ := f (x, w) ∈ RM .
where
λi := f (xi , w).
Again, this is equivalent to minimizing the negative log-likelihood,

ŵ = argmin_{w∈W} Σ_{i=1}^N −log p_{λ_i}(y_i).
3.4.1 Definition
The exponential family is a class of probability distributions, whose
PMF or PDF can be written in the form
h(y) exp [⟨θ, ϕ(y)⟩]
pθ (y) =
exp(A(θ))
= h(y) exp [⟨θ, ϕ(y)⟩ − A(θ)] ,
where θ are the natural or canonical parameters of the distribution.
The function h is known as the base measure. The function ϕ is the
sufficient statistic: it holds all the information about y and is used
to embed y in a vector space. The function A is the log-partition or log-normalizer (see below for details). All the distributions we
reviewed in Section 3.3 belong to the exponential family. With some
abuse of notation, we use pλ for the distribution in original form and
pθ for the distribution in exponential family form. As we will see,
we can go from θ to λ and vice-versa. We illustrate how to rewrite a
distribution in exponential family form below.
Example 3.2 (Bernoulli distribution). The PMF of the Bernoulli distribution with parameter λ = π equals
pλ (y) := π y (1 − π)1−y
= exp(log(π y (1 − π)1−y ))
= exp(y log(π) + (1 − y) log(1 − π))
= exp(log(π/(1 − π))y + log(1 − π))
= exp(θy − log(1 + exp(θ)))
= exp(θy − softplus(θ))
=: pθ (y).
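The last rewriting can be checked numerically; below is a small sketch of ours comparing the exponential-family form exp(θy − softplus(θ)) against the original PMF.

```python
import jax.numpy as jnp
from jax.nn import softplus

pi = 0.3
theta = jnp.log(pi / (1 - pi))  # natural parameter, theta = logit(pi)

for y in (0.0, 1.0):
    direct = pi**y * (1 - pi)**(1 - y)
    # exp(<theta, phi(y)> - A(theta)) with phi(y) = y, A = softplus.
    expfam = jnp.exp(theta * y - softplus(theta))
    assert jnp.allclose(direct, expfam)
```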
          Bernoulli                       Categorical
Y         {0, 1}                          [M]
λ         π = logistic(θ)                 π = softmax(θ)
θ         logit(π)                        log π (up to a constant)
ϕ(y)      y                               e_y
A(θ)      softplus(θ) = −log(1 − π)       logsumexp(θ)
h(y)      1                               1
          Gaussian (σ fixed)              Gaussian
θ         µ/σ                             (µ/σ², −1/(2σ²))
ϕ(y)      y/σ                             (y, y²)
A(θ)      θ²/2 = µ²/(2σ²)                 −θ_1²/(4θ_2) − (1/2) log(−2θ_2) = µ²/(2σ²) + log σ
h(y)      exp(−y²/(2σ²)) / (√(2π) σ)      1/√(2π)
and similarly for continuous variables, with the sum over y ∈ Y replaced by an integral. For the negative log-likelihood −log p_θ(y) = A(θ) − ⟨θ, ϕ(y)⟩ − log h(y), the Hessian with respect to θ is ∇²A(θ), which is independent of y.
Inference
Once we have found w by minimizing the objective function above, there are several possible strategies to perform inference for a new input x.
• Expectation. When the goal is to compute the expectation of
ϕ(Y ), we can use ∇A(f (x, w)). That is, we compute the distribu-
tion parameters associated with x by θ = f (x, w) and then we
compute the mean by µ = ∇A(θ). When f is linear in w, the
composition ∇A ◦ f is called a generalized linear model.
3.5 Summary
Differentiable programs
4 Parameterized programs
s0 ∈ S0
s1 := f1 (s0 ) ∈ S1
..
.
sK := fK (sK−1 ) ∈ SK
f (s0 ) := sK . (4.1)
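In code, such a chain is simply a loop maintaining the current state; a minimal sketch (the particular layer functions below are illustrative choices of ours):

```python
import jax.numpy as jnp

def chain(fs, s0):
    s = s0
    for f in fs:       # s_k := f_k(s_{k-1})
        s = f(s)
    return s           # f(s_0) := s_K

# Three illustrative intermediate functions.
f1 = lambda s: s + 1.0
f2 = jnp.sin
f3 = jnp.exp
print(chain([f1, f2, f3], jnp.array(0.5)))
```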
Figure 4.2: Example of a directed acyclic graph. Here the nodes are V =
{0, 1, 2, 3, 4}, the edges are E = {(0, 1), (0, 2), (0, 3), (1, 3), (2, 3), (1, 4), (3, 4)}. Parents
of the node 3 are parents(3) = {0, 1, 2}. Children of node 1 are children(1) = {3, 4}.
There is a unique root, 0, and a unique leaf, 4; 0 → 3 → 4 is a path from 0 to 4. This
is an acyclic graph since there is no cycle (i.e., a path from a node to itself). We can
order nodes 0 and 3 as 0 ≤ 3 since there is no path from 3 to 0. Similarly, we can
order 1 and 2 as 1 ≤ 2 since there is no path from 2 to 1. Two possible topological
orders of the nodes are (0, 1, 2, 3, 4) and (0, 2, 1, 3, 4).
Figure 4.3: Representation of f(x_1, x_2) = x_2 e^{x_1} √(x_1 + x_2 e^{x_1}) as a DAG, with functions as nodes and variables as edges. The function is decomposed into 8 functions in topological order.
Executing a program
When a program has multiple inputs, we can always group them into
s0 ∈ S0 as s0 = (s0,1 , . . . , s0,N0 ) with S0 = (S0,1 × · · · × S0,N0 ), since
later functions can always filter out the elements of s0 they need.
Likewise, if an intermediate function fk has multiple outputs, we can
always group them as a single output sk = (sk,1 , . . . , sk,Nk ) with Sk =
(Sk,1 × · · · × Sk,Nk ), since later functions can filter out the elements of
sk that they need.
Figure 4.4: Two possible representations of a program. Left: functions and output variables are represented by the same nodes. Right: functions and variables are represented by disjoint sets of nodes.
logsumexp(u) = logsumexp(u − c 1) + c,
where c := maxj∈[M ] uj .
Two important properties of the logistic function are that for all u ∈ R
logistic(−u) = 1 − logistic(u)
and
This mapping puts all the probability mass onto a single coordinate
(in case of ties, we pick a single coordinate arbitrarily). Unfortunately,
this mapping is a discontinuous function. As an everywhere-differentiable relaxation, we can use the softargmax defined by
softargmax(u) := exp(u) / Σ_{j=1}^M exp(u_j).
∇logsumexp(u) = softargmax(u).
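Both facts are easy to confirm with autodiff; the sketch below (ours) implements a numerically stable logsumexp via the shift by c := max_j u_j, and checks that its gradient is the softargmax.

```python
import jax
import jax.numpy as jnp

def logsumexp(u):
    # Shift by c := max_j u_j; exact thanks to the identity
    # logsumexp(u) = logsumexp(u - c 1) + c.
    c = jnp.max(u)
    return jnp.log(jnp.sum(jnp.exp(u - c))) + c

u = jnp.array([1.0, 2.0, 3.0])
softargmax_u = jnp.exp(u) / jnp.sum(jnp.exp(u))
assert jnp.allclose(jax.grad(logsumexp)(u), softargmax_u)
```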
where we impose the zero-mean constraint ū = 0, with ū := (1/M) Σ_{j=1}^M u_j. Using this constraint together with

log π_i = u_i − log Σ_{j=1}^M exp(u_j),

we then obtain

Σ_{i=1}^M log π_i = −M log Σ_{j=1}^M exp(u_j),

so that

u_i = [softargmax⁻¹(π)]_i = log π_i − (1/M) Σ_{j=1}^M log π_j.
where L stands for length. Note that we use the same number of
parameters P for each setup for notational convenience, but this of
course does not need to be the case. Throughout this section, we use
the notation p1:L := (p1 , . . . , pL ) for a sequence of L vectors.
p_{1:L′} ∈ R^{L′×M}, which potentially has a different length. An example of application is machine translation, where the sentences in the source and target languages do not necessarily have the same length. Typically, p_{1:L′} = f^u(x_{1:L}, w) is defined by the following two steps

c := f^e(x_{1:L}, w^e)
p_{1:L′} := f^d(c, w^d),

where w := (w^e, w^d), and where we reused the previously-defined encoder f^e and decoder f^d. Putting the two steps together, we obtain
sl := γ(xl , sl−1 , wγ ) l ∈ [L]
c = pooling(s1:L )
zl := g(c, pl−1 , zl−1 , wg ) l ∈ [L′ ]
pl := h(zl , wh ) l ∈ [L′ ].
This architecture is aptly named the encoder-decoder architecture.
Note that we denoted the length of the target sequence as L′ . However,
in practice, the target length can be input dependent and is often not
known ahead of time. To deal with this issue, the vocabulary (of size D in our notation) is typically augmented with an “end of sequence” (EOS)
token so that, at inference time, we know when to stop generating the
output sequence. One disadvantage of this encoder-decoder architecture,
however, is that all the information about the input sequence is contained
in the context vector c, which can therefore become a bottleneck. For
this reason, this architecture has been largely replaced with attention
mechanisms and transformers, which we study in the next sections.
4.7 Summary
• greater than:

gt(u_1, u_2) := 1 if u_1 ≥ u_2, 0 otherwise
             = step(u_1 − u_2)

• less than:

lt(u_1, u_2) := 1 if u_1 ≤ u_2, 0 otherwise
             = 1 − gt(u_1, u_2)
             = step(u_2 − u_1)

• equal:

eq(u_1, u_2) := 1 if |u_1 − u_2| = 0, 0 otherwise
             = gt(u_2, u_1) · gt(u_1, u_2)
             = step(u_2 − u_1) · step(u_1 − u_2)

• not equal:

neq(u_1, u_2) := 1 if |u_1 − u_2| > 0, 0 otherwise
              = 1 − eq(u_1, u_2)
              = 1 − step(u_2 − u_1) · step(u_1 − u_2),
eq(µ_1, µ_2) = δ(µ_1 − µ_2), where δ(u) := 1 if u = 0, 0 if u ≠ 0.
It equals 1 when µ1 = µ2 and is zero everywhere else. A natural idea to
obtain a continuous and smooth relaxation is therefore to replace δ by
a bell-shaped function. For this purpose, we can use a kernel, such as
the PDF of a distribution centered at 0. Formally, we may define
κ(µ1 , µ2 )
eq(µ1 , µ2 ) ≈ ,
κ(0, 0)
where
κ(µ1 , µ2 ) := p0,1 (µ1 − µ2 ),
for p0,1 the PDF of a zero-mean unit-scale distribution in the location-
scale family. The normalization κ(0, 0) ensures that the soft equality
operator is 1 at µ_1 = µ_2. For instance, we can use the Gaussian kernel

κ(µ_1, µ_2) := e^{−(µ_1 − µ_2)²/2},

illustrated in Fig. 5.1. The soft equality operator obtained with the logistic kernel coincides with the expression Petersen et al. (2021) arrive at, in a different manner.
eq(µ_1, µ_2) ≈ corr(U_1, U_2) = κ(µ_1, µ_2) / κ(0, 0).
Both the Gaussian and logistic kernels arise naturally as the density
of µ1 + Z1 − (µ2 + Z2 ) at 0, with Z1 and Z2 Gaussian or Gumbel random
variables.
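A minimal sketch of a soft equality operator with the Gaussian kernel (our illustration):

```python
import jax
import jax.numpy as jnp

def soft_eq(mu1, mu2):
    # kappa(mu1, mu2) / kappa(0, 0) with the Gaussian kernel
    # kappa(mu1, mu2) = exp(-(mu1 - mu2)^2 / 2), so kappa(0, 0) = 1.
    return jnp.exp(-0.5 * (mu1 - mu2) ** 2)

print(soft_eq(1.0, 1.0))            # 1.0, like the hard eq
print(soft_eq(1.0, 1.5))            # in (0, 1)
print(jax.grad(soft_eq)(1.0, 1.5))  # nonzero, unlike the hard eq
```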
• Associativity:
and(π, and(π ′ , π ′′ )) = and(and(π, π ′ ), π ′′ )
or(π, or(π ′ , π ′′ )) = or(or(π, π ′ ), π ′′ )
• Neutral element:
and(π, 1) = π
or(π, 0) = π
• De Morgan’s laws:
and

any(π) := 1 if 1 ∈ {π_1, …, π_K}, 0 otherwise.
and(π, π ′ ) = π · π ′
or(π, π ′ ) = π + π ′ − π · π ′
not(π) = 1 − π.
Figure 5.2: The Boolean and and or operators are functions from {0, 1} × {0, 1}
to {0, 1} but their continuous extensions and(π, π ′ ) := π · π ′ as well as or(π, π ′ ) :=
π + π ′ − π · π ′ define a function from [0, 1] × [0, 1] to [0, 1].
all(π) = Π_{i=1}^K π_i

and

any(π) = 1 − Π_{i=1}^K (1 − π_i).
all(π) = P(Y_1 = 1 ∩ ⋯ ∩ Y_K = 1) = Π_{i=1}^K P(Y_i = 1)

and

any(π) = P(Y_1 = 1 ∪ ⋯ ∪ Y_K = 1) = 1 − Π_{i=1}^K (1 − P(Y_i = 1)).
This is the chain rule of probability for K independent variables, and
the addition rule for K variables.
Table 5.1: Examples of triangular norms and conorms, which are continuous
relaxations of the and and or operators, respectively. More instances can be obtained
by smoothing out the min and max operators.
                t-norm                   t-conorm
Probabilistic   π · π′                   π + π′ − π · π′
Extremum        min(π, π′)               max(π, π′)
Łukasiewicz     max(π + π′ − 1, 0)       min(π + π′, 1)
and

∂_{v_1} ifelse(π, v_1, v_0) := I if π = 1, 0 if π = 0
                            = π · I,
if-else statements are composed with other functions. Let g1 : U1 → V
and g0 : U0 → V be differentiable functions. We then define v1 := g1 (u1 )
and v0 := g0 (u0 ), where u1 ∈ U1 and u0 ∈ U0 . The composition of
ifelse, g1 and g0 is then the function f : {0, 1} × U1 × U0 → V defined by
f (π, u1 , u0 ) := ifelse(π, g1 (u1 ), g0 (u0 ))
= π · g1 (u1 ) + (1 − π) · g0 (u0 ).
We obtain that the Jacobians are
∂u1 f (π, u1 , u0 ) = π · ∂g1 (u1 )
and
∂u0 f (π, u1 , u0 ) = (1 − π)∂g0 (u0 ).
As long as g1 and g0 are differentiable functions, we can therefore
differentiate through the branch variables u1 and u0 without any issue.
More problematic is the predicate variable π, as we now discuss.
∂p fh (p, u1 , u0 ) = ∂1 fh (p, u1 , u0 )
= step′ (p)(g1 (u1 ) − g0 (u0 ))
= 0.
ifelse(π, v_1, v_0) = π · v_1 + (1 − π) · v_0,

where sigmoid is for instance the logistic function or the Gaussian CDF. If we now define

g_h(u_1, u_0) := f_h(t(u_1), u_1, u_0),
Probabilistic perspective
From a probabilistic perspective, we can view Eq. (5.2) as the expecta-
tion of gi (ui ), where i ∈ {0, 1} is a binary random variable distributed
according to a Bernoulli distribution with parameter π = sigmoid(p):
Taking the expectation over the two possible branches makes the function differentiable with respect to p. Of course, this comes at the cost of evaluating both branches, instead of a single one. The probabilistic perspective suggests that we can also compute the variance if needed.
e_i := (0, …, 0, 1, 0, …, 0), where the 1 is in position i.
Combining booleans
5.7.2 Conditionals
We can now express a conditional statement as the function cond : {e_1, …, e_K} × V^K → V defined by

cond(π, v_1, …, v_K) := v_i if π = e_i, i ∈ [K]    (5.3)
                      = Σ_{i=1}^K π_i v_i.
As for the ifelse function, the cond function is discontinuous and nondifferentiable w.r.t. π ∈ {e_1, …, e_K}. However, given π = e_i fixed for some i, the function is linear in v_i and constant in v_j for j ≠ i.
We illustrate how to express a simple example, using this formalism.
(Figure: mean and standard deviation of a relaxed conditional, compared with the hard version.)
As for the ifelse case, the Jacobian w.r.t. p is null almost everywhere,

∂_p f_a(p, u_1, …, u_K) = 0,

and we may similarly define

softargmin(p) := softargmax(−p) ∈ △^K.
Probabilistic perspective
Taking the expectation over the K possible branches makes the function differentiable with respect to p, at the cost of evaluating all branches instead of a single one. As in the if-else case, we can compute the variance if needed.
For loops are a control flow for sequentially calling a fixed number K
of functions, reusing the output from the previous iteration. In full
generality, a for loop can be written as follows.
for i := 1, . . . , N do
for j := 1, . . . , N − i − 1 do
v ′ := swap(v, j, j + 1)
π := step(vj − vj+1 )
v ← ifelse(π, v ′ , v)
v1 := u1
v2 := u1 + u2
v3 := u1 + u2 + u3
..
.
vk := sk−1 + uk
f (sk−1 , uk ) := (vk , vk )
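This is exactly the contract of jax.lax.scan, which takes a function mapping (carry, input) to (new carry, output); a minimal sketch for the cumulative sum:

```python
import jax
import jax.numpy as jnp

def f(s, u):
    v = s + u      # v_k := s_{k-1} + u_k
    return v, v    # new state and output coincide here

u = jnp.array([1.0, 2.0, 3.0, 4.0])
s_final, v = jax.lax.scan(f, jnp.array(0.0), u)
print(v)  # [1. 3. 6. 10.], the cumulative sums
```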
Unlike for loops and scan, the number of iterations of while loops is
not known ahead of time, and may even be infinite. In this respect, a
while loop can be seen as a cyclic graph.
To avoid the issues with unbounded while loops, we can enforce that a
while loop stops after T iterations, i.e., we can truncate the while loop.
π0 := f (s0 )
if π0 = 1 then
r := s0
else
s1 := g(s0 ), π1 := f (s1 )
if π1 = 1 then
r := s1
else
s2 := g(s1 ), π2 := f (s2 )
if π2 = 1 then
r := s2
else
r := s3 := g(s2 )
r = ifelse(π0 ,
s0 ,
ifelse(π1 ,
s1 ,
ifelse(π2 ,
s2 ,
s3 ))).
= Σ_{i=0}^T ( Π_{j=0}^{i−1} (1 − π_j) ) π_i s_i,
where we defined

P(S_{t+1} = s_i | S_t = s_j) = p_{i,j} := { π_i if i = j; 1 − π_i if i = j + 1; 0 otherwise },

P := (p_{i,j})_{i,j=0}^T =

        s_0  s_1  s_2  s_3
 s_0  [  0    1    0    0 ]
 s_1  [  0    0    1    0 ]
 s_2  [  0    0    0    1 ]
 s_3  [  0    0    0    1 ].
r = E[S_I] = Σ_{i=1}^{+∞} P(I = i) s_{i−1}
           = Σ_{i=1}^{+∞} ( Π_{j=0}^{i−2} (1 − π_j) ) π_{i−1} s_{i−1}
           = Σ_{i=0}^{+∞} ( Π_{j=0}^{i−1} (1 − π_j) ) π_i s_i.
Because the stopping time is not known ahead of time, the sum over i
goes from 0 to ∞. However, if we enforce in the stopping criterion that
the while loop run no longer than T iterations, by setting
πi := or(f (si ), eq(i, T )) ∈ {0, 1},
we then naturally recover the expression found by unrolling the while
loop before,
r = E[S_I] = Σ_{i=0}^T ( Π_{j=0}^{i−1} (1 − π_j) ) π_i s_i.
With step, the derivative of the while loop with respect to τ will always
be 0, just like it was the case for if-else statements. If we change the
stopping criterion to f (si ) = sigmoid(τ − ε(si )), we then have (recall
that or is well defined on [0, 1] × [0, 1])
r = Σ_{i=0}^T ( Π_{j=0}^{i−1} (1 − sigmoid(ε(s_j) − τ)) ) sigmoid(ε(s_i) − τ) s_i.
5.11 Summary
Differentiating through programs

6 Finite differences
From Definition 2.3 and Definition 2.12, the directional derivative and
more generally the JVP are defined as a limit,
∂f(w)[v] := lim_{δ→0} (f(w + δv) − f(w))/δ.
This suggests that we can approximate the directional derivative and
the JVP using
∂f(w)[v] ≈ (f(w + δv) − f(w))/δ,
for a small δ > 0. Indeed, by a Taylor expansion,

f(w + δv) − f(w) = δ ∂f(w)[v] + (δ²/2) ∂²f(w)[v, v] + (δ³/3!) ∂³f(w)[v, v, v] + ⋯

so that

(f(w + δv) − f(w))/δ = ∂f(w)[v] + (δ/2) ∂²f(w)[v, v] + (δ²/3!) ∂³f(w)[v, v, v] + ⋯
                     = ∂f(w)[v] + o(δ).
Figure 6.1: The forward or central difference schemes applied to f (x) := ln(1 +
exp(−x)) to approximate f ′ (x) at x = 1 induce both truncation error (for large δ)
and round-off error (for small δ).
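The figure can be reproduced in a few lines; a sketch of ours for f(x) = ln(1 + exp(−x)) at x = 1, whose exact derivative is −1/(1 + e^x):

```python
import numpy as np

f = lambda x: np.log1p(np.exp(-x))
fprime_exact = -1.0 / (1.0 + np.exp(1.0))

for delta in (1e-1, 1e-4, 1e-8, 1e-12):
    forward = (f(1.0 + delta) - f(1.0)) / delta
    central = (f(1.0 + delta) - f(1.0 - delta)) / (2 * delta)
    # Truncation error dominates for large delta, round-off for small.
    print(delta, abs(forward - fprime_exact), abs(central - fprime_exact))
```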
More generally, finite-difference schemes approximate the directional derivative by a weighted combination of function evaluations,

∂f(w)[v] ≈ (1/δ) Σ_{i=0}^p a_i f(w + iδv).
By grouping the terms in the sum for each order of derivative, we obtain
a set of p + 1 equations to be satisfied by the p + 1 coefficients a0 , . . . , ap ,
that is,
a_0 + a_1 + ⋯ + a_p = 0
a_1 + 2a_2 + ⋯ + p a_p = 1
a_1 + 2^j a_2 + ⋯ + p^j a_p = 0  ∀j ∈ {2, …, p}.
Central schemes use symmetric evaluations, of the form

∂f(w)[v] ≈ (1/δ) Σ_{i=−p}^p a_i f(w + iδv).
Higher-order derivatives can be approximated similarly, using

∂^k f(w)[v, …, v] ≈ (1/δ^k) Σ_{i=0}^p a_i f(w + iδv).
As before, we can expand the terms in the sum. For the approximation to capture only the kth derivative, we now require the coefficients a_i to satisfy

0^j a_0 + 1^j a_1 + 2^j a_2 + ⋯ + p^j a_p = 0  ∀j ∈ {0, …, k − 1}
0^k a_0 + 1^k a_1 + 2^k a_2 + ⋯ + p^k a_p = k!
0^j a_0 + 1^j a_1 + 2^j a_2 + ⋯ + p^j a_p = 0  ∀j ∈ {k + 1, …, p}.
and similarly for central schemes, with i ranging over {−p, …, p} in (1/δ^k) Σ_i a_i f(w + iδv).
and

Im[f(w + (iδ)v)] / δ = ∂f(w)[v] + o(δ²).
This suggests that we can compute directional derivatives using the
approximation
∂f(w)[v] ≈ Im[f(w + (iδ)v)] / δ,
for 0 < δ ≪ 1. This is called the complex-step derivative approxi-
mation (Squire and Trapp, 1998; Martins et al., 2003).
Contrary to forward, backward and central differences, we see that
only a single function call is necessary. A function call on complex
numbers may take roughly twice the cost of a function call on real
numbers. However, thanks to the fact that a difference of functions is
no longer needed, the complex-step derivative approximation usually
enjoys smaller round-off error as illustrated in Fig. 6.1. That said, one
drawback of the method is that all elementary operations within the
program implementing the function f must be well-defined on complex
numbers, e.g., using overloading.
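A sketch of the complex-step approximation on the same function as in Fig. 6.1 (our illustration):

```python
import numpy as np

f = lambda x: np.log1p(np.exp(-x))  # also well-defined on complex inputs
x, delta = 1.0, 1e-20

approx = np.imag(f(x + 1j * delta)) / delta
exact = -1.0 / (1.0 + np.exp(x))
print(abs(approx - exact))  # no difference of close values: tiny round-off
```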
6.7 Complexity
6.8 Summary
The complex-step approximation enjoys smaller round-off error than central differences, but requires the function and the program implementing it to be well-defined on complex numbers.
However, whatever the method used, finite differences require a number
of function calls that is proportional to the number of dimensions. They
are therefore seldom used in machine learning, where there can be
millions or billions of dimensions. The main use cases of finite differ-
ences are therefore i) for blackbox functions of low dimension and ii)
for test purposes (e.g., checking that a gradient function is correctly
implemented). For modern machine learning, the main workhorse is
automatic differentiation, as it leverages the compositional structure of
functions. This is what we study in the next chapter.
7 Automatic differentiation
s0 ∈ S0
s1 := f1 (s0 ) ∈ S1
..
.
sK := fK (sK−1 ) ∈ SK
f (x) := sK . (7.1)
∂f(s_0) = ∂f_K(s_{K−1}) ∂f_{K−1}(s_{K−2}) ⋯ ∂f_2(s_1) ∂f_1(s_0),  (7.2)
where the ∂f_k(s_{k−1}) are the Jacobians of the intermediate functions, evaluated at the points s_0, …, s_{K−1} defined in Eq. (7.1). The main drawback of this approach is computational: computing the full ∂f(s_0) requires materializing the intermediate Jacobians in memory and performing matrix-matrix multiplications. In practice, however, the full Jacobian is rarely needed. Oftentimes, we only need to right-multiply or left-multiply with ∂f(s_0). This gives rise to forward-mode and reverse-mode autodiff, respectively.
7.1.1 Forward-mode
t_0 := v
t_1 := ∂f_1(s_0)[t_0]
⋮
t_K := ∂f_K(s_{K−1})[t_{K−1}]
∂f(s_0)[v] := t_K.
7.1.2 Reverse-mode
In machine learning, most functions whose gradient we need to compute take the form ℓ ∘ f, where ℓ is a scalar-valued loss function and f is a network. As seen in Proposition 2.3, the gradient takes the form

∇(ℓ ∘ f)(s_0) = ∂f(s_0)∗[∇ℓ(f(s_0))].
Figure 7.2: Memory usage of forward-mode autodiff for a computation chain. Here
t0 = v, sK = f (s0 ), tK = ∂f (s0 )[v].
This motivates the need for applying the adjoint ∂f (s0 )∗ to ∇ℓ(f (s0 )) ∈
SK and more generally to any output direction u ∈ SK . From Propo-
sition 2.7, we have
rK = u
rK−1 = ∂fK (sK−1 )∗ [rK ]
..
.
r0 = ∂f1 (s0 )∗ [r1 ]
∂f (s0 )∗ [u] = r0 .
Figure 7.4: Memory usage of reverse mode autodiff for a computation chain.
where
∂f (s0 )∗ [u] := backward(u; s0 , . . . , sK−1 ).
In functional programming terminology, the VJP ∂f (s0 )∗ is a closure,
as it contains the intermediate computations s_0, …, s_{K−1}. The same can be done for the JVP ∂f(s_0) if we want to apply it to multiple directions v_i.
s0 = x
s1 = f1 (s0 ) = σ(A1 s0 + b1 )
s2 = f2 (s1 ) = A2 s1 + b2
f (x) = s2 ,
t0 = v
t1 = σ ′ (A1 s0 + b1 ) ⊙ (A1 t0 )
t2 = A2 t1
∂f (x)[v] = t2 ,
r_2 = u
r_1 = ∂f_2(s_1)∗[r_2] = A_2⊤ r_2
r_0 = ∂f_1(s_0)∗[r_1] = A_1⊤ (σ′(A_1 s_0 + b_1) ⊙ r_1)
∂f(x)∗[u] = r_0.
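These formulas can be verified with jax.jvp and jax.vjp; a sketch of ours, with σ = tanh and small weights of our choosing:

```python
import jax
import jax.numpy as jnp

A1, b1 = jnp.array([[1.0, 0.5], [0.2, 0.3], [0.1, 0.4]]), jnp.zeros(3)
A2, b2 = jnp.array([[1.0, -1.0, 0.5], [0.3, 0.2, 0.1]]), jnp.zeros(2)
sigma = jnp.tanh  # illustrative activation choice

def f(x):
    s1 = sigma(A1 @ x + b1)
    return A2 @ s1 + b2

x = jnp.array([1.0, 2.0])
v = jnp.array([1.0, 0.0])   # input direction
u = jnp.array([1.0, -1.0])  # output direction

sprime = 1.0 - jnp.tanh(A1 @ x + b1) ** 2  # sigma'(A1 x + b1)
_, t2 = jax.jvp(f, (x,), (v,))
_, vjp = jax.vjp(f, x)
(r0,) = vjp(u)
assert jnp.allclose(t2, A2 @ (sprime * (A1 @ v)))
assert jnp.allclose(r0, A1.T @ (sprime * (A2.T @ u)))
```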
        Forward-mode         Reverse-mode
Time    O(M D² + K D³)       O(M² D + K M D²)
Space   O(max{M, D})         O(K D + M)
Table 7.1: Time and space complexities of forward-mode and reverse-mode autodiff
for computing the full Jacobian of a chain of functions f = fK ◦ · · · ◦ f1 , where
fk : RD → RD if k = 1, . . . , K − 1 and fK : RD → RM . We assume ∂fk is a dense
linear operator. Forward mode requires D JVPs. Reverse mode requires M VJPs.
Using Definition 2.8, we find that we can extract each row of the Jaco-
bian matrix, which corresponds to the transposed gradients ∇fi (s0 ) ∈
RD , for i ∈ [M ], by multiplying with the standard basis vector ei ∈ RM :
∇L(w; x, y) = (g1 , . . . , gK ) ∈ W1 × · · · × WK
where
∂f (x, w)∗ [u] = (r0 , (g1 , . . . , gK )),
with u = ∇ℓ(f (x, w); y) ∈ SK . The output r0 ∈ S0 , where S0 =
X , corresponds to the gradient w.r.t. x ∈ X and is typically not
needed, except in generative modelling settings. The full procedure is
summarized in Algorithm 7.3.
7.3.1 Forward-mode
The forward mode corresponds to computing a JVP in an input direction v ∈ S_0. The algorithm consists in computing intermediate JVPs along the forward pass. We initialize t_0 := v ∈ S_0. Using Proposition 2.8, the derivatives on iteration k ∈ [K] are propagated as

t_k := Σ_{j=1}^{n_k} ∂_j f_k(s_{i_1}, …, s_{i_{n_k}})[t_{i_j}], where {i_1, …, i_{n_k}} = parents(k).
7.3.2 Reverse-mode
The theorem gives a lower bound on the size of the best circuit for
computing ∇f from the size of the best circuit for computing f .
7.4 Implementation
(Section 4.1.4), A = {+, ×}. More generally, A may contain all the
necessary functions for expressing programs. We emphasize, however,
that A is not necessarily restricted to low-level functions such as log
and exp, but may also contain higher-level functions. For instance,
even though the log-sum-exp can be expressed as the composition of elementary operations (log, sum, exp), it is usually included as a primitive on its own, both because it is a very commonly-used building block and for numerical stability reasons.
An autodiff system must implement for each f ∈ A its JVP for support-
ing the forward mode, and its VJP for supporting the reverse mode.
We give a couple of examples. We start with the JVP and VJP of linear
functions.
Example 7.2 (JVP and VJP of linear functions). Consider the matrix-
vector product f (W ) = W x ∈ RM , where x ∈ RD is fixed and
W ∈ RM ×D . As already mentioned, the JVP of f at W ∈ RM ×D
along an input direction V ∈ RM ×D is simply
∂f (W )[V ] = f (V ) = V x ∈ RM .
∂f (W )∗ [u] = ux⊤ ∈ RM ×D .
∂f (W )[V ] = f (V ) = V X ∈ RM ×N .
∂f (W )∗ [U ] = U X ⊤ ∈ RM ×D .
Example 7.3 (JVP and VJP of separable functions). Consider the function f(w) := (g_1(w_1), …, g_P(w_P)), where each g_i : R → R has a derivative g_i′. The Jacobian matrix is then the diagonal matrix ∂f(w) = diag(g_1′(w_1), …, g_P′(w_P)). In this case, the JVP and VJP are actually the same: both are given by the element-wise product (g_1′(w_1), …, g_P′(w_P)) ⊙ v.
7.5 Checkpointing
(Figure: evaluation schedule of a checkpointing scheme, time step vs. function step.)
M(K) = log2 K.
Analytical formula
It turns out that we can also find an optimal scheme analytically. This
scheme was found by Griewank (1992), following the analysis of optimal
inversions of sequential programs by divide-and-conquer algorithms
done by Grimm et al. (1996); see also Griewank (2003, Section 6) for
a simple proof. The main idea consists in considering the number of
times an evaluation step fk is repeated. As we split the chain at l, all
steps from 1 to l will be repeated at least once. In other words, treating
the second half of the chain incurs one memory cost, while treating the
first half of the chain incurs one repetition cost. Griewank (1992) shows
that for fixed K, S, we can find the minimal number of repetitions
analytically and build the corresponding scheme with simple formulas
for the optimal splits.
Compared to the dynamic programming approach, this means that we do not need to compute the pointers l⋆(k, s): we can use a simple formula to set l⋆(k, s). We still need to traverse the corresponding binary tree given K, S and the l⋆(k, s) to obtain the schedules. Note that such an optimal scheme does not take into account varying computational costs for the functions f_k.
The optimal scheme presented above requires knowing the total number
of nodes in the computation graph ahead of time. However, when
differentiating through for example a while loop (Section 5.10), this is
not the case. To circumvent this issue, online checkpointing schemes
have been developed and proven to be nearly optimal (Stumm and
Walther, 2010; Wang et al., 2009). These schemes start by defining a set
of S checkpoints with the first S computations, then these checkpoints
are rewritten dynamically as the computations keep going. Once the
computations terminate, the optimal approach presented above for a
fixed length is applied on the set of checkpoints recorded.
7.8 Summary
8.1.2 Complexity
To get a sense of the computational and memory complexity of the four
approaches, we consider a chain of functions f := fK ◦ · · · ◦ f1 as done
Method                                   Computation
Reverse on reverse (VJP of gradient)     ∂(∇f)(w)∗[v]
Forward on reverse (JVP of gradient)     ∂(∇f)(w)[v]
Reverse on forward (gradient of JVP)     ∇(∂f(·)[v])(w)
Forward on forward (JVPs of JVPs)        (∂²f(w)[v, e_i])_{i=1}^P
Figure 8.1: Computation graph corresponding to reverse-mode autodiff for evaluating the gradient of f = f_K ∘ ⋯ ∘ f_1. While f is a simple chain, ∇f is a DAG.
Figure 8.2: Computation graph for computing the HVP ∇2 f (x)[v] by using reverse
mode on top of reverse mode. As the computation graph of ∇f induces fan-in
operations sk−1 , rk 7→ ∂fk (sk−1 )[rk ], the reverse mode applied on ∇f induces
branching of the computations at each such node.
Figure 8.3: Computation graph for computing the HVP ∇2 f (x)[v] by using forward
mode on top of reverse mode. The forward mode naturally follows the computations
done for the gradient, except that it passes through the derivatives of the intermediate
operations.
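In JAX, the forward-on-reverse HVP is a two-liner; a sketch of ours, checked against the materialized Hessian on a tiny example:

```python
import jax
import jax.numpy as jnp

def f(w):
    return jnp.sum(jnp.sin(w) ** 2)

def hvp(f, w, v):
    # Forward-on-reverse: JVP of the gradient function.
    return jax.jvp(jax.grad(f), (w,), (v,))[1]

w = jnp.array([0.1, 0.2])
v = jnp.array([1.0, 0.0])
assert jnp.allclose(hvp(f, w, v), jax.hessian(f)(w) @ v)
```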
∇²_GN(ℓ ∘ f) = Σ_{i=1}^M λ_i ∂f(w)∗ u_i u_i⊤ ∂f(w)
             = Σ_{i=1}^M (√λ_i ∂f(w)∗ u_i)(√λ_i ∂f(w)∗ u_i)⊤
             = Σ_{i=1}^M v_i v_i⊤, where v_i := √λ_i ∂f(w)∗ u_i.
With some slight abuse of notation, the Gauss-Newton matrix associated with the full data distribution is then

∇²_GN L(w) = E_{X,Y∼ρ}[∇²_GN L(w; X, Y)].
∇2F L(w) = ES∼qw [∇2 L(w; S)] = ES∼qw [−∇2w log qw (S)].
where ρ_X(x) := ∫ ρ(x, y) dy.
where we used · to indicate that the results holds for all y. Plugging the
result back in the Fisher information matrix concludes the proof.
u ↦ ∇²L(w)⁻¹ u,

can be computed by solving the linear system

H[v] = u

by only accessing the linear map v ↦ H[v] for any v. Among such algorithms, we have the conjugate gradient (CG) method, which applies for H positive-definite, i.e., such that ⟨v, H[v]⟩ > 0 for all v ≠ 0, or the generalized minimal residual (GMRES) method, which applies for any invertible H. A longer list of solvers can be found in public software such as SciPy (Virtanen et al., 2020). The IHVP of a strictly convex objective can in particular be computed with CG.
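A matrix-free IHVP sketch (ours), combining the HVP of Section 8.1 with conjugate gradient:

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

def L(w):
    return jnp.sum(jnp.cosh(w))  # strictly convex, Hessian = diag(cosh(w))

w = jnp.array([0.5, -0.3])
u = jnp.array([1.0, 2.0])

# Matrix-free map v -> H[v] via forward-on-reverse.
matvec = lambda v: jax.jvp(jax.grad(L), (w,), (v,))[1]
ihvp, _ = cg(matvec, u)  # solves H[v] = u
assert jnp.allclose(matvec(ihvp), u, atol=1e-5)
```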
8.4.3 Complexity
+ Σ_{i,j=1}^n ∂f_i(w)∗ ∂²_{i,j} g(f(w))∗[u] ∂f_j(w).
independently.
+ Σ_{i=1}^n ∂f(w)∗ ∂²g_i(f(w))∗[u] ∂f(w).
s0 := x
sk := fk (sk−1 , wk ) ∀k ∈ {1, . . . , K}
f (x, w) := sK ,
R_{k−1} := W_k⊤ J_k W_k
J_k := R_k ⊙ (a′(W_k s_{k−1}) a′(W_k s_{k−1})⊤)
G_k := J_k ⊗ s_{k−1} s_{k−1}⊤
diagonal of A,

E_{ω∼p}[ω ⊙ Aω] = diag(A),

where ⊙ denotes the Hadamard product (element-wise multiplication). This suggests that we can use the Monte-Carlo method to estimate the diagonal of A,

diag(A) ≈ (1/S) Σ_{i=1}^S ω_i ⊙ Aω_i,
with equality as S → ∞, since the estimator is unbiased. Since, as
reviewed in Section 8.1 and Section 8.2, we know how to multiply
efficiently with the Hessian and the Gauss-Newton matrices, we can
apply the technique with these matrices. The variance is determined
by the number S of samples drawn and therefore by the number of
matvecs performed. More elaborate approaches have been proposed to further reduce the variance (Meyer et al., 2021; Epperly et al., 2023).
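A sketch of this diagonal estimator (ours), using Rademacher probes on an explicit 2 × 2 matrix so the result can be eyeballed:

```python
import jax
import jax.numpy as jnp

A = jnp.array([[2.0, 1.0], [1.0, 3.0]])

key = jax.random.PRNGKey(0)
S = 10_000
# Rademacher probes: zero mean, E[w w^T] = I.
omega = jax.random.rademacher(key, (S, 2)).astype(jnp.float32)
est = jnp.mean(omega * (omega @ A.T), axis=0)  # (1/S) sum_i w_i . A w_i
print(est)  # close to [2., 3.] = diag(A)
```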
Suppose the objective function is of the form L(w; x, y) := ℓ(f (w; x); y)
where ℓ is the negative log-likelihood ℓ(θ; y) := − log pθ (y) of an ex-
ponential family distribution, and θ := f (w; x), for some network f .
We saw from the equivalence between the Fisher and Gauss-Newton
matrices in Proposition 8.6 (which follows from the Bartlett identity)
that
∇2GN L(w; x, ·) = EY ∼pθ [∂f (w; x)∗ ∇ℓ(θ; Y ) ⊗ ∇ℓ(θ; Y )∂f (w; x)]
= EY ∼pθ [∇L(w; x, Y ) ⊗ ∇L(w; x, Y )],
where · indicates that the result holds for any value of the second
argument. This suggests a Monte-Carlo scheme
∇²_GN L(w; x, ·) ≈ (1/S) Σ_{j=1}^S ∇L(w; x, y_{ij}) ⊗ ∇L(w; x, y_{ij}),
E[(Σ_i γ_i) ⊙ (Σ_j γ_j)] = E[Σ_i γ_i ⊙ γ_i] + Σ_{i≠j} E[γ_i ⊙ γ_j]
                         = E[Σ_i γ_i ⊙ γ_i],
where we used that E[γi ⊙ γj ] = E[γi ] ⊙ E[γj ] = 0 since γi and γj are
independent variables for i ̸= j and have zero mean, from Bartlett’s
first identity Eq. (11.4). We can then use the Monte-Carlo method to
obtain
diag(∇²_GN L(w; x, ·)) ≈ S ( (1/S) Σ_{j=1}^S ∇L(w; x, y_{ij}) ) ⊙ ( (1/S) Σ_{j=1}^S ∇L(w; x, y_{ij}) ),
with equality when all labels in the support of pθ have been sampled.
This estimator can be more convenient to implement, since it only needs
access to the gradient of the averaged losses. However, it may suffer
from higher variance. A special case of this estimator is used by Liu
et al. (2023), where they draw only one y for each x.
8.9 Summary
P(A_1 ∩ ⋯ ∩ A_K) = Π_{j=1}^K P(A_j | A_1 ∩ ⋯ ∩ A_{j−1}).
As we shall see, the graph of a graphical model encodes the dependen-
cies between the variables (S1 , . . . , SK ) and therefore how their joint
distribution factorizes. Given access to a joint probability distribution,
there are several inference problems one typically needs to solve.
9.3.2 Likelihood
A trivial task is to compute the likelihood of some observations s =
(s1 , . . . , sK ),
P(S_1 = s_1, …, S_K = s_K) = p(s_1, …, s_K).
It is also common to compute the log-likelihood,
log P(S_1 = s_1, …, S_K = s_K) = log p(s_1, …, s_K).
P(S_k = s_k) = Σ_{s_1, …, s_{k−1}, s_{k+1}, …, s_K} p(s_1, …, s_K).

Defining pairwise quantities similarly, we obtain

P(S_k = s_k, S_l = s_l) = Σ_{s_i : i ∉ {k, l}} p(s_1, …, s_K).
µ := E_{S∼p}[ϕ(S)] = Σ_{s∈S} p(s) ϕ(s) ∈ M
Here, M denotes the marginal polytope, the convex hull of the embeddings {ϕ(s) : s ∈ S}.
[µ]_{k,i} = E_S[ϕ(S)_{k,i}]
          = E_{S_k}[ϕ_k(S_k)_i]
          = E_{S_k}[I(S_k = v_i)]
          = Σ_{s_k∈S_k} P(S_k = s_k) I(s_k = v_i)
          = P(S_k = v_i).
[µ]_{k,l,i,j} = E_S[ϕ(S)_{k,l,i,j}]
              = E_{S_k,S_l}[ϕ_{k,l}(S_k, S_l)_{i,j}]
              = E_{S_k,S_l}[I(S_k = v_i, S_l = v_j)]
              = Σ_{s_k∈S_k} Σ_{s_l∈S_l} P(S_k = s_k, S_l = s_l) I(s_k = v_i, s_l = v_j)
              = P(S_k = v_i, S_l = v_j).
S0 := s0
S1 ∼ p1 (· | S0 )
S2 ∼ p2 (· | S1 )
..
.
SK ∼ pK (· | SK−1 ).
Figure 9.1: Left: Markov chain. Right: Computation graph of the forward-
backward and the Viterbi algorithms: a lattice.
P(S_1 = s_1, …, S_K = s_K) = p(s_1, …, s_K)
                           = Π_{k=1}^K P(S_k = s_k | S_{k−1} = s_{k−1})
                           = Π_{k=1}^K p_k(s_k | s_{k−1}),
S_k ∼ Categorical(π_{k−1,k,S_{k−1}}),

where

π_{k−1,k,i} := softargmax(θ_{k−1,k,i}) = (π_{k−1,k,i,j})_{j=1}^M ∈ △^M,
θ_{k−1,k,i} := (θ_{k−1,k,i,j})_{j=1}^M ∈ R^M.
We therefore have

P(S_k = j | S_{k−1} = i) = p_k(j | i) = π_{k−1,k,i,j} = [softargmax(θ_{k−1,k,i})]_j = exp(θ_{k−1,k,i,j}) / Σ_{j′} exp(θ_{k−1,k,i,j′})

and

log p_k(j | i) = θ_{k−1,k,i,j} − log Σ_{j′} exp(θ_{k−1,k,i,j′}).
A Markov chain is time-homogeneous when the same transition distribution is used at every step, p_1 = ⋯ = p_K = p.
More generally, an nth-order Markov chain may depend not only on the last variable, but on the last n variables,
Markov chains and more generally higher-order Markov chains are a special case of Bayesian network. Similarly to the computation graphs reviewed in Section 7.3, variable dependencies can be expressed using a directed acyclic graph (DAG) G = (V, E), where the vertices V = {1, …, K} represent variables and the edges E represent variable dependencies. The set {i_1, …, i_{n_k}} = parents(k) ⊆ V, where n_k := |parents(k)|, indicates the variables S_{i_1}, …, S_{i_{n_k}} that S_k depends on. This defines a partially ordered set (poset). For notational simplicity, we again assume without loss of generality that S_0 is deterministic. A computation graph is specified by functions f_1, …, f_K in topological order. In analogy, a Bayesian network is specified by conditional probability distributions p_k of S_k given S_{parents(k)}. We can then define the generative
process
S0 := s0
S1 ∼ p1 (· | S0 )
S2 ∼ p2 (· | Sparents(2) )
..
.
SK ∼ pK (· | Sparents(K) ).
P(S = s) := P(S_1 = s_1, …, S_K = s_K)
          = Π_{k=1}^K P(S_k = s_k | S_{parents(k)} = s_{parents(k)})
          := Π_{k=1}^K p_k(s_k | s_{parents(k)})
Z := Σ_{s∈S} Π_{C∈C} ψ_C(s_C),
automatically normalized),

p_θ(s) := (1/Z(θ)) Π_{C∈C} ψ_C(s_C; θ_C)
        = (1/Z(θ)) Π_{C∈C} exp(⟨θ_C, ϕ_C(s_C)⟩)
        = (1/Z(θ)) exp( Σ_{C∈C} ⟨θ_C, ϕ_C(s_C)⟩ )
        = (1/Z(θ)) exp(⟨θ, ϕ(s)⟩)
        = exp(⟨θ, ϕ(s)⟩ − A(θ)),
where

Z(θ) := Σ_{s∈S} Π_{C∈C} exp(⟨θ_C, ϕ_C(s_C)⟩)

and

A(θ) := log Z(θ).
P(Y = y) = p_θ(y)
         = exp( Σ_{i∈V} θ_i y_i + Σ_{(i,j)∈E} θ_{i,j} y_i y_j − A(θ) )
         = exp( Σ_{C∈C} ⟨θ_C, ϕ_C(y)⟩ − A(θ) ),
(Greig et al., 1989). There are two ways the above equation can
be extended. First, we can use higher-order interactions, such as
yi yj yk for (i, j, k) ∈ V 3 . Second, we may want to use categorical
variables, which leads to the Potts model.
9.6.4 Sampling
Contrary to Bayesian networks, MRFs require an explicit normalization
constant Z. As a result, sampling from a distribution represented by a
general MRF is usually more involved than for Bayesian networks. A
commonly-used technique is Gibbs sampling.
where

Z := Σ_{s∈S} Π_{k=1}^K ψ_k(s_{k−1}, s_k)
and where we used ψ_k as a shorthand for ψ_{k−1,k}, since k − 1 and k are consecutive. As explained in Example 9.2, this also includes Markov chains, by setting ψ_k(s_{k−1}, s_k) := p_k(s_k | s_{k−1}), in which case Z = 1.
These quantities can be computed with the forward and backward recursions

α_k(s_k) := Σ_{s_{k−1}∈S_{k−1}} ψ_k(s_{k−1}, s_k) α_{k−1}(s_{k−1})
β_k(s_k) := Σ_{s_{k+1}∈S_{k+1}} ψ_{k+1}(s_k, s_{k+1}) β_{k+1}(s_{k+1}),

from which the normalization constant is obtained as Z = Σ_{s_K∈S_K} α_K(s_K),
and the marginal probabilities are given by

P(S_k = s_k) = (1/Z) α_k(s_k) β_k(s_k)
P(S_{k−1} = s_{k−1}, S_k = s_k) = (1/Z) α_{k−1}(s_{k−1}) ψ_k(s_{k−1}, s_k) β_k(s_k).
We can also compute the conditional probabilities by
P(S_k = s_k | S_{k−1} = s_{k−1}) = P(S_{k−1} = s_{k−1}, S_k = s_k) / P(S_{k−1} = s_{k−1})
= α_{k−1}(s_{k−1}) ψ_k(s_{k−1}, s_k) β_k(s_k) / (α_{k−1}(s_{k−1}) β_{k−1}(s_{k−1}))
= ψ_k(s_{k−1}, s_k) β_k(s_k) / β_{k−1}(s_{k−1}).
In practice, the two recursions are often implemented in the log domain for numerical stability,
log α_k(s_k) = log Σ_{s_{k−1} ∈ S_{k−1}} exp(log ψ_k(s_{k−1}, s_k) + log α_{k−1}(s_{k−1}))
log β_k(s_k) = log Σ_{s_{k+1} ∈ S_{k+1}} exp(log ψ_{k+1}(s_k, s_{k+1}) + log β_{k+1}(s_{k+1})).
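In code, the log-domain forward recursion is a sequence of logsumexp reductions; here is a minimal sketch (with our own indexing conventions, assuming the chain starts deterministically at s_0 = 0 and states are identified with {0, . . . , M − 1}):

    import jax
    from jax.scipy.special import logsumexp

    def forward_log(log_psi):
        # log_psi has shape (K, M, M): log_psi[k-1, i, j] = log psi_k(i, j).
        log_alpha = log_psi[0, 0]  # log alpha_1(j) = log psi_1(s_0, j)
        for k in range(1, log_psi.shape[0]):
            # log alpha_k(j) = logsumexp_i [log psi_k(i, j) + log alpha_{k-1}(i)]
            log_alpha = logsumexp(log_psi[k] + log_alpha[:, None], axis=0)
        return log_alpha

    log_psi = jax.random.normal(jax.random.PRNGKey(0), (4, 3, 3))
    log_Z = logsumexp(forward_log(log_psi))  # log of the normalization constant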
6: δ⋆ := max_{s_K ∈ S_K} δ_K(s_K)
7: s⋆_K := arg max_{s_K ∈ S_K} δ_K(s_K)
8: for k := K − 1, . . . , 1 do ▷ Backtracking
9:   s⋆_k := q_{k+1}(s⋆_{k+1})
Outputs: max_{s_1 ∈ S_1, . . . , s_K ∈ S_K} p(s_1, . . . , s_K) ∝ δ⋆ and arg max_{s_1 ∈ S_1, . . . , s_K ∈ S_K} p(s_1, . . . , s_K) = (s⋆_1, . . . , s⋆_K)
• Commutativity of ⊕: a ⊕ b = b ⊕ a,
• Associativity of ⊕: a ⊕ (b ⊕ c) = (a ⊕ b) ⊕ c,
with ε := 1 by default, and
argmax_ε f := (exp(f(v)/ε) / Σ_{v′∈V} exp(f(v′)/ε))_{v∈V} ∈ P(V).
and
Q := argmax_ε_{j∈[M]} a_{K,j} ∈ △^M.
The backward variables then satisfy
r_{k,i} = Σ_{j∈[M]} r_{k+1,j} q_{k+1,j,i} = Σ_{j∈[M]} µ_{k+1,i,j}.
6: A := max_ε_{i∈[M]} a_{K,i} ∈ R
7: Q := argmax_ε_{i∈[M]} a_{K,i} ∈ △^M
8: Initialize r_{K,j} := Q_j for all j ∈ [M]
9: for k := K − 1, . . . , 1 do ▷ Backward pass
10:   for i ∈ [M] do
11:     for j ∈ [M] do
12:       µ_{k+1,i,j} := r_{k+1,j} · q_{k+1,j,i}
13:       r_{k,i} ← r_{k,i} + µ_{k+1,i,j}
Outputs: max_ε_{i_1,...,i_K ∈ [M]^K} (θ_{1,1,i_1} + Σ_{k=2}^K θ_{k,i_{k−1},i_k}) = A and ∇A(θ) = µ
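The whole point of this construction can be condensed into a few lines of JAX: implementing the forward recursion with logsumexp (i.e., max_ε with ε = 1) gives the log-partition A(θ), and a single call to jax.grad recovers the marginals µ = ∇A(θ), with no hand-written backward pass. This is a sketch under our own conventions (θ of shape (K, M, M), chain started at state 0):

    import jax
    import jax.numpy as jnp
    from jax.scipy.special import logsumexp

    def log_partition(theta):
        a = theta[0, 0]  # a_{1, j}, assuming the chain starts at i = 0
        for k in range(1, theta.shape[0]):
            a = logsumexp(theta[k] + a[:, None], axis=0)
        return logsumexp(a)  # A(theta)

    theta = jax.random.normal(jax.random.PRNGKey(0), (4, 3, 3))
    mu = jax.grad(log_partition)(theta)
    # mu[k-1, i, j] = P(S_{k-1} = i, S_k = j) for k >= 2; each slice sums to 1.
    print(mu.sum(axis=(1, 2)))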
9.10 Summary
10
Differentiating through optimization
F(w, λ) = 0 for all λ ∈ Λ.
h(λ) := g(w⋆(λ), λ), where w⋆(λ) = arg max_{w∈W} f(w, λ). (10.1)
Figure 10.1: The graph of h(λ) = max_{w∈W} f(w, λ) is the upper envelope of the graphs of the functions λ ↦ f(w, λ) for all w ∈ W.
h(λ) := max_{w∈W} f(w, λ) and w⋆(λ) := arg max_{w∈W} f(w, λ).
If f is concave in w, convex in λ, and the maximum w⋆(λ) is unique, then the function h is differentiable with gradient
∇h(λ) = ∇_2 f(w⋆(λ), λ).
and w⋆(λ) := arg max_{w∈W} f(w, λ).
If f is continuously differentiable in λ for all w ∈ W, ∇_1 f is continuous, and the maximum w⋆(λ) is unique, then the function h is differentiable with gradient
∇h(λ) = ∇_2 f(w⋆(λ), λ).
• x⋆(λ_0) = x_0,
• w⋆(λ_0) = w_0,
meaning that the Jacobian ∂w⋆(λ), assuming that it exists, satisfies
∂_1 F(w⋆(λ), λ) ∂w⋆(λ) = −∂_2 F(w⋆(λ), λ).
We can differentiate through w⋆(λ) using the IFT, assuming that the conditions of the theorem apply. Note that ∂_1 F(w, λ) requires the expression of the Jacobian ∂P_C(y). Fortunately, P_C(y) and its Jacobian are easy to compute for many sets C (Blondel et al., 2021).
Assume the assumptions of the IFT hold. The JVP t := ∂w⋆(λ)v in the input direction v ∈ Λ is obtained by solving the linear system
A t = B v.
Similarly, the VJP ∂w⋆(λ)∗u is obtained by first solving the linear system
A∗ r = u.
When these linear systems cannot be solved exactly, we can use matrix-free iterative solvers such as GMRES (Saad and Schultz, 1986) or BiCGSTAB (van der Vorst, 1992).
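The following minimal sketch (with a ridge-regression stationarity condition standing in for F; all names are ours) shows the pattern in JAX: A and B are never materialized, only accessed through JVPs of F, and the linear system is solved with a matrix-free solver.

    import jax
    import jax.numpy as jnp
    from jax.scipy.sparse.linalg import bicgstab

    X = jax.random.normal(jax.random.PRNGKey(0), (20, 5))
    y = jax.random.normal(jax.random.PRNGKey(1), (20,))
    lam = 0.1

    def F(w, lam):  # optimality condition of ridge regression
        return X.T @ (X @ w - y) + lam * w

    # Root of F, computed here in closed form for the sake of the example.
    w_star = jnp.linalg.solve(X.T @ X + lam * jnp.eye(5), X.T @ y)

    def implicit_jvp(v):
        # Solve A t = B v with A = d1 F(w*, lam), B = -d2 F(w*, lam).
        A = lambda t: jax.jvp(lambda w: F(w, lam), (w_star,), (t,))[1]
        Bv = -jax.jvp(lambda l: F(w_star, l), (lam,), (v,))[1]
        t, _ = bicgstab(A, Bv)
        return t

    print(implicit_jvp(1.0))  # the JVP dw*(lam) v, here with scalar lam, v = 1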
s_0 := x ∈ X
s_1 := f_1(s_0, w_1) ∈ S_1
⋮
s_K := f_K(s_{K−1}, w_K) ∈ S_K
f(w) := s_K. (10.2)
c(s, w) := (s_1 − f_1(x, w_1), s_2 − f_2(s_1, w_2), . . . , s_K − f_K(s_{K−1}, w_K)).
This defines an implicit function s⋆ (w) = (s⋆1 (w), . . . , s⋆K (w)), the
solution of this nonlinear system, which is given by the variables
s1 , . . . , sK defined in (10.2). The output of the feedforward network is
then f (w) = s⋆K (w).
In machine learning, the final layer s⋆K (w) is typically fed into a
loss ℓ, to define
L(w) := ℓ(s⋆K (w); y).
Note that an alternative is to write L(w) as L(w) = ℓ(f(w); y) and to set u_K := ∇_1 ℓ(f(w); y), so that
∂f(w)∗ u_K = ∇L(w).
Let us define u ∈ S1 × · · · × SK−1 × SK as u := (0, . . . , 0, ∇1 ℓ(f (w); y))
(gradient of the loss ℓ case) or u := (0, . . . , 0, uK ) (VJP of f in the
direction uK case). Using the adjoint state method, we know that the
gradient of this objective is obtained as
∂_1 c(s⋆(w), w) =
[   I      0      0    · · ·    0 ]
[ −A_2     I      0    · · ·    0 ]
[   0    −A_3     I             ⋮ ]
[   ⋮            ⋱      ⋱      0 ]
[   0    · · ·    0    −A_K     I ],
a lower block-bidiagonal matrix with identity blocks on the diagonal and blocks −A_k := −∂_1 f_k(s⋆_{k−1}(w), w_k) on the subdiagonal.
∂_1 c(s⋆(w), w)∗ =
[ I    −A∗_2     0     · · ·     0 ]
[ 0      I     −A∗_3             ⋮ ]
[ ⋮            ⋱       ⋱        0 ]
[ 0     · · ·           I    −A∗_K ]
[ 0     · · ·   · · ·    0       I ].
∂_2 c(s(w), w)∗ r = (∂_2 f_1(x, w_1)∗ r_1, ∂_2 f_2(s_1(w), w_2)∗ r_2, . . . , ∂_2 f_K(s_{K−1}(w), w_K)∗ r_K),
To differentiate the inverse function f^{−1}, we can start from the identity
f ∘ f^{−1}(ω) = ω.
Differentiating both sides, we obtain
∂f(f^{−1}(ω)) ∂f^{−1}(ω) = I,
so that ∂f^{−1}(ω) = ∂f(w)^{−1} with w = f^{−1}(ω).
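In the scalar case, this reduces to (f^{−1})′(ω) = 1/f′(f^{−1}(ω)), which is easy to check numerically (a small sketch of ours, with f = exp):

    import jax
    import jax.numpy as jnp

    f, finv = jnp.exp, jnp.log
    omega = 2.0
    lhs = jax.grad(finv)(omega)           # derivative of the inverse
    rhs = 1.0 / jax.grad(f)(finv(omega))  # 1 / f'(f^{-1}(omega))
    print(lhs, rhs)                       # both equal 0.5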
10.6 Summary
11.2 Differentiating through expectations
Discrete case
When Y is a discrete set (that is, pθ (y) is a probability mass function),
we can rewrite Eq. (11.1) as
E(θ) = Σ_{y∈Y} p_θ(y) g(y).
We then obtain
∇E(θ) = Σ_{y∈Y} g(y) ∇_θ p_θ(y). (11.2)
p_θ(Y = y) := exp(θ_y) / Σ_{i∈[M]} exp(θ_i) = [softmax(θ)]_y ∈ (0, 1),
p_θ(Y = y) := exp(⟨ϕ(y), θ⟩) / Σ_{y′∈Y} exp(⟨ϕ(y′), θ⟩).
∇_θ log p_θ(y) = ∇_θ log p_θ(y_1) + ∇_θ log p_θ(y_2 | y_1) + · · · + ∇_θ log p_θ(y_L | y_1, . . . , y_{L−1}).
so that
∇_θ log p_θ(y) = e_y/γ − ∇A(θ).
We therefore see that ∇_θ log p_θ(y) crucially depends on ∇A(θ), the gradient of the log-partition function. This gradient is available for some structured sets Y, but not in general.
As another example, we apply SFE in Section 13.4 to derive the
gradient of perturbed functions.
Baseline
SFE is known to suffer from high variance (Mohamed et al., 2020). This means that this estimator may require us to draw many samples from the distribution p_θ to work well in practice. One of the simplest variance reduction techniques consists of shifting the function g by a constant β, called a baseline, to obtain
∇E(θ) = E_{Y∼p_θ}[(g(Y) − β) ∇_θ log p_θ(Y)],
which holds since E_{Y∼p_θ}[∇_θ log p_θ(Y)] = 0 for any valid distribution p_θ. The baseline β is often set to the running average of past values of the function g, though this choice is neither optimal nor guaranteed to lower the variance (Mohamed et al., 2020).
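A minimal sketch (our own setup: a categorical distribution and an arbitrary g) of the SFE with a baseline in JAX:

    import jax
    import jax.numpy as jnp

    def sfe(key, theta, g, beta=0.0, n_samples=1000):
        # Monte Carlo estimate of E[(g(Y) - beta) * grad log p_theta(Y)].
        ys = jax.random.categorical(key, theta, shape=(n_samples,))
        score = jax.vmap(jax.grad(lambda th, i: jax.nn.log_softmax(th)[i]),
                         in_axes=(None, 0))(theta, ys)
        return jnp.mean((jax.vmap(g)(ys)[:, None] - beta) * score, axis=0)

    g = lambda i: (i.astype(jnp.float32) - 1.0) ** 2
    theta = jnp.array([0.1, 0.5, -0.2])
    estimate = sfe(jax.random.PRNGKey(0), theta, g, beta=0.5)
    exact = jax.grad(lambda th: jnp.vdot(jax.nn.softmax(th),
                                         (jnp.arange(3.0) - 1.0) ** 2))(theta)
    print(estimate, exact)  # the estimate concentrates around the exact gradient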
Control variates
Another general technique is control variates. Let us denote the expectation of a function h : R^M → R under the distribution p_θ as
H(θ) := E_{Y∼p_θ}[h(Y)].
Suppose that H(θ) and its gradient ∇H(θ) are known in closed form. Then, for any γ ≥ 0, we clearly have
E(θ) = E_{Y∼p_θ}[g(Y) − γ h(Y)] + γ H(θ),
and therefore
∇E(θ) = ∇_θ E_{Y∼p_θ}[g(Y) − γ h(Y)] + γ ∇H(θ),
where the first term can be estimated by SFE, with reduced variance if h correlates well with g.
Proof.
∂E(θ) = ∂_θ ∫_Y p_θ(y) g(y) dy
= ∫_Y ∂_θ [p_θ(y) g(y)] dy
= ∫_Y g(y) ⊗ ∇_θ p_θ(y) dy
= ∫_Y p_θ(y) g(y) ⊗ ∇_θ log p_θ(y) dy
= E_{Y∼p_θ}[g(Y) ⊗ ∇_θ log p_θ(Y)].
Such a transformation exists, not only for the normal distribution, but
for location-scale family distributions. The key advantage is that we
can now easily compute ∇E(θ), since θ is no longer involved in the
distribution. We can generalize this idea, as summarized below.
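Concretely, for a Gaussian with parameters θ = (µ, σ), sampling Y = µ + σZ with Z ∼ N(0, 1) removes θ from the sampling distribution, so reverse-mode autodiff can flow through the transform; a minimal sketch of ours:

    import jax
    import jax.numpy as jnp

    def expectation(theta, key, g, n_samples=1000):
        mu, sigma = theta
        z = jax.random.normal(key, (n_samples,))  # theta-free randomness
        return jnp.mean(g(mu + sigma * z))        # pathwise estimate of E[g(Y)]

    g = lambda y: y ** 2  # E[Y^2] = mu^2 + sigma^2, so the gradient is (2mu, 2sigma)
    print(jax.grad(expectation)((1.0, 2.0), jax.random.PRNGKey(0), g))
    # approximately (2.0, 4.0)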
The inverse transform method can be used for sampling from a probability distribution, given access to its associated quantile function. Recall that the cumulative distribution function (CDF) associated with a random variable Y is the function F_Y : R → [0, 1] defined by
F_Y(y) := P(Y ≤ y).
In the general case of CDF functions that are not strictly increasing, the quantile function is usually defined as
Q_Y(π) := inf{y ∈ R : F_Y(y) ≥ π}.
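As a small, self-contained example (not from the book), the exponential distribution with rate θ has the closed-form quantile function Q_θ(u) = −log(1 − u)/θ, so inverse transform sampling is differentiable w.r.t. θ:

    import jax
    import jax.numpy as jnp

    def sample_mean(theta, key, n_samples=1000):
        u = jax.random.uniform(key, (n_samples,))
        return jnp.mean(-jnp.log1p(-u) / theta)  # Y = Q_theta(U), E[Y] = 1/theta

    print(jax.grad(sample_mean)(2.0, jax.random.PRNGKey(0)))
    # approximately d(1/theta)/dtheta = -1/theta^2 = -0.25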
Note that, in the above example, the error function erf and its inverse do not enjoy analytical expressions, but autodiff packages usually provide accurate numerical implementations of them, together with their derivatives.
Pushforward measures
S_k ∼ p_k(· | s_determ(k), S_random(k)) ⟺ S_k ∼ p_k(· | s_{i_1}, . . . , s_{i_{p_k}}, S_{j_1}, . . . , S_{j_{q_k}})
S_k := f_k(s_determ(k), S_random(k)) := f_k(s_{i_1}, . . . , s_{i_{p_k}}, S_{j_1}, . . . , S_{j_{q_k}})
s_k := f_k(s_determ(k)) := f_k(s_{i_1}, . . . , s_{i_{p_k}}).
Special cases
If all nodes are function nodes, we recover computation graphs, reviewed
in Section 4.1.3. If all nodes are distribution nodes, we recover Bayesian
networks, reviewed in Section 9.5.
11.5.2 Examples
We now present several examples that illustrate our formalism. We use
the legend below in the following illustrations.
[Legend: deterministic variable / stochastic variable / function / sampler.]
• Example 1:
S_1 ∼ p_1(· | s_0)
S_2 := f_2(S_1)
E(s_0) := E[S_2]
∇E(s_0) = E_{S_1}[f_2(S_1) ∇_{s_0} log p_1(S_1 | s_0)]
• Example 2:
S_1 ∼ p_1
S_2 := f_2(S_1, s_0)
E(s_0) := E[S_2]
∇E(s_0) = E_{S_1}[∇_{s_0} f_2(S_1, s_0)]
• Example 3:
s_1 := f_1(s_0)
S_2 ∼ p_2(· | s_1)
S_3 := f_3(S_2)
E(s_0) := E[S_3]
∇E(s_0) = ∂f_1(s_0)∗ E_{S_2}[f_3(S_2) ∇_{s_1} log p_2(S_2 | s_1)]
• Example 4:
s_1 := f_1(s_0)
s_2 := f_2(s_0)
S_3 ∼ p_3(· | s_1)
S_4 ∼ p_4(· | s_2, S_3)
S_5 := f_5(S_4)
E(s_0) := E[S_5] = E_{S_3}[E_{S_4}[f_5(S_4)]]
∇E(s_0) = (E_{S_3}[∇_{s_1} log p_3(S_3 | s_1) E_{S_4}[f_5(S_4)]], E_{S_3}[E_{S_4}[f_5(S_4) ∇_{s_2} log p_4(S_4 | s_2, S_3)]])
SK := f (s0 ).
Sk ∼ pk (·|sdeterm(k) , Srandom(k) ),
• Transformation (location-scale transform, inverse transform)
• Change-of-variables theorem
s_k := s_{k−1} + h_k(s_{k−1}, w_k).
s_k − s_{k−1} = h_k(s_k, w_k).
s(0) = x
s′(t) = h(t, s(t), w), t ∈ [0, T]. (11.7)
Differential equations like Eq. (11.7) arise in many contexts beyond neu-
ral ODEs, ranging from modeling physical systems to pandemics (Braun
and Golubitsky, 1983). Moreover, the differential equation presented
in Eq. (11.7) is just an example of an ordinary differential equation,
while controlled differential equations or stochastic differential equations
can also be considered.
Existence of a solution
s(0) = s_0
s′(t) = h(t, s(t)), t ∈ [0, T],
whose solution, when the dynamics h(t, s) = As are linear, is given by
s(t) = exp(tA) s_0,
where exp(A) is the matrix exponential. Hence, the output s(T ) can
be expressed as a simple function of the parameters (A in this case).
However, generally, we do not have access to such analytical solutions,
and, just as for solving optimization problems in Chapter 10, we need
to resort to some iterative algorithms.
250 Differentiating through integration
Integration methods
To numerically solve an ODE, we can use integration methods, whose goal is to build a sequence s_k that approximates the solution s(t) at times t_k. The simplest integration method is the explicit Euler method, which approximates the solution between times t_{k−1} and t_k via
s(t_k) − s(t_{k−1}) = ∫_{t_{k−1}}^{t_k} h(t, s(t), w) dt ≈ δ_k h(t_{k−1}, s(t_{k−1}), w)
for a time step
δ_k := t_k − t_{k−1}.
The resulting integration scheme consists in computing, starting from s_0 = x, for k ∈ {1, . . . , K},
s_k := s_{k−1} + δ_k h(t_{k−1}, s_{k−1}, w).
where
f(x, w) := s(T) = x + ∫_0^T h(t, s(t), w) dt.
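One option, discussed later in this chapter under the name discretize-then-optimize, is to simply unroll an explicit Euler scheme and let reverse-mode autodiff differentiate through it; a minimal sketch of ours, on toy linear dynamics:

    import jax

    def h(t, s, w):
        return -w * s  # toy dynamics: s(T) = x * exp(-w T)

    def euler_solve(x, w, T=1.0, K=100):
        s, t, delta = x, 0.0, T / K
        for _ in range(K):  # unrolled; autodiff backpropagates through the loop
            s = s + delta * h(t, s, w)
            t = t + delta
        return s

    print(jax.grad(euler_solve, argnums=1)(1.0, 0.5))
    # close to d/dw exp(-w) at w = 0.5, i.e., -exp(-0.5) ~ -0.607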
To solve such problems, we need to access gradients of ℓ composed with
f through VJPs of the solution of the ODE. The VJPs can actually
be characterized as solutions of an ODE themselves thanks to the
continuous time adjoint method (Pontryagin, 1985), presented
below, and whose proof is postponed to Section 11.6.6.
s(0) = x
s′(t) = h(t, s(t), w), t ∈ [0, T],
the VJP is given by ∂_w f(x, w)∗[u] = g for
g = ∫_0^T ∂_3 h(t, s(t), w)∗ r(t) dt,
where r(t) solves, backward in time, the adjoint ODE
∂_t r(t) = −∂_s h(t, s(t), w)∗ r(t), r(T) = u.
The above ODE can then be solved by any integration method. Note,
however, that it requires first computing s(T ) and ∇L(s(T )) by an
integration method. The overall computation of the gradient using
an explicit Euler method to solve forward and backward ODEs is
summarized in Algorithm 11.2.
Algorithm 11.2 naturally looks like the reverse mode of autodiff for a residual neural network with shared weights. A striking difference is that the intermediate computations s_k are not kept in memory; instead, new variables ŝ_k are computed along the backward ODE. One may believe that, by switching to continuous time, we solved the memory issues encountered in reverse-mode autodiff. Unfortunately, this comes at the cost of numerical stability: as we use a discretization scheme to recompute the intermediate states backward in time through ŝ_k in Algorithm 11.2, we accumulate truncation errors.
The gradients computed by the discretize-then-optimize approach are therefore correct and well-defined. However, they may not match the gradients of the true ODE formulation.
To compare the discretize-then-optimize and optimize-then-discretize
approaches, Gholaminejad et al. (2019) compared their performance on
an ODE whose solution can be computed analytically by selecting h
to be linear in s. The authors observed that discretize-then-optimize
generally outperformed optimize-then-discretize. A middle ground can
actually be found by using reversible differentiation schemes.
Leapfrog method
The (asynchronous) leapfrog method (Zhuang et al., 2021; Mutze, 2013), on the other hand, is an example of a symmetric, reversible discretization method. For a constant discretization step δ, given t_{k−1}, s_{k−1}, c_{k−1} and a function h, it computes
t̄_{k−1} := t_{k−1} + δ/2
s̄_{k−1} := s_{k−1} + (δ/2) c_{k−1}
c̄_{k−1} := h(t̄_{k−1}, s̄_{k−1})
t_k := t̄_{k−1} + δ/2
s_k := s̄_{k−1} + (δ/2) c̄_{k−1}
c_k := 2 c̄_{k−1} − c_{k−1}
M(t_{k−1}, s_{k−1}, c_{k−1}; h, δ) := (t_k, s_k, c_k).
One can verify that we indeed have M(t_k, s_k, c_k; h, −δ) = (t_{k−1}, s_{k−1}, c_{k−1}).
By using a reversible symmetric discretization scheme in the optimize-then-discretize approach, we ensure that, at the end of the backward discretization pass, we recover exactly the original input. Therefore, by repeating forward and backward discretization passes, we always get the same gradient, which was not the case for the explicit Euler scheme.
By using a reversible discretization scheme in the discretize-then-
optimize method, we address the memory issues of reverse mode autodiff.
As explained in Section 7.6, we can recompute intermediate values during
the backward pass rather than storing them.
then decompose as
∂_w f(x, w)∗[u] = ∂_w s(T, x, w)∗ u
+ ∫_0^T (∂_w s(t, x, w)∗ ∂_s h(t, s(t, x, w), w)∗ − ∂²_{wt} s(t, x, w)∗) r(t) dt
+ ∫_0^T ∂_w h(t, s(t, x, w), w)∗ r(t) dt,
∂_x f(x, w)∗[u] = ∂_x s(T, x, w)∗ u
+ ∫_0^T (∂_x s(t, x, w)∗ ∂_s h(t, s(t, x, w), w)∗ − ∂²_{xt} s(t, x, w)∗) r(t) dt.
Here, the second-derivative terms ∂²_{wt} s(t, x, w)∗ r and ∂²_{xt} s(t, x, w)∗ r correspond to second derivatives of ⟨s(t, x, w), r⟩. Since the Hessian is symmetric (Schwarz's theorem, presented in Proposition 2.10), we can swap the derivatives in t and w or x. Then, to express the gradient uniquely in terms of first derivatives of s, we use an integration by parts to obtain, for example,
∫_0^T ∂²_{wt} s(t, x, w)∗ r(t) dt = ∫_0^T ∂²_{tw} s(t, x, w)∗ r(t) dt
= ∂_w s(T, x, w)∗ r(T) − ∂_w s(0, x, w)∗ r(0) − ∫_0^T ∂_w s(t, x, w)∗ ∂_t r(t) dt.
Since s(0) = x, we have ∂_w s(0, x, w)∗ r(0) = 0. The VJP w.r.t. w can then be written as
∂_w f(x, w)∗[u] = ∂_w s(T, x, w)∗[u − r(T)]
+ ∫_0^T ∂_w s(t, x, w)∗[∂_s h(t, s(t, x, w), w)∗ r(t) + ∂_t r(t)] dt
+ ∫_0^T ∂_w h(t, s(t, x, w), w)∗ r(t) dt.
By choosing r(t) to satisfy the adjoint ODE
∂_t r(t) = −∂_s h(t, s(t, x, w), w)∗ r(t), r(T) = u,
the first two terms vanish, and only the last integral remains, which is the claimed expression of the VJP.
11.7 Summary
Smoothing programs
12
Smoothing by optimization
Some examples are given in Table 12.1. As can be seen from the above, the infimal convolution, like the classical convolution, is commutative,
(f □ g)(µ) = (g □ f)(µ).

Table 12.1: Examples of infimal convolutions.
f_1   | f_2            | f_1 □ f_2
f     | 0              | inf_{u∈R^M} f(u)
ι_C   | ‖·‖_2          | d_C := inf_{u∈C} ‖· − u‖_2
f     | (1/2)‖·‖²_2    | M_f := inf_{u∈R^M} f(u) + (1/2)‖· − u‖²_2
ι_C   | ι_D            | ι_{C+D}
f     | ι_{{v}}        | f(· − v)
f     | ⟨·, v⟩         | ⟨·, v⟩ − f∗(v)
M_f(µ) := (f □ R)(µ)
= inf_{u∈R^M} f(u) + (1/2)‖µ − u‖²_2
= inf_{z∈R^M} f(µ + z) + (1/2)‖z‖²_2,
where R := (1/2)‖·‖²_2.
1. Gradient:
∇Mf (µ) = µ − proxf (µ).
2. Minimum: if f is convex, then
inf_{µ∈R^M} M_f(µ) = min_{u∈R^M} f(u).
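The gradient identity above is easy to verify numerically; a small sketch of ours with f = |·| in one dimension, whose proximal operator is soft-thresholding, approximating the envelope's infimum by a brute-force grid search:

    import jax
    import jax.numpy as jnp

    def moreau_env(mu, us):
        # Brute-force Moreau envelope of f = |.| over a grid of candidates u.
        return jnp.min(jnp.abs(us) + 0.5 * (mu - us) ** 2)

    prox = lambda mu: jnp.sign(mu) * jnp.maximum(jnp.abs(mu) - 1.0, 0.0)

    us = jnp.linspace(-5.0, 5.0, 100001)  # grid containing the minimizer u = 0
    mu = 0.3
    print(jax.grad(moreau_env)(mu, us))  # ~ 0.3
    print(mu - prox(mu))                 # mu - prox_f(mu) = 0.3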
12.2.1 Definition
Consider affine functions of the form
u ↦ ⟨u, v⟩ − b.
For a fixed slope v, the tightest such lower bound on f is obtained when b is defined by the convex conjugate of f,
f∗(v) := sup_{u∈R^M} ⟨u, v⟩ − f(u).
Figure 12.1: For a fixed slope v, the function u ↦ uv − f∗(v) is the tightest affine lower bound of f.
f∗(v) = Σ_{j=1}^M exp(v_j − 1).
12.2.3 Properties
The conjugate enjoys several useful properties, which we now summarize.
• Separable sum: if f(u) = Σ_{j=1}^M f_j(u_j), then f∗(v) = Σ_{j=1}^M f_j∗(v_j).
• Scalar multiplication: if f = c · g with c > 0, then f∗(v) = c · g∗(v/c).
• Addition of an affine function: if f = g + ⟨α, ·⟩ + β, then f∗(v) = g∗(v − α) − β.
• Composition with an invertible linear map M: if f = g ∘ M, then f∗(v) = g∗(M^{−T} v).
We now state a well-known result that will underpin this whole chapter:
smoothness and strong convexity are dual to each other (Hiriart-Urruty
and Lemaréchal, 1993; Kakade et al., 2009; Beck, 2017; Zhou, 2018).
For a review of the notions of smoothness and strong convexity, see
Section 14.4.
12.4 Examples
Its conjugate is
relu∗(v) = ι_{[0,1]}(v).
To see why, we observe that, since the objective is linear,
max_{v∈[0,1]} u v = max_{v∈{0,1}} u v = { u if u ≥ 0, 0 otherwise } = relu(u). (12.2)
Smoothed min
Unlike the logistic function, it can reach the exact values 0 or 1. However,
the function has two kinks, where the function is non-differentiable.
It turns out that the three sigmoids we presented above (step, logistic,
sparsesigmoid) are all equal to the derivative of their corresponding
non-linearity:
Figure 12.2: Some ReLU-like activations (ReLU, SoftPlus, SparsePlus) and their sigmoids (Heaviside, Logistic, SparseSigmoid). Differentiating the left functions gives the right functions.
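This derivative relationship is directly checkable with autodiff; for instance (a small sketch of ours), the derivative of SoftPlus is the logistic function:

    import jax
    import jax.numpy as jnp

    softplus = lambda u: jnp.log1p(jnp.exp(u))
    logistic = jax.nn.sigmoid
    u = 0.7
    print(jax.grad(softplus)(u), logistic(u))  # both ~ 0.668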
12.5 Summary
13
Smoothing by integration

13.1 Convolution
Averaging perspective
where
p_{µ,σ}(u) := κ_σ(µ − u) = (1/(√(2π) σ)) exp(−(1/2)((µ − u)/σ)²)
is the PDF of the Gaussian distribution with mean µ and variance
σ 2 . Therefore, we can see f ∗ κσ as the expectation of f (u) over a
Gaussian centered around µ. This property is true for all translation-
invariant kernels, that correspond to a location-scale family distribution
(e.g., the Laplace distribution). The convolution therefore performs an
averaging with all points, with points nearby µ given more weight by
the distribution. The parameter σ controls the importance we want to
give to farther points. We call this viewpoint averaging, as we replace
f (u) by E[f (U )].
Perturbation perspective
we have
(f ∗ κ_σ)(µ) := ∫_{−∞}^{∞} f(µ − z) (1/(√(2π) σ)) exp(−(1/2)(z/σ)²) dz
= E_{Z∼p_{0,σ}}[f(µ − Z)]
= E_{Z∼p_{0,σ}}[f(µ + Z)],
where, in the third line, we used that p0,σ is sign invariant, i.e., p0,σ (z) =
p0,σ (−z). This viewpoint shows that smoothing by convolution with
a Gaussian kernel can also be seen as injecting Gaussian noise or
perturbations to the function’s input.
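This perturbation viewpoint immediately suggests a Monte Carlo scheme: estimate (f ∗ κ_σ)(µ) = E[f(µ + σZ)] by sampling, and differentiate the reparameterized estimate; a minimal sketch of ours with f = relu:

    import jax
    import jax.numpy as jnp

    def smoothed_relu(mu, key, sigma=0.5, n_samples=10000):
        z = jax.random.normal(key, (n_samples,))
        return jnp.mean(jax.nn.relu(mu + sigma * z))

    key = jax.random.PRNGKey(0)
    print(smoothed_relu(0.0, key))            # ~ sigma / sqrt(2 pi) ~ 0.199
    print(jax.grad(smoothed_relu)(0.0, key))  # ~ 0.5, despite relu's kink at 0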
Limit case
Many times, we work with functions whose convolution does not have
an analytical form. In these cases, we can use a discrete convolution on
a grid of values. For two functions f and g defined over Z, the discrete
convolution is defined by
(f ∗ g)[i] := Σ_{j=−∞}^{∞} f[j] g[i − j].
Figure 13.1: Smoothing of the signal f[t] := t² + 0.3 sin(6πt) with a sampled and renormalized Gaussian kernel (σ = 0.25, 0.5, 1.0).
(f ∗ g)[i] = Σ_{j=−M}^{M} f[i − j] g[j] = (g ∗ f)[i].
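In JAX, such a smoothing on a grid is a one-liner with jnp.convolve; a minimal sketch of ours, mimicking the setting of Figure 13.1:

    import jax.numpy as jnp

    sigma = 0.5
    t = jnp.arange(-3.0, 3.0, 0.01)
    f = t ** 2 + 0.3 * jnp.sin(6 * jnp.pi * t)

    z = jnp.arange(-3 * sigma, 3 * sigma, 0.01)
    kernel = jnp.exp(-0.5 * (z / sigma) ** 2)
    kernel = kernel / kernel.sum()  # sampled, renormalized Gaussian kernel

    f_smooth = jnp.convolve(f, kernel, mode='same')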
13.1.4 Differentiation
Remarkably, provided that the two functions are integrable with inte-
grable derivatives, the derivative of the convolution satisfies
(f ∗ g)′ = (f ′ ∗ g) = (f ∗ g ′ ),
which simply stems from switching derivative and integral in the defini-
tion of the convolution. Moreover, we have the following proposition.
F_ε := C_ε{f}, G_ε := C_ε{g}, H_ε := C_ε{h_ε}.
We then obtain
H_ε(µ) = (F_ε ∗ G_ε)(µ).
In the exponential domain, the convolution is therefore the counterpart of the infimal convolution, if we replace the min-plus algebra with the sum-product algebra. Back in the log domain, we obtain
h_ε(µ) = −ε log((F_ε ∗ G_ε)(µ)).
h(t) := (f ∗ g)(t).
Table 13.1: Analogy between Fourier and Legendre transforms. See Proposition 12.3
for more conjugate calculus rules.
where
G_ε := C_ε{(1/2)‖·‖²_2} = exp(−‖·‖²_2/(2ε)),
Q_ε{f} := C_ε{f(·) − (1/2)‖·‖²_2} = exp(‖·‖²_2/(2ε) − f(·)/ε),
Q_ε^{−1}{F} := (1/2)‖·‖²_2 − ε log(F(·)).
This insight was tweeted by Gabriel Peyré in April 2020.
Proof.
f_ε(v) := ε log ∫ exp((⟨u, v⟩ − f(u))/ε) du
= ε log ∫ exp(−‖u − v‖²_2/(2ε) + ‖u‖²_2/(2ε) + ‖v‖²_2/(2ε) − f(u)/ε) du
= ε log ∫ exp(−‖u − v‖²_2/(2ε) + ‖u‖²_2/(2ε) − f(u)/ε) du + (1/2)‖v‖²_2
= ε log ∫ G_ε(v − u) Q_ε{f}(u) du + (1/2)‖v‖²_2
= ε log (Q_ε{f} ∗ G_ε)(v) + (1/2)‖v‖²_2
= (1/2)‖v‖²_2 − ε log (1/((Q_ε{f} ∗ G_ε)(v)))
= Q_ε^{−1}{1/(Q_ε{f} ∗ G_ε)}(v).
What did we gain from this viewpoint? The convex conjugate can
often be difficult to compute in closed form. If we replace RM with a dis-
crete set S (i.e., a grid), we can then approximate the smoothed convex
conjugate in O(n log n), where n = |S|, using a discrete convolution,
(Q_ε{f} ∗ G_ε)(v) ≈ Σ_{u∈S} G_ε(v − u) Q_ε{f}(u) = [K q]_v,
where K is the matrix with entries G_ε(v − u) for v, u ∈ S and q is the vector with entries Q_ε{f}(u) for u ∈ S.