The Elements of
Differentiable Programming
Mathieu Blondel
Google DeepMind
[email protected]
Vincent Roulet
Google DeepMind
[email protected]
1 Introduction
1.1 What is differentiable programming?
1.2 Book goals and scope
1.3 Intended audience
1.4 How to read this book?
1.5 Related work

I Fundamentals

2 Differentiation
2.1 Univariate functions
2.1.1 Derivatives
2.1.2 Calculus rules
2.1.3 Leibniz's notation
2.2 Multivariate functions
2.2.1 Directional derivatives
2.2.2 Gradients
2.2.3 Jacobians
2.3 Linear differentiation maps
2.3.1 The need for linear maps
2.3.2 Euclidean spaces
2.3.3 Linear maps and their adjoints
2.3.4 Jacobian-vector products
2.3.5 Vector-Jacobian products
2.3.6 Chain rule
2.3.7 Functions of multiple inputs (fan-in)
2.3.8 Functions with multiple outputs (fan-out)
2.3.9 Extensions to non-Euclidean linear spaces
2.4 Second-order differentiation
2.4.1 Second derivatives
2.4.2 Second directional derivatives
2.4.3 Hessians
2.4.4 Hessian-vector products
2.4.5 Second-order Jacobians
2.5 Higher-order differentiation
2.5.1 Higher-order derivatives
2.5.2 Higher-order directional derivatives
2.5.3 Higher-order Jacobians
2.5.4 Taylor expansions
2.6 Differential geometry
2.6.1 Differentiability on manifolds
2.6.2 Tangent spaces and pushforward operators
2.6.3 Cotangent spaces and pullback operators
2.7 Generalized derivatives
2.7.1 Rademacher's theorem
2.7.2 Clarke derivatives
2.8 Summary

3 Probabilistic learning
3.1 Probability distributions
3.1.1 Discrete probability distributions
3.1.2 Continuous probability distributions
3.2 Maximum likelihood estimation
3.2.1 Negative log-likelihood
3.2.2 Consistency w.r.t. the Kullback-Leibler divergence
3.3 Probabilistic supervised learning
3.3.1 Conditional probability distributions
3.3.2 Inference
3.3.3 Binary classification
3.3.4 Multiclass classification
3.3.5 Regression
3.3.6 Multivariate regression
3.3.7 Integer regression
3.3.8 Loss functions
3.4 Exponential family distributions
3.4.1 Definition
3.4.2 The log-partition function
3.4.3 Maximum entropy principle
3.4.4 Maximum likelihood estimation
3.4.5 Probabilistic learning with exponential families
3.5 Summary

II Differentiable programs

4 Parameterized programs
4.1 Representing computer programs
4.1.1 Computation chains
4.1.2 Directed acyclic graphs
4.1.3 Computer programs as DAGs
4.1.4 Arithmetic circuits
4.2 Feedforward networks
4.3 Multilayer perceptrons
4.3.1 Combining affine layers and activations
4.3.2 Link with generalized linear models
4.4 Activation functions
4.4.1 ReLU and softplus
4.4.2 Max pooling and log-sum-exp
4.4.3 Sigmoids: binary step and logistic functions
4.4.4 Probability mappings: argmax and softargmax
4.5 Residual neural networks
4.6 Recurrent neural networks
4.6.1 Vector to sequence
4.6.2 Sequence to vector
4.6.3 Sequence to sequence (aligned)
4.6.4 Sequence to sequence (unaligned)
4.7 Summary

5 Control flows
5.1 Comparison operators
5.2 Soft inequality operators
5.2.1 Heuristic definition
5.2.2 Stochastic process perspective
5.3 Soft equality operators
5.3.1 Heuristic definition
5.3.2 Stochastic process perspective
5.3.3 Gaussian process perspective
5.4 Logical operators
5.5 Continuous extensions of logical operators
5.5.1 Probabilistic continuous extension
5.5.2 Triangular norms and co-norms
5.6 If-else statements
5.6.1 Differentiating through branch variables
5.6.2 Differentiating through predicate variables
5.6.3 Continuous relaxations
5.7 Else-if statements
5.7.1 Encoding K branches
5.7.2 Conditionals
5.7.3 Differentiating through branch variables
5.7.4 Differentiating through predicate variables
5.7.5 Continuous relaxations
5.8 For loops
5.9 Scan functions
5.10 While loops
5.10.1 While loops as cyclic graphs
5.10.2 Unrolled while loops
5.10.3 Markov chain perspective
5.11 Summary

6 Data structures
6.1 Lists
6.1.1 Basic operations
6.1.2 Operations on variable-length lists
6.1.3 Continuous relaxations using soft indexing
6.2 Dictionaries
6.2.1 Basic operations
6.2.2 Continuous relaxation using kernel regression
6.2.3 Discrete probability distribution perspective
6.2.4 Link with attention in Transformers
6.3 Summary

18 Duality
18.1 Dual norms
18.2 Fenchel duality
18.3 Bregman divergences
18.4 Fenchel-Young loss functions
18.5 Summary

References
ABSTRACT
Artificial intelligence has recently experienced remarkable advances, fueled by large models, vast datasets, accelerated hardware, and, last but not least, the transformative power of differentiable programming. This new programming paradigm enables end-to-end differentiation of complex computer programs (including those with control flows and data structures), making gradient-based optimization of program parameters possible.

As an emerging paradigm, differentiable programming builds upon several areas of computer science and applied mathematics, including automatic differentiation, graphical models, optimization and statistics. This book presents a comprehensive review of the fundamental concepts useful for differentiable programming. We adopt two main perspectives, that of optimization and that of probability, with clear analogies between the two.

Differentiable programming is not merely the differentiation of programs, but also the thoughtful design of programs intended for differentiation. By making programs differentiable, we inherently introduce probability distributions over their execution, providing a means to quantify the uncertainty associated with program outputs.
Acknowledgements
Source code
Notation

Notation            Description
X ⊆ R^D             Input space (e.g., features)
Y ⊆ R^M             Output space (e.g., classes)
S_k ⊆ R^{D_k}       Output space on layer or state k
W ⊆ R^P             Weight space
Λ ⊆ R^Q             Hyperparameter space
Θ ⊆ R^R             Distribution parameter space, logit space
N                   Number of training samples
T                   Number of optimization iterations
x ∈ X               Input vector
y ∈ Y               Target vector
s_k ∈ S_k           State vector k
w ∈ W               Network (model) weights
λ ∈ Λ               Hyperparameters
θ ∈ Θ               Distribution parameters, logits
π ∈ [0, 1]          Probability value
π ∈ △^M             Probability vector
f                   Network function
f(·; x)             Network function with x fixed
L                   Objective function
ℓ                   Loss function
κ                   Kernel function
ϕ                   Output embedding, sufficient statistic
step                Heaviside step function
logistic_σ          Logistic function with temperature σ
logistic            Shorthand for logistic_1
p_θ                 Model distribution with parameters θ
ρ                   Data distribution over X × Y
ρ_X                 Data distribution over X
µ, σ²               Mean and variance
Z                   Random noise variable
1 Introduction
This book does not need to be read linearly, chapter by chapter. When needed, we indicate at the beginning of a chapter which chapters are recommended as prerequisites.
Fundamentals
2 Differentiation
2.1.1 Derivatives
To study functions, such as defining their derivatives, we need to capture
their infinitesimal variations around points as defined by the notion of
limit.
if
lim_{v→w} |g(v)| / |f(v)| = 0.
That is, the function f dominates g in the limit v → w. For example, f is continuous at w if and only if f(w + δ) = f(w) + o(1) as δ → 0.
[Figure: three examples, a function discontinuous at 0, a continuous function that is non-differentiable at 1 and -1, and a function that is differentiable everywhere.]
ei := (0, . . . , 0, 1, 0, . . . , 0), with the 1 in the i-th position.
2.2.2 Gradients
We now introduce the gradient vector, which gathers the partial deriva-
tives. We first recall the definitions of linear map and linear form.
The gradient of f at w is the vector of partial derivatives
∇f(w) := (∂_1 f(w), . . . , ∂_P f(w)) = (∂f(w)[e_1], . . . , ∂f(w)[e_P]).
∂f(w)[v] = Σ_{i=1}^P v_i ∂f(w)[e_i] = ⟨v, ∇f(w)⟩.
⟨a, v⟩)/∥v∥2 = 0 for any v and in particular for ∥v∥ → 0. Moreover, its
gradient is naturally given by ∇f (w) = a.
Generally, to show that a function is differentiable and find its
gradient, one approach is to approximate f (w + v) around v = 0. If we
can find a vector g such that
f(w + v) = f(w) + ⟨g, v⟩ + o(∥v∥₂),
then f is differentiable at w since ⟨g, ·⟩ is linear. Moreover, g is then
the gradient of f at w.
Remark 2.3 (Gateaux and Fréchet differentiability). Multiple defini-
tions of differentiability exist. The one presented in Definition 2.6 is
about Fréchet differentiable functions. Alternatively, if f : RP →
R has well-defined directional derivatives along any directions then
the function is Gateaux differentiable. Note that the existence of
directional derivatives in any directions is not a sufficient condition
for the function to be differentiable. In other words, any Fréchet differentiable function is Gateaux differentiable, but the converse does not hold in general.
Linearity of gradients
The notion of differentiability for multi-input functions naturally inherits
from the linearity of derivatives for single-input functions. For any
u1 , . . . , uM ∈ R and any multi-input functions f1 , . . . , fM differentiable
at w, the function u1 f1 + . . . + uM fM is differentiable at w and its
gradient is ∇(u_1 f_1 + . . . + u_M f_M)(w) = u_1 ∇f_1(w) + . . . + u_M ∇f_M(w).
2.2.3 Jacobians
Let us now consider a multi-output function f : RP → RM defined by
f (w) := (f1 (w), . . . , fM (w)), where fj : RP → R. A typical example
in machine learning is a neural network. The notion of directional
derivative can be extended to such function by defining it as the vector
composed of the coordinate-wise directional derivatives:
∂f(w)[v] := lim_{δ→0} (f(w + δv) − f(w)) / δ
         = (lim_{δ→0} (f_1(w + δv) − f_1(w)) / δ, . . . , lim_{δ→0} (f_M(w + δv) − f_M(w)) / δ) ∈ R^M,
where the limits (provided that they exist) are applied coordinate-wise. The directional derivative of f in the direction v ∈ R^P is therefore the vector that gathers the directional derivative of each f_j, i.e., ∂f(w)[v] = (∂f_j(w)[v])_{j=1}^M. In particular, we can define the partial derivatives of f at w as the vectors
∂_i f(w) := ∂f(w)[e_i] = (∂_i f_1(w), . . . , ∂_i f_M(w)) ∈ R^M.
As for the usual definition of the derivative, the directional derivative
can provide a linear approximation of a function around a current input
as illustrated in Fig. 2.4 for a parametric curve f : R → R2 .
Just as in the single-output case, differentiability is defined not only
as the existence of directional derivatives in any direction but also by
the linearity in the chosen direction.
The Jacobian of f at w is the matrix of partial derivatives
∂f(w) := ( ∂_1 f_1(w) . . . ∂_P f_1(w)
               ⋮              ⋮
           ∂_1 f_M(w) . . . ∂_P f_M(w) ) ∈ R^{M×P}.
The Jacobian can be represented by stacking columns of partial
derivatives or rows of gradients,
∂f(w) = (∂_1 f(w), . . . , ∂_P f(w)) = ( ∇f_1(w)^⊤
                                             ⋮
                                         ∇f_M(w)^⊤ ).
∂f(w)[v] = Σ_{i=1}^P v_i ∂_i f(w) = ∂f(w) v ∈ R^M.
∂f(w) = f′(w) := (f_1′(w), . . . , f_M′(w))^⊤ ∈ R^{M×1},
and, for a single-input single-output function,
∂f(w) = f′(w) ∈ R.
Example 2.4 illustrates the form of the Jacobian matrix for the
element-wise application of a differentiable function such as the softplus
activation. This example already shows that the Jacobian takes a simple
diagonal matrix form. As a consequence, the directional derivative
associated with this function is simply given by an element-wise product
rather than a full matrix-vector product as suggested in Definition 2.9.
We will revisit this point in Section 2.3.
f(w) := (σ(w_1), . . . , σ(w_P)) ∈ R^P, where σ(w) := log(1 + e^w).
Since σ is differentiable, each coordinate of this function is differentiable and the overall function is differentiable. The j-th coordinate of f is independent of the i-th coordinate of w for i ≠ j, so ∂_i f_j(w) = 0 for i ≠ j. For i = j, the result boils down to the derivative of σ at w_j. That is, ∂_j f_j(w) = σ′(w_j), where σ′(w) = e^w / (1 + e^w). The Jacobian of f is therefore the diagonal matrix
∂f(w) = diag(σ′(w_1), . . . , σ′(w_P)),
with the entries σ′(w_1), . . . , σ′(w_P) on the diagonal and zeros elsewhere.
∇⟨u, f⟩(w) = Σ_{j=1}^M u_j ∇f_j(w) = ∂f(w)^⊤ u ∈ R^P,
Chain rule
Equipped with a generic definition of differentiability and the associated
objects, gradients and Jacobians, we can now generalize the chain rule,
previously introduced for single-input single-output functions.
y = (y1 , . . . , yN )⊤ ∈ RN .
The function f can be decomposed into a linear mapping
f1 (w) = Xw and a squared error f2 (p) = ∥p − y∥22 , so that
f = f2 ◦ f1 . We can then apply the chain rule in Proposition 2.3 to
get
∇f (w) = ∂f1 (w)⊤ ∇f2 (f1 (w))
provided that f1 , f2 are differentiable at w and f1 (w), respectively.
The function f1 is linear so differentiable with Jacobian ∂f1 (w) =
X. On the other hand the partial derivatives of f2 are given by
∂j f2 (p) = 2(pj − yj ) for j ∈ {1, . . . , N }. Therefore, f2 is differen-
tiable at any p and its gradient is ∇f2 (p) = 2(p − y). By combining
the computations of the Jacobian of f1 and the gradient of f2 , we
then get the gradient of f as ∇f(w) = X^⊤ · 2(Xw − y) = 2 X^⊤ (Xw − y).
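As a quick numerical check (not part of the original text), this closed-form gradient can be compared against automatic differentiation. Below is a minimal JAX sketch with arbitrary random data; the shapes and variable names are our own.

```python
import jax
import jax.numpy as jnp

# Arbitrary data: N = 5 samples, P = 3 features.
X = jax.random.normal(jax.random.PRNGKey(0), (5, 3))
y = jax.random.normal(jax.random.PRNGKey(1), (5,))
w = jax.random.normal(jax.random.PRNGKey(2), (3,))

def f(w):
    # f = f2 o f1 with f1(w) = Xw and f2(p) = ||p - y||_2^2.
    return jnp.sum((X @ w - y) ** 2)

grad_autodiff = jax.grad(f)(w)
grad_manual = 2 * X.T @ (X @ w - y)   # chain rule: ∂f1(w)^T ∇f2(f1(w))
print(jnp.allclose(grad_autodiff, grad_manual))  # True
```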
Linear spaces, a.k.a. vector spaces, are spaces equipped with (and closed under) an addition rule compatible with multiplication by a scalar (we limit ourselves to the field of reals). Namely, in a linear space E, there exist operations + and ·, such that for any u, v ∈ E and a ∈ R, we have u + v ∈ E and a · u ∈ E. Euclidean spaces are linear spaces equipped with a basis e_1, . . . , e_P ∈ E. Any element v ∈ E can be decomposed as v = Σ_{i=1}^P v_i e_i for some unique scalars v_1, . . . , v_P ∈ R.
An inner product on R^{P1×P2} is given by ⟨A, B⟩ := Σ_{i,j=1}^{P1,P2} A_{ij} B_{ij} = tr(A^⊤ B), where tr(Z) := Σ_{i=1}^P Z_{ii} is the trace operator defined for square matrices.
∂f : E → (E → F).
∂f (·)∗ : E → (F → E).
As this is true for any v ∈ E, ∂f (w)∗ [u] is the gradient of ⟨u, f ⟩F per
Proposition 2.4.
The proof follows the one of Proposition 2.2. When the last function
is scalar-valued, which is often the case in machine learning, we obtain
the following simplified result.
∂f (x, W )[v, V ] = W v + V x ∈ F.
∂1 f (x, W )∗ [u] = W ⊤ u ∈ E1 ,
∂2 f (x, W )∗ [u] = ux⊤ ∈ E2 .
∂h(w)[v] = Σ_{i=1}^T ∂_i f(g(w))[∂g_i(w)[v]].
Here, we cannot find a finite number of functions that can express all
possible functions on R. Therefore, this space is not a mere Euclidean
space. Nevertheless, we can consider functions on this Hilbert space
(called functionals to distinguish them from the elements of the space).
The associated directional derivatives and gradients can be defined and are called, respectively, the functional derivative and the functional gradient; see, e.g., Frigyik et al. (2008) and references therein.
Figure 2.5: Points at which the second derivative is small are points at which the function is well approximated by its tangent line. On the other hand, points with a large second derivative tend to be badly approximated by the tangent line.
2.4.3 Hessians
For a multi-input function, twice differentiability is simply defined as
the differentiability of any directional derivative ∂f (w)[v] w.r.t. w.
∇²f(w) := ( ∂_{11} f(w) . . . ∂_{1P} f(w)
                 ⋮                ⋮
            ∂_{P1} f(w) . . . ∂_{PP} f(w) ),
provided that all second partial derivatives are well-defined.
The second directional derivative at w is bilinear in the directions v = Σ_{i=1}^P v_i e_i and v′ = Σ_{i=1}^P v_i′ e_i. Therefore,
∂²f(w)[v, v′] = Σ_{i,j=1}^P v_i v_j′ ∂²f(w)[e_i, e_j] = ⟨v, ∇²f(w) v′⟩.
∂ 2 f (w)[v, v ′ ] = ⟨v ′ , ∇2 f (w)[v]⟩.
The function f is twice differentiable if and only if all its coordinates are
twice differentiable. The second directional derivative is then a bilinear
map. We can then compute second directional derivatives as
∂²f(w)[v, v′] = Σ_{i,j=1}^P v_i v_j′ ∂²f(w)[e_i, e_j] = (⟨v, ∇²f_j(w) v′⟩)_{j=1}^M.
∂ⁿf(w)[v_1, . . . , v_n] = ∂(∂^{n−1}f(w)[v_1, . . . , v_{n−1}])[v_n]
= lim_{δ→0} (∂^{n−1}f(w + δ v_n)[v_1, . . . , v_{n−1}] − ∂^{n−1}f(w)[v_1, . . . , v_{n−1}]) / δ.
In the case of the sphere in Fig. 2.6, the tangent space is a plane, that
is, a Euclidean space. This property is generally true: tangent spaces
are Euclidean spaces, on which we will be able to define directional derivatives as linear operators. Now, if f is differentiable and goes from
a manifold M to a manifold N , then f ◦ α is a differentiable curve in N .
Therefore, (f ◦ α)′ (0) is the derivative of a curve passing through f (w)
at 0 and is tangent to N at f (w). Hence, the directional derivative
of f : M → N at w can be defined as a function from the tangent
space Tw M of M at w onto the tangent space Tf (w) N of N at f (w).
Overall, we built the directional derivative (JVP) by considering how a
composition of f with any curve α pushes forward the derivative of α
into the derivative of f ◦ α. The resulting JVP is called a pushforward
operator in differential geometry.
Note that elements of the cotangent space are linear mappings, not
vectors. This distinction is important to define the pullback operator
as an operator on functions as done in measure theory. From a linear
algebra viewpoint, the cotangent space is exactly the dual space of
Ty N , that is, the set of linear maps from Ty N to R, called linear forms.
As Ty N is a Euclidean space, its dual space Ty∗ N is also a Euclidean
space. The pullback operator is then defined as the operator that gives
Function                    f : M → N
Pushforward                 ∂f(w) : T_w M → T_{f(w)} N
Pullback                    ∂f(w)^⋆ : T^∗_{f(w)} N → T^∗_w M
Adjoint of pushforward      ∂f(w)^∗ : T_{f(w)} N → T_w M

Table 2.1: For a differentiable function f defined from a manifold M onto a manifold N, the JVP is generalized with the notion of pushforward ∂f(w). The counterpart of the pushforward is the pullback operation ∂f(w)^⋆ that acts on linear forms in the tangent spaces. For Riemannian manifolds, the pullback operation can be identified with the adjoint operator ∂f(w)^∗ of the pushforward operator, as any linear form is represented by a vector.
we have
T_w M = Null(2⟨w, ·⟩) = {v ∈ R^P : ⟨w, v⟩ = 0}.
Jacobian of f at w ∈ E is an element J of
conv{ lim_{n→+∞} ∂f(v_n) : (v_n)_{n=1}^{+∞} s.t. v_n ∈ E \ Ω, v_n → w }.
2.8 Summary
P(Y = y) := p(y),
E[ϕ(Y)] = Σ_{y∈Y} p(y) ϕ(y),
the variance is
V[ϕ(Y)] = E[(ϕ(Y) − E[ϕ(Y)])²] = ∫_Y p(y) (ϕ(y) − E[ϕ(Y)])² dy
p_λ(y) := (1 / (σ√(2π))) exp(−(1/2) ((y − µ)/σ)²).
p and
λ_∞ := argmin_{λ∈Λ} KL(p, p_λ) = argmin_{λ∈Λ} E_{Y∼p}[log(p(Y)/p_λ(Y))],
then λ̂_N → λ_∞ in expectation over the observations, as N → ∞. This can be seen by using
KL(p, p_λ) ≈ (1/N) Σ_{i=1}^N log(p(y_i)/p_λ(y_i)) = (1/N) Σ_{i=1}^N (log p(y_i) − log p_λ(y_i))
and the law of large numbers.
3.3.2 Inference
The main advantage of this probabilistic approach is that our prediction
model is much richer than if we just learned a function from X to Y.
Figure 3.1: The Bernoulli distribution, whose PMF and CDF are here illustrated with parameter π = 0.8. Its mean function is π = A′(θ) = logistic(θ) = 1/(1 + exp(−θ)), where θ is for instance the output of a neural network. The negative log-likelihood leads to the logistic loss, L(θ, y) = softplus(θ) − θy = log(1 + exp(θ)) − θy. The loss curve is shown for y ∈ {0, 1}.
Remark 3.1 (Link with the logistic distribution). The logistic distribution with mean and scale parameters µ and σ is a continuous probability distribution with PDF
p_{µ,σ}(u) := (1/σ) p_{0,1}((u − µ)/σ), where p_{0,1}(z) := exp(−z) / (1 + exp(−z))².
If a random variable U follows a logistic distribution with parameters µ and σ, we write U ∼ Logistic(µ, σ). The CDF of U ∼ Logistic(µ, σ) is
P(U ≤ u) = ∫_{−∞}^u p_{µ,σ}(t) dt = logistic((u − µ)/σ).
Therefore, if U ∼ Logistic(µ, σ) and Y ∼ Bernoulli(logistic((u − µ)/σ)), then
P(Y = 1) = P(U ≤ u).
Here, U can be interpreted as a latent continuous variable and u as a threshold.
Figure 3.2: The categorical distribution, whose PMF and CDF are here illustrated with parameter π = (0.3, 0.6, 0.1). Its mean function is π = ∇A(θ) = softargmax(θ), where θ ∈ R^M is for instance the output of a neural network. Here, for illustration purposes, we choose to set θ = (s, 1, 0) and vary only s. Since the mean function ∇A(θ) belongs to R³, we choose to display ⟨∇A(θ), e_i⟩ = ∇A(θ)_i for i ∈ {1, 2, 3}. The negative log-likelihood leads to the logistic loss, L(θ, y) = logsumexp(θ) − ⟨θ, y⟩. The loss curve is shown for y ∈ {e_1, e_2, e_3}, again with θ = (s, 1, 0) and varying s.
Therefore, as was also the case for the Bernoulli distribution, the mean
and the probability distribution (represented by the vector π) are the
same in this case.
softargmax(u) := exp(u) / Σ_j exp(u_j) ∈ relint(△^M).
3.3.5 Regression
For real outcomes, where Y = R, we can use, among other choices, a
normal distribution with parameters
λ := (µ, σ),
where µ ∈ R is the mean parameter and σ ∈ R+ is the standard
deviation parameter. The PDF is
p_{µ,σ}(y) := (1 / (σ√(2π))) exp(−(y − µ)² / (2σ²)).
The expectation is
EY ∼pµ,σ [Y ] = µ.
One advantage of the probabilistic perspective is that we are not limited
to predicting the mean. We can also compute the CDF
P(Y ≤ y) = (1/2) (1 + erf((y − µ) / (σ√2))),    (3.1)
where we used the error function
erf(z) := (2/√π) ∫_0^z e^{−t²} dt.
This function is available in most scientific computing libraries, such as
SciPy (Virtanen et al., 2020). From the CDF, we also easily obtain
P(a < Y ≤ b) = (1/2) (erf((b − µ) / (σ√2)) − erf((a − µ) / (σ√2))).
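For concreteness, these CDF-based quantities are easy to compute in a scientific computing library. The following is a minimal sketch using jax.scipy.special.erf; the values of µ, σ, a, b, y are arbitrary placeholders.

```python
import jax.numpy as jnp
from jax.scipy.special import erf

mu, sigma = 1.5, 2.0   # placeholder mean and standard deviation
a, b, y = 0.0, 3.0, 2.0

# CDF of Eq. (3.1): P(Y <= y) = (1 + erf((y - mu) / (sigma * sqrt(2)))) / 2.
cdf = lambda t: 0.5 * (1.0 + erf((t - mu) / (sigma * jnp.sqrt(2.0))))

p_leq_y = cdf(y)              # P(Y <= y)
p_interval = cdf(b) - cdf(a)  # P(a < Y <= b)
print(p_leq_y, p_interval)
```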
Parameterization
Typically, in regression, the mean is output by a model, while the
standard deviation σ is kept fixed (typically set to 1). Since µ is uncon-
strained, we can simply set
µ := f (x, w) ∈ R,
where f : X × W → R is for example a neural network. That is, the
output of f is the mean of the distribution,
EY ∼pµ,1 [Y ] = µ = f (x, w).
We can also use µ to predict P(Y ≤ y) or P(a < Y ≤ b), as shown
above.
Figure 3.3: The Gaussian distribution, with mean parameter µ and variance
σ 2 = 1. Its mean function is µ = A′ (θ) = θ, where θ is for instance the output of a
neural network. The negative log-likelihood leads to the squared loss, L(θ, y) =
(y − θ)2 . The loss curve is shown for y ∈ {−2, 0, 2}.
Parameterization
Typically, in multivariate regression, the mean is output by a model,
while the covariance matrix is kept fixed (typically set to the identity
matrix). Since µ is again unconstrained, we can simply set
µ := f (x, w) ∈ RM .
More generally, we can parametrize the function f so as to output both
the mean µ and the covariance matrix Σ, i.e.,
(µ, Σ) := f (x, w) ∈ RM × RM ×M .
Figure 3.4: The Poisson distribution, with mean parameter λ. For the PMF
and the CDF, the lines between markers are shown for visual aid: the Poisson
distribution does not assign probability mass to non-integer values. Its mean function
is λ = A′ (θ) = exp(θ), where θ is for instance the output of a neural network.
The negative log-likelihood leads to the Poisson loss, L(θ, y) = − log pλ (y) =
−yθ + exp(θ) + log(y!), which is a convex function of θ. The loss curve is shown for
y ∈ {1, 4, 10}.
P(Y = y) = p_λ(y) := λ^y exp(−λ) / y!.
Its CDF is
P(Y ≤ y) = Σ_{i=0}^y P(Y = i).
Its mean and variance are
E[Y] = V[Y] = λ.
3.4.1 Definition
The exponential family is a class of probability distributions, whose
PMF or PDF can be written in the form
h(y) exp [⟨θ, ϕ(y)⟩]
pθ (y) =
exp(A(θ))
= h(y) exp [⟨θ, ϕ(y)⟩ − A(θ)] ,
where θ are the natural or canonical parameters of the distribution.
The function h is known as the base measure. The function ϕ is the
sufficient statistic: it holds all the information about y and is used
to embed y in a vector space. The function A is the log-partition or log-normalizer (see below for details). All the distributions we
reviewed in Section 3.3 belong to the exponential family. With some
abuse of notation, we use pλ for the distribution in original form and
pθ for the distribution in exponential family form. As we will see,
we can go from θ to λ and vice-versa. We illustrate how to rewrite a
distribution in exponential family form below.
pλ (y) := π y (1 − π)1−y
= exp(log(π y (1 − π)1−y ))
= exp(y log(π) + (1 − y) log(1 − π))
= exp(log(π/(1 − π))y + log(1 − π))
= exp(θy − log(1 + exp(θ)))
= exp(θy − softplus(θ))
=: pθ (y).
and similarly for continuous variables. Here, we defined the affine map B(θ) := (⟨θ, ϕ(y)⟩ + log h(y))_{y∈Y}.
            Bernoulli               Categorical
Y           {0, 1}                  [M]
λ           π = logistic(θ)         π = softargmax(θ)
θ           logit(π)                log π (up to an additive constant)
ϕ(y)        y                       e_y
A(θ)        softplus(θ)             logsumexp(θ)
h(y)        1                       1

            Normal (fixed variance σ²)        Normal (mean and variance)
θ           µ/σ                               (µ/σ², −1/(2σ²))
ϕ(y)        y/σ                               (y, y²)
A(θ)        θ²/2 = µ²/(2σ²)                   −θ₁²/(4θ₂) − (1/2) log(−2θ₂) = µ²/(2σ²) + log σ
h(y)        exp(−y²/(2σ²)) / (σ√(2π))         1/√(2π)
where conv(S) is the convex hull of S and P(Y) is the set of valid
probability distributions over Y.
Similarly, the Hessian ∇2 A(θ) coincides with the covariance matrix
of ϕ(Y ) according to pθ (Wainwright and Jordan, 2008, Chapter 3).
When the exponential family is minimal, which means that the
parameters θ uniquely identify the distribution, it is known that ∇A
is a one-to-one mapping from Θ to M. That is, µ(θ) = ∇A(θ) and
θ = (∇A)−1 (µ(θ)).
Its gradient is
which is independent of y.
Many times, Θ will be the entire RM but this is not always the case.
For instance, as we previously discussed, for a multivariate normal
distribution, where θ = (µ, Σ) = f (x, w), we need to ensure that Σ is
a positive semidefinite matrix.
Training
Given input-output pairs {(x_i, y_i)}_{i=1}^N, we then seek to find the parameters w of f(x, w) by minimizing the negative log-likelihood
argmin_{w∈W} Σ_{i=1}^N −log p_{θ_i}(y_i) = argmin_{w∈W} Σ_{i=1}^N A(θ_i) − ⟨θ_i, ϕ(y_i)⟩,
where θ_i := f(x_i, w).
Inference
Once we found w by minimizing the objective function above, there are
several possible strategies to perform inference for a new input x.
µ = p = ∇A(θ) = softargmax(θ) ∈ △M .
3.5 Summary
Differentiable programs

4 Parameterized programs
s0 ∈ S0
s1 := f1(s0) ∈ S1
...
sK := fK(sK−1) ∈ SK
f(s0) := sK.    (4.1)
Figure 4.2: Example of a directed acyclic graph. Here the nodes are V =
{0, 1, 2, 3, 4}, the edges are E = {(0, 1), (0, 2), (0, 3), (1, 3), (2, 3), (1, 4), (3, 4)}. Parents
of the node 3 are pa(3) = {0, 1, 2}. Children of node 1 are ch(1) = {3, 4}. There is
a unique root, 0, and a unique leaf, 4; 0 → 3 → 4 is a path from 0 to 4. This is an
acyclic graph since there is no cycle (i.e., a path from a node to itself). We can order
nodes 0 and 3 as 0 ≤ 3 since there is no path from 3 to 0. Similarly, we can order 1
and 2 as 1 ≤ 2 since there is no path from 2 to 1. Two possible topological orders of
the nodes are (0, 1, 2, 3, 4) and (0, 2, 1, 3, 4).
Figure 4.3: Representation of f(x1, x2) = x2 e^{x1} √(x1 + x2 e^{x1}) as a DAG, with functions and variables as nodes. Edges indicate function and variable dependencies. The function f is decomposed as 8 elementary functions in topological order.
Executing a program
When a program has multiple inputs, we can always group them into
s0 ∈ S0 as s0 = (s0,1 , . . . , s0,N0 ) with S0 = (S0,1 × · · · × S0,N0 ), since
later functions can always filter out what elements of s0 they need.
Likewise, if an intermediate function fk has multiple outputs, we can always group them as a single output sk = (sk,1, . . . , sk,Nk) with Sk = (Sk,1 × · · · × Sk,Nk), since later functions can filter out the elements of sk that they need.

Figure 4.4: Two possible representations of a program. Left: functions and output variables are represented by the same nodes. Right: functions and variables are represented by a disjoint set of nodes.
s0 := x
s1 := f1(s0, w1)
s2 := f2(s1, w2)
...
sK := fK(sK−1, wK),
s1 = a1 (W1 x + b1 ).
sK = aK (WK sK−1 + bK ).
sk = relu(Ak sk−1 + bk ),
softplus(u) := log(1 + eu ).
Max pooling
max(u) := max_{j∈[M]} u_j.
logsumexp(u) = logsumexp(u − c1) + c,
where c := max_{j∈[M]} u_j.
More generally, we can introduce a temperature parameter γ > 0 and then have
s̃ = log Σ_{i=1}^M exp(ũ_i).
Written differently, we have the identity
log(Σ_{i=1}^M u_i) = logsumexp(log(u)).
at all other points, it has zero derivative at these points, which makes it
difficult to use as part of a neural network trained with backpropagation.
Logistic function
A better sigmoid is the logistic function, which is a mapping from R
to (0, 1) and is defined as
logistic(u) := 1 / (1 + e^{−u}) = e^u / (1 + e^u) = 1/2 + (1/2) tanh(u/2).
It maps (−∞, 0) to (0, 0.5) and [0, +∞) to [0.5, 1), and it satisfies logistic(0) = 0.5. It can therefore be seen as a mapping from real values to probability values. The logistic can be seen as a differentiable approximation to the
discontinuous binary step function step(u). The logistic function can
be shown to be the derivative of softplus, i.e., for all u ∈ R
softplus′ (u) = logistic(u).
Two important properties of the logistic function are that for all u ∈ R
logistic(−u) = 1 − logistic(u)
and
logistic′ (u) = logistic(u) · logistic(−u)
= logistic(u) · (1 − logistic(u)).
Other sigmoids are possible; see Section 13.6.
Argmax
The argmax operator is defined by
argmax(u) := ϕ(argmax_{j∈[M]} u_j).
This mapping puts all the probability mass onto a single coordinate (in
case of ties, we pick a single coordinate arbitrarily). Unfortunately, this
mapping is a discontinuous function.
Softargmax
As a differentiable everywhere relaxation, we can use the softargmax
defined by
softargmax(u) := exp(u) / Σ_{j=1}^M exp(u_j).
This operator is commonly known in the literature as softmax but this
is a misnomer: this operator really defines a differentiable relaxation
of the argmax. The output of the softargmax belongs to the relative
interior of the probability simplex, meaning that it can never reach the
borders of the simplex. If we denote π = softargmax(u), this means
that πj ∈ (0, 1), that is, πj can never be exactly 0 or 1. The softargmax
is the gradient of log-sum-exp,
∇logsumexp(u) = softargmax(u).
The softargmax can be seen as a generalization of the logistic function,
as we have for all u ∈ R
[softargmax((u, 0))]1 = logistic(u).
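The identity ∇logsumexp(u) = softargmax(u), and the link with the logistic function, can be checked numerically. A minimal JAX sketch (the test vector is arbitrary):

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def softargmax(u):
    # Numerically stable: subtract the max before exponentiating.
    z = u - jnp.max(u)
    e = jnp.exp(z)
    return e / jnp.sum(e)

u = jnp.array([2.0, -1.0, 0.5])
print(jnp.allclose(jax.grad(logsumexp)(u), softargmax(u)))   # True: gradient of log-sum-exp
print(jnp.allclose(softargmax(jnp.array([3.0, 0.0]))[0],
                   jax.nn.sigmoid(3.0)))                     # True: logistic special case
```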
where ū := (1/M) Σ_{j=1}^M u_j. Using this constraint together with
log π_i = u_i − log Σ_{j=1}^M exp(u_j),
we then obtain
Σ_{i=1}^M log π_i = −M log Σ_{j=1}^M exp(u_j),
so that
u_i = [softargmax^{−1}(π)]_i = log π_i − (1/M) Σ_{j=1}^M log π_j.
where L stands for length. Note that we use the same number of
parameters P for each setup for notational convenience, but this of
course does not need to be the case. Throughout this section, we use
the notation p1:L := (p1 , . . . , pL ) for a sequence of L vectors.
c := f^e(x_{1:L}, w^e)
p_{1:L′} := f^d(c, w^d)
4.7 Summary
5 Control flows

5.1 Comparison operators

• greater than:
  gt(u1, u2) := 1 if u1 ≥ u2, and 0 otherwise,
  so that gt(u1, u2) = step(u1 − u2);

• less than:
  lt(u1, u2) := 1 if u1 ≤ u2, and 0 otherwise,
  so that lt(u1, u2) = 1 − gt(u1, u2) = step(u2 − u1);

• equal:
  eq(u1, u2) := 1 if |u1 − u2| = 0, and 0 otherwise,
  so that eq(u1, u2) = gt(u2, u1) · gt(u1, u2) = step(u2 − u1) · step(u1 − u2);

• not equal:
  neq(u1, u2) := 1 if |u1 − u2| > 0, and 0 otherwise,
  so that neq(u1, u2) = 1 − eq(u1, u2) = 1 − step(u2 − u1) · step(u1 − u2),
Figure 5.1: The greater than and equal to operators are discontinuous functions,
leading to black or white pictures. They can be smoothed with appropriate approxi-
mations of the Heaviside step function.
• sigmoid_σ(0) = 1/2.
gt(µ1, µ2) = step(µ1 − µ2) ≈ sigmoid_σ(µ1 − µ2) =: gt_σ(µ1, µ2)
lt(µ1, µ2) = step(µ2 − µ1) ≈ sigmoid_σ(µ2 − µ1) =: lt_σ(µ1, µ2) = 1 − sigmoid_σ(µ1 − µ2) = 1 − gt_σ(µ1, µ2).
We see that the soft inequality operators are based on the CDF of the
difference between U1 and U2 .
From a perturbation perspective, we can also define noise variables
Z1 ∼ p0,1 and Z2 ∼ p0,1 such that U1 = µ1 + σ1 Z1 and U2 = µ2 + σ2 Z2
(Section 12.4.1). We then have
Gaussian case
Logistic case
When U1 ∼ Gumbel(µ1 , σ) and U2 ∼ Gumbel(µ2 , σ), we have
gt(µ1 , µ2 ) = E [gt(U1 , U2 )] ,
where Ui ∼ δµi and where δµi is the delta distribution that assigns a
probability of 1 to µi .
where ∥ϕ(µ)∥ := √⟨ϕ(µ), ϕ(µ)⟩ = √κ(µ, µ). This is the cosine similarity between ϕ(µ1) and ϕ(µ2).
k(µ1 , µ2 ) := κ(µ1 − µ2 ),
that depend only on the difference between inputs. When the kernel has
a scale parameter σ > 0, we use the notation κσ . We can then define a
soft equality operator as
eq(µ1, µ2) ≈ eq_σ(µ1, µ2) := κ_σ(µ1 − µ2) / κ_σ(0).
Several isotropic kernels can be chosen, such as the Gaussian kernel
κ_σ(t) := exp(−t² / (2σ²))
E[eq(U1, U2)] = P(U1 = U2) = 0,
and in particular
f_{|X|}(0) = 2 f_X(0).
Therefore,
f_{|U1 − U2|}(0) = 2 f_{U1 − U2}(0),
further justifying using the PDF of U1 − U2 evaluated at 0. When X follows a normal distribution, |X| follows the so-called folded normal distribution.
Gaussian case
When U1 ∼ Normal(µ1, σ1²) and U2 ∼ Normal(µ2, σ2²), we obtain from Eq. (5.1)
f_{U1 − U2}(t) = (1 / √(2π(σ1² + σ2²))) exp(−(t − (µ1 − µ2))² / (2(σ1² + σ2²)))
so that
eq_σ(µ1, µ2) = exp(−(µ1 − µ2)² / (2(σ1² + σ2²))) ∈ [0, 1].
We indeed recover κ_σ(µ1 − µ2)/κ_σ(0), where κ_σ is the Gaussian kernel with σ = √(σ1² + σ2²). For the CDF of the absolute difference, we obtain
P(|U1 − U2| ≤ ε) = Φ((ε − (µ1 − µ2))/σ) − Φ((−ε − (µ1 − µ2))/σ).
Logistic case
When U1 ∼ Gumbel(µ1, σ) and U2 ∼ Gumbel(µ2, σ), we obtain
eq_σ(µ1, µ2) = sech²((µ1 − µ2) / (2σ)) ∈ [0, 1].
We indeed recover κ_σ(µ1 − µ2)/κ_σ(0), where κ_σ is the logistic kernel with σ = σ1 = σ2.
for some kernel k. Equipped with such a mapping from real numbers
to random variables, we need a measure of similarity between random
variables. A natural choice is their correlation
corr(U_{µ1}, U_{µ2}) := Cov(U_{µ1}, U_{µ2}) / √(Var(U_{µ1}) Var(U_{µ2})) ∈ [0, 1].
We therefore obtain
corr(U_{µ1}, U_{µ2}) = k(µ1, µ2) / √(k(µ1, µ1) k(µ2, µ2)) = ⟨ϕ(µ1), ϕ(µ2)⟩ / (∥ϕ(µ1)∥ ∥ϕ(µ2)∥),
• Commutativity:
  and(π, π′) = and(π′, π)
  or(π, π′) = or(π′, π)
• Associativity:
• Neutral element:
  and(π, 1) = π
  or(π, 0) = π
• De Morgan's laws:
and
any(π) := 1 if 1 ∈ {π1, . . . , πK}, and 0 otherwise.
and(π, π ′ ) = π · π ′
or(π, π ′ ) = π + π ′ − π · π ′
not(π) = 1 − π.
Figure 5.3: The Boolean and and or operators are functions from {0, 1} × {0, 1}
to {0, 1} (corners in the figure) but their continuous extensions and(π, π ′ ) := π · π ′
as well as or(π, π ′ ) := π + π ′ − π · π ′ define a function from [0, 1] × [0, 1] to [0, 1].
all and any, which are functions from [0, 1]^K to [0, 1], as
all(π) = Π_{i=1}^K π_i
any(π) = 1 − Π_{i=1}^K (1 − π_i).
all(π) = P(Y1 = 1 ∩ · · · ∩ YK = 1) = Π_{i=1}^K P(Yi = 1)
any(π) = P(Y1 = 1 ∪ · · · ∪ YK = 1)
       = 1 − P(¬(Y1 = 1 ∪ · · · ∪ YK = 1))
       = 1 − P(Y1 ≠ 1 ∩ · · · ∩ YK ≠ 1)
       = 1 − Π_{i=1}^K (1 − P(Yi = 1)).
These are the chain rule of probability and the addition rule of proba-
bility for K independent variables.
More generally, in the fuzzy logic literature (Klir and Yuan, 1995;
Jayaram and Baczynski, 2008), the concepts of triangular norms and
co-norms have been introduced to provide continuous relaxations of the
and and or operators, respectively.
Figure 5.4: Alternative relaxations of the Boolean and and or operators using
triangular norms (t-norms).
Table 5.1: Examples of triangular norms and conorms, which are continuous
relaxations of the and and or operators, respectively. More instances can be obtained
by smoothing out the min and max operators.
Probabilistic perspective
From a probabilistic perspective, we can view Eq. (5.4) as the expecta-
tion of gi (ui ), where i ∈ {0, 1} is a binary random variable distributed
according to a Bernoulli distribution with parameter π = sigmoid(p):
fs (p, u1 , u0 ) = Ei∼Bernoulli(sigmoid(p)) [gi (ui )] .
Taking the expectation over the two possible branches makes the function differentiable with respect to p, since sigmoid(p) is differentiable.
Of course, this comes at the cost of evaluating both branches, instead
of a single one. The probabilistic perspective suggests that we can also
compute the variance if needed as
Vi∼Bernoulli(sigmoid(p)) [gi (ui )]
h i
=Ei∼Bernoulli(sigmoid(p)) (fs (p, u1 , u0 ) − gi (ui ))2 .
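A minimal JAX sketch of this soft if-else is given below; the branch functions g1 and g0 are arbitrary placeholders, and the construction simply mixes both branches with the sigmoid of the predicate, as in the expectation above.

```python
import jax
import jax.numpy as jnp

def soft_ifelse(p, u1, u0, g1, g0):
    # E_{i ~ Bernoulli(sigmoid(p))}[g_i(u_i)]: both branches are evaluated.
    pi = jax.nn.sigmoid(p)
    return pi * g1(u1) + (1.0 - pi) * g0(u0)

# Example branches (placeholders): g1 squares, g0 negates.
g1 = lambda u: u ** 2
g0 = lambda u: -u

value = soft_ifelse(2.0, 3.0, 3.0, g1, g0)
grad_p = jax.grad(soft_ifelse, argnums=0)(2.0, 3.0, 3.0, g1, g0)  # nonzero, unlike the hard version
print(value, grad_p)
```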
Figure 5.5: Computation graphs of programs using if-else statements with either
hard or soft comparison operators. By using a hard comparison operator (step
function, left panel) the predicate π is a discrete variable (represented by a dashed
line). Depending on the value (0 or 1) of the predicate π, only one branch (red or blue)
contributes to the output. Derivatives along a path of continuous variables (dense
lines) can be computed. However, discrete variables such as the predicate prevent the
propagation of meaningful derivatives. By using a soft comparison operator (sigmoid,
right panel), the predicate is a continuous variable and derivatives with respect to
the input p can be taken. In this case both branches (corresponding to g0 and g1 )
contribute to the output and therefore need to be evaluated.
where
π_a := sigmoid_σ(x − a) and π_b := sigmoid_σ(b − x).
The difference stems from the fact that the local approach smoothes out a ≤ x and x ≤ b independently (treating 1_{X≥a} and 1_{X≤b} as independent random variables), while the global approach smoothes out a ≤ x ≤ b simultaneously. In practice, both approaches approximate the original function well as σ → 0 and coincide for σ sufficiently small, as illustrated in Fig. 5.6.
Figure 5.6: Global versus local smoothing approaches on a gate function f (x) := 1
if x ∈ [−1, 1], and f (x) := 0 otherwise. In our notation, we can write f (x) =
ifelse(and(gt(x, −1), lt(x, 1)), 1, 0). A local approach smoothes out gt and lt separately.
A global approach uses the expectation of the whole program, see Remark 5.1. We
observe that, though the approaches differ for large σ, they quickly coincide for
smaller σ.
ei := (0, . . . , 0, 1, 0, . . . , 0), with the 1 in the i-th position,
Combining booleans
To form such a vector π ∈ {e1, . . . , eK}, we can combine the previously-
defined comparison and logical operators to define π = (π1 , . . . , πK ).
However, we need to ensure that only one πi is non-zero. We give an
example in Example 5.1.
5.7.2 Conditionals
We can now express a conditional statement as the function
cond : {e1 , . . . , eK } × V K → V defined by
cond(π, v1, . . . , vK) := v1 if π = e1, . . . , vK if π = eK    (5.5)
                        = Σ_{i=1}^K π_i v_i.
Figure 5.7: A conditional with three branches: the soft-thresholding operator (see
Example 5.1). It is a piecewise linear function (dotted black line). Using a softargmax,
we can induce a categorical probability distribution over the three branches. The
expected value (blue line) can be seen as a smoothed out version of the operator.
The induced distribution allows us to also compute the standard deviation.
As for the ifelse case, the Jacobian w.r.t. p is null almost everywhere,
∂p fa (p, u1 , . . . , uK ) = 0.
softargmin(p) := softargmax(−p) ∈ △K .
Probabilistic perspective
From a probabilistic perspective, we can view Eq. (5.6) as the ex-
pectation of gi (ui ), where i ∈ [K] is a categorical random variable
distributed according to a categorical distribution with parameter
π = softargmax(p):
fs (p, u1 , . . . , uK ) = Ei∼Categorical(softargmax(p)) [gi (ui )] .
Taking the expectation over the K possible branches makes the function
differentiable with respect to p, at the cost of evaluating all branches,
instead of a single one. Similarly as for the if-else case, we can compute
the variance if needed as
Vi∼Categorical(softargmax(p)) [gi (ui )]
h i
=Ei∼Categorical(softargmax(p)) (fs (p, u1 , . . . , uK ) − gi (ui ))2 .
This is illustrated in Fig. 5.7.
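Below is a minimal JAX sketch of this construction on a three-branch soft-thresholding-like operator in the spirit of Fig. 5.7. The particular logits and threshold used to encode the branches are our own choice (Example 5.1 is not reproduced here); the function returns both the smoothed value and the induced standard deviation.

```python
import jax
import jax.numpy as jnp

def soft_cond(p, values, temperature=1.0):
    # E_{i ~ Categorical(softargmax(p / T))}[values_i], plus the induced std.
    pi = jax.nn.softmax(p / temperature)
    mean = jnp.sum(pi * values)
    std = jnp.sqrt(jnp.sum(pi * (values - mean) ** 2))
    return mean, std

def soft_threshold(x, tau=1.0, temperature=0.1):
    # Three branches (our stand-in for the soft-thresholding operator):
    # x - tau if x > tau, 0 if -tau <= x <= tau, x + tau if x < -tau.
    values = jnp.array([x - tau, 0.0, x + tau])
    # Logits favoring the branch whose condition is satisfied (an assumed encoding).
    p = jnp.array([x - tau, -jnp.abs(x) + tau, -x - tau])
    return soft_cond(p, values, temperature)

print(soft_threshold(2.0))   # roughly (1.0, small std)
print(soft_threshold(0.3))   # roughly (0.0, small std)
```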
For loops are a control flow for sequentially calling a fixed number K
of functions, reusing the output from the previous iteration. In full
generality, a for loop can be written as follows.
for i := 1, . . . , N do
    for j := 1, . . . , N − i − 1 do
        v′ := swap(v, j, j + 1)
        π := step(v_j − v_{j+1})
        v ← ifelse(π, v′, v)
Figure 5.8: A for loop forms a computation chain. A feedforward network can be seen as a parameterized for loop, where each function fk depends on some parameters wk.

Figure 5.9: Computation graph of the scan function. Sequence-to-sequence RNNs can be seen as a parameterized scan function.
v1 := u1
v2 := u1 + u2
v3 := u1 + u2 + u3
...
vk := sk−1 + uk
f(sk−1, uk) := (vk, vk)
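This cumulative-sum scan can be written directly with jax.lax.scan; a minimal sketch:

```python
import jax
import jax.numpy as jnp

def f(s_prev, u):
    # One scan step: the new carry and the new output are both s_{k-1} + u_k.
    v = s_prev + u
    return v, v

u = jnp.array([1.0, 2.0, 3.0, 4.0])
s0 = 0.0
final_carry, v = jax.lax.scan(f, s0, u)
print(v)                                # [1. 3. 6. 10.] -- cumulative sums
print(jnp.allclose(v, jnp.cumsum(u)))   # True
```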
Figure 5.10: A while loop can be represented as a cyclic graph. The while loop
stops if π = 1 and performs another iteration s ← g(s), π ← f (s) if π = 0.
Unlike for loops and scan, the number of iterations of while loops is
not known ahead of time, and may even be infinite. In this respect, a
while loop can be seen as a cyclic graph, as illustrated in Fig. 5.10.
π0 := f(s0)
if π0 = 1 then
    r := s0
else
    s1 := g(s0), π1 := f(s1)
    if π1 = 1 then
        r := s1
    else
        s2 := g(s1), π2 := f(s2)
        if π2 = 1 then
            r := s2
        else
            r := s3 := g(s2)
r = ifelse(π0, s0, ifelse(π1, s1, ifelse(π2, s2, s3)))
= Σ_{i=0}^T (Π_{j=0}^{i−1} (1 − π_j)) π_i s_i,
where we defined
s_i := g(s_{i−1}) =: g^i(s_0) := (g ∘ · · · ∘ g)(s_0) ∈ S   (i times)
π_i := f(s_i) ∈ {0, 1}.
See also (Petersen et al., 2021). If we further define the shorthand notation
π̃_0 := π_0 and π̃_i := (Π_{j=0}^{i−1} (1 − π_j)) π_i for i ∈ {1, . . . , T},
Figure 5.11: Computation graph of an unrolled truncated while loop. As in Fig. 5.5, we depict continuous variables in dense lines and discrete variables in dashed lines. The output of a while loop with at most T iterations can be written as a conditional with T + 1 branches, cond(π̃, s_0, . . . , s_T) = Σ_{t=0}^T π̃_t s_t.
P := (p_{i,j})_{i,j=0}^T :=
          s0   s1   s2   s3
    s0  (  0    1    0    0 )
    s1  (  0    0    1    0 )
    s2  (  0    0    0    1 )
    s3  (  0    0    0    1 )
r = E[S_I]
  = Σ_{i=1}^{+∞} P(I = i) E[S_i | I = i]
  = Σ_{i=1}^{+∞} P(I = i) s_{i−1}
  = Σ_{i=1}^{+∞} (Π_{j=0}^{i−2} (1 − π_j)) π_{i−1} s_{i−1}
  = Σ_{i=0}^{+∞} (Π_{j=0}^{i−1} (1 − π_j)) π_i s_i.
Because the stopping time is not known ahead of time, the sum over i
goes from 0 to ∞. However, if we enforce in the stopping criterion that
the while loop runs no longer than T iterations, by setting
P =
          s0       s1        s2        s3
    s0  (  π0    1 − π0      0         0   )
    s1  (  0       π1      1 − π1      0   )
    s2  (  0       0         π2      1 − π2 )
    s3  (  0       0         0         1   )
With the help of this framework, we can backpropagate even through the
while loop’s stopping criterion, provided that we smooth out the predi-
cate. For example, we saw that the stopping criterion in Example 5.5 is
f (si ) = step(τ − ε(si )) and therefore
Due to the step function, the derivative of the while loop with respect
to τ will always be 0, just like it was the case for if-else statements. If
we change the stopping criterion to f (si ) = sigmoid(τ − ε(si )), we then
have (recall that or is well defined on [0, 1] × [0, 1])
r = Σ_{i=0}^T (Π_{j=0}^{i−1} (1 − sigmoid(τ − ε(s_j)))) sigmoid(τ − ε(s_i)) s_i.
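A minimal sketch of such a smoothed, truncated while loop is given below. Example 5.5 is not reproduced here, so as a stand-in we take g to be a Newton step for computing √a and ε(s) = |s² − a| (our own choices); the point is only that the result is differentiable with respect to the threshold τ.

```python
import jax
import jax.numpy as jnp

def soft_while(tau, s0, T=10, a=2.0):
    # Unrolled while loop with a smoothed stopping criterion.
    # Stand-in for Example 5.5: g is a Newton step for sqrt(a),
    # eps(s) = |s^2 - a| measures how far s is from convergence.
    g = lambda s: 0.5 * (s + a / s)
    eps = lambda s: jnp.abs(s ** 2 - a)

    s = s0
    r = 0.0
    keep_going = 1.0                        # running product of (1 - pi_j)
    for _ in range(T):
        pi = jax.nn.sigmoid(tau - eps(s))   # soft "stop now" probability pi_i
        r = r + keep_going * pi * s         # contribution of s_i, weighted by pi_tilde_i
        keep_going = keep_going * (1.0 - pi)
        s = g(s)
    return r + keep_going * s               # remaining mass on s_T (forced stop)

print(soft_while(0.1, 1.0))            # a smoothed mixture of the Newton iterates
print(jax.grad(soft_while)(0.1, 1.0))  # nonzero gradient w.r.t. the threshold tau
```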
5.11 Summary
• Unlike for loops and scan, the number of iterations of while loops
is not known ahead of time and may even be infinite. However,
unrolled while loops define valid directed acyclic graphs. We
defined a principled way to differentiate through the stopping
criterion of a while loop, thanks to a Markov chain perspective.
6 Data structures
6.1 Lists
l := (l1, . . . , lK) ∈ L_K(V)
where
L_K(V) := V^K = V × · · · × V (K times).
Setting values
We now present how to replace values from a list l ∈ L_K(V). We define the function list.set : L_K(V) × [K] × V → L_K(V) as
[list.set(l, i, v)]_j := v if i = j, and l_j if i ≠ j,
for j ∈ [K]. In the functional programming spirit, the function returns the whole new list, even though a single element has been modified. Again, the function is continuous and differentiable in l ∈ L_K(V) and v ∈ V but not in i ∈ [K]. In the particular case V = R, given a list l = (l1, . . . , lK), we can write
list.set(l, i, v) = l + (v − l_i) e_i.
That is, we subtract the old value l_i and add the new value v at the location i ∈ [K].
Implementation
A fixed-length list can be implemented as an array, which enables O(1)
random access to individual elements. The hardware counterpart of
an array is random access memory (RAM), in which memory can be
retrieved by address (location).
Initializing lists
list.init(v) := (v),
Pushing values
In order to add new values either to the left or to the right, we define
list.pushLeft : LK (V) × V → LK+1 (V) as
list.pushLeft(l, v) := (v, l1 , . . . , lK ).
Popping values
In order to remove values either from the left or from the right, we
define list.popLeft : LK (V) → LK−1 (V) × V as
list.popLeft(l) := (l2 , . . . , lK ), l1
Inserting values
The pushLeft and pushRight functions can only insert values at the
beginning and at the end of a list, respectively. We now study the insert
function, whose goal is to be able to add a new value at an arbitrary
location, shifting all values to the right and increasing the list size by 1.
We define the function list.insert : LK (V) × [K + 1] × V → LK+1 (V) as
[list.insert(l, i, v)]_j := l_j if j < i, v if j = i, and l_{j−1} if j > i,
Differentiability
The list.init, list.push and list.pop functions are readily continuous and
differentiable with respect to their arguments (a continuous relaxation
is not needed). As for the list.set function, the list.insert function is
continuous and differentiable in l and v, but not in i.
Implementation
Under the hood, a variable-length list can be implemented as a linked
list or as a dynamic array. A linked list gives O(K) random access while
a dynamic array allows O(1) random access, at the cost of memory
reallocations.
= cond(π_i, l1, . . . , lK)
= E_{I∼Categorical(π_i)}[l_I],
Figure 6.1: The list.get(l, i) function is continuous and differentiable in l but not in i. Its relaxation list.softGet(l, π_i) is differentiable in both l and π_i. When V = R, list.softGet(l, π_i) can be seen as taking the inner product between the list l and the probability distribution π_i, instead of the delta distribution (canonical vector) e_i.
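A minimal JAX sketch of soft indexing for V = R is given below; the way the index distribution π_i is produced (a softargmax over negative squared distances to the target position) is one possible choice of ours, not prescribed by the text.

```python
import jax
import jax.numpy as jnp

def soft_get(l, pi):
    # list.softGet: inner product between the list and a probability vector.
    return jnp.dot(pi, l)

def index_distribution(i, K, temperature=0.5):
    # One possible way to relax an integer index i into a distribution over [K]:
    # softargmax of negative squared distances to i (an assumption, not the book's choice).
    positions = jnp.arange(K, dtype=jnp.float32)
    return jax.nn.softmax(-((positions - i) ** 2) / temperature)

l = jnp.array([10.0, 20.0, 30.0, 40.0, 50.0])
pi = index_distribution(2.3, len(l))
print(soft_get(l, pi))                                                  # between l[2] and l[3]
print(jax.grad(lambda i: soft_get(l, index_distribution(i, 5)))(2.3))   # differentiable in i
```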
Setting values
Inserting values
Multi-dimensional indexing
6.2 Dictionaries
DL (K, V) := LL (K × V) = (K × V)L
Getting values
The goal of the dict.get function is to retrieve the value associated with
a key, assuming that the dictionary contains this key. Formally, we
define the dict.get : DL (K, V) × K → V ∪ {∞} function as
dict.get(d, k) := v_i if ∃i ∈ [L] s.t. k = k_i, and ∞ if k ∉ {k1, . . . , kL}.
Equivalently, using the equality operator eq, we can write
dict.get(d, k) := Σ_{i=1}^L eq(k, k_i) v_i / Σ_{i=1}^L eq(k, k_i).
Figure 6.2: Given a set of key-value pairs (ki , vi ) ∈ K × V defining a dictionary d,
we can estimate a continuous mapping from K to V using Nadaraya–Watson kernel
regression (here, illustrated with K = V = R). When keys are normalized to have
unit norm, this recovers softargmax attention from Transformers.
Setting values
The goal of the dict.set function is to replace the value associated with
an existing key. Formally, we define the dict.set : DL (K, V) × K × V →
DL (K, V) function as
(dict.set(d, k, v))_i := (k_i, v) if k_i = k, and (k_i, v_i) if k_i ≠ k.
Implementation
f(v|k) = f(k, v) / f(k).
E[V | K = k] = ∫_V f(v|k) v dv = ∫_V (f(k, v)/f(k)) v dv.
This is the Bayes predictor, in the sense that E[V |K] is the minimizer
of E[(h(K) − V )2 ] over the space of measurable functions h : K →
V. Using a sample of L input-output pairs (ki , vi ), corresponding to
key-value pairs in our case, Nadaraya–Watson kernel regression
estimates the joint PDF and the marginal PDF using kernel density
estimation (KDE). Using a product of isotropic kernels κσ and ρσ for
key-value pairs, we can define
f̂_σ(k, v) := (1/L) Σ_{i=1}^L κ_σ(k − k_i) ρ_σ(v − v_i).
Ê[V | K = k] := ∫_V (f̂_σ(k, v) / f̂_σ(k)) v dv
= ∫_V ((1/L) Σ_{i=1}^L κ_σ(k − k_i) ρ_σ(v − v_i)) / ((1/L) Σ_{i=1}^L κ_σ(k − k_i)) v dv
= (Σ_{i=1}^L κ_σ(k − k_i) ∫_V ρ_σ(v − v_i) v dv) / (Σ_{i=1}^L κ_σ(k − k_i))
= (Σ_{i=1}^L κ_σ(k − k_i) v_i) / (Σ_{i=1}^L κ_σ(k − k_i)).
We therefore define
dict.softGet(d, k) := (Σ_{i=1}^L κ_σ(k − k_i) v_i) / (Σ_{i=1}^L κ_σ(k − k_i)).
π_{k,i} := κ_σ(k − k_i) / Σ_{j=1}^L κ_σ(k − k_j)   ∀i ∈ [L].
Figure 6.3: Computation graph of the dict.softGet function. We can use a kernel
κσ to produce a discrete probability distribution πk = (πk,1 , . . . , πk,L ) ∈ △L , that
captures the affinity between the dictionary keys (k1 , . . . , kL ) and the input key k.
The dict.softGet function can then merely be seen as a convex combination (weighted
average) of values (v1 , . . . , vL ) using the probability values (πk,1 , . . . , πk,L ) as weights.
This distribution captures the affinity between the input key k and the keys (k1, . . . , kL) of the dictionary d. As illustrated in Fig. 6.3, when the keys are normalized to have unit norm, we obtain
π_{k,i} = κ_σ(k − k_i) / Σ_{j=1}^L κ_σ(k − k_j) = exp(⟨k, k_i⟩/σ²) / Σ_{j=1}^L exp(⟨k, k_j⟩/σ²).
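A minimal sketch of dict.softGet with a Gaussian kernel follows; as noted above, for unit-norm keys this reduces to softargmax (dot-product) attention with scores scaled by 1/σ². The toy keys and values are arbitrary.

```python
import jax
import jax.numpy as jnp

def dict_soft_get(keys, values, query, sigma=1.0):
    # Nadaraya-Watson estimate: softargmax over Gaussian-kernel affinities.
    # For unit-norm keys and query, this equals softargmax(<query, keys> / sigma^2) @ values.
    sq_dists = jnp.sum((keys - query) ** 2, axis=-1)
    pi = jax.nn.softmax(-sq_dists / (2.0 * sigma ** 2))
    return pi @ values

# Toy dictionary with L = 3 key-value pairs (keys in R^2, scalar values).
keys = jnp.array([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])
values = jnp.array([1.0, 2.0, 3.0])
query = jnp.array([0.9, 0.1])

print(dict_soft_get(keys, values, query))   # close to values[0], softly mixed with the others
```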
6.3 Summary
Differentiating through programs

7 Finite differences
From Definition 2.4 and Definition 2.13, the directional derivative and
more generally the JVP are defined as a limit,
∂f(w)[v] := lim_{δ→0} (f(w + δv) − f(w)) / δ.
This suggests that we can approximate the directional derivative and the JVP using
∂f(w)[v] ≈ (f(w + δv) − f(w)) / δ,
f(w + δv) − f(w) = δ ∂f(w)[v] + (δ²/2) ∂²f(w)[v, v] + (δ³/3!) ∂³f(w)[v, v, v] + . . .
so that
(f(w + δv) − f(w)) / δ = ∂f(w)[v] + (δ/2) ∂²f(w)[v, v] + (δ²/3!) ∂³f(w)[v, v, v] + . . . = ∂f(w)[v] + O(δ).
[Figure 7.1: approximation error of the forward difference, central difference, and complex-step approximations as a function of the step size δ.]
By grouping the terms in the sum for each order of derivative, we obtain a set of p + 1 equations to be satisfied by the p + 1 coefficients a_0, . . . , a_p, that is,
a_0 + a_1 + . . . + a_p = 0
a_1 + 2 a_2 + . . . + p a_p = 1
a_1 + 2^j a_2 + . . . + p^j a_p = 0   ∀j ∈ {2, . . . , p}.
As before, we can expand the terms in the sum. For the approximation to capture only the k-th derivative, we now require the coefficients a_i to satisfy
0^j a_0 + 1^j a_1 + 2^j a_2 + . . . + p^j a_p = 0   ∀j ∈ {0, . . . , k − 1}
0^k a_0 + 1^k a_1 + 2^k a_2 + . . . + p^k a_p = k!
0^j a_0 + 1^j a_1 + 2^j a_2 + . . . + p^j a_p = 0   ∀j ∈ {k + 1, . . . , p}.
and
Im(f(w + (iδ) v) / δ) = ∂f(w)[v] + O(δ²).
This suggests that we can compute directional derivatives using the approximation
∂f(w)[v] ≈ Im(f(w + (iδ) v) / δ),
for 0 < δ ≪ 1. This is called the complex-step derivative approxi-
mation (Squire and Trapp, 1998; Martins et al., 2003).
Contrary to forward, backward and central differences, we see that
only a single function call is necessary. A function call on complex
numbers may take roughly twice the cost of a function call on real
numbers. However, thanks to the fact that a difference of functions is
no longer needed, the complex-step derivative approximation usually
enjoys smaller round-off error as illustrated in Fig. 7.1. That said, one
drawback of the method is that all elementary operations within the
program implementing the function f must be well-defined on complex
numbers, e.g., using overloading.
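As a sketch, the complex-step approximation can be compared with a forward difference on a simple scalar function (here f(w) = exp(w) sin(w), an arbitrary choice); we use JAX with 64-bit precision enabled so that the tiny step does not underflow.

```python
import jax
jax.config.update("jax_enable_x64", True)  # so that delta = 1e-20 is representable
import jax.numpy as jnp

def f(w):
    # Works for real and complex inputs alike (arbitrary test function).
    return jnp.exp(w) * jnp.sin(w)

df_true = lambda w: jnp.exp(w) * (jnp.sin(w) + jnp.cos(w))

w = 0.7
forward = (f(w + 1e-8) - f(w)) / 1e-8                 # forward difference, limited by round-off
complex_step = jnp.imag(f(w + 1j * 1e-20)) / 1e-20    # complex step: no subtractive cancellation

print(abs(forward - df_true(w)))        # around 1e-8
print(abs(complex_step - df_true(w)))   # essentially machine precision
```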
7.7 Complexity
7.8 Summary
8 Automatic differentiation
s0 ∈ S0
s1 := f1(s0) ∈ S1
...
sK := fK(sK−1) ∈ SK
f(s0) := sK.    (8.1)
∂f (s0 ) = ∂fK (sK−1 )∂fK−1 (sK−2 ) . . . ∂f2 (s1 )∂f1 (s0 ), (8.2)
where ∂fk (sk−1 ) are the Jacobians of the intermediate functions com-
puted at s0 , . . . , sK , as defined in Eq. (8.1). The main drawback of
this approach is computational: computing the full ∂f(s0) requires materializing the intermediate Jacobians in memory and performing matrix-matrix multiplications. However, in practice, computing the full
Jacobian is rarely needed. Indeed, oftentimes, we only need to right-
multiply or left-multiply with ∂f (s0 ). This gives rise to forward-mode
and reverse-mode autodiff, respectively.
8.1.1 Forward-mode
t0 := v
t1 := ∂f1(s0)[t0]
...
tK := ∂fK(sK−1)[tK−1]
∂f(s0)[v] := tK.
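In JAX, this recursion is exposed through jax.jvp; a minimal sketch on a two-function chain (the intermediate functions are arbitrary placeholders):

```python
import jax
import jax.numpy as jnp

f1 = lambda s: jnp.tanh(s)        # placeholder intermediate functions
f2 = lambda s: jnp.sum(s ** 2)
f = lambda s: f2(f1(s))

s0 = jnp.array([0.5, -1.0, 2.0])
v = jnp.array([1.0, 0.0, 0.0])

# Propagate (s_k, t_k) jointly along the chain, as in the recursion above.
s1, t1 = jax.jvp(f1, (s0,), (v,))
s2, t2 = jax.jvp(f2, (s1,), (t1,))

# Same result in one call on the composed function.
_, jvp_direct = jax.jvp(f, (s0,), (v,))
print(jnp.allclose(t2, jvp_direct))   # True
```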
8.1.2 Reverse-mode
In machine learning, most functions whose gradient we need to compute
take the form ℓ ◦ f , where ℓ is a scalar-valued loss function and f is a
network. As seen in Proposition 2.3, the gradient takes the form
Figure 8.2: Memory usage of forward-mode autodiff for a computation chain. Here
t0 = v, sK = f (s0 ), tK = ∂f (s0 )[v].
This motivates the need for applying the adjoint ∂f (s0 )∗ to ∇ℓ(f (s0 )) ∈
SK and more generally to any output direction u ∈ SK . From Propo-
sition 2.7, we have
rK = u
rK−1 = ∂fK(sK−1)∗[rK]
...
r0 = ∂f1(s0)∗[r1]
∂f(s0)∗[u] = r0.
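The corresponding reverse-mode recursion can be written with jax.vjp, reusing the chain from the previous sketch: a forward pass that stores the VJP closures, followed by a backward pass.

```python
import jax
import jax.numpy as jnp

f1 = lambda s: jnp.tanh(s)
f2 = lambda s: jnp.sum(s ** 2)

s0 = jnp.array([0.5, -1.0, 2.0])

# Forward pass: store intermediate values and VJP closures.
s1, vjp1 = jax.vjp(f1, s0)
s2, vjp2 = jax.vjp(f2, s1)

# Backward pass: propagate the output direction u = 1 (scalar output).
(r1,) = vjp2(jnp.ones_like(s2))
(r0,) = vjp1(r1)

print(jnp.allclose(r0, jax.grad(lambda s: f2(f1(s)))(s0)))  # True
```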
Figure 8.4: Memory usage of reverse mode autodiff for a computation chain.
where
∂f (s0 )∗ [u] := backward(u; s0 , . . . , sK−1 ).
In functional programming terminology, the VJP ∂f (s0 )∗ is a closure,
as it contains the intermediate computations s0 , . . . sK . The same can
be done for the JVP ∂f (s0 ) if we want to apply to multiple directions
vi .
s0 = x
s1 = f1 (s0 ) = σ(A1 s0 + b1 )
s2 = f2 (s1 ) = A2 s1 + b2
f (x) = s2 ,
t0 = v
t1 = σ ′ (A1 s0 + b1 ) ⊙ (A1 t0 )
t2 = A2 t1
∂f (x)[v] = t2 ,
r2 = u
r1 = ∂f2 (s1 )∗ [r2 ] = A⊤
2 r2
r0 = ∂f1 (s0 )∗ [r1 ] = A⊤ ′
1 (σ (A1 s0 + b1 ) ⊙ r1 )
∂f (x)∗ [u] = r0 .
Forward-mode Reverse-mode
Time O(M D2 + KD3 ) O(M 2 D + KM D2 )
Space O(max{M, D}) O(KD + M )
Table 8.1: Time and space complexities of forward-mode and reverse-mode autodiff
for computing the full Jacobian of a chain of functions f = fK ◦ · · · ◦ f1 , where
fk : RD → RD if k = 1, . . . , K − 1 and fK : RD → RM . We assume ∂fk is a dense
linear operator. Forward mode requires D JVPs. Reverse mode requires M VJPs.
Using Definition 2.9, we find that we can extract each row of the Jaco-
bian matrix, which corresponds to the transposed gradients ∇fi (s0 ) ∈
RD , for i ∈ [M ], by multiplying with the standard basis vector ei ∈ RM :
... ...
∇L(w; x, y) = (g1 , . . . , gK ) ∈ W1 × · · · × WK
where
∂f (x, w)∗ [u] = (r0 , (g1 , . . . , gK )),
with u = ∇ℓ(f (x, w); y) ∈ SK . The output r0 ∈ S0 , where S0 = X ,
corresponds to the gradient w.r.t. x ∈ X and is typically not needed,
except in generative modeling settings. The full procedure is summarized
in Algorithm 8.3.
8.2. Feedforward networks 171
Forward pass
... ...
Backward pass
j=1
8.3.1 Forward-mode
The forward mode corresponds to computing a JVP in an input direction
v ∈ S0 . The algorithm consists in computing intermediate JVPs along
the forward pass. We initialize t0 := v ∈ S0 . Using Proposition 2.8, the
derivatives on iteration k ∈ [K] are propagated as
j=1
8.3.2 Reverse-mode
The theorem gives an upper bound on the size of the best circuit for
computing ∇f from the size of the best circuit for computing f .
8.4 Implementation
(Section 4.1.4), A = {+, ×}. More generally, A may contain all the
necessary functions for expressing programs. We emphasize, however,
that A is not necessarily restricted to low-level functions such as log
and exp, but may also contain higher-level functions. For instance,
even though the log-sum-exp can be expressed as the composition
of elementary operations (log, sum, exp), it is usually included as a
primitive on its own, both because it is a very commonly-used building
block, but also for numerical stability reasons.
An autodiff system must implement for each f ∈ A its JVP for support-
ing the forward mode, and its VJP for supporting the reverse mode.
We give a couple of examples. We start with the JVP and VJP of linear
functions.
Example 8.2 (JVP and VJP of linear functions). Consider the matrix-
vector product f (W ) = W x ∈ RM , where x ∈ RD is fixed and
W ∈ RM ×D . As already mentioned in Section 2.3.1, the JVP of f
at W ∈ RM ×D along an input direction V ∈ RM ×D is simply
∂f (W )[V ] = f (V ) = V x ∈ RM .
∂f (W )∗ [u] = ux⊤ ∈ RM ×D .
∂f (W )[V ] = f (V ) = V X ∈ RM ×N .
∂f (W )∗ [U ] = U X ⊤ ∈ RM ×D .
Example 8.3 (JVP and VJP of separable function). Consider the func-
tion f (w) := (g1 (w1 ), . . . , gP (wP )), where each gi : R → R has a
derivative gi′ . The Jacobian matrix is then a diagonal matrix
In this case, the JVP and VJP are actually the same
8.5 Checkpointing
Function step
2
1
0
Time step
M(K) = log2 K.
Analytical formula
It turns out that we can also find an optimal scheme analytically. This
scheme was found by Griewank (1992), following the analysis of optimal
inversions of sequential programs by divide-and-conquer algorithms
done by Grimm et al. (1996); see also Griewank (2003, Section 6) for
a simple proof. The main idea consists in considering the number of
times an evaluation step fk is repeated. As we split the chain at l, all
steps from 1 to l will be repeated at least once. In other words, treating
the second half of the chain incurs one memory cost, while treating the
first half of the chain incurs one repetition cost. Griewank (1992) shows
that for fixed K, S, we can find the minimal number of repetitions
analytically and build the corresponding scheme with simple formulas
for the optimal splits.
Compared to the dynamic programming approach, it means that we
do not need to compute the pointers l∗ (k, s), and we can use a simple
formula to set l∗ (k, s). We still need to traverse the corresponding binary
tree given K, S and the l∗ (k, s) to obtain the schedules. Note that such
optimal scheme does not take into account varying computational costs
for the functions fk .
The optimal scheme presented above requires knowing the total number
of nodes in the computation graph ahead of time. However, when
differentiating through for example a while loop (Section 5.10), this is
not the case. To circumvent this issue, online checkpointing schemes
have been developed and proven to be nearly optimal (Stumm and
Walther, 2010; Wang et al., 2009). These schemes start by defining a set
of S checkpoints with the first S computations, then these checkpoints
are rewritten dynamically as the computations keep going. Once the
computations terminate, the optimal approach presented above for a
fixed length is applied on the set of checkpoints recorded.
184 Automatic differentiation
8.8 Summary
• The forward mode: i) uses JVPs, ii) builds the Jacobian one
column at a time, iii) is efficient for tall Jacobians (M ≥ P ), iv)
need not store intermediate computations.
• The reverse mode: i) uses VJPs, builds the Jacobian one row at
a time, iii) is efficient for wide Jacobians (P ≥ M ), iv) needs to
store intermediate computations, in order to be computationally
optimal.
187
188 Second-order automatic differentiation
9.1.2 Complexity
To get a sense of the computational and memory complexity of the four
approaches, we consider a chain of functions f := fK ◦ · · · ◦ f1 as done
9.1. Hessian-vector products 189
Method Computation
Reverse on reverse (VJP of gradient) ∂(∇f )(w)∗ [v]
Forward on reverse (JVP of gradient) ∂(∇f )(w)[v]
Reverse on forward (gradient of JVP) ∇(∂f (·)[v])(w)
Forward on forward (JVPs of JVPs) (∂ 2 f (w)[v, ei ])Pi=1
Figure 9.1: Computation graph corresponding to reverse mode autodiff for eval-
uating the gradient of f = fK ◦ . . . f1 . While f is a simple chain, ∇f is a DAG.
Figure 9.2: Computation graph for computing the HVP ∇2 f (x)[v] by using reverse
mode on top of reverse mode. As the computation graph of ∇f induces fan-in
operations sk−1 , rk 7→ ∂fk (sk−1 )[rk ], the reverse mode applied on ∇f induces
branching of the computations at each such node.
Figure 9.3: Computation graph for computing the HVP ∇2 f (x)[v] by using forward
mode on top of reverse mode. The forward mode naturally follows the computations
done for the gradient, except that it passes through the derivatives of the intermediate
operations.
9.1. Hessian-vector products 191
j=1
M
∇2GN (ℓ ◦ f ) = λi ∂f (w)∗ ui u⊤
i ∂f (w)
∗
X
j=1
M p p ⊤
= λi ∂f (w)∗ ui λi ∂f (w)∗ ui
X
j=1
M
= vi vi⊤ where vi := λi ∂f (w)∗ ui .
X p
j=1
With some slight abuse of notation, we then have that the Gauss-Newton
matrix associated with a pair (x, y) is
is then h i
∇2GN L(w) = EX,Y ∼ρ ∇2GN L(w; X, Y ) .
∇2F L(w) = ES∼qw [∇2 L(w; S)] = ES∼qw [−∇2w log qw (S)].
where ρX (x) :=
R
ρ(x, y)dy.
where we used · to indicate that the results holds for all y. Plugging the
result back in the Fisher information matrix concludes the proof.
198 Second-order automatic differentiation
u 7→ ∇2 L(w)−1 u,
H[v] = u
by only accessing the linear map v 7→ H[v] for any v. Among such
algorithms, we have the conjugate gradient (CG) method, that applies
for H positive-definite, i.e., such that ⟨v, H[v]⟩ > 0 for all v ̸= 0, or the
generalized minimal residual (GMRES) method, that applies for
any invertible H. A longer list of solvers can be found in public software
such as SciPy (Virtanen et al., 2020). The IHVP of a strictly convex
9.4. Inverse-Hessian vector product 199
9.4.3 Complexity
j=1
i=1
n
+ ∂fi (w)∗ ∂i,j g(f (w))∗ [u]∂fj (w).
X
2
i,j=1
independently.
i=1
n
+ ∂f (w)∗ ∂ 2 gi (f (w))∗ [u]∂f (w).
X
i=1
s0 := x
sk := fk (sk−1 , wk ) ∀k ∈ {1, . . . , K}
f (x, w) := sK ,
Rk−1 := Wk⊤ Jk Wk
Jk := Rk ⊙ (a′ (Wk sk−1 )a′ (Wk sk−1 )⊤ )
Gk := Jk ⊗ sk−1 s⊤
k−1
diagonal of A,
Eω∼p [ω ⊙ Aω] = Diag(A),
where ⊙ denotes the Hadamard product (element-wise multiplication).
This suggests that we can use the Monte-Carlo method to estimate the
diagonal of A,
S
1X
Diag(A) ≈ ωi ⊙ Aωi ,
S i=1
with equality as S → ∞, since the estimator is unbiased. Since, as
reviewed in Section 9.1 and Section 9.2, we know how to multiply
efficiently with the Hessian and the Gauss-Newton matrices, we can
apply the technique with these matrices. The variance is determined
by the number S of samples drawn and therefore by the number of
matvecs performed. More elaborated approaches have been proposed to
further reduce the variance (Meyer et al., 2021; Epperly et al., 2023).
Suppose the objective function is of the form L(w; x, y) := ℓ(f (w; x); y)
where ℓ is the negative log-likelihood ℓ(θ; y) := − log pθ (y) of an ex-
ponential family distribution, and θ := f (w; x), for some network f .
We saw from the equivalence between the Fisher and Gauss-Newton
matrices in Proposition 9.6 (which follows from the Bartlett identity)
that
∇2GN L(w; x, ·) = EY ∼pθ [∂f (w; x)∗ ∇ℓ(θ; Y ) ⊗ ∇ℓ(θ; Y )∂f (w; x)]
= EY ∼pθ [∇L(w; x, Y ) ⊗ ∇L(w; x, Y )],
where · indicates that the result holds for any value of the second
argument. This suggests a Monte-Carlo scheme
S
1X
∇2GN L(w; x, ·) ≈ [∇L(w; x, yij ) ⊗ ∇L(w; x, yij )]
S j=1
γj = E γi ⊙ γi +
X X X X
E γi ⊙ γi ⊙ γj
i j i i̸=j
" #
=E
X
γi ⊙ γi
i
where we used that E[γi ⊙ γj ] = E[γi ] ⊙ E[γj ] = 0 since γi and γj are
independent variables for i ̸= j and have zero mean, from Bartlett’s
first identity Eq. (12.2). We can then use the Monte-Carlo method to
obtain
S S
1 1X 1X
diag(∇2GN L(w; x, ·)) ≈ ∇ L(w; x, yij ) ⊙ ∇ L(w; x, yij ) ,
S S j=1
S j=1
with equality when all labels in the support of pθ have been sampled.
This estimator can be more convenient to implement, since it only needs
access to the gradient of the averaged losses. However, it may suffer
from higher variance. A special case of this estimator is used by Liu
et al. (2023), where they draw only one y for each x.
212 Second-order automatic differentiation
9.9 Summary
213
214 Inference in graphical models as differentiation
j=1
K j−1
=
Y \
P Aj Ai .
j=1 i=1
j=1
s∈S s1 ,...,sK ∈S
As we shall see, the graph of a graphical model encodes the dependen-
cies between the variables (S1 , . . . , SK ) and therefore how their joint
distribution factorizes. Given access to a joint probability distribution,
there are several inference problems one typically needs to solve.
10.3.2 Likelihood
A simple task is to compute the likelihood of some observations
s = (s1 , . . . , sK ),
P(S1 = s1 , . . . , Sk = sk ) = p(s1 , . . . , sk ).
It is also common to compute the log-likelihood,
log P(S1 = s1 , . . . , Sk = sk ) = log p(s1 , . . . , sk ).
This is the mode of the joint probability distribution. This is also known
as maximum a-posteriori (MAP) inference in the literature (Wainwright
and Jordan, 2008).
= p(s1 , . . . , sK )
X X
Defining similarly
Ck,l (sk , sl ) := S1 × · · · × {sk } × · · · × {sl } × · · · × SK ,
we obtain
P(Sk = sk , Sl = sl ) = p(s1 , . . . , sK ).
X
s∈S
s∈S
Convex hull
s∈S
[µ]k,i = ES [ϕ(S)k,i ]
= ESk [ϕk (Sk )i ]
= ESk [I(Sk = vi )]
= P(Sk = sk )I(sk = vi )
X
sk ∈Sk
= P(Sk = vi ).
[µ]k,l,i,j = ES [ϕ(S)k,l,i,j ]
= ESk ,Sl [ϕk,l (Sk , Sl )i,j ]
= ESk ,Sl [I(Sk = vi , Sl = vj )]
= P(Sk = sk , Sl = sl )I(sk = vi , sl = vj )
X X
sk ∈Sk sl ∈Sl
= P(Sk = vi , Sl = vj ).
S0 := s0
S1 ∼ p1 (· | S0 )
S2 ∼ p2 (· | S1 )
..
.
SK ∼ pK (· | SK−1 ).
...
...
...
Figure 10.1: Left: Markov chain. Right: Computation graph of the forward-
backward and the Viterbi algorithms: a lattice.
start ... end
P(S1 = s1 , . . . , SK = sK ) = p(s1 , . . . , sK )
K
= P(Sk = sk | Sk−1 = sk−1 )
Y
k=1
K
= pk (sk | sk−1 ),
Y
k=1
Sk ∼ Categorical(πk−1,k,Sk−1 )
where
πk−1,k,i := softargmax(θk−1,k,i ) ∈ △M
= (πk−1,k,i,j )M
j=1
θk−1,k,i := (θk−1,k,i,j )M
j=1 ∈ R
M
We therefore have
P(Sk = j | Sk−1 = i) = pk (j | i)
= πk−1,k,i,j
= [softargmax(θk−1,k,i )]j
exp(θk−1,k,i,j )
=P
j ′ exp(θk−1,k,i,j ′ )
and
j′
p1 = · · · = pK = p.
More generally, a nth -order Markov chain may depend, not only on the
last variable, but on the last n variables,
Markov chains and more generally higher-order Markov chains are a spe-
cial case of Bayesian network. Similarly to computation graphs reviewed
in Section 8.3, variable dependencies can be expressed using a directed
acyclic graph (DAG) G = (V, E), where the vertices V = {1, . . . , K}
represent variables and edges E represent variable dependencies. The set
{i1 , . . . , ink } = pa(k) ⊆ V, where nk := |pa(k)|, indicates the variables
Si1 , . . . , Sink that Sk depends on. This defines a partially ordered
set (poset). For notational simplicity, we again assume without loss of
generality that S0 is deterministic. A computation graph is specified
by functions f1 , . . . , fK in topological order. In analogy, a Bayesian
network is specified by conditional probability distributions pk of Sk
10.5. Bayesian networks 223
S0 := s0
S1 ∼ p1 (· | S0 )
S2 ∼ p2 (· | Spa(2) )
..
.
SK ∼ pK (· | Spa(K) ).
P(S = s) := P(S1 = s1 , . . . , SK = sK )
K
= P(Sk = sk |Spa(k) = spa(k) )
Y
k=1
K
:= pk (sk |spa(k) )
Y
k=1
Z := ψC (sC ),
X Y
s∈S C∈C
k=1
matically normalized),
1 Y
pθ (s) := ψC (sC ; θC )
Z(θ) C∈C
1 Y
= exp(⟨θC , ϕC (sC )⟩)
Z(θ) C∈C
1
!
= exp ⟨θC , ϕC (sC )⟩
X
Z(θ) C∈C
1
= exp (⟨θ, ϕ(s)⟩)
Z(θ)
= exp (⟨θ, ϕ(s)⟩ − A(θ))
226 Inference in graphical models as differentiation
where
s∈S C∈C
s∈S
A(θ) := log Z(θ)
P(Y = y) = pθ (y)
= exp θi yi +
X X
θi,j yi yj − A(θ)
i∈V (i,j)∈E
!
= exp ⟨θC , ϕC (y)⟩ − A(θ) ,
X
C∈C
graph cut algorithms (Greig et al., 1989). There are two ways the
above equation can be extended. First, we can use higher-order
interactions, such as yi yj yk for (i, j, k) ∈ V 3 . Second, we may want
to use categorical variables, which leads to the Potts model.
10.6.4 Sampling
Contrary to Bayesian networks, MRFs require an explicit normalization
constant Z. As a result, sampling from a distribution represented by a
general MRF is usually more involved than for Bayesian networks. A
commonly-used technique is Gibbs sampling.
where
K
Z := ψk (sk−1 , sk−1 )
XY
s∈S k=1
and where we used ψk as a shorthand for ψk−1,k , since k − 1 and k are
consecutive. As explained in Example 10.2, this also includes Markov
228 Inference in graphical models as differentiation
chains by setting
in which case Z = 1.
sk−1 ∈Sk−1
sk+1 ∈Sk+1
sK ∈SK s1 ∈S1
and the marginal probabilities by
1
P(Sk = sk ) = αk (sk )βk (sk )
Z
1
P(Sk−1 = sk−1 , Sk = sk ) = αk−1 (sk−1 )ψk (sk−1 , sk )βk (sk ).
Z
We can also compute the conditional probabilities by
P(Sk−1 = sk−1 , Sk = sk )
P(Sk = sk | Sk−1 = sk−1 ) =
P(Sk−1 = sk−1 )
αk−1 (sk−1 )ψk (sk−1 , sk )βk (sk )
=
αk−1 (sk−1 )βk−1 (sk−1 )
ψk (sk−1 , sk )βk (sk )
= .
βk−1 (sk−1 )
In practice, the two recursions are often implemented in the log-domain
for numerical stability,
log αk (sk ) = log exp(log ψk (sk−1 , sk ) + log αk−1 (sk−1 ))
X
sk−1 ∈Sk−1
log βk (sk ) = log exp(log ψk+1 (sk , sk+1 ) + log βk+1 (sk+1 )).
X
sk+1 ∈Sk+1
6: δ ⋆ := max δK (sK )
sK ∈SK
7: s⋆K := arg max δK (sK )
sK ∈SK
8: for k := K − 1, . . . , 1 do ▷ Backtracking
9: s⋆k := qk+1 (s⋆k+1 )
Outputs: max p(s1 , . . . , sK ) ∝ δ ⋆
s1 ∈S1 ,...,sK ∈SK
arg max p(s1 , . . . , sK ) = (s⋆1 , . . . , s⋆K )
s1 ∈S1 ,...,sK ∈SK
s∈S
• Commutativity of ⊕: a ⊕ b = b ⊕ a,
• Associativity of ⊕: a ⊕ (b ⊕ c) = (a ⊕ b) ⊕ c,
with ε := 1 by default.
and
!
argmaxε f (v) := exp(f (v ′ )/ε)/ exp(f (v)/ε)
X
∈ P(V).
v∈V v∈V v ′ ∈V
and
Q := argmaxε aK,j ∈ △M ,
j∈[M ]
=
X
rk+1,j qk+1,j,i
j∈[M ]
=
X
µk+1,i,j .
j∈[M ]
6: A := maxε aK,i ∈ R
i∈[M ]
7: Q := argmaxε aK,i ∈ △M
i∈[M ]
8: Initialize rK,j = Qj ∀j ∈ [K]
9: for k := K − 1, . . . , 1 do ▷ Backward pass
10: for i ∈ [M ] do
11: for j ∈ [M ] do
12: µk+1,i,j = rk+1,j · qk+1,j,i
13: rk,i ← µk+1,i,j
maxε θ1,1,i1 + = A, ∇A(θ) = µ
PK
Outputs: k=2 θk,ik−1 ,ik
i1 ,...,iK ∈[M ]K
10.10 Summary
239
240 Differentiating through optimization
F (w, λ) = 0
for all λ ∈ Λ.
h(λ) := g(w⋆ (λ), λ), where w⋆ (λ) = arg max f (w, λ). (11.1)
w∈W
5 w ( 1)
w ( 2)
w ( )
10
0 1 2
Function
Figure 11.2: The graph of h(λ) = maxw∈W f (w, λ) is the upper-envelope of the
graphs of the functions λ 7→ f (w, λ) for all w ∈ W.
and
w⋆ (λ) := arg max f (w, λ).
w∈W
If f is concave in w, convex in λ, and the maximum w⋆ (λ) is
unique, then the function h is differentiable with gradient
and
w⋆ (λ) := arg max f (w, λ).
w∈W
If f is continuously differentiable in λ for all w ∈ W, ∇1 f is
continuous and the maximum w⋆ (λ) is unique, then the function
h is differentiable with gradient
• x⋆ (λ0 ) = x0 ,
⋆
• ∂x⋆ (λ) = − ∂∂21 FF (x (λ),λ)
(x⋆ (λ),λ) .
not true: failure of the IFT assumptions does not necessarily mean that
the implicit function is not differentiable, as we now illustrate.
meaning that the Jacobian ∂w⋆ (λ), assuming that it exists, satisfies
• w⋆ (λ0 ) = w0 ,
Assume the assumptions of the IFT hold. The JVP t := ∂w⋆ (λ)v
in the input direction v ∈ Λ is obtained by solving the linear system
At = Bv.
A∗ r = u.
and by the inverse function theorem, we have ∂f −1 (λ, 0) = (∂f (λ, w∗ (λ)))−1 .
So using block matrix inversions formula
!−1 !
A B ∼ ∼
= ,
C D −(D − CA−1 B)−1 CA−1 ∼
we get the claimed expression. Though we expressed the proof in terms of
Jacobians and matrices, the result naturally holds for the corresponding
linear operators, JVPs, VJPs, and their inverses.
Solving this linear system w.r.t. r at s = s⋆ (w) gives the adjoint variable
r ⋆ (w). We then get
s0 := x ∈ X
s1 := f1 (s0 , w1 ) ∈ S1
..
.
sK := fK (sK−1 , wK ) ∈ SK
f (w) := sK . (11.2)
s1 − f1 (x, w1 )
s2 − f2 (s1 , w2 )
c(s, w) :=
.. .
.
sK − fK (sK−1 , wK )
11.4. Adjoint state method 253
This defines an implicit function s⋆ (w) = (s⋆1 (w), . . . , s⋆K (w)), the
solution of this nonlinear system, which is given by the variables
s1 , . . . , sK defined in (11.2). The output of the feedforward network is
then f (w) = s⋆K (w).
In machine learning, the final layer s⋆K (w) is typically fed into a
loss ℓ, to define
L(w) := ℓ(s⋆K (w); y).
Note that an alternative is to write L(w) as
L(w) = min ℓ(s; y) s.t. c(s, w) = 0.
s∈S
0 ... 0 −AK I
254 Differentiating through optimization
I −A∗1 0 0
...
.. ..
0 I . .
−A∗2
.
.. ..
..
∗
∂1 c(s (w), w) =
⋆
. I . 0 .
.
. .. ..
. . . −A∗K
0 ... ... 0 I
∂2 f1 (x, w1 )∗ r1
∗
∂2 f2 (s1 (w), w2 )∗ r2
∂2 c(s(w), w) r =
.. ,
.
∂2 fK (s1 (w), wK )∗ rK
minimize over w but we do not require this, as gradients are not neces-
sarily used for optimization. Our exposition also supports computing
the VJP of any vector-valued function f , while existing works derive
the gradient of a scalar-valued loss function.
with w = f −1 (ω).
The inverse function theorem can be used to prove the implicit function
theorem; see proof of Theorem 11.4. Conversely, recall that, in order to
use the implicit function theorem, we need to choose a root objective
F : W × Λ → W. If we set W = Λ = RQ and F (w, ω) = f (w) −
ω, with f : RQ → RQ , then we have that the root w⋆ (ω) satisfying
F (w⋆ (ω), ω) = 0 is exactly w⋆ (ω) = f −1 (ω). Moreover, ∂1 F (w, ω) =
∂f (w) and ∂2 F (w, ω) = −I. By applying the implicit function theorem
with this F , we indeed recover the inverse function theorem.
f ◦ f −1 (ω) = ω.
∂f (f −1 (ω))∂f −1 (ω) = I,
11.6 Summary
258
12.2. Differentiating through expectations 259
1. Sample y1 , . . . , yn from p.
Continuous case
When Y is a continuous set (that is, pθ (y) is a probability density
function), we can rewrite E(θ) as
Z
E(θ) = pθ (y)g(y)dy.
Y
Discrete case
When Y is a discrete set (that is, pθ (y) is a probability mass function),
we can rewrite E(θ) as
E(θ) = pθ (y)g(y).
X
y∈Y
We then obtain
∇E(θ) = g(y)∇θ pθ (y).
X
y∈Y
exp(⟨ϕ(y), θ⟩)
pθ (y) := P ,
y ′ ∈Y exp(⟨ϕ(y ), θ⟩)
′
where θ = f (x, w) ∈ RM .
12.2. Differentiating through expectations 263
The key idea of the score function estimator (SFE), also known as
REINFORCE, is to rewrite ∇E(θ) as an expectation. The estimator is
based on the logarithmic derivative identity
∇θ pθ (y)
∇θ log pθ (y) = ⇐⇒ ∇θ pθ (y) = pθ (y)∇θ log pθ (y).
pθ (y)
Using this identity, we obtain the following gradient estimator.
Proof. Z
∇E(θ) = ∇θ pθ (y)g(y)dy
Y
Z
= pθ (y)g(y)∇θ log pθ (y)dy
Y
= EY ∼pθ [g(Y )∇θ log pθ (Y )].
y∈Y
so that
∇θ log pθ (y) = ey /γ − ∇A(θ).
We therefore see that ∇θ log pθ (y) crucially depends on ∇A(θ), the
gradient of the log-partition. This gradient is available for some
structured sets Y, see e.g. (Mensch and Blondel, 2018), but not in
general.
Suppose both the distribution and the function now depend on θ. When
g is scalar-valued and differentiable w.r.t. θ, we want to differentiate
Baseline
SFE is known to suffer from high variance (Mohamed et al., 2020).
This means that this estimator may require us to draw many samples
from the distribution pθ to work well in practice. One of the simplest
variance reduction technique consists in shifting the function g with a
constant β, called a baseline, to obtain
∇E(θ) = EY ∼pθ [(g(Y ) − β)∇θ log pθ (Y )].
The reason this is still a valid estimator of ∇E(θ) stems from
∇θ pθ (Y )
EY ∼pθ [∇θ log pθ (Y )] = EY ∼pθ
pθ (Y )
= ∇θ EY ∼pθ [1]
= ∇θ 1
= 0,
268 Differentiating through integration
for any valid distribution pθ . The baseline β is often set to the running
average of past values of the function g, though it is neither optimal
nor does it guarantee to lower the variance (Mohamed et al., 2020).
Control variates
Another general technique are control variates. Let us denote the
expectation of a function h : RM → R under the distribution pθ as
H(θ) := EY ∼pθ [h(Y )].
Suppose that H(θ) and its gradient ∇H(θ) are known in closed form.
Then, for any γ ≥ 0, we clearly have
E(θ) = EY ∼pθ [g(Y )]
= EY ∼pθ [g(Y ) − γ(h(Y ) − H(θ))]
= EY ∼pθ [g(Y ) − γh(Y )] + γH(θ)
and therefore
∇E(θ) = ∇θ EY ∼pθ [g(Y ) − γh(Y )] + γ∇H(θ).
Applying SFE, we then obtain
∇E(θ) = EY ∼pθ [(g(Y ) − γh(Y ))∇θ log pθ (Y )] + γ∇H(θ).
Examples of h include a bound on f or a second-order Taylor expansion
of f , assuming that these approximations are easier to integrate than f
(Mohamed et al., 2020).
The expression of the VJP follows by using the SFE on the scalar valued
integrand ⟨g(Y ), u⟩. The JVP is obtained as the adjoint operator of the
VJP and the Jacobian follows.
then, we obtain
Using the previous subsection with g(y, θ) = g(y)∇θ log pθ (θ), we easily
obtain an estimator of the Hessian.
Proposition 12.3 (SFE for the Hessian). Let us define the scalar-
valued function E(θ) := EY ∼pθ [g(Y )]. Then,
The key advantage is that we can now easily compute the derivatives
by mere application of the chain rule, since the parameters µ and σ are
moved from the distribution to the function:
∂
E(µ, σ) = EZ∼Normal(0,1) [g ′ (µ + σZ)]
∂µ
∂
E(µ, σ) = σ · EZ∼Normal(0,1) [g ′ (µ + σZ)].
∂σ
The change of variable
U := µ + σZ (12.4)
Proof. We have
FZ (z) = P(Z ≤ z)
U −µ
=P ≤z
σ
= P(U ≤ µ + σz)
= FU (µ + σz)
Note that, in the above example, the error function erf and its
inverse do not enjoy analytical expressions but autodiff packages usually
provide numerical routines to compute them and differentiate through
them. Nonetheless, one caveat of the inverse transform is that it indeed
requires access to (approximations of) the quantile function and its
derivatives, which may be difficult for complicated distributions.
U ∼ q ⇐⇒ Z ∼ p, U := T (Z).
Pushforward measures
More generally, we can define the notion of pushforward, in the language
of measures. Denote M(Z) the set of measures on a set Z. A measure
α ∈ M(Z), that has a density dα(z) := pZ (z)dz, can be integrated
against a funtcion f as
Z Z
f (z)dα(z) = f (z)pZ (z)dz.
Z Z
functions f ∈ C(U)
Z Z
f (u)dβ(u) = f (T (z))dα(z).
U Z
Sk ∼ pk (· | sdeterm(k) , Srandom(k) )
⇐⇒ Sk ∼ pk (· | si1 , . . . , sipk , Sj1 , . . . , Sjqk )
280 Differentiating through integration
Sk := fk (sdeterm(k) , Srandom(k) )
:= fk (si1 , . . . , sipk , Sj1 , . . . , Sjqk )
sk := fk (sdeterm(k) )
:= fk (si1 , . . . , sipk ).
Special cases
If all nodes are function nodes, we recover computation graphs, reviewed
in Section 4.1.3. If all nodes are distribution nodes, we recover Bayesian
networks, reviewed in Section 10.5.
12.5. Stochastic programs 281
12.5.2 Examples
We now present several examples that illustrate our formalism. We use
the legend below in the following illustrations.
Deterministic Stochastic
Function Sampler
variable variable
S1 ∼ p1 (· | s0 )
S2 := f2 (S1 )
E(s0 ) := E[S2 ]
∇E(s0 ) = ES1 [f2 (S1 )∇s0 log p1 (S1 | s0 )]
S1 ∼ p1
S2 := f2 (S1 , s0 )
E(s0 ) := E[S2 ]
∇E(s0 ) = ES1 [∇s0 f2 (S1 , s0 )]
s1 := f1 (s0 )
S2 ∼ p2 (· | s1 )
S3 := f3 (S2 )
E(s0 ) := E[S3 ]
∇E(s0 ) = ∂f (s0 )∗ ES2 [f3 (S2 )∇s1 log p2 (S2 | s1 )]
• Example 4:
12.5. Stochastic programs 283
s1 := f1 (s0 )
s2 := f2 (s0 )
S3 ∼ p3 (· | s1 )
S4 ∼ p4 (· | s2 , S3 )
S5 := f5 (S4 )
E(s0 ) := E[S5 ] = ES3 [ES4 [f5 (S4 )]]
∇E(s0 ) = ES3 [∂f1 (s0 )∗ ∇s1 log p(S3 | s1 )ES4 [f5 (S4 )]]
+ ES3 [ES4 [∂f2 (s0 )∗ ∇s2 log p4 (S4 |s2 , S3 )f5 (S4 )]]
SK := f (s0 ).
Sk ∼ pk (·|sdeterm(k) , Srandom(k) ),
Our formalism uses two types of nodes: distribution nodes with asso-
ciated conditional distribution pk and function nodes with associated
function fk . It is often possible to convert between node types.
Converting a distribution node into a function node is exactly the
reparametrization trick studied in Section 12.4. We can use transforma-
tions such as the location-scale transform or the inverse transform.
Converting a function node into a distribution node can be done
using the change-of-variables theorem, studied in Section 12.4.5, on a
pushforward distribution.
Because the pathwise estimator has lower variance than SFE, this is
the method of choice when the fk functions are available. The conversion
from distribution node to function node and vice-versa is illustrated in
Fig. 12.1.
Transformation
(location-scale transform, inverse transform)
Change-of-variables theorem
Existence of a solution
First and foremost, the question is whether s(t) is well-defined. For-
tunately, the answer is positive under mild conditions, as shown by
Picard-Lindelöf’s theorem recalled below (Butcher, 2016, Theorem 16).
s(0) = s0
s′ (t) = h(t, s(t)) t ∈ [0, T ],
st = exp(tA)(s0 ),
where exp(A) is the matrix exponential. Hence, the output s(T ) can
be expressed as a simple function of the parameters (A in this case).
However, generally, we do not have access to such analytical solutions,
and, just as for solving optimization problems in Chapter 11, we need
to resort to some iterative algorithms.
Integration methods
To numerically solve an ODE, we can use integration methods,
whose goal is to build a sequence sk that approximates the solution
s(t) at times tk . The simplest integration method is the explicit Euler
method, that approximates the solutions between times tk−1 and tk as
Z tk
s(tk−1 ) − s(tk ) = h(t, s(t), w)dt
tk−1
for a time-step
δk := tk − tk−1 .
288 Differentiating through integration
where Z T
f (x, w) := s(T ) = x + h(t, s(t), w)dt.
0
12.6. Differential equations 289
s(0) = x
s′ (t) = h(t, s(t), w) t ∈ [0, T ],
for Z T
g= ∂3 h(t, s(t), w)∗ r(t)dt
0
and for r solving the adjoint (backward) ODE
0.5
0.0
0.5 1.0 1.5
Figure 12.2: Finding the optimal parameters of an ODE to fit some observed data.
The dots represent the trajectories of a dynamical system observed at regular times
(time is represented here by a gradient color, the lighter the color, the larger the
time). Each line represents the solution of an ODE given by some hyperparameters
w. The objective is to find the hyperparameters of the ODE such that its solution
fits the data points. Green and orange lines fail to do so while the blue line fits
the data. To compute such parameters w, we need to backpropagate through the
solution of the ODE.
The above ODE can then be solved by any integration method. Note,
however, that it requires first computing s(T ) and ∇L(s(T )) by an
integration method. The overall computation of the gradient using
an explicit Euler method to solve forward and backward ODEs is
summarized in Algorithm 12.2.
Algorithm 12.2 naturally looks like the reverse mode of autodiff for
a residual neural networks with shared weights. A striking difference
is that the intermediate computations sk are not kept in memory and,
instead, new variables ŝk are computed along the backward ODE. One
may believe that by switching to continuous time, we solved the memory
issues encountered in reverse-mode autodiff. Unfortunately, this comes
at the cost of numerical stability. As we use a discretization scheme
to recompute the intermediate states backward in time through ŝk in
Algorithm 12.2, we accumulate some truncation errors.
To understand the issue here, consider applying Algorithm 12.2
repeatedly on the same parameters but using ŝ0 instead of s0 = x each
time. In the continuous realm, σ(T ) = s(0). But after discretization,
ŝ0 ≈ σ(T ) does not match s0 . Therefore, by applying Algorithm 12.2
292 Differentiating through integration
with s0 = ŝ0 , we would not get the same output even if in continu-
ous time we naturally should have. This phenomenon is illustrated in
Fig. 12.3. It intuitively shows why Algorithm 12.2 induces some noise
in the estimation of the gradient.
Total backward
discretization
error ...
Local forward
discretization.
Backward approx. error
of the ODE
Figure 12.3: Forward and backward discretizations when using the continuous
adjoint method.
where δ > 0 is some fixed discretization step, tk is the time step (typically
tk = tk−1 +δ), sk is the approximation of s(tk ), and ck is some additional
context variables used by the discretization method to build the iterates.
An explicit Euler method does not have a context, but just as an
optimization method may update some internal states, a discretization
method can update some context variable. The discretization scheme
in Eq. (12.8) is a forward discretization scheme as we took a positive
discretization step. By taking a negative discretization step, we obtain
the corresponding backward discretization scheme, for k ∈ (K, . . . , 1),
Leapfrog method
The (asynchronous) leapfrog method (Zhuang et al., 2021; Mutze,
2013) on the other hand is an example of symmetric reversible discretiza-
tion method. For a constant discretization step δ, given tk−1 , sk−1 , ck−1
and a function h, it computes
δ
t̄k−1 := tk−1 +
2
δ
s̄k−1 := sk−1 + ck−1
2
c̄k−1 := h(t̄k−1 , s̄k−1 )
δ
tk := t̄k−1 +
2
δ
sk := s̄k−1 + c̄k−1
2
ck := 2c̄k−1 − ck−1
M(tk−1 , sk−1 , ck−1 ; h, δ) := (tk , sk , ck ).
One can verify that we indeed have M(tk , sk , ck ; h, −δ) = (tk−1 , sk , ck ).
By using a reversible symmetric discretization scheme in the optimize-
then-discretize approach, we ensure that, at the end of the backward
discretization pass, we recover exactly the original input. Therefore, by
repeating forward and backward discretization schemes we always get
the same gradient, which was not the case for an Euler explicit scheme.
296 Differentiating through integration
then decompose as
∂w f (x, w)∗ [u]
= ∂w s(T, x, w)∗ u
Z T
+ (∂w s(t, x, w)∗ ∂s∗ h(t, s(t, x, w), w)∗ − ∂wt
2
s(t, x, w)∗ )r(t)dt
0
Z T
+ ∂w h(t, s(t, x, w), w)∗ r(t)dt,
0
∂x f (x, w)∗ [u]
= ∂x s(T, x, w)∗ u
Z T
+ (∂x s(t, x, w)∗ ∂s∗ h(t, s(t, x, w), w)∗ − ∂xt
2
s(t, x, w)∗ )r(t)dt
0
Here the second derivative terms ∂wt 2 s(t, x, w)∗ r, ∂ 2 s(t, x, w)∗ r cor-
xt
respond to second derivatives of ⟨s(t, x, w), r⟩. Since the Hessian is
symmetric (Schwartz’s theorem presented in Proposition 2.10), we can
swap the derivatives in t and w or x. Then, to express the gradient
uniquely in terms of first derivatives of s, we use an integration by part
to have for example
Z T Z T
2
∂wt s(t, x, w)∗ r(t)dt = 2
∂tw s(t, x, w)∗ r(t)dt
0 0
= (∂w s(T, x, w)∗ r(T ) − ∂w s(0, x, w)∗ r(0))
Z T
− ∂w s(t, x, w)∗ r(t)∗ ∂t r(t)dt.
0
Since s(0) = x, we have ∂w s(0, x, w)∗ r(0) = 0. The VJP w.r.t. w can
then be written as
∂w f (x, w)∗ [u]
= ∂w s(T, x, w)∗ [u − r(T )]
Z T
+ ∂w s(t, x, w)∗ [∂s h(t, s(t, x, w), w)∗ r(t) + ∂t r(t)]dt
0
Z T
+ ∂w h(t, s(t, x, w), w)∗ r(t)dt.
0
By choosing r(t) to satisfy the adjoint ODE
∂t r(t) = −∂s h(t, s(t, x, w), w)∗ r(t), r(T ) = u,
298 Differentiating through integration
12.7 Summary
• We can also first discretize the problem in such a way that the
gradient can simply be computed by reverse mode auto-diff, ap-
plied on the discretization steps. This is the discretize-then-
optimize approach. The optimize-then-discretize approach has
no memory cost, but discrepancies between the forward and back-
ward discretization passes often lead to numerical errors. The
discretize-then-optimize introduces no such discrepancies but may
come at a large memory cost.
Smoothing programs
13
Smoothing by optimization
301
302 Smoothing by optimization
Existence
The infimal convolution (f □g)(µ) exists if the infimum inf u∈RM f (u) +
g(µ − u) is finite (Bauschke and Combettes, 2017, Proposition 12.6).
A sufficient condition to achieve this is that u 7→ f (u) + g(µ − u) is
convex for all µ ∈ RM . However, this is not a necessary condition. For
13.1. Primal approach 303
Properties
A crucial property of the Moreau envelope envf is that for any convex
function f , it is always a smooth function, even when f itself is not
smooth. By smooth, we formally mean that the resulting function envf
is differentiable everywhere with Lipschitz-continuous gradients. We say
L-smooth, if the gradients are L-Lipshcitz continuous. Such a property
can determine the efficiency of optimization algorithms as reviewed in
Section 15.4. We recap below useful properties of the Moreau envelope.
Proof.
1. This is best seen using the dual approach detailed in Section 13.3.
4. This follows from the fact that the infimum of a jointly convex
function is convex.
5. We have
1
inf envf (µ) = inf ∥µ − u∥22 + f (u)
inf
µ∈RM 2
µ∈RM u∈RM
1
= inf inf ∥µ − u∥22 + f (u)
u∈R µ∈R 2
M M
= inf f (u).
u∈RM
Examples
To illustrate smoothing from the Moreau envelope perspective, we
show how to smooth the 1-norm. In this case, we obtain an analytical
expression for the Moreau envelope.
306 Smoothing by optimization
3.0
Huber loss
2.5
Absolute loss
2.0
1.5
1.0
0.5
0.0
3 2 1 0 1 2 3
Figure 13.1: The Huber loss is the Moreau envelope of the absolute loss.
Figure 13.2: The Moreau envelope is not limited to convex functions. For instance,
the ramp function is continuous but nonconvex, and the step function is not only
nonconvex but also discontinuous. In this figure, we approximately computed the
infimum over u ∈ R in Definition 13.2 by restricting the search on a finite grid, in a
closed interval.
i=1
T
= di ∇envfi (µ).
X
i=1
In the particular case f (u) = (f1 (u1 ), . . . , fT (uT )), we obtain
T
∂envf (µ)∗ [d] = di envfi (µi ).
X
i=1
308 Smoothing by optimization
i=1
Smoothing vector-valued functions by Moreau envelope (or more gen-
erally, by infimal convolution) remains an open area of research. We
will see in Chapter 14 that smoothing by convolution more naturally
supports vector-valued functions.
13.2. Legendre–Fenchel transforms, convex conjugates 309
13.2.1 Definition
Consider the class of affine functions of the form
u 7→ ⟨u, v⟩ − b.
f (u)
v
f * (v)
0 1
u
Figure 13.3: For a fixed slope v, the function u 7→ uv − f ∗ (v) is the tighest affine
lower bound of f with slope v.
0.0
1.0
0.5
f * (v)
f (u)
0.5
1.0 v
f * (v)
0.0
0.0 0.5 1.0 2 1 0 1
u v
Figure 13.4: Left: instead of representing a convex function f by its graph (u, f (u))
for u ∈ dom(f ), we can represent it by the set of tangents with slope v and intercept
−f ∗ (v) for v ∈ dom(f ∗ ). Right: by varying the slope v of all possible tangents, we
obtain a function of the slope v rather than of the original input u. The colors of the
tangents on the left are chosen to match the colors of the vertical lines on the right.
j=1
See for instance Boyd and Vandenberghe (2004) or Beck (2017) for
many more examples.
312 Smoothing by optimization
i=1
13.2.3 Properties
The conjugate enjoys several useful properties, that we now summarize.
Proof.
1. This follows from the fact that v 7→ supu∈C g(u, v) is convex if g
is convex in v. Note that this is true even if g is nonconvex in u.
Here, g(u, v) = ⟨u, v⟩ − f (u), which is affine in v and therefore
convex in v.
2. This follows immediately from Definition 13.3.
3. This follows from Danskin’s theorem, reviewed in Section 11.2.
Another way to see this is by observing that
f ∗ (v) = ⟨g, v⟩ − f (g)
f ∗ (v ′ ) ≥ ⟨g, v ′ ⟩ − f (g),
where g := arg max ⟨u, v⟩ − f (u). Subtracting the two, we obtain
u∈dom(f )
M
f ∗ (v) = fj∗ (vj ).
X
j=1
f ∗ (v) = c · g ∗ (v/c).
f ∗ (v) = g ∗ (v − α) − β.
f ∗ (v) = g ∗ (M −T v).
i=1
fΩ (u) = ⟨u, vΩ
⋆
⟩ − fΩ∗ (vΩ
⋆
) ≥ ⟨u, v ⋆ ⟩ − fΩ∗ (v ⋆ ) = f (u) − Ω(v ⋆ )
and similarly
f (u) − Ω(vΩ
⋆
) = ⟨u, v ⋆ ⟩ − f ∗ (v ⋆ ) − Ω(vΩ
⋆
) ≥ ⟨u, vΩ
⋆
⟩ − fΩ∗ (vΩ
⋆
) = fΩ (u).
Proof. We have
envf = fΩ = fΩ∗ .
Given the equivalence between the primal and dual approaches, using
one approach or the other is mainly a matter of mathematical or
algorithmic convenience, depending on the case.
In this book, we focus on applications of smoothing techniques to dif-
ferentiable programming. For applications to non-smooth optimization,
see Nesterov (2005) and Beck and Teboulle (2012).
Primal approach
ε 1 1
R(u) = ∥u/ε∥22 = ∥u∥22 = Ω(u).
2 2ε ε
We therefore get
1 1 1
fεΩ = f □ Ω = (εf □Ω) = envεf .
ε ε ε
Shannon’s entropy
A definition of information content satisfying the criteria above is
1
I(E) := log = − log p(E).
p(E)
Indeed, − log 1 = 0, − log 0 = ∞ and − log is a decreasing function over
(0, 1]. Using this information content definition leads to Shannon’s
entropy (Shannon, 1948)
H(Y ) = E[I(Y )] = − p(y) log p(y).
X
y∈Y
i=1
Gini’s entropy
As an alternative, we can define information content as
1
I(E) = (1 − p(E)).
2
322 Smoothing by optimization
0.7
0.6
0.5
0.4
0.3
0.2
Tsallis 1 (Shannon)
0.1 Tsallis = 1.5
0.0
Tsallis = 2 (Gini)
0.0 0.2 0.4 0.6 0.8 1.0
Figure 13.5: Tsallis entropies of the distribution π = (π, 1 − π) ∈ △2 , for π ∈ [0, 1].
An entropy is a non-negative concave function that attains its maximum at the
uniform distribution, here (0.5, 0.5). A negative entropy, a.k.a. negentropy, can be used
as a dual regularization function Ω to smooth out a function f when dom(f ∗ ) ⊆ △M .
Tsallis entropies
Given α ≥ 1, a more general information content definition is
1
I(E) = (1 − p(E)α−1 ).
α(α − 1)
13.3. Dual approach 323
so that
M
p
∥v∥pp =
X
vi .
i=1
Tsallis entropies for α → 1 (Shannon entropy), α = 1.5 and α = 2 (Gini
entropy) are illustrated in Fig. 13.5 and Fig. 13.6.
1. H(π) = 0 if π ∈ {e1 , . . . , eM },
2. H is strictly concave,
The softplus
The sparseplus
If we use the regularizer Ω(π) = π(π − 1), which comes from using
Gini’s negentropy with 21 ⟨π, π − 1⟩ with π = (π, 1 − π), we obtain
0,
u ≤ −1
reluΩ (u) = sparseplus(u) = 1
(u + 1)2 , −1 < u < 1 .
4
u≥1
u,
See Fig. 13.8 (left figure) for a comparison of softplus and sparseplus.
326 Smoothing by optimization
min(u) = − max(−u).
and
∇maxΩ (u) = ∇δΩ (u − τ ⋆ 1),
where τ ⋆ is the solution w.r.t. τ of the above min, which satisfies
the root equation
⟨∇δΩ (u − τ ⋆ 1), 1⟩ = 1.
= min τ + δΩ (u − τ 1),
τ ∈R
where we used that we can swap the min and the max, since (u, v) 7→
⟨u, v⟩ − Ω(v) is convex-concave and v ∈ △M is an affine constraint. The
gradient ∇δΩ (u) follows from Danskin’s theorem. The root equation
follows from computing the derivative of τ 7→ τ +δΩ (u−τ 1) and setting
it to zero.
= logsumexp(u)
M
= log
X
euj .
j=1
A unique property of the softmax, which is not the case of all maxΩ
operators, is that it supports associativity.
= ⟨u, π ⋆ ⟩ − Ω(π ⋆ )
where
π ⋆ = sparseargmax(u) := arg min ∥u − π∥22 .
π∈△M
Proof. This follows from the fact that Ω(π) is up to a constant equal
to 12 ∥π∥22 and completing the square.
and τ ⋆ satisfies
M
[ui − τ ]+ = 1.
X
i=1
Using Proposition 13.8 proves the proposition’s first part. Setting the
derivative w.r.t. τ to zero gives the second part.
13.5. Smoothed max operators 331
u2
u2
u2
2.5 2.5 2.5
2.5 0.0 2.5 2.5 0.0 2.5 2.5 0.0 2.5
u1 u1 u1
0 2 4 2 4 0 2 4
Value Value Value
Figure 13.7: Max, softmax and sparsemax functions. The max function has non-
smooth contour lines (set of points {u ∈ R3 : f (u) = c} for some constant c
represented by dashed gray lines). So the gradient along these contour lines switch
suddenly at the corners of the contour lines switch. This shows that the max function
is not differentiable everywhere, namely, non-differentiable on the set of points
{u ∈ R3 : ui = uj for any i ̸= j}. The contour lines of the softmax and sparsemax
functions on the other hand are smooth illustrating that these functions are smooth
counterpart of the max function.
It can be shown (Duchi et al., 2008; Condat, 2016) that the exact
solution τ ⋆ is obtained by
j⋆
1 X
τ ⋆ = ⋆ u[i] − 1 , (13.5)
j i=1
Unlike the logistic function, it can reach the exact values 0 or 1. However,
the function has two kinks, where the function is non-differentiable.
Activations Sigmoids
2 1.0
ReLU Heaviside
SoftPlus Logistic
1 SparsePlus 0.5 SparseSigmoid
0 0.0
2 1 0 1 2 2 1 0 1 2
Figure 13.8: Smoothed ReLU functions and relaxed step functions (sigmoids).
Differentiating the left functions gives the right functions.
The softargmax
exp(u)
argmaxΩ (u) = softargmax(u) = PM ,
j=1 exp(uj )
Proof. We know that maxΩ (u) = logsumexp(u) and that ∇ maxΩ (u) =
argmaxΩ (u). Differentiating logsumexp(u) gives softargmax(u).
13.7. Relaxed argmax operators 335
The sparseargmax
sparseargmax(u) = [u − τ ⋆ ]+ ,
= argmax(u1, u2, 0)
1 2 3
u2
u2
2.5 2.5 2.5
2.5 0.0 2.5 2.5 0.0 2.5 2.5 0.0 2.5
u1 u1 u1
0.0 0.5 1.0 0.0 0.5 1.0 0.0 0.5 1.0
Value Value Value
= softargmax(u1u2, 0)
1 2 3
u2
u2
u2
u2
13.8 Summary
14.1 Convolution
339
340 Smoothing by integration
Averaging perspective
where
1 1 µ−u 2
pµ,σ (u) := κσ (µ − u) = κσ (z) = √ e− 2 ( σ )
2πσ
is the PDF of the Gaussian distribution with mean µ and variance
σ 2 . Therefore, we can see f ∗ κσ as the expectation of f (u) over a
Gaussian centered around µ. This property is true for all translation-
invariant kernels, that correspond to a location-scale family distribution
(e.g., the Laplace distribution). The convolution therefore performs an
averaging with all points, with points nearby µ given more weight by
the distribution. The parameter σ controls the importance we want to
give to farther points. We call this viewpoint averaging, as we replace
f (u) by E[f (U )].
Perturbation perspective
we have
Z ∞
1 z 2
(f ∗ κσ )(µ) := f (µ − z)e− 2 ( σ ) dz
−∞
= EZ∼p0,σ [f (µ − Z)]
= EZ∼p0,σ [f (µ + Z)],
where, in the third line, we used that p0,σ is sign invariant, i.e., p0,σ (z) =
p0,σ (−z). This viewpoint shows that smoothing by convolution with
a Gaussian kernel can also be seen as injecting Gaussian noise or
perturbations to the function’s input.
Limit case
Many times, we work with functions whose convolution does not have
an analytical form. In these cases, we can use a discrete convolution on
a grid of values. For two functions f and g defined over Z, the discrete
convolution is defined by
∞
(f ∗ g)[i] := f [j]g[i − j].
X
j=−∞
j=−∞
342 Smoothing by integration
10
= 0.25
8 = 0.5
= 1.0
6
0
3 2 1 0 1 2 3
Figure 14.1: Smoothing of the signal f [t] := t2 + 0.3 sin(6πt) with a sampled and
renormalized Gaussian kernel.
M
(f ∗ g)[i] = f [i − j]g[j] = (g ∗ f )[i].
X
j=−M
14.1.4 Differentiation
Remarkably, provided that the two functions are integrable with inte-
grable derivatives, the derivative of the convolution satisfies
(f ∗ g)′ = (f ′ ∗ g) = (f ∗ g ′ ),
which simply stems from switching derivative and integral in the defini-
tion of the convolution. Moreover, we have the following proposition.
Fε := Cε {f }
Gε := Cε {g}
Hε := Cε {hε }.
14.1. Convolution 345
Table 14.1: Analogy between Fourier and Legendre transforms. See Proposition 13.3
for more conjugate calculus rules.
(f □g)∗ = f ∗ + g ∗ .
0
3 2 1 0 1 2 3
Figure 14.2: Applying the smoothed conjugate twice gives a smoothed biconjugate
(convex envelope) of the function.
Proof.
1 1
Z
fε (v) := ε log exp ⟨u, v⟩ − f (u)) du
ε ε
1 1 1 1
Z
= ε log exp − ∥u − v∥22 + ∥u∥22 + ∥v∥2 − f (u)) du
2
2ε 2ε 2ε ε
1 1 1 1
Z
= ε log exp − ∥u − v∥22 + ∥u∥22 − f (u)) du + ∥v∥22
2ε 2ε ε 2
1
Z
= ε log Gε (v − u)Qε {f }(u)du + ∥v∥22
2
1
= ε log(Qε {f } ∗ Gε )(v) + ∥v∥2
2
1 1
= ∥v∥ − ε log
2
(v)
2 Qε {f } ∗ Gε
1
= Q−1ε (v)
Qε {f } ∗ Gε
What did we gain from this viewpoint? The convex conjugate can
often be difficult to compute in closed form. If we replace RM with a dis-
crete set S (i.e., a grid), we can then approximate the smoothed convex
350 Smoothing by integration
u∈S
= Kq,
14.3 Examples
3.0 1.0
= 0.5
2.5 = 1.0 0.8
2.0 = 2.0 0.6
1.5
0.4
1.0
0.5 0.2
0.0 0.0
3 2 1 0 1 2 3 3 2 1 0 1 2 3
Figure 14.3: Smoothing of the ReLU and Heaviside functions by convolution with
a Gaussian kernel, for three values of the width σ.
(r ∗ κσ )′ = (r′ ∗ κσ ) = h ∗ κσ = Φσ .
we obtain,
(r ∗ κσ )′′ = (h ∗ κσ )′ = (h′ ∗ κσ ) = δ ∗ κσ = κσ ,
U := µ + σZ.
We then have
U ∼ pµ,σ ,
where pµ,σ is the location-family distribution generated by the noise
distribution p. It is the pushforward distribution of Z through the
transformation (see Section 12.4.4). In this notation, the initial noise
distribution p is then simply p = p0,1 . The perturbed function can then
be expressed from these two perspectives as
= f ∗ κσ (µ),
∇p(z)
Z
EZ∼p [h(Z)∇ log p(Z)] = h(z) p(z)dz
RM p(z)
Z
= h(z)∇p(z)dz.
RM
and therefore
103
SFE
102
SFE with forward difference
SFE with central difference
Gradient error 101
100
10 1
This should not be too surprising, as we already know from the convo-
lution perspective that fσ (µ) = (f ∗ κσ )(µ) → f (µ) when σ → 0. This
recovers the randomized forward-mode estimator already presented in
Section 8.7.
p(z) = exp(−ν(z)),
where
ν(z) = z + γ + exp(−(z + γ)),
and where γ ≈ 0.577 is Euler’s constant. We extend it to a multivariate
distribution by taking M independent centered Gumbel distributions
Z = (Z1 , . . . , Zm ) with associated location-scale family
µ + σZ ∼ pµ,σ ,
where
i(u) := arg max ui
i∈[M ]
ϕ(i) := ei
with ϕ(i) is the one-hot encoding of i ∈ [M ]. It turns out that the
function y(u) perturbed using Gumbel noise enjoys a closed form
expectation, which is nothing else than the softargmax.
362 Smoothing by integration
Y := i(µ + σ · Z) ∈ [M ],
qµ,σ := Categorical(softargmax(µ/σ)).
Moreover, we have
By Remark 14.3, we have that e−µi /σ−Zi ∼ Exp(exp(µi /σ−γ)). One eas-
ily verifies as an exercise, that, for U1 , . . . , UM independent exponential
variables with parameters u1 , . . . , um , we have P(arg mini∈[M ] {Ui } =
k) = uk / M i=1 ui . Hence, we get
P
exp(µk /σ)
P(Y = k) = PM ,
i=1 exp(µi /σ)
that is,
Y ∼ Categorical(softargmax(µ/σ)).
The last claim follows from the distribution of Y and the definition of
ϕ.
f (u) := max ui ,
i∈[M ]
V := f (µ + σ · Z).
is distributed according to
qµ,σ := pσLSE(µ/σ),σ .
Moreover, we have
M
! !
P max {µi + σZi } ≤ t = exp − (exp(µi /σ − γ)) exp(−t/σ)
X
i∈[M ]
i=1
= exp(− exp(−(t − σLSE(µ/σ))/σ − γ)).
For further reading on the Gumbel trick, see Tim Vieira’s great
blog.
14.5.6 Perturb-and-MAP
Previously, we discussed the Gumbel trick in the classification setting,
where Y = [M ]. In the structured prediction setting, outputs are
typically embedded in RM but the output space is very large. That is,
Y ⊆ RM but |Y| ≫ M . Structured outputs are then decoded using a
maximum a-posteriori (MAP) oracle
f (u) := max⟨y, u⟩
y∈Y
The first estimator usually has lower variance. Note that we cannot use
the reparametrization trick this time, since y is discontinuous, contrary
to f .
14.5.7 Gumbel-softmax
where
y(u) := arg max ⟨y, u⟩.
y∈{e1 ,...,eM }
T := softargmaxτ (µ + Z) ∈ △M ,
where π := softargmax(µ).
14.6 Summary
Optimizing differentiable
programs
15
Optimization basics
L⋆ := inf L(w),
w∈W
assuming that the infimum exists (i.e., L(w) is lower bounded). We will
denote a solution, if it exists, by
w⋆ ∈ arg min L(w) := w ∈ W : L(w) = min
′
L(w′ ) .
w∈W w ∈W
L(wt ) − L⋆ ≤ ε. (15.1)
370
15.2. Oracles 371
15.2 Oracles
wt+1 := wt − γ∇L(wt ),
L∥w − v∥22 /2, where the last inequality follows from the smoothness
assumption and standard integration.
1.2
KL(p, q)
1.0 0.5||p q||21
0.8
0.6
0.4
0.2
0.0
0.0 0.2 0.4 0.6 0.8 1.0
ρ∞ := lim ρt .
t→+∞
1.0
Convergence rate R(t) (log scale)
10 2
R(t) = 1/ t (sublinear)
10 8
R(t) = 1/t (sublinear) 0.6
R(t) = 1/t 2 (sublinear)
10 11
R(t) = e t (linear) 0.4
10 14 R(t) = e t2 (superlinear)
0.2
10 17
10 20 0.0
0 20 40 60 80 100 0 20 40 60 80 100
Iteration t Iteration t
Figure 15.4: Left: convergence rates. Right: progress ratios. An algorithm with
sublinear convergence rates eventually eventually stops making progress. An algorithm
with linear convergence rate eventually reaches a state of constant progress. An
algorithm with superlinear convergence rate makes faster progress after each iteration.
erally depend on the dimension of the problem, making them unfit for
high-dimensional problems.
This explains the immense success of first-order algorithms for train-
ing neural networks. Fortunately, using reverse-mode autodiff, as studied
in Chapter 8, it can be shown that computing a gradient has roughly
the same complexity as evaluating the function itself Section 8.3.3
15.6 Summary
384
16.1. Gradient descent 385
w2
w1
0.5 0.5
1 0 1 1 0 1
w1 w1
βγ 2
L(wt+1 ) ≤ L(wt ) − γ∥∇L(wt )∥22 + ∥∇L(wt )∥22 .
2
Therefore, for β-smooth functions, by selecting γ ≤ β1 , we get that
γ
L(wt+1 ) − L(wt ) ≤ − ∥∇L(wt )∥22 ,
2
which illustrates the main mechanism behind gradient descent: each
iteration decreases the objective by a constant times the norm of the
386 First-order optimization
Non-convex case
Without further assumptions, i.e., in the non-convex case, the above
result (i.e., convergence to a stationary point, measured by the gradient
norm) is the best we may get in theory. Denoting Ts (ε) the number
of iterations needed for a gradient descent to output a point that is
ε-stationary, i.e., ∥∇L(ŵ)∥2 ≤ ε, we have Ts (ε) ≤ O(ε−2 ).
Convex case
By adding a convexity assumption on the objective, we can use the lower
bound provided by the convexity assumption to ensure convergence to
a minimum. Namely, for a β-smooth and convex function f , and with
stepsize γ ≤ 1/β, we have that (Nesterov, 2018)
1
L(wT ) − L⋆ ≤ ∥w0 − w⋆ ∥22 .
γT
That is, we get a sublinear convergence rate, and the associated compu-
tational complexity to find a minimum is T (ε) = O(1/ε).
That is, we obtain a linear convergence rate and the associated computa-
tional complexity is T (ε) = O(ln ε−1 ). The above convergence rates may
be further refined (Nesterov, 2018); we focused above on the simplest
result for clarity.
Strong convexity can also be replaced by a weaker assumption,
gradient-dominating property (Polyak, 1963), i.e., ∥∇L(v)∥22 ≥ c(L(v)−
L⋆ ) for some constant c and any v ∈ W. A convex, gradient-dominating
function can also be minimized at a linear rate.
v t+1 := νv t − γ∇L(wt )
wt+1 := wt + v t+1 .
w2
w1
0.5
1 0 1
w1
as a empirical distribution ρ = ρn
n
1X
L(w) = ES∼ρn [L(w; S)] = L(w; (Xi , Yi )).
n i=1
In this case, we see that the full gradient ∇L(w), as needed by gradient
descent, is the average of the individual gradients. That is, the cost of
computing ∇L(w) is proportional to the number of training points n. For
n very large, that is a very large amount of samples, this computational
cost can be prohibitive. Stochastic gradients circumvent this issue.
constant term depending on the variance of the oracle and the stepsize.
One can diminish the variance by considering mini-batches: if the
variance of a single stochastic gradient is σ12 , considering a mini-batch
of m gradients reduces the variance of the corresponding oracle to
σm = σ12 /m. To decrease the additional term, one may also decrease
the stepsizes over the iterations. For example, by choosing a decreasing
stepsize like γ t = t−1/2 , the convergence rate is then of the order
O((∥w0 − w⋆ ∥22 + σ 2 ln t)/ (t)). The stepsize can also be selected as
p
v t+1 := νv t + g(wt ; S t )
wt+1 := wt − γv t+1 .
v t+1 := νv t + g(wt + νv t ; S t )
wt+1 := wt − γv t+1 .
392 First-order optimization
mt+1 := ν1 mt + (1 − ν1 )g t
v t+1 := ν2 v t + (1 − ν2 )(g t )2
m̂t+1 := mt+1 /(1 − ν1t )
v̂ t+1 := v t+1 /(1 − ν2t )
p
wt+1 := wt − γ m̂t+1 / v̂ t+1 + ε ,
⟨∇L(w⋆ ), w − w⋆ ⟩ ≥ 0 ∀w ∈ C.
• If C = RP , we obviously have
projC (w) = w.
• If Ω(w) = 0, we have
proxγΩ (w) = w.
which is used in the group lasso (Yuan and Lin, 2006) and can be
used to encourage group sparsity.
For a review of more proximal operators, see for instance (Bach et al.,
2012; Parikh, Boyd, et al., 2014).
16.5 Summary
wt+1 := wt − γ t B t ∇L(wt ),
399
400 Second-order optimization
wt+1 = wt − dt ,
17.1.5 Linesearch
In practice, we may not have access to an initial point close enough
from the minimizer. In that case, even for strictly convex functions for
402 Second-order optimization
In that case, an estimate of the Hessian can be constructed just like for
the gradient using that
h i
ES∼ρ ∇2 L(w; S) = ∇2 L(w).
404 Second-order optimization
Denote then
ηt
wt+1 := arg min q t (J t w + δ t ) + ∥w − wt ∥22 .
w∈W 2
17.2.3 Linesearch
Similarly to Newton’s method, the iterates of a Gauss-Newton method
may diverge when used alone. However, the direction −(∇2GN (ℓ◦f )(wt )+
η t I)−1 ∇L(wt ) defines a descent direction for any η t > 0 and can be
combined with a stepsize γ t (typically chosen using a linesearch) to
obtain iterates of the form
wt+1 = wt − γ t (∇2GN (ℓ ◦ f )(wt ) + η t I)−1 ∇L(wt ).
Negative log-likelihood
We consider objectives of the form
17.4.1 BFGS
A celebrated example of quasi-Newton method is the BFGS method
(Broyden, 1970; Fletcher, 1970; Goldfarb, 1970; Shanno, 1970), whose
acronym follows from its author names. The rationale of the BFGS
update stems once again from a variational viewpoint. We wish to
build a simple quadratic model of the objective ht (w) = L(wt ) +
⟨∇L(wt ), w − wt ⟩ + 12 ⟨w − wt , Qt (w − wt )⟩ for some Qt built along
the iterations rather than taken as ∇2 L(wt ). One desirable property of
such quadratic model would be that its gradients at consecutive iterates
match the gradients of the original function, i.e., ∇ht (wt ) = ∇L(wt )
and ∇ht (wt−1 ) = ∇L(wt−1 ). A simpler condition, called the secant
condition consists in considering the differences of these vectors, that
is, ensuring that
∇ht (wt ) − ∇ht (wt−1 ) = ∇L(wt ) − ∇L(wt−1 )
⇐⇒ Qt (wt − wt−1 ) = ∇L(wt ) − ∇L(wt−1 )
⇐⇒ wt − wt−1 = B t (∇L(wt ) − ∇L(wt−1 )),
for B t = (Qt )−1 . Building B t , a surrogate of the inverse of the Hessian
satisfying the secant equation, can then be done as
B t+1 := I −ρt st (y t )⊤ B t I −ρt st (y t )⊤ + ρt st (st )⊤
where
st := wt+1 − wt
y t := ∇L(wt+1 ) − ∇L(wt )
1
ρt := t t .
⟨s , y ⟩
17.5. Approximate Hessian diagonal inverse preconditionners 411
17.6 Summary
wt+1 := wt − γ t B t ∇L(wt )
We introduce in this section dual norms, since they are useful in this
book.
413
414 Duality
Proof.
In the case when f (w) = Aw, where A is a linear map, and when
both ℓ and R are convex, we can state a much stronger result.
Table 18.1: Examples of loss conjugates. For regression losses (squared, absolute),
where yi ∈ RM , we define ti = ϕ(yi ) = yi . For classification losses (logistic, per-
ceptron, hinge), where yi ∈ [M ], we define ti = ϕ(yi ) = eyi . To simplify some
expressions, we defined the change of variable µi := yi − αi .
Ai w := W xi ,
where u, v ∈ dom(f ).
D D D
uj X
Bf (u, v) = uj log uj +
X X
− vj ,
j=1
vj j=1 j=1
Properties
Bregman divergences enjoy several useful properties.
f (u)
Df (u, v)
f (v) + h∇f (v), u − vi
v u
3. From the fact that u 7→ Bf (u, v) is the sum of f (u) and a linear
function of u.
18.5 Summary
422
References 423