
arXiv:2403.14606v2 [cs.LG] 24 Jul 2024

The Elements of
Differentiable Programming

Mathieu Blondel
Google DeepMind
[email protected]
Vincent Roulet
Google DeepMind
[email protected]

Draft version 2 (last update: July 24, 2024)


Contents

1 Introduction 6
1.1 What is differentiable programming? . . . . . . . . . . . . 6
1.2 Book goals and scope . . . . . . . . . . . . . . . . . . . . 9
1.3 Intended audience . . . . . . . . . . . . . . . . . . . . . . 10
1.4 How to read this book? . . . . . . . . . . . . . . . . . . . 10
1.5 Related work . . . . . . . . . . . . . . . . . . . . . . . . . 10

I Fundamentals 12

2 Differentiation 13
2.1 Univariate functions . . . . . . . . . . . . . . . . . . . . . 13
2.1.1 Derivatives . . . . . . . . . . . . . . . . . . . . . . 13
2.1.2 Calculus rules . . . . . . . . . . . . . . . . . . . . 17
2.1.3 Leibniz’s notation . . . . . . . . . . . . . . . . . . 19
2.2 Multivariate functions . . . . . . . . . . . . . . . . . . . . 20
2.2.1 Directional derivatives . . . . . . . . . . . . . . . . 20
2.2.2 Gradients . . . . . . . . . . . . . . . . . . . . . . 21
2.2.3 Jacobians . . . . . . . . . . . . . . . . . . . . . . 25
2.3 Linear differentiation maps . . . . . . . . . . . . . . . . . 30
2.3.1 The need for linear maps . . . . . . . . . . . . . . 31
2.3.2 Euclidean spaces . . . . . . . . . . . . . . . . . . . 32
2.3.3 Linear maps and their adjoints . . . . . . . . . . . 33
2.3.4 Jacobian-vector products . . . . . . . . . . . . . . 33
2.3.5 Vector-Jacobian products . . . . . . . . . . . . . . 35
2.3.6 Chain rule . . . . . . . . . . . . . . . . . . . . . . 36
2.3.7 Functions of multiple inputs (fan-in) . . . . . . . . 36
2.3.8 Functions with multiple outputs (fan-out) . . . . . 38
2.3.9 Extensions to non-Euclidean linear spaces . . . . . 39
2.4 Second-order differentiation . . . . . . . . . . . . . . . . . 40
2.4.1 Second derivatives . . . . . . . . . . . . . . . . . . 40
2.4.2 Second directional derivatives . . . . . . . . . . . . 41
2.4.3 Hessians . . . . . . . . . . . . . . . . . . . . . . . 42
2.4.4 Hessian-vector products . . . . . . . . . . . . . . . 43
2.4.5 Second-order Jacobians . . . . . . . . . . . . . . . 44
2.5 Higher-order differentiation . . . . . . . . . . . . . . . . . 45
2.5.1 Higher-order derivatives . . . . . . . . . . . . . . . 45
2.5.2 Higher-order directional derivatives . . . . . . . . . 45
2.5.3 Higher-order Jacobians . . . . . . . . . . . . . . . 46
2.5.4 Taylor expansions . . . . . . . . . . . . . . . . . . 46
2.6 Differential geometry . . . . . . . . . . . . . . . . . . . . 47
2.6.1 Differentiability on manifolds . . . . . . . . . . . . 48
2.6.2 Tangent spaces and pushforward operators . . . . . 48
2.6.3 Cotangent spaces and pullback operators . . . . . 50
2.7 Generalized derivatives . . . . . . . . . . . . . . . . . . . 53
2.7.1 Rademacher’s theorem . . . . . . . . . . . . . . . 53
2.7.2 Clarke derivatives . . . . . . . . . . . . . . . . . . 54
2.8 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 56

3 Probabilistic learning 59
3.1 Probability distributions . . . . . . . . . . . . . . . . . . . 59
3.1.1 Discrete probability distributions . . . . . . . . . . 59
3.1.2 Continuous probability distributions . . . . . . . . 60
3.2 Maximum likelihood estimation . . . . . . . . . . . . . . . 61
3.2.1 Negative log-likelihood . . . . . . . . . . . . . . . 61
3.2.2 Consistency w.r.t. the Kullback-Leibler divergence . 61
3.3 Probabilistic supervised learning . . . . . . . . . . . . . . 62
3.3.1 Conditional probability distributions . . . . . . . . 62
3.3.2 Inference . . . . . . . . . . . . . . . . . . . . . . . 62
3.3.3 Binary classification . . . . . . . . . . . . . . . . . 63
3.3.4 Multiclass classification . . . . . . . . . . . . . . . 65
3.3.5 Regression . . . . . . . . . . . . . . . . . . . . . . 67
3.3.6 Multivariate regression . . . . . . . . . . . . . . . 68
3.3.7 Integer regression . . . . . . . . . . . . . . . . . . 69
3.3.8 Loss functions . . . . . . . . . . . . . . . . . . . . 70
3.4 Exponential family distributions . . . . . . . . . . . . . . . 71
3.4.1 Definition . . . . . . . . . . . . . . . . . . . . . . 71
3.4.2 The log-partition function . . . . . . . . . . . . . . 72
3.4.3 Maximum entropy principle . . . . . . . . . . . . . 74
3.4.4 Maximum likelihood estimation . . . . . . . . . . . 75
3.4.5 Probabilistic learning with exponential families . . . 76
3.5 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 77

II Differentiable programs 79

4 Parameterized programs 80
4.1 Representing computer programs . . . . . . . . . . . . . . 80
4.1.1 Computation chains . . . . . . . . . . . . . . . . . 80
4.1.2 Directed acyclic graphs . . . . . . . . . . . . . . . . 81
4.1.3 Computer programs as DAGs . . . . . . . . . . . . 83
4.1.4 Arithmetic circuits . . . . . . . . . . . . . . . . . . 85
4.2 Feedforward networks . . . . . . . . . . . . . . . . . . . . 86
4.3 Multilayer perceptrons . . . . . . . . . . . . . . . . . . . . 87
4.3.1 Combining affine layers and activations . . . . . . . 87
4.3.2 Link with generalized linear models . . . . . . . . . 87
4.4 Activation functions . . . . . . . . . . . . . . . . . . . . . 88
4.4.1 ReLU and softplus . . . . . . . . . . . . . . . . . . 88
4.4.2 Max pooling and log-sum-exp . . . . . . . . . . . . 89
4.4.3 Sigmoids: binary step and logistic functions . . . . 90
4.4.4 Probability mappings: argmax and softargmax . . . 91
4.5 Residual neural networks . . . . . . . . . . . . . . . . . . 93
4.6 Recurrent neural networks . . . . . . . . . . . . . . . . . . 94
4.6.1 Vector to sequence . . . . . . . . . . . . . . . . . 95
4.6.2 Sequence to vector . . . . . . . . . . . . . . . . . 96
4.6.3 Sequence to sequence (aligned) . . . . . . . . . . . 96
4.6.4 Sequence to sequence (unaligned) . . . . . . . . . 97
4.7 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 98

5 Control flows 99
5.1 Comparison operators . . . . . . . . . . . . . . . . . . . . 99
5.2 Soft inequality operators . . . . . . . . . . . . . . . . . . 101
5.2.1 Heuristic definition . . . . . . . . . . . . . . . . . 101
5.2.2 Stochastic process perspective . . . . . . . . . . . 102
5.3 Soft equality operators . . . . . . . . . . . . . . . . . . . 104
5.3.1 Heuristic definition . . . . . . . . . . . . . . . . . 104
5.3.2 Stochastic process perspective . . . . . . . . . . . 106
5.3.3 Gaussian process perspective . . . . . . . . . . . . 109
5.4 Logical operators . . . . . . . . . . . . . . . . . . . . . . 110
5.5 Continuous extensions of logical operators . . . . . . . . . 111
5.5.1 Probabilistic continuous extension . . . . . . . . . 111
5.5.2 Triangular norms and co-norms . . . . . . . . . . . 113
5.6 If-else statements . . . . . . . . . . . . . . . . . . . . . . 114
5.6.1 Differentiating through branch variables . . . . . . 115
5.6.2 Differentiating through predicate variables . . . . . 116
5.6.3 Continuous relaxations . . . . . . . . . . . . . . . 117
5.7 Else-if statements . . . . . . . . . . . . . . . . . . . . . . 120
5.7.1 Encoding K branches . . . . . . . . . . . . . . . . 120
5.7.2 Conditionals . . . . . . . . . . . . . . . . . . . . . 121
5.7.3 Differentiating through branch variables . . . . . . 122
5.7.4 Differentiating through predicate variables . . . . . 123
5.7.5 Continuous relaxations . . . . . . . . . . . . . . . 124
5.8 For loops . . . . . . . . . . . . . . . . . . . . . . . . . . . 125
5.9 Scan functions . . . . . . . . . . . . . . . . . . . . . . . . 127
5.10 While loops . . . . . . . . . . . . . . . . . . . . . . . . . 128
5.10.1 While loops as cyclic graphs . . . . . . . . . . . . 128
5.10.2 Unrolled while loops . . . . . . . . . . . . . . . . . 129
5.10.3 Markov chain perspective . . . . . . . . . . . . . . 132
5.11 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 135
6 Data structures 136
6.1 Lists . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 136
6.1.1 Basic operations . . . . . . . . . . . . . . . . . . . 137
6.1.2 Operations on variable-length lists . . . . . . . . . 138
6.1.3 Continuous relaxations using soft indexing . . . . . 140
6.2 Dictionaries . . . . . . . . . . . . . . . . . . . . . . . . . 143
6.2.1 Basic operations . . . . . . . . . . . . . . . . . . . 143
6.2.2 Continuous relaxation using kernel regression . . . 145
6.2.3 Discrete probability distribution perspective . . . . 146
6.2.4 Link with attention in Transformers . . . . . . . . 147
6.3 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 148

III Differentiating through programs 150

7 Finite differences 151


7.1 Forward differences . . . . . . . . . . . . . . . . . . . . . 151
7.2 Backward differences . . . . . . . . . . . . . . . . . . . . 152
7.3 Central differences . . . . . . . . . . . . . . . . . . . . . . 153
7.4 Higher-accuracy finite differences . . . . . . . . . . . . . . 154
7.5 Higher-order finite differences . . . . . . . . . . . . . . . . 155
7.6 Complex-step derivatives . . . . . . . . . . . . . . . . . . 156
7.7 Complexity . . . . . . . . . . . . . . . . . . . . . . . . . . 157
7.8 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 157

8 Automatic differentiation 159


8.1 Computation chains . . . . . . . . . . . . . . . . . . . . . 159
8.1.1 Forward-mode . . . . . . . . . . . . . . . . . . . . 160
8.1.2 Reverse-mode . . . . . . . . . . . . . . . . . . . . 162
8.1.3 Complexity of entire Jacobians . . . . . . . . . . . 167
8.2 Feedforward networks . . . . . . . . . . . . . . . . . . . . 169
8.2.1 Computing the adjoint . . . . . . . . . . . . . . . 169
8.2.2 Computing the gradient . . . . . . . . . . . . . . . 170
8.3 Computation graphs . . . . . . . . . . . . . . . . . . . . . 172
8.3.1 Forward-mode . . . . . . . . . . . . . . . . . . . . 172
8.3.2 Reverse-mode . . . . . . . . . . . . . . . . . . . . 173
8.3.3 Complexity, the Baur-Strassen theorem . . . . . . . 173
8.4 Implementation . . . . . . . . . . . . . . . . . . . . . . . 174
8.4.1 Primitive functions . . . . . . . . . . . . . . . . . 174
8.4.2 Closure under function composition . . . . . . . . 175
8.4.3 Examples of JVPs and VJPs . . . . . . . . . . . . 176
8.4.4 Automatic linear transposition . . . . . . . . . . . 177
8.5 Checkpointing . . . . . . . . . . . . . . . . . . . . . . . . 178
8.5.1 Recursive halving . . . . . . . . . . . . . . . . . . 179
8.5.2 Dynamic programming . . . . . . . . . . . . . . . 181
8.5.3 Online checkpointing . . . . . . . . . . . . . . . . 183
8.6 Reversible layers . . . . . . . . . . . . . . . . . . . . . . . 184
8.6.1 General case . . . . . . . . . . . . . . . . . . . . . 184
8.6.2 Case of orthonormal JVPs . . . . . . . . . . . . . 184
8.7 Randomized forward-mode estimator . . . . . . . . . . . . 185
8.8 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 185

9 Second-order automatic differentiation 187


9.1 Hessian-vector products . . . . . . . . . . . . . . . . . . . 187
9.1.1 Four possible methods . . . . . . . . . . . . . . . 187
9.1.2 Complexity . . . . . . . . . . . . . . . . . . . . . . 188
9.2 Gauss-Newton matrix . . . . . . . . . . . . . . . . . . . . 192
9.2.1 An approximation of the Hessian . . . . . . . . . . 192
9.2.2 Gauss-Newton chain rule . . . . . . . . . . . . . . 193
9.2.3 Gauss-Newton vector product . . . . . . . . . . . . 193
9.2.4 Gauss-Newton matrix factorization . . . . . . . . . 194
9.2.5 Stochastic setting . . . . . . . . . . . . . . . . . . 195
9.3 Fisher information matrix . . . . . . . . . . . . . . . . . . 195
9.3.1 Definition using the score function . . . . . . . . . 195
9.3.2 Link with the Hessian . . . . . . . . . . . . . . . . 196
9.3.3 Equivalence with the Gauss-Newton matrix . . . . 196
9.4 Inverse-Hessian vector product . . . . . . . . . . . . . . . 198
9.4.1 Definition as a linear map . . . . . . . . . . . . . . 198
9.4.2 Implementation with matrix-free linear solvers . . . 198
9.4.3 Complexity . . . . . . . . . . . . . . . . . . . . . . 199
9.5 Second-order backpropagation . . . . . . . . . . . . . . . 200
9.5.1 Second-order Jacobian chain rule . . . . . . . . . . 200
9.5.2 Computation chains . . . . . . . . . . . . . . . . . 202
9.5.3 Fan-in and fan-out . . . . . . . . . . . . . . . . . 203
9.6 Block diagonal approximations . . . . . . . . . . . . . . . 204
9.6.1 Feedforward networks . . . . . . . . . . . . . . . . 204
9.6.2 Computation graphs . . . . . . . . . . . . . . . . . 206
9.7 Diagonal approximations . . . . . . . . . . . . . . . . . . 206
9.7.1 Computation chains . . . . . . . . . . . . . . . . . 207
9.7.2 Computation graphs . . . . . . . . . . . . . . . . . 208
9.8 Randomized estimators . . . . . . . . . . . . . . . . . . . 209
9.8.1 Girard-Hutchinson estimator . . . . . . . . . . . . 209
9.8.2 Bartlett estimator for the factorization . . . . . . . 210
9.8.3 Bartlett estimator for the diagonal . . . . . . . . . 211
9.9 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 212

10 Inference in graphical models as differentiation 213


10.1 Chain rule of probability . . . . . . . . . . . . . . . . . . . 213
10.2 Conditional independence . . . . . . . . . . . . . . . . . . 214
10.3 Inference problems . . . . . . . . . . . . . . . . . . . . . . 215
10.3.1 Joint probability distributions . . . . . . . . . . . . 215
10.3.2 Likelihood . . . . . . . . . . . . . . . . . . . . . . 215
10.3.3 Maximum a-posteriori inference . . . . . . . . . . . 215
10.3.4 Marginal inference . . . . . . . . . . . . . . . . . . 216
10.3.5 Expectation, convex hull, marginal polytope . . . . 216
10.3.6 Complexity of brute force . . . . . . . . . . . . . . 218
10.4 Markov chains . . . . . . . . . . . . . . . . . . . . . . . . 218
10.4.1 The Markov property . . . . . . . . . . . . . . . . 219
10.4.2 Time-homogeneous Markov chains . . . . . . . . . 221
10.4.3 Higher-order Markov chains . . . . . . . . . . . . . 222
10.5 Bayesian networks . . . . . . . . . . . . . . . . . . . . . . 222
10.5.1 Expressing variable dependencies using DAGs . . . 222
10.5.2 Parameterizing Bayesian networks . . . . . . . . . 223
10.5.3 Ancestral sampling . . . . . . . . . . . . . . . . . 224
10.6 Markov random fields . . . . . . . . . . . . . . . . . . . . 224
10.6.1 Expressing factors using undirected graphs . . . . . 224
10.6.2 MRFs as exponential family distributions . . . . . . 225
10.6.3 Conditional random fields . . . . . . . . . . . . . . 227
10.6.4 Sampling . . . . . . . . . . . . . . . . . . . . . . . 227
10.7 Inference on chains . . . . . . . . . . . . . . . . . . . . . 227
10.7.1 The forward-backward algorithm . . . . . . . . . . 228
10.7.2 The Viterbi algorithm . . . . . . . . . . . . . . . . 229
10.8 Inference on trees . . . . . . . . . . . . . . . . . . . . . . 231
10.9 Inference as differentiation . . . . . . . . . . . . . . . . . 232
10.9.1 Inference as gradient of the log-partition . . . . . . 232
10.9.2 Semirings and softmax operators . . . . . . . . . . 233
10.9.3 Inference as backpropagation . . . . . . . . . . . . 235
10.10 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . 237

11 Differentiating through optimization 239


11.1 Implicit functions . . . . . . . . . . . . . . . . . . . . . . 239
11.1.1 Optimization problems . . . . . . . . . . . . . . . 240
11.1.2 Nonlinear equations . . . . . . . . . . . . . . . . . 240
11.1.3 Application to bilevel optimization . . . . . . . . . 240
11.2 Envelope theorems . . . . . . . . . . . . . . . . . . . . . . 241
11.2.1 Danskin’s theorem . . . . . . . . . . . . . . . . . . 242
11.2.2 Rockafellar’s theorem . . . . . . . . . . . . . . . . 243
11.3 Implicit function theorem . . . . . . . . . . . . . . . . . . 244
11.3.1 Univariate functions . . . . . . . . . . . . . . . . . 244
11.3.2 Multivariate functions . . . . . . . . . . . . . . . . 246
11.3.3 JVP and VJP of implicit functions . . . . . . . . . 247
11.3.4 Proof of the implicit function theorem . . . . . . . 248
11.4 Adjoint state method . . . . . . . . . . . . . . . . . . . . 249
11.4.1 Differentiating nonlinear equations . . . . . . . . . 249
11.4.2 Relation with envelope theorems . . . . . . . . . . 250
11.4.3 Proof using the method of Lagrange multipliers . . 251
11.4.4 Proof using the implicit function theorem . . . . . 251
11.4.5 Reverse mode as adjoint method with backsubstitution . . 252
11.5 Inverse function theorem . . . . . . . . . . . . . . . . . . 255
11.5.1 Differentiating inverse functions . . . . . . . . . . 255
11.5.2 Link with the implicit function theorem . . . . . . 255
11.5.3 Proof of inverse function theorem . . . . . . . . . 255
11.6 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 257
12 Differentiating through integration 258
12.1 Differentiation under the integral sign . . . . . . . . . . . 258
12.2 Differentiating through expectations . . . . . . . . . . . . 259
12.2.1 Parameter-independent distributions . . . . . . . . 259
12.2.2 Parameter-dependent distributions . . . . . . . . . 260
12.2.3 Application to expected loss functions . . . . . . . 262
12.2.4 Application to experimental design . . . . . . . . . 263
12.3 Score function estimators, REINFORCE . . . . . . . . . . 264
12.3.1 Scalar-valued functions . . . . . . . . . . . . . . . 264
12.3.2 Variance reduction . . . . . . . . . . . . . . . . . . 267
12.3.3 Vector-valued functions . . . . . . . . . . . . . . . 268
12.3.4 Second derivatives . . . . . . . . . . . . . . . . . . 269
12.4 Path gradient estimators, reparametrization trick . . . . . 270
12.4.1 Location-scale transforms . . . . . . . . . . . . . . 270
12.4.2 Differentiable transforms . . . . . . . . . . . . . . 272
12.4.3 Inverse transforms . . . . . . . . . . . . . . . . . . 273
12.4.4 Pushforward operators . . . . . . . . . . . . . . . 275
12.4.5 Change-of-variables theorem . . . . . . . . . . . . 277
12.5 Stochastic programs . . . . . . . . . . . . . . . . . . . . . 278
12.5.1 Stochastic computation graphs . . . . . . . . . . . 278
12.5.2 Examples . . . . . . . . . . . . . . . . . . . . . . 281
12.5.3 Unbiased gradient estimators . . . . . . . . . . . . 283
12.5.4 Local vs. global expectations . . . . . . . . . . . . 284
12.6 Differential equations . . . . . . . . . . . . . . . . . . . . 286
12.6.1 Parameterized differential equations . . . . . . . . 286
12.6.2 Continuous adjoint method . . . . . . . . . . . . . 288
12.6.3 Gradients via the continuous adjoint method . . . . 290
12.6.4 Gradients via reverse-mode on discretization . . . . 293
12.6.5 Reversible discretization schemes . . . . . . . . . . 294
12.6.6 Proof of the continuous adjoint method . . . . . . 296
12.7 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 298

IV Smoothing programs 300

13 Smoothing by optimization 301


13.1 Primal approach . . . . . . . . . . . . . . . . . . . . . . . 301
13.1.1 Infimal convolution . . . . . . . . . . . . . . . . . 302
13.1.2 Moreau envelope . . . . . . . . . . . . . . . . . . 303
13.1.3 Vector-valued functions . . . . . . . . . . . . . . . 307
13.2 Legendre–Fenchel transforms, convex conjugates . . . . . . 309
13.2.1 Definition . . . . . . . . . . . . . . . . . . . . . . 309
13.2.2 Closed-form examples . . . . . . . . . . . . . . . . 310
13.2.3 Properties . . . . . . . . . . . . . . . . . . . . . . 312
13.2.4 Conjugate calculus . . . . . . . . . . . . . . . . . 314
13.2.5 Fast Legendre transform . . . . . . . . . . . . . . 314
13.3 Dual approach . . . . . . . . . . . . . . . . . . . . . . . . 315
13.3.1 Duality between strong convexity and smoothness . 315
13.3.2 Smoothing by dual regularization . . . . . . . . . . 316
13.3.3 Equivalence between primal and dual regularizations 318
13.3.4 Regularization scaling . . . . . . . . . . . . . . . . 319
13.3.5 Generalized entropies . . . . . . . . . . . . . . . . 320
13.4 Smoothed ReLU functions . . . . . . . . . . . . . . . . . 324
13.5 Smoothed max operators . . . . . . . . . . . . . . . . . . 326
13.5.1 Definition and properties . . . . . . . . . . . . . . 326
13.5.2 Reduction to root finding . . . . . . . . . . . . . . 327
13.5.3 The softmax . . . . . . . . . . . . . . . . . . . . . 328
13.5.4 The sparsemax . . . . . . . . . . . . . . . . . . . . 329
13.5.5 Recovering smoothed ReLU functions . . . . . . . 332
13.6 Relaxed step functions (sigmoids) . . . . . . . . . . . . . . 332
13.7 Relaxed argmax operators . . . . . . . . . . . . . . . . . . 333
13.8 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 337

14 Smoothing by integration 339


14.1 Convolution . . . . . . . . . . . . . . . . . . . . . . . . . 339
14.1.1 Convolution operators . . . . . . . . . . . . . . . . 339
14.1.2 Convolution with a kernel . . . . . . . . . . . . . . 340
14.1.3 Discrete convolution . . . . . . . . . . . . . . . . . 341
14.1.4 Differentiation . . . . . . . . . . . . . . . . . . . . 343
14.1.5 Multidimensional convolution . . . . . . . . . . . . 343
14.1.6 Link between convolution and infimal convolution . 343
14.1.7 The soft infimal convolution . . . . . . . . . . . . 344
14.1.8 The soft Moreau envelope . . . . . . . . . . . . . . 345
14.2 Fourier and Laplace transforms . . . . . . . . . . . . . . . 346
14.2.1 Convolution theorem . . . . . . . . . . . . . . . . 346
14.2.2 Link between Fourier and Legendre transforms . . . 346
14.2.3 The soft Legendre-Fenchel transform . . . . . . . . 347
14.3 Examples . . . . . . . . . . . . . . . . . . . . . . . . . . . 350
14.3.1 Smoothed step function . . . . . . . . . . . . . . . 350
14.3.2 Smoothed ReLU function . . . . . . . . . . . . . . 351
14.4 Perturbation of blackbox functions . . . . . . . . . . . . . 353
14.4.1 Expectation in a location-scale family . . . . . . . 353
14.4.2 Gradient estimation by reparametrization . . . . . 354
14.4.3 Gradient estimation by SFE, Stein’s lemma . . . . 355
14.4.4 Link between reparametrization and SFE . . . . . . 356
14.4.5 Variance reduction and evolution strategies . . . . 357
14.4.6 Zero-temperature limit . . . . . . . . . . . . . . . 358
14.5 Gumbel tricks . . . . . . . . . . . . . . . . . . . . . . . . 359
14.5.1 The Gumbel distribution . . . . . . . . . . . . . . 359
14.5.2 Perturbed comparison . . . . . . . . . . . . . . . . 360
14.5.3 Perturbed argmax . . . . . . . . . . . . . . . . . . 361
14.5.4 Perturbed max . . . . . . . . . . . . . . . . . . . . 362
14.5.5 Gumbel trick for sampling . . . . . . . . . . . . . . 363
14.5.6 Perturb-and-MAP . . . . . . . . . . . . . . . . . . 364
14.5.7 Gumbel-softmax . . . . . . . . . . . . . . . . . . . 366
14.6 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 367

V Optimizing differentiable programs 369

15 Optimization basics 370


15.1 Objective functions . . . . . . . . . . . . . . . . . . . . . 370
15.2 Oracles . . . . . . . . . . . . . . . . . . . . . . . . . . . . 371
15.3 Variational perspective of optimization algorithms . . . . . 372
15.4 Classes of functions . . . . . . . . . . . . . . . . . . . . . 372
15.4.1 Lipschitz functions . . . . . . . . . . . . . . . . . 372
15.4.2 Smooth functions . . . . . . . . . . . . . . . . . . 373
15.4.3 Convex functions . . . . . . . . . . . . . . . . . . 375
15.4.4 Strongly-convex functions . . . . . . . . . . . . . . 377
15.4.5 Nonconvex functions . . . . . . . . . . . . . . . . 378
15.5 Performance guarantees . . . . . . . . . . . . . . . . . . . 380
15.6 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 383

16 First-order optimization 384


16.1 Gradient descent . . . . . . . . . . . . . . . . . . . . . . . 384
16.1.1 Variational perspective . . . . . . . . . . . . . . . 384
16.1.2 Convergence for smooth functions . . . . . . . . . 385
16.1.3 Momentum and accelerated variants . . . . . . . . 387
16.2 Stochastic gradient descent . . . . . . . . . . . . . . . . . 388
16.2.1 Stochastic gradients . . . . . . . . . . . . . . . . . 389
16.2.2 Vanilla SGD . . . . . . . . . . . . . . . . . . . . . 390
16.2.3 Momentum variants . . . . . . . . . . . . . . . . . 391
16.2.4 Adaptive variants . . . . . . . . . . . . . . . . . . 392
16.3 Projected gradient descent . . . . . . . . . . . . . . . . . 392
16.3.1 Variational perspective . . . . . . . . . . . . . . . 393
16.3.2 Optimality conditions . . . . . . . . . . . . . . . . 394
16.3.3 Commonly-used projections . . . . . . . . . . . . . 394
16.4 Proximal gradient method . . . . . . . . . . . . . . . . . . 395
16.4.1 Variational perspective . . . . . . . . . . . . . . . 396
16.4.2 Optimality conditions . . . . . . . . . . . . . . . . 396
16.4.3 Commonly-used proximal operators . . . . . . . . . 397
16.5 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 397

17 Second-order optimization 399


17.1 Newton’s method . . . . . . . . . . . . . . . . . . . . . . 399
17.1.1 Variational perspective . . . . . . . . . . . . . . . 399
17.1.2 Regularized Newton method . . . . . . . . . . . . 400
17.1.3 Approximate direction . . . . . . . . . . . . . . . . 401
17.1.4 Convergence guarantees . . . . . . . . . . . . . . . 401
17.1.5 Linesearch . . . . . . . . . . . . . . . . . . . . . . 401
17.1.6 Geometric interpretation . . . . . . . . . . . . . . 402
17.1.7 Stochastic Newton’s method . . . . . . . . . . . . 403
17.2 Gauss-Newton method . . . . . . . . . . . . . . . . . . . 404
17.2.1 With exact outer function . . . . . . . . . . . . . . 405
17.2.2 With approximate outer function . . . . . . . . . . 406
17.2.3 Linesearch . . . . . . . . . . . . . . . . . . . . . . 407
17.2.4 Stochastic Gauss-Newton . . . . . . . . . . . . . . 407
17.3 Natural gradient descent . . . . . . . . . . . . . . . . . . 408
17.3.1 Variational perspective . . . . . . . . . . . . . . . 408
17.3.2 Stochastic natural gradient descent . . . . . . . . . 409
17.4 Quasi-Newton methods . . . . . . . . . . . . . . . . . . . 410
17.4.1 BFGS . . . . . . . . . . . . . . . . . . . . . . . . 410
17.4.2 Limited-memory BFGS . . . . . . . . . . . . . . . 411
17.5 Approximate Hessian diagonal inverse preconditioners . . 411
17.6 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 411

18 Duality 413
18.1 Dual norms . . . . . . . . . . . . . . . . . . . . . . . . . 413
18.2 Fenchel duality . . . . . . . . . . . . . . . . . . . . . . . . 414
18.3 Bregman divergences . . . . . . . . . . . . . . . . . . . . 417
18.4 Fenchel-Young loss functions . . . . . . . . . . . . . . . . 420
18.5 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 421

References 422
The Elements of
Differentiable Programming
Mathieu Blondel¹ and Vincent Roulet¹
¹Google DeepMind

ABSTRACT
Artificial intelligence has recently experienced remarkable advances, fueled by large models, vast datasets, accelerated hardware, and, last but not least, the transformative power of differentiable programming. This new programming paradigm enables end-to-end differentiation of complex computer programs (including those with control flows and data structures), making gradient-based optimization of program parameters possible.
As an emerging paradigm, differentiable programming builds upon several areas of computer science and applied mathematics, including automatic differentiation, graphical models, optimization and statistics. This book presents a comprehensive review of the fundamental concepts useful for differentiable programming. We adopt two main perspectives, that of optimization and that of probability, with clear analogies between the two.
Differentiable programming is not merely the differentiation of programs, but also the thoughtful design of programs intended for differentiation. By making programs differentiable, we inherently introduce probability distributions over their execution, providing a means to quantify the uncertainty associated with program outputs.
Acknowledgements

We thank the following people for sending us feedback, suggestions and typos: Fabian Pedregosa, Kevin Murphy, Niklas Schmitz, Nidham
Gazagnadou, Bruno De Backer, David López, Guillaume Gautier, Sam
Duffield, Logan Bruns, Wojciech Stokowiec, (add your name here!).

Source code

We provide some Python source code to accompany the book on github.

Notation

Table 1: Naming conventions

Notation        Description
X ⊆ R^D         Input space (e.g., features)
Y ⊆ R^M         Output space (e.g., classes)
S_k ⊆ R^{D_k}   Output space on layer or state k
W ⊆ R^P         Weight space
Λ ⊆ R^Q         Hyperparameter space
Θ ⊆ R^R         Distribution parameter space, logit space
N               Number of training samples
T               Number of optimization iterations
x ∈ X           Input vector
y ∈ Y           Target vector
s_k ∈ S_k       State vector k
w ∈ W           Network (model) weights
λ ∈ Λ           Hyperparameters
θ ∈ Θ           Distribution parameters, logits
π ∈ [0, 1]      Probability value
π ∈ △^M         Probability vector


Table 2: Naming conventions (continued)

Notation        Description
f               Network function
f(·; x)         Network function with x fixed
L               Objective function
ℓ               Loss function
κ               Kernel function
ϕ               Output embedding, sufficient statistic
step            Heaviside step function
logistic_σ      Logistic function with temperature σ
logistic        Shorthand for logistic_1
p_θ             Model distribution with parameters θ
ρ               Data distribution over X × Y
ρ_X             Data distribution over X
µ, σ²           Mean and variance
Z               Random noise variable
1
Introduction

1.1 What is differentiable programming?

A computer program is a sequence of elementary instructions for performing a task. In traditional computer programming, the program is
typically manually written by a programmer. However, for certain tasks,
particularly those involving intricate patterns and complex decision-
making, such as image recognition or text generation, manually writing
a program is extremely challenging, if not impossible.
In contrast, modern neural networks offer a different approach. They
are constructed by combining parameterized functional blocks and
are trained directly from data using gradient-based optimization. This
end-to-end training process, where the network learns both feature
extraction and task execution simultaneously, allows neural networks to
tackle complex tasks that were previously considered insurmountable
for traditional, hand-coded programs. This new programming paradigm
has been referred to as “differentiable programming” or “software 2.0”,
terms popularized among others by LeCun (2018) and Karpathy (2017).
We give an informal definition below.


Definition 1.1 (Differentiable programming). Differentiable programming is a programming paradigm in which complex computer programs (including those with control flows and data structures) can be differentiated end-to-end automatically, enabling gradient-based optimization of parameters in the program.

Modern neural networks as parameterized programs

In differentiable programming, as in regular computer programming, a classical program is defined as the composition of elementary operations,
forming a computation graph. The key difference is that the program
(such as a neural network) contains parameters that can be adjusted
from data and can be differentiated end-to-end, using automatic
differentiation (autodiff). Typically, it is assumed that the program
defines a mathematically valid function (a.k.a. pure function): the
function should return identical values for identical arguments and
should not have any side effects. Moreover, the function should have
well-defined derivatives, ensuring that it can be used in a gradient-
based optimization algorithm. Therefore, differentiable programming
is not only the art of differentiating through programs but also of
designing meaningful differentiable programs.

Why are derivatives important?

Machine learning typically boils down to optimizing a certain objective function, which is the composition of a loss function and a model (network) function. Derivative-free optimization is called zero-order optimization. It only assumes that we can evaluate the objective function that we wish to optimize. Unfortunately, it is known to suffer from the curse of dimensionality, i.e., it only scales to small-dimensional problems, typically fewer than 10 dimensions. Derivative-based optimization, on the other hand, is much more efficient and can scale to millions or billions of parameters. Algorithms that use first and second derivatives are known as first-order and second-order algorithms, respectively.

Why is autodiff important?

Before the autodiff revolution, researchers and practitioners needed to manually implement the gradient of the functions they wished to
optimize. Manually deriving gradients can become very tedious for
complicated functions. Moreover, every time the function is changed
(for example, for trying out a new idea), the gradient needs to be re-
derived. Autodiff is a game changer because it allows users to focus on
quickly and creatively experimenting with functions for their tasks.

Differentiable programming is not just deep learning

While there is clearly overlap between deep learning and differentiable programming, their focus is different. Deep learning studies artificial
neural networks composed of multiple layers, able to learn intermediate
representations of the data. Neural network architectures have been
proposed with various inductive biases. For example, convolutional
neural networks are designed for images and transformers are designed
for sequences. On the other hand, differentiable programming studies the
techniques for designing complex programs and differentiating through
them. It is useful beyond deep learning: for instance in reinforcement
learning, probabilistic programming and scientific computing in general.

Differentiable programming is not just autodiff

While autodiff is a key ingredient of differentiable programming, this is not the only one. Differentiable programming is also concerned with
the design of principled differentiable operations. In fact, much research
on differentiable programming has been devoted to making classical computer programming operations compatible with autodiff. As we shall
see, many differentiable relaxations can be interpreted in a probabilis-
tic framework. A core theme of this book is the interplay between
optimization, probability and differentiation. Differentiation is useful
for optimization and conversely, optimization can be used to design
differentiable operators.

Our vision for differentiable programming

Computer programming offers powerful tools like control flows, data structures, and standard libraries, enabling users to construct complex
programs for solving intricate problems. Our long-term vision is to
achieve parity between traditional and differentiable programming, em-
powering programmers to seamlessly express differentiable programs
(such as neural networks) using the full suite of tools they are accus-
tomed to. However, as discussed earlier, differentiable programming is
not simply a matter of applying automatic differentiation to existing
code. Programs must be designed with differentiability in mind. This
usually comes down to inducing a probability distribution over the program
or its components. While significant work remains to fully realize this
ambitious goal, we hope this book offers a solid foundation.

1.2 Book goals and scope

The present book aims to provide a comprehensive introduction to differentiable programming with an emphasis on core mathematical
tools.

• In Part I, we review fundamentals: differentiation and probabilistic learning.

• In Part II, we review differentiable programs. This includes neural networks, sequence networks and control flows.

• In Part III, we review how to differentiate through programs. This includes automatic differentiation, but also differentiating through optimization and integration (in particular, expectations).

• In Part IV, we review smoothing programs. We focus on two main techniques: infimal convolution, which comes from the world of optimization, and convolution, which comes from the world of integration. We also strive to spell out the connections between them.

• In Part V, we review optimizing programs: basic optimization concepts, first-order algorithms, second-order algorithms and duality.

Our goal is to present the fundamental techniques useful for differentiable programming, not to survey how these techniques have been used in various applications.

1.3 Intended audience

This book is intended to be a graduate-level introduction to differentiable programming. Our pedagogical choices are made with the machine
learning community in mind. Some familiarity with calculus, linear
algebra, probability theory and machine learning is beneficial.

1.4 How to read this book?

This book does not need to be read linearly chapter by chapter. When
needed, we indicate at the beginning of a chapter which chapters are
recommended to be read as a prerequisite.

1.5 Related work

Differentiable programming builds upon a variety of connected topics. We review in this section relevant textbooks, tutorials and software.
Standard textbooks on backpropagation and automatic differentiation are those of Werbos (1994) and Griewank and Walther (2008).
A tutorial with a focus on machine learning is provided by Baydin
et al. (2018). Automatic differentiation is also reviewed as part of more
general textbooks, such as those of Deisenroth et al. (2020), Murphy
(2022) (from a linear algebra perspective) and Murphy (2023) (from a
functional perspective; autodiff section authored by Roy Frostig). The
present book was also influenced by Peyré (2020)’s textbook on data
science. The history of reverse-mode autodiff is reviewed by Griewank
(2012).
A tutorial on different perspectives of backpropagation is “There and
Back Again: A Tale of Slopes and Expectations” (link), by Deisenroth

and Ong. A tutorial on implicit differentiation is “Deep Implicit Layers - Neural ODEs, Deep Equilibrium Models, and Beyond” (link), by Kolter,
Duvenaud, and Johnson.
The standard reference on inference in graphical models and its
connection with exponential families is that of Wainwright and Jor-
dan (2008). Differentiable programming is also related to probabilistic
programming; see, e.g., Meent et al. (2018).
A review of smoothing from the infimal convolution perspective is
provided by Beck and Teboulle (2012). A standard textbook on convex
optimization is that of Nesterov (2018). A textbook on first-order
optimization methods is that of Beck (2017).
Autodiff implementations that accelerated the autodiff revolution
in machine learning are Theano (Bergstra et al., 2010) and Autograd
(Maclaurin et al., 2015). Major modern implementations of autodiff
include Tensorflow (Abadi et al., 2016), JAX (Bradbury et al., 2018),
and PyTorch (Paszke et al., 2019). We in particular acknowledge the
JAX team for influencing our view of autodiff.
Part I

Fundamentals
2
Differentiation

In this chapter, we review key differentiation concepts. In particular, we emphasize the fundamental role played by linear maps.

2.1 Univariate functions

2.1.1 Derivatives
To study functions, for instance to define their derivatives, we need to capture their infinitesimal variations around points, as defined by the notion of limit.

Definition 2.1 (Limit). A function f : R → R tends to c ∈ R as its input v ∈ R approaches w ∈ R if, for any ε > 0, there exists R > 0 such that for any v ∈ R satisfying 0 < |v − w| ≤ R, we have |f(v) − c| ≤ ε. We say that c is the limit of f as v approaches w and denote it

lim_{v→w} f(v) = c.

Limits are preserved under additions and multiplications. Namely, if lim_{v→w} f(v) = c and lim_{v→w} g(v) = d, then, denoting (af + bg)(w) := af(w) + bg(w) for any a, b ∈ R and (fg)(w) := f(w)g(w), we have, by definition of the limit, lim_{v→w} (af + bg)(v) = ac + bd and lim_{v→w} (fg)(v) = cd. The preservation of the limit under additions and multiplications by a scalar is generally referred to as the linearity of the limit, a property that many definitions below inherit.
With the notion of limit we can already delineate a class of “well-
behaved” functions, that is, functions whose limit at any point equals the value of the function at that point. Such a property defines
continuous functions.

Definition 2.2 (Continuous function). A function f : R → R is continuous at a point w ∈ R if

lim_{v→w} f(v) = f(w).

A function f is said to be continuous if it is continuous at all points in its domain.

Although the notion of continuity appears to be a benign assumption, several simple functions, such as the Heaviside step function (displayed
in the left panel of Fig. 2.2), are not continuous and require special
treatment.

Remark 2.1 (Landau’s notation). In the following, we use Landau’s little o notation. We write

g(v) = o(f(v)) as v → w

if

lim_{v→w} |g(v)|/|f(v)| = 0.
That is, the function f dominates g in the limit v → w. For example,
f is continuous at w if and only if

f (w + δ) = f (w) + o(1) as δ → 0.

We now explain derivatives. Consider a function f : R → R. As illustrated in Fig. 2.1, its value on an interval [w0, w0 + δ] can be
approximated by the secant between its values f (w0 ) and f (w0 + δ),
a linear function with slope (f (w0 + δ) − f (w0 ))/δ. In the limit of an

Figure 2.1: A function f can be locally approximated around a point w0 by a secant, a linear function w ↦ aw + b with slope a and intercept b, crossing f at w0 with value u0 = f(w0) and crossing at w0 + δ with value uδ = f(w0 + δ). Using u0 = aw0 + b and uδ = a(w0 + δ) + b, we find that its slope is a = (f(w0 + δ) − f(w0))/δ and the intercept is b = f(w0) − aw0. The derivative f′(w0) of a function f at a point w0 is then defined as the limit of the slope a when δ → 0. It is the slope of the tangent of f at w0. The value f(w) of the function at w can then be locally approximated around w0 by w ↦ f′(w0)w + f(w0) − f′(w0)w0 = f(w0) + f′(w0)(w − w0).

infinitesimal variation δ around w0, the secant converges to the tangent of f at w0 and the resulting slope defines the derivative of f at w0. The
definition below formalizes this intuition.

Definition 2.3 (Derivative). The derivative of f : R → R at w ∈ R is defined as

f′(w) := lim_{δ→0} (f(w + δ) − f(w))/δ,    (2.1)

provided that the limit exists. If f′(w) is well-defined at a particular w, we say that the function f is differentiable at w.
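As a quick numerical illustration (a minimal sketch in plain Python, not tied to any particular library), the limit in Eq. (2.1) can be approximated by evaluating the difference quotient for shrinking values of δ; for f(w) = w³ at w = 2, the quotients approach f′(2) = 12.

```python
# Approximate the derivative of Definition 2.3 by the difference quotient
# (f(w + delta) - f(w)) / delta for decreasing values of delta.
def f(w):
    return w ** 3

w = 2.0
for delta in [1e-1, 1e-2, 1e-3, 1e-4]:
    quotient = (f(w + delta) - f(w)) / delta
    print(delta, quotient)  # tends to f'(2) = 3 * 2**2 = 12
```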

Here, and in the following definitions, if f is differentiable at any w ∈ R, we say that it is differentiable everywhere or differentiable for
short. If f is differentiable at a given w, then it is necessarily continuous
at w as shown in the following proposition. Non-differentiability of a
continuous function at a given point w is generally illustrated by a kink
of this function at w as shown in Fig. 2.2.

[Figure 2.2 panels, left to right: “Discontinuous at 0”; “Continuous, non-differentiable at 1 and -1”; “Differentiable everywhere”.]
Figure 2.2: Graphical representation of discontinuity or non-differentiability. A discontinuous function presents a jump in function values at a given point (left
panel). A continuous but non-differentiable everywhere function presents kinks at the
points of non-differentiability (middle panel). A differentiable everywhere function is
smooth (right panel).

Proposition 2.1 (Differentiability implies continuity). If f : R → R is differentiable at w ∈ R, then it is continuous at w ∈ R.

Proof. In little o notation, f is differentiable at w if there exists f′(w) ∈ R, such that

f (w + δ) = f (w) + f ′ (w)δ + o(δ) as δ → 0.

Since f ′ (w)δ + o(δ) = o(1) as δ → 0, f is continuous at w.

In addition to enabling the construction of a linear approximation of f in a neighborhood of w, since it is the slope of the tangent of f at
w, the derivative f ′ informs us about the monotonicity of f around
w. If f ′ (w) is positive, the function is increasing around w. Conversely,
if f ′ (w) is negative, the function is decreasing. Such information can be
used to develop iterative algorithms seeking to minimize f by computing
iterates of the form wt+1 = wt − γf ′ (wt ) for γ > 0, which move along
descent directions of f around wt .
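A minimal sketch of this idea in Python (with an arbitrary choice of objective and of step size γ): iterating wt+1 = wt − γf′(wt) on f(w) = (w − 3)², whose derivative is f′(w) = 2(w − 3), drives the iterates towards the minimizer w* = 3.

```python
# Minimize f(w) = (w - 3)^2 with the update w_{t+1} = w_t - gamma * f'(w_t).
def f_prime(w):
    return 2.0 * (w - 3.0)

w, gamma = 0.0, 0.1
for t in range(50):
    w = w - gamma * f_prime(w)
print(w)  # close to the minimizer w* = 3
```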
For several elementary functions such as w^n, e^w, ln w, cos w or sin w,
their derivatives can be obtained directly by applying the definition of
the derivative in Eq. (2.1) as illustrated in Example 2.1.

Example 2.1 (Derivative of power function). Consider f(w) = w^n for w ∈ R, n ∈ N \ {0}. For any δ ∈ R, we have

(f(w + δ) − f(w))/δ = ((w + δ)^n − w^n)/δ
                    = (Σ_{k=0}^n (n choose k) δ^k w^{n−k} − w^n)/δ
                    = Σ_{k=1}^n (n choose k) δ^{k−1} w^{n−k}
                    = (n choose 1) w^{n−1} + Σ_{k=2}^n (n choose k) δ^{k−1} w^{n−k},

where, in the second line, we used the binomial theorem. Since (n choose 1) = n and lim_{δ→0} Σ_{k=2}^n (n choose k) δ^{k−1} w^{n−k} = 0, we get f′(w) = n w^{n−1}.

Remark 2.2 (Functions on a subset U of R). For simplicity, we presented the definition of the derivative for a function defined on the whole set of real numbers R. If a function f : U → R is defined on a subset U ⊆ R of the real numbers, as is the case for f(w) = √w defined on U = R+, the derivative of f at w ∈ U is defined by
the limit in (2.1) provided that the function f is well defined on a
neighborhood of w, that is, there exists r > 0 such that w + δ ∈ U
for any |δ| ≤ r. The function f is then said to be differentiable everywhere or differentiable for short if it is differentiable at any
point w in the interior of U, the set of points w ∈ U such that
{w + δ : |δ| ≤ r} ⊆ U for r sufficiently small. For points lying at the
boundary of U (such as a and b if U = [a, b]), one may define the
right and left derivatives of f at a and b, meaning that the limit is
taken by approaching a from the right or b from the left.

2.1.2 Calculus rules

For a given w ∈ R and two functions f : R → R and g : R → R, the derivative of elementary operations on f and g such as their sums,
products or compositions can easily be derived from the definition of
the derivative, under appropriate conditions on the differentiability
properties of f and g at w. For example, if the derivatives of f and g

exist at w, then the derivatives of their weighted sum or product exist, and satisfy the rules

∀a, b ∈ R, (af + bg)′(w) = af′(w) + bg′(w)   (Linearity)
(fg)′(w) = f′(w)g(w) + f(w)g′(w),   (Product rule)

where (fg)(w) = f(w)g(w). The linearity can be verified directly from the linearity of the limits. For the product rule, in little o notation, we have, as δ → 0,

(fg)(w + δ) = (f(w) + f′(w)δ + o(δ))(g(w) + g′(w)δ + o(δ))
            = f(w)g(w) + f′(w)g(w)δ + f(w)g′(w)δ + o(δ),

hence the result.


If the derivatives of g at w and of f at g(w) exist, then the derivative
of the composition (f ◦ g)(w) := f (g(w)) at w exists and is given by

(f ◦ g)′ (w) = f ′ (g(w))g ′ (w). (Chain rule)

We prove this result more generally in Proposition 2.2. As seen in the sequel, the linearity and the product rule can be seen as byproducts of
the chain rule, making the chain rule the cornerstone of differentiation.
Consider a function that can be expressed using sums, products or
compositions of elementary functions, such as f(w) = e^w ln w + cos w².
Its derivative can be computed by applying the aforementioned rules
on the decomposition of f into elementary operations and functions, as
illustrated in Example 2.2.

Example 2.2 (Applying rules of differentiation). Consider f(w) = e^w ln w + cos w². The derivative of f on w > 0 can be computed step by step as follows, denoting sq(w) := w²,

f′(w) = (exp · ln)′(w) + (cos ◦ sq)′(w)   (Linearity)
(exp · ln)′(w) = exp′(w) · ln(w) + exp(w) · ln′(w)   (Product rule)
(cos ◦ sq)′(w) = cos′(sq(w)) sq′(w)   (Chain rule)
exp′(w) = exp(w), ln′(w) = 1/w,   (Elem. func.)
sq′(w) = 2w, cos′(w) = − sin(w).   (Elem. func.)

We therefore obtain that f′(w) = e^w ln w + e^w/w − 2w sin w².

Such a process is purely mechanical and lends itself to an automated procedure, which is the main idea of automatic differentiation presented in Chapter 8.
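A minimal sketch of such an automated procedure, assuming JAX (one of the autodiff libraries mentioned in Section 1.5) is available: jax.grad applies these rules mechanically and recovers the expression derived by hand in Example 2.2.

```python
# Check the derivative of Example 2.2 with automatic differentiation.
import jax
import jax.numpy as jnp

def f(w):
    return jnp.exp(w) * jnp.log(w) + jnp.cos(w ** 2)

def f_prime_by_hand(w):
    return jnp.exp(w) * jnp.log(w) + jnp.exp(w) / w - 2 * w * jnp.sin(w ** 2)

w = 1.5
print(jax.grad(f)(w))      # derivative computed automatically
print(f_prime_by_hand(w))  # matches the expression derived above
```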

2.1.3 Leibniz’s notation


The notion of derivative was first introduced independently by Newton
and Leibniz in the 18th century (Ball, 1960). The latter considered
derivatives as the quotient of infinitesimal variations. Namely, denoting
u = f (w) a variable depending on w through f , Leibniz considered the
derivative of f as the quotient
f′ = du/dw   with   f′(w) = du/dw |_w

where du and dw denote infinitesimal variations of u and w respectively and the symbol |_w denotes the evaluation of the derivative at a given point w. This notation simplifies the statement of the chain rule first discovered by Leibniz (Rodriguez and Lopez Fernandez, 2010) as we have for v = g(w) and u = f(v)

du/dw = (du/dv) · (dv/dw).
This hints that derivatives are multiplied when considering compositions.
At evaluation, the chain rule in Leibniz notation recovers the formula
presented above as
du/dw |_w = du/dv |_{g(w)} · dv/dw |_w = f′(g(w))g′(w) = (f ◦ g)′(w).

The ability of Leibniz’s notation to capture the chain rule as a mere product of quotients made it popular throughout the centuries, especially
in mechanics (Ball, 1960). The rationale behind Leibniz’s notation,
that is, the concept of “infinitesimal variations” was questioned by
later mathematicians for its potential logical issues (Ball, 1960). The
notation f ′ (w) first introduced by Euler and further popularized by
Lagrange (Cajori, 1993) has then taken over in numerous mathematical
textbooks. The concept of infinitesimal variations has been rigorously
defined by considering the set of hyperreal numbers. They extend
the set of real numbers by considering each number as a sum of a
non-infinitesimal part and an infinitesimal part (Hewitt, 1948). The
formalism of infinitesimal variations further underlies the development
of automatic differentiation algorithms through the concept of dual
numbers.
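As an informal sketch of this last point (with hypothetical class and helper names, not the book's implementation), a dual number carries a value together with an infinitesimal coefficient; propagating both through elementary operations recovers the derivative, which is the basic mechanism behind forward-mode automatic differentiation.

```python
import math

class Dual:
    """Represents a + b*eps with eps**2 = 0: val is the value, eps the derivative part."""
    def __init__(self, val, eps=0.0):
        self.val, self.eps = val, eps

    def __add__(self, other):
        return Dual(self.val + other.val, self.eps + other.eps)

    def __mul__(self, other):
        # Product rule: (a + b*eps) * (c + d*eps) = a*c + (a*d + b*c)*eps.
        return Dual(self.val * other.val,
                    self.val * other.eps + self.eps * other.val)

def exp(x):
    # Chain rule for the elementary function exp.
    return Dual(math.exp(x.val), math.exp(x.val) * x.eps)

# Differentiate f(w) = w * exp(w) at w = 1 by seeding the infinitesimal part to 1.
w = Dual(1.0, 1.0)
out = w * exp(w)
print(out.val, out.eps)  # f(1) = e and f'(1) = 2e
```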

2.2 Multivariate functions

2.2.1 Directional derivatives

Let us now consider a function f : R^P → R with multi-dimensional input w = (w1, . . . , wP) ∈ R^P. The most important example in machine learning is a function which, to the parameters w ∈ R^P of a neural network, associates a loss value in R. Variations of f need to be defined along specific directions, such as the variation f(w + δv) − f(w) of f around w ∈ R^P in the direction v ∈ R^P by an amount δ > 0. This consideration naturally leads to the definition of the directional derivative.

Definition 2.4 (Directional derivative). The directional derivative of f at w in the direction v is given by

∂f(w)[v] := lim_{δ→0} (f(w + δv) − f(w))/δ,

provided that the limit exists.

One example of directional derivative consists in computing the derivative of a function f at w in any of the canonical directions

ei := (0, . . . , 0, 1, 0, . . . , 0), where the 1 is in the i-th position.

This allows us to define the notion of partial derivatives, denoted for i ∈ [P],

∂i f(w) := ∂f(w)[ei] = lim_{δ→0} (f(w + δei) − f(w))/δ.

This is also denoted in Leibniz’s notation as ∂i f(w) = ∂f(w)/∂wi or ∂i f(w) = ∂_{wi} f(w). By moving along only the i-th coordinate of the function, the partial derivative is akin to using the function ϕ(ωi) = f(w1, . . . , ωi, . . . , wP) around ωi, keeping all other coordinates fixed at their values wi.

2.2.2 Gradients
We now introduce the gradient vector, which gathers the partial deriva-
tives. We first recall the definitions of linear map and linear form.

Definition 2.5 (Linear map, linear form). A function l : R^P → R^M is a linear map if for any a1, a2 ∈ R, v1, v2 ∈ R^P,

l[a1 v1 + a2 v2] = a1 l[v1] + a2 l[v2].

A linear map with values in R, l : R^P → R, is called a linear form.

Linearity plays a crucial role in the differentiability of a function.

Definition 2.6 (Differentiability, single-output case). A function f : R^P → R is differentiable at w ∈ R^P if its directional derivative is defined along any direction, linear in any direction, and if

lim_{∥v∥2→0} |f(w + v) − f(w) − ∂f(w)[v]| / ∥v∥2 = 0.

We can now introduce the gradient.

Definition 2.7 (Gradient). The gradient of a differentiable function f : R^P → R at a point w ∈ R^P is defined as the vector of partial derivatives

∇f(w) := (∂1 f(w), . . . , ∂P f(w)) = (∂f(w)[e1], . . . , ∂f(w)[eP]).

By linearity, the directional derivative of f at w in the direction v = Σ_{i=1}^P vi ei is then given by

∂f(w)[v] = Σ_{i=1}^P vi ∂f(w)[ei] = ⟨v, ∇f(w)⟩.

Here, ⟨·, ·⟩ denotes the inner product. We provide its definition in Euclidean spaces in Section 2.3.2.
In the definition above, the fact that the gradient can be used to
compute the directional derivative is a mere consequence of linearity.
However, in more abstract cases presented in later sections, the gradient
is defined through this property.
As a simple example, any linear function of the form f(w) = ⟨a, w⟩ = Σ_{i=1}^P ai wi is differentiable as we have (⟨a, w + v⟩ − ⟨a, w⟩ − ⟨a, v⟩)/∥v∥2 = 0 for any v and in particular for ∥v∥ → 0. Moreover, its gradient is naturally given by ∇f(w) = a.
Generally, to show that a function is differentiable and find its
gradient, one approach is to approximate f (w + v) around v = 0. If we
can find a vector g such that
f (w + v) = f (w) + ⟨g, v⟩ + o(∥v∥2 ),
then f is differentiable at w since ⟨g, ·⟩ is linear. Moreover, g is then
the gradient of f at w.
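A minimal sketch, assuming JAX: jax.jvp evaluates the directional derivative ∂f(w)[v] directly, and the result coincides with the inner product ⟨v, ∇f(w)⟩ computed from jax.grad, as the linearity argument above predicts.

```python
import jax
import jax.numpy as jnp

def f(w):
    # An arbitrary differentiable function from R^3 to R.
    return jnp.sum(w ** 2) + jnp.prod(w)

w = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([0.5, -1.0, 2.0])

# Directional derivative computed as a Jacobian-vector product.
_, directional = jax.jvp(f, (w,), (v,))
# Same quantity obtained from the gradient and an inner product.
print(directional, jnp.dot(v, jax.grad(f)(w)))
```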
Remark 2.3 (Gateaux and Fréchet differentiability). Multiple defini-
tions of differentiability exist. The one presented in Definition 2.6 is
about Fréchet differentiable functions. Alternatively, if f : R^P → R has well-defined directional derivatives along any directions then
the function is Gateaux differentiable. Note that the existence of
directional derivatives in any directions is not a sufficient condition
for the function to be differentiable. In other words, any Fréchet differentiable function is Gateaux differentiable, but the converse is not true. As a counter-example, one can verify that the function f(x1, x2) = x1³/(x1² + x2²) is Gateaux differentiable at 0 but not
(Fréchet) differentiable at 0 (because the directional derivative at 0
is not linear).
Some authors also require Gateaux differentiable functions to
have linear directional derivatives along any direction. These are
still not Fréchet differentiable functions. Indeed, the limit in Defini-
tion 2.6 is over any vectors tending to 0 (potentially in a pathological
way), while directional derivatives look at such limits uniquely in
terms of a single direction.
In the remainder of this chapter, all definitions of differentiability
are in terms of Fréchet differentiability.
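To make the counter-example of Remark 2.3 concrete, the following sketch (plain Python, with a small δ standing in for the limit) approximates the directional derivative of f(x1, x2) = x1³/(x1² + x2²) at 0 along a few directions: it exists along every direction but is not additive in the direction, hence not linear.

```python
def f(x1, x2):
    # f(0, 0) is defined to be 0; we only evaluate away from the origin below.
    return x1 ** 3 / (x1 ** 2 + x2 ** 2)

def dir_deriv_at_zero(v1, v2, delta=1e-6):
    # (f(0 + delta * v) - f(0)) / delta with f(0) = 0.
    return f(delta * v1, delta * v2) / delta

print(dir_deriv_at_zero(1.0, 0.0))  # 1.0 along e1
print(dir_deriv_at_zero(0.0, 1.0))  # 0.0 along e2
print(dir_deriv_at_zero(1.0, 1.0))  # 0.5 along e1 + e2, not 1.0 + 0.0
```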

Example 2.3 illustrates how to compute the gradient of the logistic loss and validate its differentiability.

Example 2.3 (Gradient of logistic loss). Consider the logistic loss ℓ(θ, y) := −⟨y, θ⟩ + log Σ_{i=1}^M e^{θi}, that measures the prediction error of the logits θ ∈ R^M w.r.t. the correct label y ∈ {e1, . . . , eM}. Let us compute the gradient of this loss w.r.t. θ for fixed y, i.e., we want to compute the gradient of f(θ) := ℓ(θ, y). Let us decompose f as f = l + logsumexp with l(θ) := ⟨−y, θ⟩ and

logsumexp(θ) := log Σ_{i=1}^M exp(θi),

the log-sum-exp function. The function l is linear so differentiable with gradient ∇l(θ) = −y. We therefore focus on logsumexp. Denoting exp(θ) = (exp(θ1), . . . , exp(θM)), using that exp(x) = 1 + x + o(x), log(1 + x) = x + o(x), and denoting ⊙ the elementwise product, we get

logsumexp(θ + v) = log(⟨exp(θ + v), 1⟩)
                 = log(⟨exp(θ) ⊙ exp(v), 1⟩)
                 = log(⟨exp(θ) ⊙ (1 + v + o(∥v∥2)), 1⟩)
                 = log(⟨exp(θ), 1⟩ + ⟨exp(θ), v⟩ + o(∥v∥2))
                 = log(⟨exp(θ), 1⟩) + ⟨exp(θ)/⟨exp(θ), 1⟩, v⟩ + o(∥v∥2).

The above decomposition of logsumexp(θ + v) shows that it is differentiable, and that ∇logsumexp(θ) = softargmax(θ), where

softargmax(θ) := (e^{θ1}/Σ_{j=1}^M e^{θj}, . . . , e^{θM}/Σ_{j=1}^M e^{θj}).

In total, we then get that ∇f(θ) = −y + softargmax(θ).
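As a brief numerical check, assuming JAX: the automatic gradient of the logistic loss matches the closed form −y + softargmax(θ), with jax.nn.softmax playing the role of softargmax.

```python
import jax
import jax.numpy as jnp

def logistic_loss(theta, y):
    # f(theta) = -<y, theta> + log sum_j exp(theta_j)
    return -jnp.dot(y, theta) + jnp.log(jnp.sum(jnp.exp(theta)))

theta = jnp.array([0.5, -1.2, 2.0])
y = jnp.array([0.0, 0.0, 1.0])  # one-hot label e_3

print(jax.grad(logistic_loss)(theta, y))  # automatic gradient w.r.t. theta
print(-y + jax.nn.softmax(theta))         # closed form -y + softargmax(theta)
```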

Linearity of gradients
The notion of differentiability for multi-input functions naturally inherits
from the linearity of derivatives for single-input functions. For any
u1 , . . . , uM ∈ R and any multi-input functions f1 , . . . , fM differentiable
at w, the function u1 f1 + . . . + uM fM is differentiable at w and its
gradient is

∇(u1 f1 + . . . + uM fM )(w) = u1 ∇f1 (w) + . . . + uM ∇fM (w).

Why is the gradient useful?


The gradient defines the steepest ascent direction of f from w. To see
why, we note that
    arg max_{v∈R^P, ∥v∥₂≤1} ∂f(w)[v] = arg max_{v∈R^P, ∥v∥₂≤1} ⟨v, ∇f(w)⟩ = ∇f(w)/∥∇f(w)∥₂,

where we assumed ∇f (w) ̸= 0. The gradient ∇f (w) is orthogonal to


the level set of the function (the set of points w sharing the same
value f (w)) and points towards higher values of f as illustrated in
Fig. 2.3. Conversely, the negative gradient −∇f(w) points towards lower values of f.

Figure 2.3: The gradient of a function f : R² → R at (w_1, w_2) is the normal vector to the tangent space of the level set L_{f(w_1,w_2)} = {(w_1′, w_2′) : f(w_1′, w_2′) = f(w_1, w_2)} and points towards points with higher function values.

Figure 2.4: The directional derivative of a parametric curve f : R → R² at w is the tangent to the curve at the point f(w) ∈ R².

This observation motivates the development of optimization algorithms such as gradient descent, which iteratively performs the update w_{t+1} = w_t − γ∇f(w_t), for a stepsize γ > 0. It thereby seeks a minimizer of f by moving along the steepest descent direction around w_t, given, up to a multiplicative factor, by −∇f(w_t).
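As an illustration, here is a minimal gradient descent loop written with JAX; the quadratic-plus-quartic objective and the stepsize value are arbitrary choices made for the sake of the example:

    import jax
    import jax.numpy as jnp

    def f(w):
        # A simple smooth objective; any differentiable scalar-valued function works.
        return jnp.sum((w - 1.0) ** 2) + 0.1 * jnp.sum(w ** 4)

    grad_f = jax.grad(f)

    w = jnp.zeros(3)      # initial iterate
    gamma = 0.1           # stepsize
    for t in range(200):
        w = w - gamma * grad_f(w)  # move along the steepest descent direction

    print(w)  # should be close to a minimizer of f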

2.2.3 Jacobians
Let us now consider a multi-output function f : RP → RM defined by
f (w) := (f1 (w), . . . , fM (w)), where fj : RP → R. A typical example
in machine learning is a neural network. The notion of directional
derivative can be extended to such function by defining it as the vector
composed of the coordinate-wise directional derivatives:
    ∂f(w)[v] := lim_{δ→0} (f(w + δv) − f(w))/δ
              = (lim_{δ→0} (f_1(w + δv) − f_1(w))/δ, . . . , lim_{δ→0} (f_M(w + δv) − f_M(w))/δ) ∈ R^M,

where the limits (provided that they exist) are applied coordinate-wise. The directional derivative of f in the direction v ∈ R^P is therefore the vector that gathers the directional derivatives of each f_j, i.e., ∂f(w)[v] = (∂f_j(w)[v])_{j=1}^M. In particular, we can define the partial derivatives of

f at w as the vectors

    ∂_i f(w) := ∂f(w)[e_i] = (∂_i f_1(w), . . . , ∂_i f_M(w)) ∈ R^M.
As for the usual definition of the derivative, the directional derivative
can provide a linear approximation of a function around a current input
as illustrated in Fig. 2.4 for a parametric curve f : R → R2 .
Just as in the single-output case, differentiability is defined not only
as the existence of directional derivatives in any direction but also by
the linearity in the chosen direction.

Definition 2.8 (Differentiability, multi-output case). A function f : R^P → R^M is (Fréchet) differentiable at a point w ∈ R^P if its directional derivative is defined along all directions, linear along all directions, and

    lim_{∥v∥₂→0} ∥f(w + v) − f(w) − ∂f(w)[v]∥₂ / ∥v∥₂ = 0.

The partial derivatives of each coordinate’s function are gathered in


the Jacobian matrix.
Definition 2.9 (Jacobian). The Jacobian of a differentiable function f : R^P → R^M at w is defined as the matrix gathering the partial derivatives of each coordinate's function, provided they exist,

    ∂f(w) := (∂_j f_i(w))_{i∈[M], j∈[P]} ∈ R^{M×P},

i.e., the M × P matrix whose (i, j)th entry is ∂_j f_i(w). The Jacobian can be represented by stacking columns of partial derivatives or rows of gradients,

    ∂f(w) = (∂_1 f(w), . . . , ∂_P f(w)) = (∇f_1(w), . . . , ∇f_M(w))^⊤.

By linearity, the directional derivative of f at w along any input direction v = ∑_{i=1}^P v_i e_i ∈ R^P is then given by

    ∂f(w)[v] = ∑_{i=1}^P v_i ∂_i f(w) = ∂f(w)v ∈ R^M.

Notice that we use bold ∂ to indicate the Jacobian matrix. The


Jacobian matrix naturally generalizes the concepts of derivatives and
gradients presented earlier. As for the single input case, to show that
a function is differentiable, one approach is to approximate f (w + v)
around v = 0. If we find a linear map l such that

f (w + v) = f (w) + l[v] + o(∥v∥2 ),

then f is differentiable at w. Moreover, if l is represented by matrix J


such that l[v] = J v then J = ∂f (w).
As a simple example, any linear function f (w) = Aw for A ∈ RM ×P
is differentiable, since all its coordinate-wise components are single-
output linear functions, and the Jacobian of f at any w is given by
∂f (w) = A.
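For instance, this can be checked numerically with JAX (the matrix A below is an arbitrary choice):

    import jax
    import jax.numpy as jnp

    A = jnp.array([[1.0, 2.0, 3.0],
                   [4.0, 5.0, 6.0]])   # A of size 2 x 3

    def f(w):
        return A @ w                   # linear map f(w) = Aw

    w = jnp.array([0.1, -0.2, 0.3])
    J = jax.jacobian(f)(w)             # materializes the 2 x 3 Jacobian

    print(jnp.allclose(J, A))          # expected: True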

Remark 2.4 (Special cases of the Jacobian). For single-output functions f : R^P → R, i.e., M = 1, the Jacobian matrix reduces to a row vector identified as the transpose of the gradient, i.e.,

    ∂f(w) = ∇f(w)^⊤ ∈ R^{1×P}.

For a single-input function f : R → R^M, the Jacobian reduces to a single column vector of derivatives, denoted

    ∂f(w) = f′(w) := (f_1′(w), . . . , f_M′(w))^⊤ ∈ R^{M×1}.

For a single-input single-output function f : R → R, the Jacobian reduces to the derivative of f, i.e.,

    ∂f(w) = f′(w) ∈ R.

Example 2.4 illustrates the form of the Jacobian matrix for the
element-wise application of a differentiable function such as the softplus
activation. This example already shows that the Jacobian takes a simple
diagonal matrix form. As a consequence, the directional derivative
associated with this function is simply given by an element-wise product
rather than a full matrix-vector product as suggested in Definition 2.9.
We will revisit this point in Section 2.3.

Example 2.4 (Jacobian matrix of the softplus activation). Consider the element-wise application of the softplus defined for w ∈ R^P by

    f(w) := (σ(w_1), . . . , σ(w_P)) ∈ R^P, where σ(w) := log(1 + e^w).

Since σ is differentiable, each coordinate of this function is differentiable and the overall function is differentiable. The jth coordinate of f is independent of the ith coordinate of w for i ≠ j, so ∂_i f_j(w) = 0 for i ≠ j. For i = j, the result boils down to the derivative of σ at w_j. That is, ∂_j f_j(w) = σ′(w_j), where σ′(w) = e^w/(1 + e^w). The Jacobian of f is therefore a diagonal matrix

    ∂f(w) = diag(σ′(w_1), . . . , σ′(w_P)),

the P × P matrix with σ′(w_1), . . . , σ′(w_P) on its diagonal and zeros elsewhere.
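A quick numerical check of Example 2.4 can be sketched in JAX (the input values are arbitrary):

    import jax
    import jax.numpy as jnp

    def softplus(w):
        return jnp.log1p(jnp.exp(w))   # elementwise softplus

    w = jnp.array([0.3, -1.0, 2.5])
    J = jax.jacobian(softplus)(w)      # full Jacobian, here a 3 x 3 matrix

    sigma_prime = jnp.exp(w) / (1.0 + jnp.exp(w))  # derivative of softplus, i.e., the sigmoid
    print(jnp.allclose(J, jnp.diag(sigma_prime)))  # expected: True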

Variations along outputs


Rather than considering variations of f along an input direction v ∈ RP ,
we may also consider the variations of f along an output direction
u ∈ RM , namely, computing the gradient ∇⟨u, f ⟩(w) of the single-
output function
⟨u, f ⟩(w) := ⟨u, f (w)⟩ ∈ R.
In particular, we may consider computing the gradients ∇f_j(w) of each function coordinate f_j = e_j^⊤ f at w, where e_j is the jth canonical vector in R^M. The infinitesimal variations of f at w along any output direction u = ∑_{j=1}^M u_j e_j ∈ R^M are given by

    ∇⟨u, f⟩(w) = ∑_{j=1}^M u_j ∇f_j(w) = ∂f(w)^⊤ u ∈ R^P,

where ∂f(w)^⊤ is the Jacobian's transpose. Using the definition of the derivative as a limit, we obtain for i ∈ [P]

    ∇_i⟨u, f⟩(w) = [∂f(w)^⊤ u]_i = lim_{δ→0} ⟨u, f(w + δe_i) − f(w)⟩/δ.

Chain rule
Equipped with a generic definition of differentiability and the associated
objects, gradients and Jacobians, we can now generalize the chain rule,
previously introduced for single-input single-output functions.

Proposition 2.2 (Chain rule). Consider f : RP → RM and g :


RM → RR . If f is differentiable at w ∈ RP and g is differen-
tiable at f (w) ∈ RM , then the composition g ◦ f is differentiable
at w ∈ RP and its Jacobian is given by

∂(g ◦ f )(w) = ∂g(f (w))∂f (w).

Proof. We progressively approximate g ◦ f (w + v) using the differentia-


bility of f at w and g at f (w),

g(f (w + v)) = g(f (w) + ∂f (w)v + o(∥v∥))


= g(f (w)) + ∂g(f (w))∂f (w)v + o(∥v∥).

Hence, g ◦ f is differentiable at w with Jacobian ∂g(f (w))∂f (w).

Proposition 2.2 can be seen as the cornerstone of all derivative computations. For example, it can be used to rederive the linearity or the product rule associated with the derivatives of single-input single-output functions.
When g is scalar-valued, combined with Remark 2.4, we obtain a
simple expression for ∇(g ◦ f ).

Proposition 2.3 (Chain rule, scalar-valued case). Consider f : RP →


RM and g : RM → R. The gradient of the composition is given by

∇(g ◦ f )(w) = ∂f (w)⊤ ∇g(f (w)).

This is illustrated with linear regression in Example 2.5.


Example 2.5 (Linear regression). Consider the squared residuals of a linear regression of N inputs x_1, . . . , x_N ∈ R^D onto N targets y_1, . . . , y_N ∈ R with a vector w ∈ R^D, that is,

    f(w) = ∥Xw − y∥₂² = ∑_{i=1}^N (⟨x_i, w⟩ − y_i)²

for X = (x_1, . . . , x_N)^⊤ ∈ R^{N×D} and y = (y_1, . . . , y_N)^⊤ ∈ R^N.
The function f can be decomposed into a linear mapping f_1(w) = Xw and a squared error f_2(p) = ∥p − y∥₂², so that f = f_2 ∘ f_1. We can then apply the chain rule in Proposition 2.3 to get

    ∇f(w) = ∂f_1(w)^⊤ ∇f_2(f_1(w)),

provided that f_1, f_2 are differentiable at w and f_1(w), respectively. The function f_1 is linear so differentiable with Jacobian ∂f_1(w) = X. On the other hand, the partial derivatives of f_2 are given by ∂_j f_2(p) = 2(p_j − y_j) for j ∈ {1, . . . , N}. Therefore, f_2 is differentiable at any p and its gradient is ∇f_2(p) = 2(p − y). By combining the computations of the Jacobian of f_1 and the gradient of f_2, we then get the gradient of f as

    ∇f(w) = 2X^⊤(f_1(w) − y) = 2X^⊤(Xw − y).
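The closed form of Example 2.5 can be checked against autodiff; the sketch below uses JAX with randomly generated data (names and sizes are arbitrary):

    import jax
    import jax.numpy as jnp

    key_X, key_y, key_w = jax.random.split(jax.random.PRNGKey(0), 3)
    X = jax.random.normal(key_X, (5, 3))   # N = 5 inputs of dimension D = 3
    y = jax.random.normal(key_y, (5,))
    w = jax.random.normal(key_w, (3,))

    def f(w):
        return jnp.sum((X @ w - y) ** 2)   # squared residuals ||Xw - y||_2^2

    grad_autodiff = jax.grad(f)(w)
    grad_closed_form = 2.0 * X.T @ (X @ w - y)

    print(jnp.allclose(grad_autodiff, grad_closed_form, atol=1e-5))  # expected: True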

2.3 Linear differentiation maps

The Jacobian matrix is useful as a representation of the partial deriva-


tives. However, the core idea underlying the definition of differentiable
functions, as well as their implementation in an autodiff framework,
lies in the access to two key linear maps. These two maps encode in-
finitesimal variations along input or output directions and are referred to, respectively, as Jacobian-vector product (JVP) and vector-Jacobian product (VJP). This section formalizes these notions, in the
context of Euclidean spaces.

2.3.1 The need for linear maps

So far, we have focused on functions f : RP → RM , that take a vector as


input and produce a vector as output. However, functions that use matrix
or even tensor inputs/outputs are commonplace in neural networks. For
example, consider the function of matrices of the form f (W ) = W x,
where x ∈ RD and W ∈ RM ×D . This function takes a matrix as input,
not a vector. Of course, a matrix W ∈ RM ×D can always be “flattened”
into a vector w ∈ RM D , by stacking the columns of W . We denote this
operation by w = vec(W ) and its inverse by W = vec−1 (w). We can
then equivalently write f (W ) as f˜(w) = f (vec−1 (w)) = vec−1 (w)x,
so that the previous framework applies. However, we will now see that
this would be inefficient.

Indeed, the resulting Jacobian of f̃ at any w is a matrix of size M × MD, which, after some computations, can be observed to be mostly filled with zeros. Getting the directional derivative of f at W ∈ R^{M×D} in a direction V ∈ R^{M×D} would consist in (i) vectorizing V into v = vec(V), (ii) computing the matrix-vector product ∂f̃(w)v at a cost of the order of M²D operations (ignoring the fact that the Jacobian has zero entries), (iii) re-shaping the result into a matrix.

On the other hand, since f is linear in its matrix input, we can


infer that the directional derivative of f at any W ∈ RM ×D in any
direction V ∈ RM ×D is simply given by the function itself applied on
V . Namely, we have ∂f (W )[V ] = f (V ) = V x, which is simple to
implement and clearly only requires M D operations. Note that the
cost would have been the same, had we skipped the zero entries of ∂f̃(w) in the matrix-vector product. The point here is that by considering the operations associated
to the differentiation of a function as linear maps rather than using the
associated representation as a Jacobian matrix, we can streamline the
associated implementations and exploit the structures of the underlying
input or output space. To that end, we now recall the main abstractions
necessary to extend the previous definitions in the context of Euclidean
spaces.
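Anticipating the formal definitions below, the following sketch illustrates this point in JAX: the JVP of f(W) = Wx is obtained directly as a linear map, without ever materializing the Jacobian of the flattened function (the sizes below are arbitrary):

    import jax
    import jax.numpy as jnp

    x = jnp.ones(4)                       # x of dimension D = 4

    def f(W):
        return W @ x                      # f(W) = Wx, matrix input, vector output

    kW, kV = jax.random.split(jax.random.PRNGKey(0))
    W = jax.random.normal(kW, (3, 4))     # point at which we differentiate
    V = jax.random.normal(kV, (3, 4))     # direction, same shape as W

    # JVP as a linear map: no M x (M*D) Jacobian is ever formed.
    _, jvp_out = jax.jvp(f, (W,), (V,))

    print(jnp.allclose(jvp_out, V @ x))   # since f is linear in W: df(W)[V] = Vx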

2.3.2 Euclidean spaces

Linear spaces, a.k.a. vector spaces, are spaces equipped with (and closed under) an addition rule compatible with multiplication by a scalar (we limit ourselves to the field of reals). Namely, in a linear space E, there exist the operations + and ·, such that for any u, v ∈ E and a ∈ R, we have u + v ∈ E and a · u ∈ E. Euclidean spaces are linear spaces equipped with a finite basis e_1, . . . , e_P ∈ E. Any element v ∈ E can be decomposed as v = ∑_{i=1}^P v_i e_i for some unique scalars v_1, . . . , v_P ∈ R. A canonical example of a Euclidean space is the set R^P of all vectors of size P that we already covered. The set of matrices R^{P₁×P₂} of size P₁ × P₂ is also naturally a Euclidean space generated by the set of canonical matrices E_{ij} ∈ {0, 1}^{P₁×P₂} for i ∈ [P₁], j ∈ [P₂], filled with zeros except at the (i, j)th entry, which equals one. For example, W ∈ R^{P₁×P₂} can be written W = ∑_{i=1}^{P₁} ∑_{j=1}^{P₂} W_{ij} E_{ij}.
Euclidean spaces are naturally equipped with a notion of inner product, defined below.

Definition 2.10 (Inner product). An inner product on a linear


space E is a function ⟨·, ·⟩ : E × E → R that is

• bilinear: x 7→ ⟨x, w⟩ and y 7→ ⟨v, y⟩ are linear for any w, v ∈


E,

• symmetric: ⟨w, v⟩ = ⟨v, w⟩ for any w, v ∈ E,

• positive definite: ⟨w, w⟩ ≥ 0 for any w ∈ E, and ⟨w, w⟩ = 0


if and only if w = 0.

An inner product defines a norm ∥w∥ := √⟨w, w⟩.

The norm induced by an inner product defines a distance ∥w − v∥ between w, v ∈ E, and therefore a notion of convergence.
For vectors, where E = R^P, the inner product is the usual one, ⟨w, v⟩ = ∑_{i=1}^P w_i v_i. For matrices, where E = R^{P₁×P₂}, the inner product is the so-called Frobenius inner product. It is defined for any W, V ∈

R^{P₁×P₂} by

    ⟨W, V⟩ := ⟨vec(W), vec(V)⟩ = ∑_{i=1}^{P₁} ∑_{j=1}^{P₂} W_{ij} V_{ij} = tr(W^⊤ V),

where tr(Z) := ∑_{i=1}^P Z_{ii} is the trace operator defined for square matrices Z ∈ R^{P×P}. For tensors of order R, which generalize matrices to E = R^{P₁×...×P_R}, the inner product is defined similarly for W, V ∈ R^{P₁×...×P_R} by

    ⟨W, V⟩ := ⟨vec(W), vec(V)⟩ = ∑_{i₁=1}^{P₁} · · · ∑_{i_R=1}^{P_R} W_{i₁...i_R} V_{i₁...i_R},

where W_{i₁...i_R} is the (i₁, . . . , i_R)th entry of W.

2.3.3 Linear maps and their adjoints


The notion of linear map defined in Definition 2.5 naturally extends
to Euclidean spaces. Namely, a function l : E → F from a Euclidean
space E onto a Euclidean space F is a linear map if for any w, v ∈ E
and a, b ∈ R, we have l[aw + bv] = a · l[w] + b · l[v]. When E = RP and
F = RM , there always exists a matrix A ∈ RM ×P such that l[x] = Ax.
Therefore, we can think of A as the “materialization” of l.
We can define the adjoint operator of a linear map.

Definition 2.11 (Adjoint operator). Given two Euclidean spaces E


and F equipped with inner products ⟨·, ·⟩E and ⟨·, ·⟩F , the adjoint
of a linear map l : E → F is the unique linear map l∗ : F → E such
that for any v ∈ E and u ∈ F,

⟨l[v], u⟩F = ⟨v, l∗ [u]⟩E .

The adjoint can be thought as the counterpart of the matrix trans-


pose for linear maps. When l[v] = Av, we have l∗ [u] = A⊤ u since
⟨l[v], u⟩F = ⟨Av, u⟩F = ⟨v, A⊤ u⟩E = ⟨v, l∗ [u]⟩E .

2.3.4 Jacobian-vector products


We now define the directional derivative using linear maps, leading
to the notion of Jacobian-vector product (JVP). This can be used to

facilitate the treatment of functions on matrices or be used for further


extensions to infinite-dimensional spaces. In the following, E and F
denote two Euclidean spaces equipped with inner products ⟨·, ·⟩E and
⟨·, ·⟩F . We start by defining differentiability in general Euclidean spaces.

Definition 2.12 (Differentiability in Euclidean spaces). A function f : E → F is differentiable at a point w ∈ E if the directional derivative along v ∈ E

    ∂f(w)[v] := lim_{δ→0} (f(w + δv) − f(w))/δ

is well-defined for any v ∈ E, linear in v, and if

    lim_{∥v∥_E→0} ∥f(w + v) − f(w) − ∂f(w)[v]∥_F / ∥v∥_E = 0.

We can now formally define the Jacobian-vector product.

Definition 2.13 (Jacobian-vector product). For a differentiable func-


tion f : E → F, the linear map ∂f (w) : E → F, mapping v to
∂f (w)[v] is called the Jacobian-vector product (JVP) by anal-
ogy with Definition 2.9. The function ∂f is a function from E to a
linear map from E to F. That is, we have

∂f : E → (E → F).

Strictly speaking, v belongs to E and may therefore not be a vector, if for instance E is a set of real matrices. We nevertheless keep the name JVP, as it is now widely adopted.

Recovering the gradient


Previously, we saw that for differentiable functions with vector input
and scalar output, the directional derivative is equal to the inner prod-
uct between the direction and the gradient. The same applies when
considering differentiable functions from a Euclidean space with single
outputs, except that the gradient is now an element of the input space
and the inner product is the one associated with the input space.

Proposition 2.4 (Gradient). If a function f : E → R is differen-


tiable at w ∈ E, then there exists ∇f (w) ∈ E, called the gradient
of f at w such that the directional derivative of f at w along any
input direction v ∈ E is given by

∂f (w)[v] = ⟨∇f (w), v⟩E .

In Euclidean spaces, the existence of the gradient can simply be


shown by decomposing the partial derivative along a basis of E. Such a
definition generalizes to infinite-dimensional (e.g., Hilbert spaces) spaces
as seen in Section 2.3.9.
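For instance, with JAX, the gradient of a scalar-valued function of a matrix input is returned as a matrix of the same shape, i.e., as an element of the input space (a small sketch; the function below is arbitrary):

    import jax
    import jax.numpy as jnp

    def f(W):
        # scalar-valued function of a matrix input
        return jnp.sum(jnp.tanh(W))

    W = jnp.ones((2, 3))
    G = jax.grad(f)(W)       # the gradient lives in the input space E = R^{2x3}

    print(G.shape)           # (2, 3)
    # For this f, the gradient is 1 - tanh(W)^2 elementwise:
    print(jnp.allclose(G, 1.0 - jnp.tanh(W) ** 2))  # expected: True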

2.3.5 Vector-Jacobian products


For functions with vector input and vector output, we already discussed
infinitesimal variations along output directions. The same approach
applies for Euclidean spaces and is tied to the adjoint of the JVP as
detailed in Proposition 2.5.

Proposition 2.5 (Vector-Jacobian product). If a function f : E →


F is differentiable at w ∈ E, then its infinitesimal variation along an
output direction u ∈ F is given by the adjoint map ∂f (w)∗ : F →
E of the JVP, called the vector-Jacobian product (VJP). It
satisfies

∇⟨u, f ⟩F (w) = ∂f (w)∗ [u],

where we denoted ⟨u, f ⟩F (w) := ⟨u, f (w)⟩F . The function ∂f (·)∗


is a function from E to a linear map from F to E. That is, we have

∂f (·)∗ : E → (F → E).

Proof. The chain rule presented in Proposition 2.2 naturally generalizes


to Euclidean spaces (see Proposition 2.6). Since ⟨u, ·⟩ is linear, its
directional derivative is itself. Therefore, the directional derivative of
⟨u, f ⟩F is
    ∂(⟨u, f⟩_F)(w)[v] = ⟨u, ∂f(w)[v]⟩_F = ⟨∂f(w)*[u], v⟩_E.

As this is true for any v ∈ E, ∂f (w)∗ [u] is the gradient of ⟨u, f ⟩F per
Proposition 2.4.

2.3.6 Chain rule


The chain rule presented before in terms of Jacobian matrices can
readily be formulated in terms of linear maps to take advantage of the
implementations of the JVP and VJP as linear maps.

Proposition 2.6 (Chain rule, general case). Consider f : E → F


and g : F → G for E, F, G some Euclidean spaces. If f is differ-
entiable at w ∈ E and g is differentiable at f (w) ∈ F, then the
composition g ◦ f is differentiable at w ∈ E. Its JVP is given by

∂(g ◦ f )(w)[v] = ∂g(f (w))[∂f (w)[v]]

and its VJP is given by

∂(g ◦ f )(w)∗ [u] = ∂f (w)∗ [∂g(f (w))∗ [u]].

The proof follows the one of Proposition 2.2. When the last function
is scalar-valued, which is often the case in machine learning, we obtain
the following simplified result.

Proposition 2.7 (Chain rule, scalar case). Consider f : E → F and


g : F → R, the gradient of the composition is given by

∇(g ◦ f )(w) = ∂f (w)∗ [∇g(f (w))].
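Proposition 2.7 is exactly what jax.vjp exposes: the VJP of f applied to the gradient of g. A sketch (the functions below are arbitrary choices):

    import jax
    import jax.numpy as jnp

    def f(w):
        return jnp.stack([jnp.sin(w[0]) * w[1], w[0] + w[1] ** 2])  # f: R^2 -> R^2

    def g(p):
        return jnp.sum(p ** 2)                                      # g: R^2 -> R

    w = jnp.array([0.7, -1.3])

    # Gradient of the composition via the VJP of f applied to the gradient of g.
    p, f_vjp = jax.vjp(f, w)            # evaluates f(w) and returns the VJP map
    grad_manual = f_vjp(jax.grad(g)(p))[0]

    grad_direct = jax.grad(lambda w: g(f(w)))(w)
    print(jnp.allclose(grad_manual, grad_direct))  # expected: True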

2.3.7 Functions of multiple inputs (fan-in)


Oftentimes, the inputs of a function do not belong to only one Euclidean
space but to a product of them. An example is f (x, W ) = W x, which
is defined on E = RD × RM ×D . In such a case, it is convenient to
generalize the notion of partial derivatives to handle blocks of inputs.
Consider a function f (w1 , . . . , wS ) defined on E = E1 × . . . × ES ,
where wi ∈ Ei . We denote the partial derivative with respect to the
ith input wi along vi ∈ Ei as ∂i f (w1 , . . . , wS )[vi ]. Equipped with this

notation, we can analyze how JVPs or VJPs are decomposed along


several inputs.

Proposition 2.8 (Multiple inputs). Consider a differentiable function of the form f(w) = f(w_1, . . . , w_S) with signature f : E → F, where w := (w_1, . . . , w_S) ∈ E and E := E_1 × · · · × E_S. Then the JVP with the input direction v = (v_1, . . . , v_S) ∈ E is given by

    ∂f(w)[v] = ∂f(w_1, . . . , w_S)[v_1, . . . , v_S]
             = ∑_{i=1}^S ∂_i f(w_1, . . . , w_S)[v_i] ∈ F.

The VJP with the output direction u ∈ F is given by

    ∂f(w)*[u] = ∂f(w_1, . . . , w_S)*[u]
              = (∂_1 f(w_1, . . . , w_S)*[u], . . . , ∂_S f(w_1, . . . , w_S)*[u]) ∈ E.

Example 2.6 (Matrix-vector product). Consider f (x, W ) = W x,


where W ∈ RM ×D and x ∈ RD . This corresponds to setting
E = E1 × E2 = RD × RM ×D and F = RM . For the JVP, letting
v ∈ RD and V ∈ RM ×D , we obtain

∂f (x, W )[v, V ] = W v + V x ∈ F.

We can also access the individual JVPs as


∂1 f (x, W )[v] = W v ∈ F,
∂2 f (x, W )[V ] = V x ∈ F.

For the VJP, letting u ∈ RM , we obtain

∂f (x, W )∗ [u] = (W ⊤ u, ux⊤ ) ∈ E.

We can access the individual VJPs by

∂1 f (x, W )∗ [u] = W ⊤ u ∈ E1 ,
∂2 f (x, W )∗ [u] = ux⊤ ∈ E2 .
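Example 2.6 can be reproduced with jax.jvp and jax.vjp, which handle multiple inputs exactly as in Proposition 2.8 (a sketch; the shapes are arbitrary):

    import jax
    import jax.numpy as jnp

    def f(x, W):
        return W @ x            # f(x, W) = Wx

    kx, kW, kv, kV, ku = jax.random.split(jax.random.PRNGKey(0), 5)
    x, v = jax.random.normal(kx, (4,)), jax.random.normal(kv, (4,))
    W, V = jax.random.normal(kW, (3, 4)), jax.random.normal(kV, (3, 4))
    u = jax.random.normal(ku, (3,))

    # JVP: sum of the per-input JVPs, here Wv + Vx.
    _, jvp_out = jax.jvp(f, (x, W), (v, V))
    print(jnp.allclose(jvp_out, W @ v + V @ x))        # expected: True

    # VJP: one output per input, here (W^T u, u x^T).
    _, f_vjp = jax.vjp(f, x, W)
    vjp_x, vjp_W = f_vjp(u)
    print(jnp.allclose(vjp_x, W.T @ u), jnp.allclose(vjp_W, jnp.outer(u, x)))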

Remark 2.5 (Nested inputs). It is sometimes convenient to group


inputs into meaningful parts. For instance, if the input is naturally
broken down into two parts x = (x1 , x2 ), where x1 is a text
part and x2 is an image part, and the network parameters are
naturally grouped into three layers w = (w1 , w2 , w3 ), we can write
f (x, w) = f ((x1 , x2 ), (w1 , w2 , w3 )). This is mostly a convenience
and we can again reduce it to a function of a single input, thanks
to the linear map perspective in Euclidean spaces.

Remark 2.6 (Hiding away inputs). It will often be convenient to


ignore inputs when differentiating. We use the semicolon for this
purpose. For instance, a function of the form L(w; x, y) (notice
the semicolon) has signature L : W → R because we treat x and
y as constants. Therefore, the gradient is ∇L(w; x, y) ∈ W. On
the other hand, the function L(w, x, y) (notice the comma) has
signature L : W × X × Y → R so its gradient is ∇L(w, x, y) ∈
W × X × Y. If we need to access partial gradients, we use indexing,
e.g., ∇1 L(w, x, y) ∈ W or ∇w L(w, x, y) ∈ W when there is no
ambiguity.

2.3.8 Functions with multiple outputs (fan-out)


Similarly, it is often convenient to deal with functions that have multiple
outputs.

Proposition 2.9 (Multiple outputs). Consider a differentiable func-


tion of the form f (w) = (f1 (w), . . . , fT (w)), with signatures f : E →
F and fi : E → Fi , where F := F1 × · · · × FT . Then the JVP with
the input direction v ∈ E is given by

∂f (w)[v] = (∂f1 (w)[v], . . . , ∂fT (w)[v]) ∈ F.



The VJP with the output direction u = (u_1, . . . , u_T) ∈ F is

    ∂f(w)*[u] = ∂f(w)*[u_1, . . . , u_T]
              = ∑_{i=1}^T ∂f_i(w)*[u_i] ∈ E.

Combined with the chain rule, we obtain that the Jacobian of

    h(w) := g(f(w)) = g(f_1(w), . . . , f_T(w))

is ∂h(w) = ∑_{i=1}^T ∂_i g(f(w)) ∘ ∂f_i(w) and therefore the JVP is

    ∂h(w)[v] = ∑_{i=1}^T ∂_i g(f(w))[∂f_i(w)[v]].

2.3.9 Extensions to non-Euclidean linear spaces


We focused on Euclidean spaces, i.e., linear spaces with a finite basis.
However, the notions introduced earlier can be defined in more generic
spaces.
For example, directional derivatives (see Definition 2.12) can
be defined in any linear space equipped with a norm and complete
with respect to this norm. Such spaces are called Banach spaces.
Completeness is a technical assumption that requires that any Cauchy
sequence converges (a Cauchy sequence is a sequence whose elements
become arbitrarily close to each other as the sequence progresses). A
function f : E → F defined from a Banach space E onto a Banach space
F is then called Gateaux differentiable if its directional derivative is
defined along any direction (where limits are defined w.r.t. the norm in
F). Some authors also require the directional derivative to be linear to
define a Gateaux differentiable function.
Fréchet differentiability can also naturally be generalized to
Banach spaces. The only difference is that, in generic Banach spaces, the linear map l satisfying Definition 2.12 must be continuous, i.e., there must exist C > 0 such that ∥l[v]∥ ≤ C∥v∥, where the norms are those of the Banach spaces F and E, respectively.
The definitions of gradient and VJPs require in addition a notion of
inner product. They can be defined in Hilbert spaces, that is, linear

spaces equipped with an inner product and complete with respect to


the norm induced by the inner product (they could also be defined
in a Banach space by considering operations in the dual space, see,
e.g. (Clarke et al., 2008)). The existence of the gradient is ensured by
Riesz’s representation theorem which states that any continuous
linear form in a Hilbert space can be represented by the inner product
with a vector. Since for a differentiable function f : E → R, the JVP
∂f (w) : E → R is a linear form, Riesz’s representation theorem ensures
the existence of the gradient as the element g ∈ E such that ∂f(w)[v] = ⟨g, v⟩ for any v ∈ E. The VJP is also well-defined as the adjoint of the JVP w.r.t. the inner product of the Hilbert space.
As an example, the space of square-integrable functions on R is a Hilbert space equipped with the inner product ⟨a, b⟩ := ∫_R a(x)b(x) dx.

Here, we cannot find a finite number of functions that can express all
possible functions on R. Therefore, this space is not a mere Euclidean
space. Nevertheless, we can consider functions on this Hilbert space
(called functionals to distinguish them from the elements of the space).
The associated directional derivatives and gradients, can be defined
and are called respectively, functional derivative and functional
gradient, see, e.g., Frigyik et al. (2008) and references therein.

2.4 Second-order differentiation

2.4.1 Second derivatives


For a single-input, single-output differentiable function f : R → R,
its derivative at any point is itself a function f ′ : R → R. We may
then consider the derivative of the derivative at any point: the second
derivative.

Definition 2.14 (Second derivative). The second derivative f (2) (w)


of a differentiable function f : R → R at w ∈ R is defined as the
derivative of f ′ at w, that is,
    f^{(2)}(w) := lim_{δ→0} (f′(w + δ) − f′(w))/δ,

provided that the limit is well-defined. If the second derivative of a function f is well-defined at w, the function is said to be twice differentiable at w.

Figure 2.5: Points at which the second derivative is small are points at which the function is well approximated by its tangent line. On the other hand, points with a large second derivative tend to be badly approximated by the tangent line.

If f has a small second derivative at a given w, the derivative around w is almost constant. That is, the function behaves like a line around w, as illustrated in Fig. 2.5. Hence, the second derivative is usually interpreted as the curvature of the function at a given point.
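In an autodiff framework, the second derivative of a univariate function can be obtained by differentiating twice; for instance, in JAX:

    import jax
    import jax.numpy as jnp

    f = jnp.sin
    d2f = jax.grad(jax.grad(f))       # second derivative of sin is -sin

    w = 0.4
    print(jnp.allclose(d2f(w), -jnp.sin(w)))  # expected: True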

2.4.2 Second directional derivatives


For a multi-input function f : RP → R, we saw that the directional
derivative encodes infinitesimal variations of f along a given direction.
To analyze the second derivative, the curvature of the function at a
given point w, we can consider the variations along a pair of directions,
as defined below.
Definition 2.15 (Second directional derivative). The second directional derivative of f : R^P → R at w ∈ R^P along v, v′ ∈ R^P is defined as the directional derivative of w ↦ ∂f(w)[v] along v′, that is,

    ∂²f(w)[v, v′] := lim_{δ→0} (∂f(w + δv′)[v] − ∂f(w)[v])/δ,

provided that ∂f (w)[v] is well-defined around w and that the limit


exists.

Of particular interest are the variations of a function along the canonical directions: the second partial derivatives, defined as

    ∂²_{ij} f(w) := ∂²f(w)[e_i, e_j]

for e_i, e_j the ith and jth canonical directions in R^P, respectively. In Leibniz notation, the second partial derivatives are denoted

    ∂²_{ij} f(w) = ∂²f(w) / (∂w_i ∂w_j).

2.4.3 Hessians
For a multi-input function, twice differentiability is simply defined as
the differentiability of any directional derivative ∂f (w)[v] w.r.t. w.

Definition 2.16 (Twice differentiability). A function f : RP → R is


twice differentiable at w ∈ RP if it is differentiable and ∂f : RP →
(RP → R) is also differentiable at w.

As a result, the second directional derivative is a bilinear form.

Definition 2.17 (Bilinear map, bilinear form). A function b : R^P × R^P → R^M is a bilinear map if b[v, ·] : R^P → R^M is linear for any v and b[·, v′] is linear for any v′. That is,

    b[v, v′] = ∑_{i=1}^P v_i b[e_i, v′] = ∑_{i=1}^P ∑_{j=1}^P v_i v′_j b[e_i, e_j],

for v = ∑_{i=1}^P v_i e_i and v′ = ∑_{i=1}^P v′_i e_i. A bilinear map with values in R, b : R^P × R^P → R, is called a bilinear form.

The second directional derivative can be computed along any two


directions from the knowledge of the second partial derivatives gathered
in the Hessian.

Definition 2.18 (Hessian). The Hessian of a twice differentiable function f : R^P → R at w is the P × P matrix gathering all second partial derivatives:

    ∇²f(w) := (∂²_{ij} f(w))_{i,j∈[P]} ∈ R^{P×P},

provided that all second partial derivatives are well-defined.
The second directional derivative at w is bilinear in any directions v = ∑_{i=1}^P v_i e_i and v′ = ∑_{i=1}^P v′_i e_i. Therefore,

    ∂²f(w)[v, v′] = ∑_{i,j=1}^P v_i v′_j ∂²f(w)[e_i, e_j] = ⟨v, ∇²f(w)v′⟩.

Given the gradient of f , the Hessian is equivalent to the transpose


of the Jacobian of the gradient. By slightly generalizing the notation ∇
to denote the transpose of the Jacobian of a function (which matches
its definition for single-output functions), we have that the Hessian can
be expressed as ∇2 f (w) = ∇(∇f )(w), which justifies its notation.
Similarly to the case of first-order differentiability, twice differentiability of f at w is ensured by having the second partial derivatives not only defined but also continuous in a neighborhood of w. Remarkably, under twice differentiability with continuous second partial derivatives, the Hessian is guaranteed to be symmetric (Schwarz, 1873).

Proposition 2.10 (Symmetry of the Hessian). If a function f : R^P → R is twice differentiable at w, then its Hessian ∇²f(w) is symmetric, that is, ∂²_{ij} f(w) = ∂²_{ji} f(w) for any i, j ∈ {1, . . . , P}.

The symmetry of the Hessian means that it can alternatively be written as ∇²f(w) = (∂²_{ji} f(w))_{i,j∈[P]} = ∂(∇f)(w), i.e., the Jacobian of the gradient of f.

2.4.4 Hessian-vector products


Similarly to the Jacobian, we can exploit the formal definition of the
Hessian as a bilinear form to extend its definition for Euclidean spaces.

In particular, we can define the notion of Hessian-vector product.

Definition 2.19 (Hessian-vector product). If a function f : E → R


defined on a Euclidean space E with inner product ⟨·, ·⟩, is twice
differentiable at w ∈ E, then for any v ∈ E, there exists ∇2 f (w)[v],
called the Hessian-vector product (HVP) of f at w along v
such that for any v ′ ∈ E,

∂ 2 f (w)[v, v ′ ] = ⟨v ′ , ∇2 f (w)[v]⟩.

In particular, for E = R^P, the HVP is ∇²f(w)[v] = (∂²f(w)[v, e_i])_{i=1}^P.

From an autodiff point of view, the HVP can be implemented in


four different ways, as explained in Section 9.1.
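As a preview, here is one of those ways, sketched in JAX: a forward-over-reverse HVP that applies the JVP of the gradient map (the test function is an arbitrary choice):

    import jax
    import jax.numpy as jnp

    def f(w):
        return jnp.sum(w ** 3) + jnp.dot(w, w)    # arbitrary twice-differentiable function

    def hvp(f, w, v):
        # Forward-over-reverse: JVP of the gradient map, no Hessian is materialized.
        return jax.jvp(jax.grad(f), (w,), (v,))[1]

    w = jnp.array([1.0, -2.0, 0.5])
    v = jnp.array([0.3, 0.1, -0.7])

    dense = jax.hessian(f)(w) @ v                 # materializes the full Hessian
    print(jnp.allclose(hvp(f, w, v), dense))      # expected: True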

2.4.5 Second-order Jacobians

The previous definitions naturally extend to multi-output functions


f : E → F, where f := (f1 , . . . , fM ), fj : E → Fj and F := F1 ×· · ·×FM .
The second directional derivative is defined by gathering the second directional derivatives of each coordinate's function. That is, for w, v, v′ ∈ E,

    ∂²f(w)[v, v′] = (∂²f_j(w)[v, v′])_{j=1}^M ∈ F.

The function f is twice differentiable if and only if all its coordinates are twice differentiable. The second directional derivative is then a bilinear map. We can then compute second directional derivatives as

    ∂²f(w)[v, v′] = ∑_{i,j=1}^P v_i v′_j ∂²f(w)[e_i, e_j] = (⟨v, ∇²f_j(w)v′⟩)_{j=1}^M.

When E = R^P and F_j = R, so that F = R^M, the bilinear map can be materialized as a tensor

    ∂²f(w) = (∂²f(w)[e_i, e_j])_{i,j=1}^P ∈ R^{M×P×P},

the “second-order Jacobian” of f . However, similarly to the Hessian,


it is usually more convenient to apply the bilinear map to prescribed vectors v and v′ than to materialize it as a tensor.

2.5 Higher-order differentiation

2.5.1 Higher-order derivatives


Derivatives can be extended to any order. Formally, the nth derivative
can be defined inductively as follows for a single-input, single-output
function.

Definition 2.20 (nth order derivative). The nth derivative f^{(n)} of a function f : R → R at w ∈ R is defined as

    f^{(n)}(w) := (f^{(n−1)})′(w) = lim_{δ→0} (f^{(n−1)}(w + δ) − f^{(n−1)}(w))/δ,

provided that f^{(n−1)} is differentiable around w and that the limit exists. In such a case, the function is said to be n times differentiable at w.

2.5.2 Higher-order directional derivatives


For a multi-input function f , we can naturally extend the notion of
directional derivative as follows.

Definition 2.21 (nth order directional derivative). The nth directional derivative of f : R^P → R at w ∈ R^P along v_1, . . . , v_n is defined as

    ∂^n f(w)[v_1, . . . , v_n]
        = ∂(∂^{n−1} f(w)[v_1, . . . , v_{n−1}])[v_n]
        = lim_{δ→0} (∂^{n−1} f(w + δv_n)[v_1, . . . , v_{n−1}] − ∂^{n−1} f(w)[v_1, . . . , v_{n−1}])/δ.

A multi-input function f is n times differentiable if it is (n − 1) times differentiable and its (n − 1)th directional derivative along any directions is differentiable. As a consequence, the nth directional derivative is a multilinear form.
Definition 2.22 (Multilinear map, multilinear form). A function c : ⊗_{i=1}^n R^P → R^M is a multilinear map if it is linear in each coordinate given all others fixed, that is, if v_j ↦ c[v_1, . . . , v_j, . . . , v_n] is linear in v_j for any j ∈ [n]. It is a multilinear form if it has values in R.

The nth order directional derivative is then given by

    ∂^n f(w)[v_1, . . . , v_n] = ∑_{i_1,...,i_n=1}^P v_{1,i_1} · · · v_{n,i_n} ∂^n f(w)[e_{i_1}, . . . , e_{i_n}].

Its materialization is an nth order tensor

    ∇^n f(w) = (∂^n f(w)[e_{i_1}, . . . , e_{i_n}])_{i_1,...,i_n=1}^P ∈ R^{P×...×P}.

2.5.3 Higher-order Jacobians


All the above definitions extend directly to the case of multi-output functions f : E → F, where F = F₁ × · · · × F_M. The nth directional derivatives are

    ∂^n f(w)[v_1, . . . , v_n] = (∂^n f_j(w)[v_1, . . . , v_n])_{j=1}^M.

The function f is then n times differentiable if it is (n − 1) times differentiable and its (n − 1)th directional derivative along any directions is differentiable. As a consequence, the nth directional derivative is a multilinear map. The nth directional derivative can be decomposed into partial derivatives as

    ∂^n f(w)[v_1, . . . , v_n] = ∑_{i_1,...,i_n=1}^P v_{1,i_1} · · · v_{n,i_n} ∂^n f(w)[e_{i_1}, . . . , e_{i_n}].

When E = R^P and F = R^M, it is materialized by an (n + 1)th order tensor

    ∂^n f(w) = (∂^n f_j(w)[e_{i_1}, . . . , e_{i_n}])_{j∈[M], i_1,...,i_n∈[P]} ∈ R^{M×P×...×P}.

2.5.4 Taylor expansions


With Landau’s little o notation, we have seen that if a function is
differentiable it is approximated by a linear function in v,
f (w + v) = f (w) + ⟨∇f (w), v⟩ + o(∥v∥2 ).
Such an expansion of the function up to its first derivative is called the
first-order Taylor expansion of f around w.

If the function f is twice differentiable, we can approximate it by a quadratic in v, leading to the second-order Taylor expansion of f around w,

    f(w + v) = f(w) + ⟨∇f(w), v⟩ + ½⟨v, ∇²f(w)v⟩ + o(∥v∥₂²).

Compared to the first-order Taylor approximation, it is naturally more accurate around w, as reflected by the fact that ∥v∥₂³ ≤ ∥v∥₂² for ∥v∥₂ ≤ 1.
More generally, we can build the nth order Taylor expansion of an n times differentiable function f : R^P → R^M around w ∈ R^P by

    f(w + v) = f(w) + ∂f(w)[v] + (1/2) ∂²f(w)[v, v] + . . . + (1/n!) ∂^n f(w)[v, . . . , v] + o(∥v∥₂^n),

where the last directional derivative is taken along v repeated n times. Note that, using the change of variable w′ = w + v ⇔ v = w′ − w, it is often convenient to write the nth order Taylor expansion of f(w′) around w as

    f(w′) = f(w) + ∑_{j=1}^n (1/j!) ∂^j f(w)[w′ − w, . . . , w′ − w] + o(∥w′ − w∥₂^n),

where the jth term takes the directional derivative along w′ − w repeated j times.

Taylor expansions will prove useful in Chapter 7 for computing deriva-


tives by finite differences.
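As an illustration, the sketch below checks numerically with JAX that the first- and second-order Taylor expansions become increasingly accurate as the perturbation shrinks, on an arbitrary test function:

    import jax
    import jax.numpy as jnp

    def f(w):
        return jnp.sum(jnp.exp(w) * jnp.sin(w))   # arbitrary smooth test function

    w = jnp.array([0.2, -0.5])
    v = jnp.array([1.0, 2.0])

    g = jax.grad(f)(w)
    H = jax.hessian(f)(w)

    for eps in [1e-1, 1e-2, 1e-3]:
        d = eps * v
        first = f(w) + g @ d
        second = first + 0.5 * d @ H @ d
        # The first-order error shrinks like eps^2, the second-order error like eps^3.
        print(eps, abs(f(w + d) - first), abs(f(w + d) - second))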

2.6 Differential geometry

In this chapter, we progressively generalized the notion of derivative


from real numbers to vectors and variables living in a linear space (a.k.a.
vector space), either finite dimensional or infinite dimensional. We can
further generalize these notions by considering a local notion of linearity.
This is formalized by smooth manifolds in differential geometry,
whose terminology is commonly adopted in the automatic differentiation
literature and software. In this section, we give a brief overview of
derivatives on smooth manifolds (simply referred as manifolds), and
refer to Boumal (2023) for a complete introduction.

2.6.1 Differentiability on manifolds

Essentially, a manifold is a set that can be locally approximated by


a Euclidean space. The most common example is a sphere like the
Earth. Seen from the Moon, the Earth is not a plane, but locally, at a
human level, it can be seen as a flat surface. Euclidean spaces are also
trivial examples of manifold. A formal characterization of the sphere
as a manifold is presented in Example 2.7. For now, we may think of
a “manifold” as some set (e.g., the sphere) contained in some ambient
Euclidean space; note however that manifolds can be defined generally
without being contained in a Euclidean space (Boumal, 2023, Chapter
8). Differentiability in manifolds is simply inherited from the notion of
differentiability in the ambient Euclidean space.

Definition 2.23 (Differentiability of restricted functions). Let M and


N be manifolds. A function f : M → N defined from M ⊆ E to
N ⊆ F, with E and F Euclidean spaces, is differentiable if f is
the restriction of a differentiable function f¯ : E → F, so that f
coincides with f¯ on M.

Our objective is to formalize the directional derivatives and gradients


for functions defined on manifolds. This formalization leads to the
definitions of tangent spaces and cotangent spaces, and the associated
generalizations of JVP and VJP operators as pushforward and pullback
operators, respectively.

2.6.2 Tangent spaces and pushforward operators

To generalize the notion of directional derivatives of a function f , the


one property we want to preserve is the chain rule. Rather than starting
from the variations of f at a given point along a direction, we start with
the variations of f along curves. Namely, on a manifold like the sphere
S P in RP , we can look at curves α : R → S P passing by w ∈ S P at
time 0, that is, α(0) = w. For single-input functions like α, we denoted
for simplicity α′ (0) := (α1′ (0), . . . , αP′ (0)). The directional derivative of
f must typically serve to define the derivative of f ◦ α at 0, such that
(f ◦ α)′ (0) = ∂f (w)[α′ (0)]. In the case of the sphere, as illustrated in

Fig. 2.6, the derivative α′ (0) of a curve α passing through a point w is


always tangent to the sphere at w. The tangent plane to the sphere
at w, then captures all possible relevant vectors to pass to the JVP
we are building. To define the directional derivative of a function f on
a manifold, we therefore restrict ourselves to an operator defined on
the tangent space Tw M, whose definition below is simplified for our
purposes.

Definition 2.24 (Tangent space). The tangent space of a mani-


fold M at w ∈ M is defined as

Tw M := {v = α′ (0) for any α : R → M differentiable s.t. α(0) = w}.

In the case of the sphere in Fig. 2.6, the tangent space is a plane, that
is, a Euclidean space. This property is generally true: tangent spaces
are Euclidean spaces such that we will be able to define directional
derivatives as linear operators. Now, if f is differentiable and goes from
a manifold M to a manifold N , then f ◦ α is a differentiable curve in N .
Therefore, (f ◦ α)′ (0) is the derivative of a curve passing through f (w)
at 0 and is tangent to N at f (w). Hence, the directional derivative
of f : M → N at w can be defined as a function from the tangent
space Tw M of M at w onto the tangent space Tf (w) N of N at f (w).
Overall, we built the directional derivative (JVP) by considering how a
composition of f with any curve α pushes forward the derivative of α
into the derivative of f ◦ α. The resulting JVP is called a pushforward
operator in differential geometry.

Definition 2.25 (Pushforward operator). Given two manifolds M


and N , the pushforward operator of a differentiable function
f : M → N at w ∈ M is the linear map ∂f (w) : Tw M → Tf (w) N
defined by
∂f (w)[v] := (f ◦ α)′ (0),
for any v ∈ Tw M such that v = α′(0), for a differentiable curve
α : R → M, passing by w at 0, i.e., α(0) = w.

Figure 2.6: A differentiable function f defined from a sphere M to a sphere N


defines a push-forward operator that maps tangent vectors (derivatives of functions
on the sphere passing by w) in the tangent space Tw M to tangent vectors of N at
f (w) in the tangent space Tf (w) N .

2.6.3 Cotangent spaces and pullback operators

To generalize the JVP, we composed f : M → N with any single-input


function α : R → M giving values on the manifold. The derivative of
any such α is then pushed forward from Tw M to Tf (w) N by the action
of f . To define the VJP, we take a symmetric approach. We consider all
single-output differentiable functions β : N → R defined on y ∈ N with
y = f (w) for some w ∈ M. We then want to pull back the derivatives
of β when precomposing it by f . Therefore, the space on which the
VJP acts is the space of directional derivatives of any β : N → R at y,
defining the cotangent space.

Definition 2.26 (Cotangent space). The cotangent space of a manifold N at y ∈ N is defined as

    T_y^* N = {u = ∂β(y) for β : N → R differentiable}
            = {u : T_y N → R, u a linear map}.

Note that elements of the cotangent space are linear mappings, not
vectors. This distinction is important to define the pullback operator
as an operator on functions as done in measure theory. From a linear
algebra viewpoint, the cotangent space is exactly the dual space of
Ty N , that is, the set of linear maps from Ty N to R, called linear forms.
As Ty N is a Euclidean space, its dual space Ty∗ N is also a Euclidean
space. The pullback operator is then defined as the operator that gives

access to directional derivatives of β ◦ f given the directional derivative


of β at f (w).

Definition 2.27 (Pullback operator). Given two manifolds M and N, the pullback operator of a differentiable function f : M → N at w ∈ M is the linear map ∂f(w)^⋆ : T_{f(w)}^* N → T_w^* M defined by

    ∂f(w)^⋆ u := ∂(β ∘ f)(w),

for any u ∈ T_{f(w)}^* N such that ∂β(f(w)) = u, for a differentiable function β : N → R.

Contrary to the pushforward operator that acts on vectors, the


pullback operator acts on linear forms. Hence, the slight difference in
notation between ∂f (w)⋆ and ∂f (w)∗ , the adjoint operator of ∂f (w).
To properly define the adjoint operator ∂f (w)∗ , we need a notion
of inner product. Since tangent spaces are Euclidean spaces, we can
define an inner product ⟨·, ·⟩w for each Tw M and w ∈ M, making M
a Riemannian manifold. Equipped with these inner products, the
cotangent space can be identified with the tangent space, and we can
define gradients.

Definition 2.28 (Gradients in Riemannian manifolds). Let M be a Riemannian manifold equipped with inner products ⟨·, ·⟩_w. For any cotangent vector μ ∈ T_w^* M, with w ∈ M, there exists a unique tangent vector u ∈ T_w M such that

    ∀v ∈ T_w M,  μ[v] = ⟨u, v⟩_w.

In particular, for any differentiable function f : M → R, we can define the gradient of f as the unique tangent vector ∇f(w) ∈ T_w M such that

    ∀v ∈ T_w M,  ∂f(w)[v] = ⟨∇f(w), v⟩_w.

Therefore, rather than pulling back directional derivatives, we can


pull back gradients. The corresponding operator is then naturally the
adjoint ∂f (w)∗ of the pushforward operator. Namely, given two Rieman-
nian manifolds M and N , and a differentiable function f : M → N ,

we have

    (∂f(w)^⋆ μ)[v] = ⟨∂f(w)^*[u], v⟩_w  for any v ∈ T_w M,

for μ = ⟨·, u⟩_{f(w)} ∈ T_{f(w)}^* N represented by u ∈ T_{f(w)} N.

    Function                  f          M → N
    Pushforward               ∂f(w)      T_w M → T_{f(w)} N
    Pullback                  ∂f(w)^⋆    T_{f(w)}^* N → T_w^* M
    Adjoint of pushforward    ∂f(w)^*    T_{f(w)} N → T_w M

Table 2.1: For a differentiable function f defined from a manifold M onto a manifold N, the JVP is generalized by the notion of pushforward ∂f(w). The counterpart of the pushforward is the pullback operation ∂f(w)^⋆, which acts on linear forms on the tangent spaces. For Riemannian manifolds, the pullback operation can be identified with the adjoint operator ∂f(w)^* of the pushforward operator, as any linear form is represented by a vector.

Example 2.7 (The sphere as a manifold). The sphere S^P in R^P is defined as the set of points w ∈ R^P satisfying c(w) := ⟨w, w⟩ − 1 = 0, with JVP ∂c(w)[v] = 2⟨w, v⟩.
For any v = (v_1, . . . , v_{P−1}) ∈ R^{P−1} close enough to a point w on the sphere, we can define ψ_1(v) = √(1 − ⟨v, v⟩) such that ψ(v) = (v_1, . . . , v_{P−1}, ψ_1(v)) satisfies ⟨ψ(v), ψ(v)⟩ = 1, that is, c(ψ(v)) = 0. With the help of the mapping ψ^{-1} from a neighborhood of w in the sphere to R^{P−1}, we can see the sphere locally as a space of dimension P − 1.
The tangent space can be naturally characterized in terms of the constraining function c. Namely, a curve α : R → S^P such that α(0) = w satisfies c(α(δ)) = 0 for any δ ∈ R. Hence, differentiating this implicit equation, we have

    0 = (c ∘ α)′(0) = ∂c(w)[α′(0)].

That is, α′(0) is in the null space of ∂c(w), denoted

    Null(∂c(w)) := {v ∈ R^P : ∂c(w)[v] = 0}.



The tangent space of S^P at w is then

    T_w S^P = Null(2⟨w, ·⟩) = {v ∈ R^P : ⟨w, v⟩ = 0}.

We naturally recover that the tangent space is a Euclidean space of dimension P − 1, defined as the set of vectors orthogonal to w.

2.7 Generalized derivatives

While we largely focus on differentiable functions in this book, it is


important to characterize non-differentiable functions. We distinguish
here two cases: continuous functions and non-continuous functions. For
the former case, there exist generalizations of the notions of directional
derivative, gradient and Jacobian, presented below. For non-continuous
functions, even if derivatives exist almost everywhere, they may be
uninformative. For example, piecewise constant functions, encountered
in e.g. control flows (Chapter 5), are almost everywhere differentiable
but with zero derivatives. In such cases, surrogate functions can be
defined to ensure the differentiability of a program (Part IV).

2.7.1 Rademacher’s theorem


We first recall the definition of locally Lipschitz continuous function.

Definition 2.29 (Locally Lipschitz continuous function). A function


f : E → F, is Lipschitz continuous if there exists C ≥ 0 such that
for any x, y ∈ E,

∥f (x) − f (y)∥ ≤ C∥x − y∥.

A function f : E → F is locally Lipschitz continuous if for any


x ∈ E, there exists a neighborhood U of x such that f restricted
to U is Lipschitz continuous.

Rademacher’s theorem (Rademacher, 1919) ensures that f is differ-


entiable almost everywhere; see also Morrey Jr (2009) for a standard
proof.

Proposition 2.11 (Rademacher’s theorem). Let E and F denote Eu-


clidean spaces. If f : E → F is locally Lipschitz-continuous, then f
is almost everywhere differentiable, that is, the set of points in E at
which f is not differentiable is of (Lebesgue) measure zero.

2.7.2 Clarke derivatives

Rademacher’s theorem hints that the definitions of directional deriva-


tives, gradients and Jacobians may be generalized to locally Lipschitz
continuous functions. This is what Clarke (1975) did in his seminal
work, which laid the foundation of nonsmooth analysis. The first
building block is a notion of generalized directional derivative.

Definition 2.30 (Clarke generalized directional derivative). The Clarke generalized directional derivative of a locally Lipschitz continuous function f : E → R at w ∈ E in the direction v ∈ E is

    ∂_C f(w)[v] := lim sup_{u→w, δ↘0} (f(u + δv) − f(u))/δ,

provided that the limit exists, where δ ↘ 0 means that δ approaches 0 by non-negative values and the limit superior is defined as

    lim sup_{x→a} f(x) := lim_{ε→0} sup{f(x) : x ∈ B(a, ε) \ {a}}

for B(a, ε) := {x ∈ E : ∥x − a∥ ≤ ε} the ball centered at a of radius ε.

There are two differences with the usual definition of a directional


derivative: (i) we considered slopes of the function in a neighborhood of
the point rather than at the given point, (ii) we took a limit superior
rather than a usual limit. The first point is rather natural in the light
of Rademacher’s theorem: we can properly characterize variations on
points where the function is differentiable, therefore we may take the
limits of these slopes as a candidate slope for the point of interest. The
second point is more technical but essential: it allows us to characterize
the directional derivative as the supremum of some linear forms (Clarke

et al., 2008). These linear forms in turn define a set of generalized


gradients (Clarke et al., 2008, Chapter 2).

Definition 2.31 (Clarke generalized gradient). A Clarke generalized gradient of a locally Lipschitz continuous function f : E → R at w ∈ E is a point g ∈ E such that, for all v ∈ E,

    ∂_C f(w)[v] ≥ ⟨g, v⟩.

The set of Clarke generalized gradients is called the Clarke subdifferential of f at w.

Definition 2.30 and Definition 2.31 can be used in non-Euclidean


spaces, such as Banach or Hilbert spaces (Clarke et al., 2008). In
Euclidean spaces, the Clarke generalized gradients can be characterized
more simply thanks to Rademacher’s theorem (Clarke et al., 2008,
Theorem 8.1). Namely, they can be defined as a convex combination of
limits of gradients of f evaluated at a sequence in E \ Ω that converges
to w (Proposition 2.12). In the following, the convex hull of a set S ⊆ E,
the set of convex combinations of elements of S, is denoted

    conv(S) := {λ_1 s_1 + . . . + λ_m s_m : m ∈ N, λ_i ≥ 0, ∑_{i=1}^m λ_i = 1, s_i ∈ S}.

Proposition 2.12 (Characterization of Clarke generalized gradients). Let f : E → R be locally Lipschitz continuous and denote by Ω the set of points at which f is not differentiable (Proposition 2.11). An element g ∈ E is a Clarke generalized gradient of f at w ∈ E if and only if

    g ∈ conv({lim_{n→+∞} ∇f(v_n) : (v_n)_{n=1}^{+∞} s.t. v_n ∈ E \ Ω, v_n → w}).

The Jacobian of a function f : E → F between two Euclidean spaces


can be generalized similarly (Clarke et al., 2008, Section 3.3).

Definition 2.32 (Clarke generalized Jacobian). Let f : E → F be locally Lipschitz continuous and denote by Ω the set of points at which f is not differentiable (Proposition 2.11). A Clarke generalized Jacobian of f at w ∈ E is an element J of

    conv({lim_{n→+∞} ∂f(v_n) : (v_n)_{n=1}^{+∞} s.t. v_n ∈ E \ Ω, v_n → w}).

For a continuously differentiable function f : E → F or f : E → R, there is a unique generalized Jacobian or gradient, namely the usual one (Clarke et al., 2008, Proposition 3.1, page 78). The chain rule can be generalized to these objects (Clarke et al., 2008). Recently, Bolte and Pauwels (2020)
and Bolte et al. (2022) further generalized Clarke gradients through the
definition of conservative gradients to define automatic differentiation
schemes for nonsmooth functions.

2.8 Summary

• The usual definition of derivatives of real-valued univariate


functions extends to multivariate functions f : RP → R through
the notion of directional derivative ∂f (w)[v] at w ∈ RP in
the direction v ∈ RP .

• To take advantage of the representation of w = ∑_{j=1}^P w_j e_j using
the canonical basis {e_1, . . . , e_P}, the definition of differentiable
functions requires the linearity of the directional derivative w.r.t.
the direction v.

• This requirement gives rise to the notion of gradient ∇f (w) ∈


RP , the vector that gathers the partial derivatives and further
defines the steepest ascent direction at w.

• For vector-input vector-output functions f : RP → RM , the di-


rectional derivative leads to the definition of Jacobian matrix
∂f (w) ∈ RM ×P , the matrix which gathers all partial derivatives
(notice that we use bold ∂). The chain rule is then the product
of Jacobian matrices.

• These notions can be extended to general Euclidean spaces, such as


the spaces of matrices or tensors. For functions of the form f : E →
R, the gradient is ∇f (w) ∈ E. More generally, for functions
of the form f : E → F, the Jacobian ∂f (w) can be seen as a

linear map (notice the non-bold ∂). The directional derivative


at w ∈ E naturally defines a linear map l[v] = ∂f (w)[v], where
∂f (w) : E → F is called the Jacobian vector product (JVP)
and captures the infinitesimal variation at w ∈ E along the input
direction v ∈ E.

• Its adjoint ∂f (w)∗ : F → E defines another linear map l[u] =


∂f (w)∗ [u] called the vector Jacobian product (VJP) and
captures the infinitesimal variation at w ∈ E along the output
direction u ∈ F. The chain rule is then the composition of
these linear maps.

• For the particular case when we compose a scalar-valued function


ℓ (such as a loss function) with a vector-valued function f (such
as a network function), the gradient is given by ∇(ℓ ∘ f)(w) =
∂f(w)^*[∇ℓ(f(w))]. This is why being able to apply the adjoint to
a gradient, which as we shall see can be done with reverse-mode
autodiff, is so pervasive in machine learning.

• The definitions of JVP and VJP operators can further be general-


ized in the context of differentiable geometry. In that framework,
the JVP amounts to the pushforward operator that acts on
tangent vectors. The VJP amounts to the pullback operator
that acts on cotangent vectors.

• We also saw that the Hessian matrix of a function f (w) from


RP to R is denoted ∇2 f (w) ∈ RP ×P . It is symmetric if the second
partial derivatives are continuous. Seen as linear map, the Hessian
leads to the notion of Hessian-vector product (HVP), which
we saw can be reduced to the JVP or the VJP of ∇f (w).

• The main take-away message of this chapter is that computing the


directional derivative or the gradient of compositions of functions
does not require computing intermediate Jacobians but only to
evaluate linear maps (JVPs or VJPs) associated with these inter-
mediate functions. The goal of automatic differentiation, presented
in Chapter 8, is precisely to provide an efficient implementation

of these maps for computation chains or more generally for


computation graphs.
3 Probabilistic learning

In this chapter, we review how to perform probabilistic learning. We


also introduce exponential family distributions, as they play a key role
in this book.

3.1 Probability distributions

3.1.1 Discrete probability distributions

A discrete probability distribution over a set Y is specified by its


probability mass function (PMF) p : Y → [0, 1]. The probability of
y ∈ Y is then defined by

P(Y = y) := p(y),

where Y denotes a random variable. When Y follows a distribution p,


we write Y ∼ p (with some abuse of notation, we use the same letter p
to denote the distribution and the PMF). The expectation of ϕ(Y ),
where Y ∼ p and ϕ : Y → RM , is then

    E[ϕ(Y)] = ∑_{y∈Y} p(y) ϕ(y),

its variance (for one-dimensional variables) is

    V[ϕ(Y)] = E[(ϕ(Y) − E[ϕ(Y)])²] = ∑_{y∈Y} p(y)(ϕ(y) − E[ϕ(Y)])²

and its mode is

    arg max_{y∈Y} p(y).

The Kullback-Leibler (KL) divergence (also known as relative entropy) between two discrete distributions over Y, with associated PMFs p and q, is the statistical "distance" defined by

    KL(p, q) := ∑_{y∈Y} p(y) log(p(y)/q(y)) = E_{Y∼p}[log(p(Y)/q(Y))].
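For instance, for two discrete distributions represented as probability vectors, the KL divergence can be computed as follows (a small sketch in JAX; the distributions are arbitrary):

    import jax.numpy as jnp

    p = jnp.array([0.6, 0.3, 0.1])
    q = jnp.array([0.2, 0.5, 0.3])

    kl = jnp.sum(p * jnp.log(p / q))  # KL(p, q) = sum_y p(y) log(p(y) / q(y))
    print(kl)                         # a non-negative scalar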

3.1.2 Continuous probability distributions


A continuous probability distribution over Y is specified by its proba-
bility density function (PDF) p : Y → R+ . The probability of A ⊆ Y
is then Z
P(Y ∈ A) = p(y)dy.
A
The definitions of expectation, variance and KL divergence are defined
analogously to the discrete setting, simply replacing y∈Y with Y .
P R

Specifically, the expectation of ϕ(Y) is

E[ϕ(Y)] = ∫_Y p(y) ϕ(y) dy,

the variance is

V[ϕ(Y)] = E[(ϕ(Y) − E[ϕ(Y)])²] = ∫_Y p(y) (ϕ(y) − E[ϕ(Y)])² dy

and the KL divergence is

KL(p, q) := ∫_Y p(y) log(p(y)/q(y)) dy = E_{Y∼p}[log(p(Y)/q(Y))].
The mode is defined as the argmax of the PDF.
When Y = R, we can also define the cumulative distribution function (CDF)

P(Y ≤ b) = ∫_{−∞}^{b} p(y) dy.

The probability of Y lying in the semi-closed interval (a, b] is then


P(a < Y ≤ b) = P(Y ≤ b) − P(Y ≤ a).

3.2 Maximum likelihood estimation

3.2.1 Negative log-likelihood


We saw that a probability distribution over Y is specified by p(y), which
is called the probability mass function (PMF) for discrete variables
or the probability density function (PDF) for continuous variables. In
practice, the true distribution p generating the data is unknown and we
wish to approximate it with a distribution pλ , with parameters λ ∈ Λ.
Given a finite set of i.i.d. observations y1 , . . . , yN , how do we fit λ ∈ Λ
to the data? This can be done by maximizing the likelihood of the
data, i.e., we seek to solve
λ̂_N := argmax_{λ∈Λ} ∏_{i=1}^{N} pλ(y_i).

This is known as maximum likelihood estimation (MLE). Because


the log function is monotonically increasing, this is equivalent to mini-
mizing the negative log-likelihood, i.e., we have
λ̂_N = argmin_{λ∈Λ} ∑_{i=1}^{N} − log pλ(y_i).

Example 3.1 (MLE for the normal distribution). Suppose we set pλ


to the normal distribution with parameters λ = (µ, σ), i.e.,

pλ(y) := (1/(σ√(2π))) exp(−(1/2) ((y − µ)/σ)²).

Then, given observations y1 , . . . , yN , the MLE estimators for µ and


σ 2 are the sample mean and the sample variance, respectively.
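As a quick numerical check of this example, here is a minimal sketch (in Python/JAX; the observations below are arbitrary) comparing the negative log-likelihood at the closed-form MLE, the sample mean and the biased sample variance, against nearby parameter values.

```python
import jax.numpy as jnp

def normal_nll(mu, sigma2, y):
    # Average negative log-likelihood of Normal(mu, sigma2) on observations y.
    return jnp.mean(0.5 * (y - mu) ** 2 / sigma2 + 0.5 * jnp.log(2 * jnp.pi * sigma2))

y = jnp.array([1.2, -0.3, 0.7, 2.1, 0.4])   # arbitrary observations
mu_mle = jnp.mean(y)                         # sample mean
sigma2_mle = jnp.mean((y - mu_mle) ** 2)     # biased sample variance

print(normal_nll(mu_mle, sigma2_mle, y))         # smallest value
print(normal_nll(mu_mle + 0.1, sigma2_mle, y))   # larger
print(normal_nll(mu_mle, 1.5 * sigma2_mle, y))   # larger
```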

3.2.2 Consistency w.r.t. the Kullback-Leibler divergence


It is well-known that the MLE estimator is consistent, in the sense of
the Kullback-Leibler divergence. That is, denoting the true distribution

p and

λ_∞ := argmin_{λ∈Λ} KL(p, pλ) = argmin_{λ∈Λ} E_{Y∼p}[log(p(Y)/pλ(Y))],

then λ̂_N → λ_∞ in expectation over the observations, as N → ∞. This can be seen by using

KL(p, pλ) ≈ (1/N) ∑_{i=1}^{N} log(p(y_i)/pλ(y_i)) = (1/N) ∑_{i=1}^{N} [log p(y_i) − log pλ(y_i)]

and the law of large numbers.
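This argument is easy to illustrate numerically. Below is a minimal sketch (assuming the true distribution is Normal(0, 1) and the candidate is Normal(µ, 1)): the sample average of log p(y_i) − log pλ(y_i) concentrates around the exact KL divergence, µ²/2 in this case, as N grows.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

key = jax.random.PRNGKey(0)
y = jax.random.normal(key, (100_000,))  # i.i.d. samples from the true p = Normal(0, 1)

mu = 1.5  # candidate parameter of p_lambda = Normal(mu, 1)
# Monte Carlo estimate of KL(p, p_lambda) = E_p[log p(Y) - log p_lambda(Y)].
kl_mc = jnp.mean(norm.logpdf(y, loc=0.0) - norm.logpdf(y, loc=mu))
kl_exact = 0.5 * mu ** 2  # closed form for two unit-variance Gaussians
print(kl_mc, kl_exact)    # close for large N
```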

3.3 Probabilistic supervised learning

3.3.1 Conditional probability distributions


Many times in machine learning, instead of a probability P(Y = y) for
some y ∈ Y, we wish to define a conditional probability P(Y = y|X =
x), for some input x ∈ X . This can be achieved by reduction to an
unconditional probability distribution,
P(Y = y|X = x) := pλ (y)
where
λ := f (x, w)
and f is a model function with model parameters w ∈ W. That is,
rather than being a deterministic function from X to Y, f is a function
from X to Λ, the set of permissible distribution parameters of the
output distribution associated with the input.
We emphasize that λ could be a single parameter or a collection of
parameters. For instance, in the Bernoulli distribution, λ = π, while in
the univariate normal distribution, λ = (µ, σ).
In Section 3.4 and throughout this book, we will also use the notation
pθ instead of pλ when θ are the canonical parameters of an exponential
family distribution (Section 3.4).

3.3.2 Inference
The main advantage of this probabilistic approach is that our prediction
model is much richer than if we just learned a function from X to Y.

We now have access to the whole distribution over possible outcomes in


Y and can compute various statistics:
• Probability: P(Y = y|X = x) or P(Y ∈ A|X = x),

• Expectation: E[ϕ(Y )|X = x] for some function ϕ,

• Variance: V[ϕ(Y )|X = x],

• Mode: argmax_{y∈Y} pλ(y).


We now review probability distributions useful for binary classification,
multiclass classification, regression, multivariate regression, and integer
regression. In the following, to make the notation more lightweight, we
omit the dependence on x.

3.3.3 Binary classification


For binary outcomes, where Y = {0, 1}, we can use a Bernoulli
distribution with parameter
λ := π ∈ [0, 1].
When a random variable Y is distributed according to a Bernoulli
distribution with parameter π, we write
Y ∼ Bernoulli(π).
The PMF of this distribution is

pπ(y) := π if y = 1, and 1 − π if y = 0.
The Bernoulli distribution is a binomial distribution with a single
trial. Since y ∈ {0, 1}, the PMF can be rewritten as
pπ(y) = π^y (1 − π)^{1−y}.
The mean is
E[Y ] = π = P(Y = 1)
and the variance is
V[Y ] = π(1 − π) = P(Y = 1)P(Y = 0).

Figure 3.1: The Bernoulli distribution, whose PMF and CDF are here illustrated with parameter π = 0.8. Its mean function is π = A′(θ) = logistic(θ) = 1/(1 + exp(−θ)), where θ is for instance the output of a neural network. The negative log-likelihood leads to the logistic loss, L(θ, y) = softplus(θ) − θy = log(1 + exp(θ)) − θy. The loss curve is shown for y ∈ {0, 1}.

Parameterization using a sigmoid


Since the parameter π of a Bernoulli distribution needs to belong to
[0, 1], we typically use a sigmoid function (Section 4.4.3), such as a
logistic function as the output layer:

π := f (x, w) := logistic(g(x, w)),

where g : X × W → R is for example a neural network and


logistic(a) := 1/(1 + exp(−a)) ∈ (0, 1).
When g is linear in w, this is known as binary logistic regression.
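As a concrete sketch (assuming a linear g(x, w) = ⟨w, x⟩, so this is exactly binary logistic regression; the input and weights below are arbitrary):

```python
import jax.numpy as jnp
from jax.nn import sigmoid  # the logistic function

def g(x, w):
    return jnp.dot(w, x)  # a linear model in w

def predict_proba(x, w):
    # pi = P(Y = 1 | X = x) = logistic(g(x, w)), always in (0, 1).
    return sigmoid(g(x, w))

x = jnp.array([0.5, -1.2, 3.0])
w = jnp.array([0.1, 0.4, -0.2])
print(predict_proba(x, w))
```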

Remark 3.1 (Link with the logistic distribution). The logistic distri-
bution with mean and scale parameters µ and σ is a continuous
probability distribution with PDF
pµ,σ(u) := (1/σ) p0,1((u − µ)/σ)
where
p0,1(z) := exp(−z)/(1 + exp(−z))².
If a random variable U follows a logistic distribution with parameters µ and σ, we write U ∼ Logistic(µ, σ). The CDF of U ∼ Logistic(µ, σ) is

P(U ≤ u) = ∫_{−∞}^{u} pµ,σ(t) dt = logistic((u − µ)/σ).
Therefore, if
U ∼ Logistic(µ, σ)
and
Y ∼ Bernoulli(logistic((u − µ)/σ)),
then
P(Y = 1) = P(U ≤ u).
Here, U can be interpreted as a latent continuous variable and u
as a threshold.

3.3.4 Multiclass classification


For categorical outcomes with M possible choices, where Y = [M ],
we can use a categorical distribution with parameters
λ := π ∈ △M ,
where we define the probability simplex
△^M := {π ∈ R^M_+ : ⟨π, 1⟩ = 1},

i.e., the set of valid discrete probability distributions. When Y follows


a categorical distribution with parameter π, we write
Y ∼ Categorical(π).
The PMF of the categorical distribution is
pπ (y) := ⟨π, ϕ(y)⟩ = πy ,
where
ϕ(y) := ey
is the standard basis vector for the coordinate y ∈ [M ].
Since Y is a categorical variable, it does not make sense to compute
the expectation of Y but we can compute that of ϕ(Y ) = eY ,
EY ∼pπ [ϕ(Y )] = π.

Figure 3.2: The categorical distribution, whose PMF and CDF are here illustrated with parameter π = (0.3, 0.6, 0.1). Its mean function is π = ∇A(θ) = softargmax(θ), where θ ∈ R^M is for instance the output of a neural network. Here, for illustration purposes, we choose to set θ = (s, 1, 0) and vary only s. Since the mean function ∇A(θ) belongs to R³, we choose to display ⟨∇A(θ), e_i⟩ = ∇A(θ)_i, for i ∈ {1, 2, 3}. The negative log-likelihood leads to the logistic loss, L(θ, y) = logsumexp(θ) − ⟨θ, y⟩. The loss curve is shown for y ∈ {e1, e2, e3}, again with θ = (s, 1, 0) and varying s.

Therefore, as was also the case for the Bernoulli distribution, the mean
and the probability distribution (represented by the vector π) are the
same in this case.

Parameterization using a softargmax

Since the parameter vector π of a categorical distribution needs to


belong to △M , we typically use a softargmax as the output layer:

π := f (x, w) := softargmax(g(x, w)),

where g : X × W → RM is for example a neural network and

softargmax(u) := exp(u) / ∑_j exp(u_j) ∈ relint(△^M).

The output of the softargmax is in the relative interior of △^M, relint(△^M) = △^M ∩ R^M_{>0}. That is, the produced probabilities are always strictly positive. The categorical distribution is a multinomial distribution with
a single trial. When g is linear in w, this is therefore known as multi-
class or multinomial logistic regression, though strictly speaking a
multinomial distribution could use more than one trial.

3.3.5 Regression
For real outcomes, where Y = R, we can use, among other choices, a
normal distribution with parameters
λ := (µ, σ),
where µ ∈ R is the mean parameter and σ ∈ R+ is the standard
deviation parameter. The PDF is
pµ,σ(y) := (1/(σ√(2π))) exp(−(y − µ)²/(2σ²)).
The expectation is
EY ∼pµ,σ [Y ] = µ.
One advantage of the probabilistic perspective is that we are not limited
to predicting the mean. We can also compute the CDF
P(Y ≤ y) = (1/2) [1 + erf((y − µ)/(σ√2))],    (3.1)
where we used the error function

erf(z) := (2/√π) ∫_0^z e^{−t²} dt.
This function is available in most scientific computing libraries, such as
SciPy (Virtanen et al., 2020). From the CDF, we also easily obtain
P(a < Y ≤ b) = (1/2) [erf((b − µ)/(σ√2)) − erf((a − µ)/(σ√2))].
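A minimal sketch of these formulas, using the erf function available in scientific computing libraries (here jax.scipy.special.erf; the parameter values are arbitrary):

```python
import jax.numpy as jnp
from jax.scipy.special import erf

def normal_cdf(y, mu, sigma):
    # Eq. (3.1): P(Y <= y) for Y ~ Normal(mu, sigma^2).
    return 0.5 * (1.0 + erf((y - mu) / (sigma * jnp.sqrt(2.0))))

def interval_proba(a, b, mu, sigma):
    # P(a < Y <= b) = P(Y <= b) - P(Y <= a).
    return normal_cdf(b, mu, sigma) - normal_cdf(a, mu, sigma)

print(normal_cdf(0.0, 0.0, 1.0))            # 0.5
print(interval_proba(-1.0, 1.0, 0.0, 1.0))  # about 0.6827
```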

Parameterization
Typically, in regression, the mean is output by a model, while the
standard deviation σ is kept fixed (typically set to 1). Since µ is uncon-
strained, we can simply set
µ := f (x, w) ∈ R,
where f : X × W → R is for example a neural network. That is, the
output of f is the mean of the distribution,
EY ∼pµ,1 [Y ] = µ = f (x, w).
We can also use µ to predict P(Y ≤ y) or P(a < Y ≤ b), as shown
above.

Figure 3.3: The Gaussian distribution, with mean parameter µ and variance σ² = 1. Its mean function is µ = A′(θ) = θ, where θ is for instance the output of a neural network. The negative log-likelihood leads to the squared loss, L(θ, y) = (y − θ)². The loss curve is shown for y ∈ {−2, 0, 2}.

3.3.6 Multivariate regression


More generally, for multivariate outcomes, where Y = RM , we can
use a multivariate normal distribution with parameters
λ := (µ, Σ),
where µ ∈ RM is the mean and Σ ∈ RM ×M is the covariance matrix.
The PDF is
pµ,Σ(y) := (1/√((2π)^M |Σ|)) exp(−(1/2) ⟨y − µ, Σ⁻¹(y − µ)⟩).

Using a diagonal covariance matrix is equivalent to using M independent


normal distributions for each Yj , for j ∈ [M ]. The expectation is
EY ∼pµ,Σ [Y ] = µ.

Parameterization
Typically, in multivariate regression, the mean is output by a model,
while the covariance matrix is kept fixed (typically set to the identity
matrix). Since µ is again unconstrained, we can simply set
µ := f (x, w) ∈ RM .
More generally, we can parametrize the function f so as to output both
the mean µ and the covariance matrix Σ, i.e.,
(µ, Σ) := f (x, w) ∈ RM × RM ×M .

Figure 3.4: The Poisson distribution, with mean parameter λ. For the PMF and the CDF, the lines between markers are shown for visual aid: the Poisson distribution does not assign probability mass to non-integer values. Its mean function is λ = A′(θ) = exp(θ), where θ is for instance the output of a neural network. The negative log-likelihood leads to the Poisson loss, L(θ, y) = − log pλ(y) = −yθ + exp(θ) + log(y!), which is a convex function of θ. The loss curve is shown for y ∈ {1, 4, 10}.

The function f must be designed such that Σ is symmetric and positive


semi-definite. This is easy to achieve for instance by parametrizing
Σ = SS ⊤ for some matrix S.
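A minimal sketch of this parametrization (here S is given directly; in practice S would be part of the output of f(x, w)):

```python
import jax.numpy as jnp

M = 3
S = 0.5 * jnp.eye(M) + 0.1  # any real matrix, e.g. produced by a network

# Sigma = S S^T is symmetric positive semi-definite by construction.
Sigma = S @ S.T
print(jnp.allclose(Sigma, Sigma.T))              # symmetric
print(jnp.all(jnp.linalg.eigvalsh(Sigma) >= 0))  # non-negative eigenvalues
```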

3.3.7 Integer regression

For integer outcomes, where Y = N, we can use, among other choices,


a Poisson distribution with mean parameter λ > 0. Its PMF is

P(Y = y) = pλ(y) := λ^y exp(−λ) / y!.

Its CDF is

P(Y ≤ y) = ∑_{i=0}^{y} P(Y = i).

The Poisson distribution implies that the index of dispersion (the


ratio between variance and mean) is 1, since

E[Y ] = V[Y ] = λ.

When this assumption is inappropriate, one can use generalized Poisson


distributions (Satterthwaite, 1942).

Parameterization using an exponential


Since the parameter λ of a Poisson distribution needs to be strictly
positive, we typically use an exponential function as output layer:
λ := f (x, w) := exp(g(x, w)) > 0,
where g : X × W → R.

3.3.8 Loss functions


In the conditional setting briefly reviewed in Section 3.3.1, we can use
maximum likelihood estimation (MLE) to estimate the model parame-
ters w ∈ W of f . Given a set of input-output pairs (x1 , y1 ), . . . , (xN , yN ),
we choose the model parameters that maximize the likelihood of the
data,
ŵ_N := argmax_{w∈W} ∏_{i=1}^{N} p_{λi}(y_i),
where
λi := f (xi , w).
Again, this is equivalent to minimizing the negative log-likelihood,
ŵ_N = argmin_{w∈W} ∑_{i=1}^{N} − log p_{λi}(y_i).

Interestingly, MLE allows us to recover several popular loss functions.

• For the Bernoulli distribution with parameter λi = πi = logistic(g(xi , w)),


we have
− log pλi (yi ) = − [yi log πi + (1 − yi ) log(1 − πi )] ,
which is the binary logistic loss function.

• For the categorical distribution with parameters λi = πi = softargmax(θi), where θi := g(xi, w), we have

− log p_{λi}(yi) = log ∑_{j=1}^{M} exp(θi,j) − θi,yi

= logsumexp(θi) − ⟨θi, e_{yi}⟩,



which is the multiclass logistic loss function, also known as


cross-entropy loss.
• For the normal distribution with mean λi = µi = f(xi, w) and fixed variance σi², we have

− log p_{λi}(yi) = (1/(2σi²)) (yi − µi)² + (1/2) log σi² + (1/2) log(2π),

which is, with unit variance and up to additive and multiplicative constants, the squared loss function.
• For the Poisson distribution with mean λi = exp(θi ), where θi :=
g(xi , w), we have
− log pλi (yi ) = −yi log(λi ) + λi + log(yi !)
= −yi θi + exp(θi ) + log(yi !)
which is the Poisson loss function. The loss function is convex
w.r.t. λi and θi for yi ≥ 0.
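The four losses above can be written directly in terms of the model output θ = g(x, w) (or µ = f(x, w) in the normal case). Below is a minimal sketch (in Python/JAX; the function names are ours, not those of a particular library):

```python
import jax.numpy as jnp
from jax.nn import softplus
from jax.scipy.special import logsumexp, gammaln

def binary_logistic_loss(theta, y):
    # -log p(y) for Bernoulli with pi = logistic(theta), y in {0, 1}.
    return softplus(theta) - theta * y

def multiclass_logistic_loss(theta, y):
    # -log p(y) for categorical with pi = softargmax(theta), y a class index.
    return logsumexp(theta) - theta[y]

def squared_loss(mu, y):
    # -log p(y) for Normal(mu, 1), up to an additive constant.
    return 0.5 * (y - mu) ** 2

def poisson_loss(theta, y):
    # -log p(y) for Poisson with lambda = exp(theta); log(y!) = gammaln(y + 1).
    return -y * theta + jnp.exp(theta) + gammaln(y + 1.0)

print(binary_logistic_loss(0.3, 1))
print(multiclass_logistic_loss(jnp.array([0.5, 1.0, -2.0]), 1))
print(squared_loss(0.2, 1.0), poisson_loss(0.7, 3))
```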

3.4 Exponential family distributions

3.4.1 Definition
The exponential family is a class of probability distributions, whose
PMF or PDF can be written in the form
pθ(y) = h(y) exp[⟨θ, ϕ(y)⟩] / exp(A(θ))
= h(y) exp[⟨θ, ϕ(y)⟩ − A(θ)],
where θ are the natural or canonical parameters of the distribution.
The function h is known as the base measure. The function ϕ is the
sufficient statistic: it holds all the information about y and is used
to embed y in a vector space. The function A is the log-partition
or log-normalizer (see below for details). All the distributions we
reviewed in Section 3.3 belong to the exponential family. With some
abuse of notation, we use pλ for the distribution in original form and
pθ for the distribution in exponential family form. As we will see,
we can go from θ to λ and vice-versa. We illustrate how to rewrite a
distribution in exponential family form below.

Example 3.2 (Bernoulli distribution). The PMF of the Bernoulli


distribution with parameter λ = π equals

pλ(y) := π^y (1 − π)^{1−y}
= exp(log(π^y (1 − π)^{1−y}))
= exp(y log(π) + (1 − y) log(1 − π))
= exp(log(π/(1 − π)) y + log(1 − π))
= exp(θy − log(1 + exp(θ)))
= exp(θy − softplus(θ))
=: pθ(y).

Therefore, Bernoulli distributions belong to the exponential family,


with natural parameter θ = logit(π) := log(π/(1 − π)).

We rewrite the previously-described distributions in exponential


family form in Table 3.1. This list is non-exhaustive: there are many
more distributions in the exponential family! (Barndorff-Nielsen, 2014)

3.4.2 The log-partition function


The log-partition function A is the logarithm of the distribution’s
normalization factor. That is,
A(θ) := log ∑_{y∈Y} h(y) exp[⟨θ, ϕ(y)⟩]

for discrete random variables and


A(θ) := log ∫_Y h(y) exp[⟨θ, ϕ(y)⟩] dy
for continuous random variables. We denote the set of valid parameters
Θ := {θ ∈ RM : A(θ) < +∞} ⊆ RM .
We can conveniently rewrite A(θ) for discrete random variables as
A(θ) = logsumexp(B(θ)) := log ∑_{y∈Y} exp([B(θ)]_y),

and similarly for continuous variables. Here, we defined the affine map
B(θ) := (⟨θ, ϕ(y)⟩ + log h(y))y∈Y .

Table 3.1: Examples of distributions in the exponential family.

Bernoulli: Y = {0, 1}; λ: π = logistic(θ); θ = logit(π); ϕ(y) = y; A(θ) = softplus(θ); h(y) = 1.

Categorical: Y = [M]; λ: π = softmax(θ); θ = log π + A(θ)1; ϕ(y) = e_y; A(θ) = logsumexp(θ); h(y) = 1.

Normal (location only): Y = R; λ: µ = θσ; θ = µ/σ; ϕ(y) = y/σ; A(θ) = θ²/2 = µ²/(2σ²); h(y) = exp(−y²/(2σ²)) / (√(2π) σ).

Normal (location-scale): Y = R; λ: (µ, σ²) = (−θ₁/(2θ₂), −1/(2θ₂)); θ = (µ/σ², −1/(2σ²)); ϕ(y) = (y, y²); A(θ) = −θ₁²/(4θ₂) − (1/2) log(−2θ₂) = µ²/(2σ²) + log σ; h(y) = 1/√(2π).

Multivariate normal: Y = R^M; λ: (µ, Σ) = (−(1/2) θ₂⁻¹θ₁, −(1/2) θ₂⁻¹); θ = (Σ⁻¹µ, −(1/2) Σ⁻¹); ϕ(y) = (y, yy⊤); A(θ) = −(1/4) θ₁⊤θ₂⁻¹θ₁ − (1/2) log|−2θ₂| = (1/2) µ⊤Σ⁻¹µ + (1/2) log|Σ|; h(y) = (2π)^{−M/2}.

Poisson: Y = N; λ = exp(θ); θ = log λ; ϕ(y) = y; A(θ) = exp(θ); h(y) = 1/y!.

Since A(θ) is the composition of logsumexp, a convex function, and of


B, an affine map, we immediately obtain the following proposition.

Proposition 3.1 (Convexity of the log-partition). A(θ) is a convex


function.

A major property of the log-partition function is that its gradient


coincides with the expectation of ϕ(Y ) according to pθ .

Proposition 3.2 (Gradient of the log-partition).

µ(θ) := ∇A(θ) = EY ∼pθ [ϕ(Y )] ∈ M.

Proof. The result follows directly from

∇A(θ) = ∂B(θ)∗ ∇logsumexp(B(θ)) = (ϕ(y))y∈Y softmax(B(θ)).

The gradient ∇A(θ) is therefore often called the mean function.


The set of achievable means µ(θ) is defined by

M := conv(ϕ(Y)) := {Ep [ϕ(Y )] : p ∈ P(Y)},

where conv(S) is the convex hull of S and P(Y) is the set of valid
probability distributions over Y.
Similarly, the Hessian ∇2 A(θ) coincides with the covariance matrix
of ϕ(Y ) according to pθ (Wainwright and Jordan, 2008, Chapter 3).
When the exponential family is minimal, which means that the
parameters θ uniquely identify the distribution, it is known that ∇A
is a one-to-one mapping from Θ to M. That is, µ(θ) = ∇A(θ) and
θ = (∇A)−1 (µ(θ)).
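For the categorical distribution, A(θ) = logsumexp(θ) and ϕ(y) = e_y, so Proposition 3.2 specializes to ∇A(θ) = softargmax(θ). A minimal sketch checking this with automatic differentiation (anticipating Chapter 8; the value of θ is arbitrary):

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

A = logsumexp                        # log-partition of the categorical distribution
theta = jnp.array([0.5, -1.0, 2.0])

mean_function = jax.grad(A)(theta)   # mu(theta) = grad A(theta)
expectation = jax.nn.softmax(theta)  # E[phi(Y)] = pi = softargmax(theta)
print(jnp.allclose(mean_function, expectation))  # True
```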

3.4.3 Maximum entropy principle


Suppose we observe the empirical mean µ̂ = (1/N) ∑_{i=1}^{N} ϕ(y_i) ∈ M of some observations y_1, . . . , y_N. How do we find a probability distribution achieving this mean? Clearly, such a distribution may not be unique.

One way to choose among all possible distributions is by using the maximum entropy principle. Let us define the Shannon entropy by

H(p) := − ∑_{y∈Y} p(y) log p(y)

for discrete variables and by


H(p) := − ∫_Y p(y) log p(y) dy

for continuous variables. This captures the level of “uncertainty” in


p, i.e., it is maximized when the distribution is uniform. Then, the
maximum entropy distribution satisfying the first-order moment
condition (i.e., whose expectation matches the empirical mean) is

p⋆ := argmax_{p∈P(Y)} H(p)  s.t.  E_{Y∼p}[ϕ(Y)] = µ̂.

It can be shown that the maximum entropy distribution is necessarily


in the exponential family with sufficient statistics defined by ϕ and its
canonical parameters θ coincide with the Lagrange multipliers of
the above constraint (Wainwright and Jordan, 2008, Section 3.1).

3.4.4 Maximum likelihood estimation


Similarly as in Section 3.2, to fit the parameters θ ∈ Θ of an exponential
family distribution to some i.i.d. observations y1 , . . . , yN , we can use
the MLE principle, i.e.,
θ̂_N = argmax_{θ∈Θ} ∏_{i=1}^{N} pθ(y_i) = argmin_{θ∈Θ} ∑_{i=1}^{N} − log pθ(y_i).

Fortunately, for exponential family distributions, the log probabil-


ity/density enjoys a particularly simple form.

Proposition 3.3 (Negative log-likelihood). The negative log-likelihood


of an exponential family distribution is

− log pθ (y) = A(θ) − ⟨θ, ϕ(y)⟩ − log h(y).



Its gradient is

−∇θ log pθ (y) = ∇A(θ) − ϕ(y) = EY ∼pθ [ϕ(Y )] − ϕ(y)

and its Hessian is

−∇2θ log pθ (y) = ∇2 A(θ),

which is independent of y.

It follows from Proposition 3.1 that θ 7→ − log pθ (y) is convex.


Interestingly, we see that the gradient is the residual between the
expectation of ϕ(Y ) according to the model and the observed ϕ(y).

3.4.5 Probabilistic learning with exponential families


In the supervised probabilistic learning setting, we wish to estimate a
conditional distribution of the form pw (y | x). Given a model function
f , such as a neural network, a common approach for defining such a
conditional distribution is by reduction to the unconditional setting,

pw (y | x) := pθ (y) where θ := f (x, w).

In other words, the role of f is to produce the parameters of pθ given


some input x. It is a function from X × W to Θ. Note that f must be
designed such that it produces an output in

Θ := {θ ∈ RM : A(θ) < +∞}.

Many times, Θ will be the entire RM but this is not always the case.
For instance, as we previously discussed, for a multivariate normal
distribution, where θ = (µ, Σ) = f (x, w), we need to ensure that Σ is
a positive semidefinite matrix.

Training
Given input-output pairs {(xi, yi)}_{i=1}^{N}, we then seek to find the parameters w of f(x, w) by minimizing the negative log-likelihood

argmin_{w∈W} ∑_{i=1}^{N} − log p_{θi}(yi) = argmin_{w∈W} ∑_{i=1}^{N} A(θi) − ⟨θi, ϕ(yi)⟩

where θi := f (xi , w). While − log pθ (y) is a convex function of θ for


exponential family distributions, we emphasize that − log pf (x,w) (y) is
typically a nonconvex function of w, when f is a nonlinear function,
such as a neural network.

Inference
Once we found w by minimizing the objective function above, there are
several possible strategies to perform inference for a new input x.

• Expectation. When the goal is to compute the expectation of


ϕ(Y ), we can use ∇A(f (x, w)). That is, we compute the distribu-
tion parameters associated with x by θ = f (x, w) and then we
compute the mean by µ = ∇A(θ). When f is linear in w, the
composition ∇A ◦ f is called a generalized linear model.

• Probability. When the goal is to compute the probability of a


certain y, we can compute the distribution parameters associated
with x by θ = f (x, w) and then we can compute P(Y = y|X =
x) = pθ (y). In the particular case of the categorical distribution
(of which the Bernoulli distribution is a special case), we point
out again that the mean and the probability vector coincide:

µ = p = ∇A(θ) = softargmax(θ) ∈ △M .

• Other statistics. When the goal is to compute other quantities,


such as the variance or the CDF, we can convert the natural pa-
rameters θ to the original distribution parameters λ (see Table 3.1
for examples). Then, we can use established formulas for the
distribution in original form, to compute the desired quantities.

3.5 Summary

• We reviewed discrete and continuous probability distributions.

• We saw how to fit distribution parameters to data using the


maximum likelihood estimation (MLE) principle and saw its
connection with the Kullback-Leibler divergence.

• Instead of designing a model function from the input space X to


the output space Y, we saw that we can perform probabilistic
supervised learning by designing a model function from X to
distribution parameters Λ.

• Leveraging the so-obtained parametric conditional distribution


then allowed us to compute, not only output probabilities, but
also various statistics such as the mean and the variance of the
outputs.

• We reviewed the exponential family, a principled generalization


of numerous distributions, which we saw is tightly connected with
the maximum entropy principle.

• Importantly, the approaches described in this chapter produce per-


fectly valid computation graphs, meaning that we can combine
them with neural networks and we can use automatic differentia-
tion, to compute their derivatives.
Part II

Differentiable programs
4
Parameterized programs

Neural networks can be thought of as parameterized programs: programs


with learnable parameters. In this chapter, we begin by reviewing how to
represent programs mathematically. We then review several key neural
network architectures and components.

4.1 Representing computer programs

4.1.1 Computation chains

To begin with, we consider simple programs that apply a sequence of


functions f1 , . . . , fK to an input s0 ∈ S0 . We call such programs com-
putation chains. For example, an image may go through a sequence of
transformations such as cropping, rotation, normalization, and so on. In
neural networks, the transformations are typically parameterized, and
the parameters are learned, leading to feedforward networks, presented
in Section 4.2. Another example of sequence of functions is a for loop,
presented in Section 5.8.


Figure 4.1: A computation chain is a sequence of function compositions. In the graph above, each intermediate node represents a single function. The first node represents the input, the last node the output. Edges represent the dependencies of the functions with respect to previous outputs or to the initial input.

Formally, a computation chain can be written as

s0 ∈ S0
s1 := f1 (s0 ) ∈ S1
...
sK := fK (sK−1 ) ∈ SK
f (s0 ) := sK . (4.1)

Here, s0 is the input, sk ∈ Sk is an intermediate state of the program,


and sK ∈ SK is the final output. Of course, the domain (input space)
of fk must be compatible with the image (output space) of fk−1 . That is,
we should have fk : Sk−1 → Sk . We can write the program equivalently
as

f (s0 ) = (fK ◦ · · · ◦ f2 ◦ f1 )(s0 )


= fK (. . . f2 (f1 (s0 ))).

A computation chain can be represented by a directed graph, shown


in Fig. 4.1. The edges in the chain define a total order. The order is
total, since two nodes are necessarily linked to each other by a path.

4.1.2 Directed acylic graphs


In generic programs, intermediate functions may depend, not only on
the previous function output, but on the outputs of several different
functions. Such dependencies are best expressed using graphs.
A directed graph G = (V, E) is defined by a set of vertices or
nodes V and a set of edges E defining directed dependencies between

Figure 4.2: Example of a directed acyclic graph. Here the nodes are V =
{0, 1, 2, 3, 4}, the edges are E = {(0, 1), (0, 2), (0, 3), (1, 3), (2, 3), (1, 4), (3, 4)}. Parents
of the node 3 are pa(3) = {0, 1, 2}. Children of node 1 are ch(1) = {3, 4}. There is
a unique root, 0, and a unique leaf, 4; 0 → 3 → 4 is a path from 0 to 4. This is an
acyclic graph since there is no cycle (i.e., a path from a node to itself). We can order
nodes 0 and 3 as 0 ≤ 3 since there is no path from 3 to 0. Similarly, we can order 1
and 2 as 1 ≤ 2 since there is no path from 2 to 1. Two possible topological orders of
the nodes are (0, 1, 2, 3, 4) and (0, 2, 1, 3, 4).

vertices. An edge (i, j) ∈ E is an ordered pair of vertices i ∈ V and


j ∈ V. It is also denoted i → j, to indicate that j depends on i. For
representing inputs and outputs, it will be convenient to use incoming
half-edges → j and outgoing half-edges i →.

In a graph G = (V, E), the parents of a vertex j is the set of nodes


pointing to j, denoted pa(j) := {i : i → j}. The children of a vertex i
is the set of nodes i is pointing to, that is, ch(i) := {j : i → j}. Vertices
without parents are called roots and vertices without children are called
leaves.

A path from i to j is defined by a sequence of vertices j1 , . . . , jm ,


potentially empty, such that i → j1 → . . . → jm → j. An acyclic graph
is a graph such that there exists no vertex i with a path from i to i. A
directed acyclic graph (DAG) is a graph that is both directed and
acyclic.

The edges of a DAG define a partial order of the vertices, denoted


i ⪯ j if there exists a path from i to j. The order is partial, since
two vertices may not necessarily be linked to each other by a path.
Nevertheless, we can define a total order called a topological order:
any order such that i ≤ j if and only if there is no path from j to i.


Figure 4.3: Representation of f (x1 , x2 ) = x2 ex1 x1 + x2 ex1 as a DAG, with
functions and variables as nodes. Edges indicate function and variable dependencies.
The function f is decomposed as 8 elementary functions in topological order.

4.1.3 Computer programs as DAGs


We assume that a program defines a mathematically valid function (a.k.a.
pure function): the program should return identical values for identical
arguments and should not have any side effects. We also assume that the
program halts, i.e., that it terminates in a finite number of steps. As
such a program is made of a finite number of intermediate functions and
intermediate variables, the dependencies between functions and variables
can be expressed using a directed acyclic graph (DAG). Without loss of
generality, we make the following simplifying assumptions:

1. There is a single input s0 ∈ S0 .

2. There is a single output sK ∈ SK .

3. Each intermediate function fk in the program outputs a single


variable sk ∈ Sk .

We number the nodes as V := {0, 1, . . . , K}. Node 0 is the root, corre-


sponding to the input s0 ∈ S0 . Node K is the leaf, corresponding to the
final output sK ∈ SK . Because of the third assumption above, apart
from s0 , each variable sk is in bijection with a function fk . Therefore,
node 0 represents the input s0 , and nodes 1, . . . , K represent both a
function fk and an output variable sk .
Edges in the DAG represent dependencies. The parents i1 , . . . , ipk :=
pa(k) of node k, where pk := |pa(k)|, indicate the variables spa (k) :=
si1 , . . . , sipk that the function fk needs to perform its computation. Put

Algorithm 4.1 Executing a program


Functions: f1 , . . . , fK in topological order
Input: input s0 ∈ S0
1: for k := 1, . . . , K do
2: Retrieve parent nodes i1 , . . . , ipk := pa(k)
3: Compute sk := fk (spa(k) ) := fk (si1 , . . . , sipk )
4: Output: f (s0 ) := sK

differently, the parents i1 , . . . , ipk indicate the functions fi1 , . . . , fipk


that need to be evaluated, prior to evaluating fk .
An example of computation graph in our formalism is presented
in Fig. 4.3.

Executing a program

To execute a program, we need to ensure that we evaluate the intermedi-


ate functions in the correct order. Therefore, we assume that the nodes
0, 1, . . . , K are in a topological order (if this is not the case, we need
to perform a topological sort first). We can then execute a program by
evaluating for k ∈ [K]

sk := fk (spa(k) ) := fk (si1 , . . . , sipk ) ∈ Sk .

Note that we can either view fk as a single-input function of spa(k) ,


which is a tuple of elements, or as a multi-input function of si1 , . . . , sipk .
The two views are essentially equivalent.
The procedure for executing a program is summarized in Algo-
rithm 4.1.
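A minimal sketch of Algorithm 4.1 (the graph, functions and input below are hypothetical; nodes are assumed to be numbered in topological order):

```python
import jax.numpy as jnp

# For each node k >= 1, its parent nodes and its function f_k.
parents = {1: [0], 2: [0, 1], 3: [2]}
funcs = {1: jnp.exp,
         2: lambda a, b: a + b,
         3: jnp.sqrt}

def execute(s0):
    # Algorithm 4.1: evaluate s_k := f_k(s_pa(k)) in topological order.
    s = {0: s0}
    for k in sorted(funcs):
        s[k] = funcs[k](*(s[i] for i in parents[k]))
    return s[max(funcs)]  # the final output s_K

print(execute(1.0))  # sqrt(1 + exp(1))
```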

Dealing with multiple program inputs or outputs

When a program has multiple inputs, we can always group them into
s0 ∈ S0 as s0 = (s0,1 , . . . , s0,N0 ) with S0 = (S0,1 × · · · × S0,N0 ), since
later functions can always filter out what elements of s0 they need.
Likewise, if an intermediate function fk has multiple outputs, we can
always group them as a single output sk = (sk,1 , . . . , sk,Nk ) with Sk =

Figure 4.4: Two possible representations of a program. Left: Functions and output
variables are represented by the same nodes. Right: functions and variables are
represented by a disjoint set of nodes.

(Sk,1 × · · · × Sk,Nk ), since later functions can filter out the elements of
sk that they need.

Alternative representation: bipartite graphs

In our formalism, because a function fk always has a single output sk , a


node k can be seen as representing both the variable sk and the function
fk . Alternatively, as shown in Fig. 4.4, we can represent variables and
functions as separate nodes, that is, using a bipartite graph. This
formalism is akin to factor graphs (Frey et al., 1997; Loeliger, 2004)
used in probabilistic modeling, but with directed edges. One advantage
of this formalism is that is allows functions to explicitly have multiple
outputs. We focus on our formalism for simplicity.

4.1.4 Arithmetic circuits

Arithmetic circuits are one of the simplest examples of computation


graph, originating from computational complexity theory. Formally,
an arithmetic circuit over a field F, such as the reals R, is a directed
acyclic graph (DAG) whose root nodes are elements of F and whose
functions fk are either + or ×. The latter are often called gates.
Contrary to the general computation graph case, because each fk is
either + or ×, it is important to allow the graph to have several root
nodes. Root nodes can be either variables or constants, and should
belong to F.
Arithmetic circuits can be used to compute polynomials. There
are potentially multiple arithmetic circuits for representing a given
polynomial. One important question is then to find the most efficient
arithmetic circuit for computing a given polynomial. To compare arith-
metic circuits representing the same polynomial, an intuitive notion of

complexity is the circuit size, as defined below.

Definition 4.1 (Circuit and polynomial sizes). The size S(C) of a


circuit C is the number of edges in the directed acyclic graph
representing C. The size S(f ) of a polynomial f is the smallest S(C)
among all C representing f .

For more information on arithmetic circuits, we refer the reader to


the monograph of Chen et al. (2011).

4.2 Feedforward networks

A feedforward network can be seen as a computation chain with pa-


rameterized functions fk ,

s0 := x
s1 := f1 (s0 , w1 )
s2 := f2 (s1 , w2 )
...
sK := fK (sK−1 , wK ),

for a given input x ∈ X and learnable parameters w1 , . . . , wK ∈


W1 × · · · × WK . Each function fk is called a layer and each sk ∈ Sk
can be seen as an intermediate representation of the input x. The
dimensionality of Sk is known as the width (or number of hidden units)
of layer k. A feedforward network defines a function sK =: f (x, w) from
X × W to SK , where w := (w1 , . . . , wK ) ∈ W := W1 × . . . × WK .
Given such a parameterized program, we can learn the parameters
by adjusting w to fit some data. For instance, given a dataset of (xi , yi )
pairs, we may minimize the squared loss ∥yi −f (xi , w)∥22 on average over
the data, w.r.t. w. The minimization of such a loss requires accessing
its gradients with respect to w.

4.3 Multilayer perceptrons

4.3.1 Combining affine layers and activations

In the previous section, we did not specify how to parametrize the


feedforward network. A typical parametrization, called the multilayer
perceptron (MLP), uses fully-connected (also called dense) layers of
the form
sk = fk (sk−1 , wk ) := ak (Wk sk−1 + bk ),
where we defined the tuple wk := (Wk , bk ) and where we assumed that
Wk and bk are a matrix and vector of appropriate size. We can further
decompose the layer into two functions. The function s 7→ Wk s + bk
is called an affine layer. The function v 7→ ak (v) is a parameter-free
nonlinearity, often called an activation function (see Section 4.4).
More generally, we may replace the matrix-vector product Wk sk−1
by any parametrized linear function of sk−1 . For example, convolu-
tional layers use the convolution of an input sk−1 with some filters
Wk , seen as a linear map.

Remark 4.1 (Dealing with multiple inputs). Sometimes, it is neces-


sary to deal with networks of multiple inputs. For example, sup-
pose we want to design a function g(x1 , x2 , wg ), where x1 ∈ X1
and x2 ∈ X2 . A simple way to do so is to use the concatenation
x := x1 ⊕ x2 ∈ X1 ⊕ X2 as input to a network f (x, wf ). Alterna-
tively, instead of concatenating x1 and x2 at the input layer, they
can be concatenated after having been through one or more hidden
layers.

4.3.2 Link with generalized linear models

When the depth is K = 1 (only one layer), the output of an MLP is

s1 = a1 (W1 x + b1 ).

This is called a generalized linear model (GLM); see Section 3.4.


Therefore, MLPs include GLMs as a special case. In particular, when
a1 is the softargmax (see Section 4.4), we obtain (multiclass) logistic

regression. For general depth K, the output of an MLP is

sK = aK (WK sK−1 + bK ).

This can be seen as a GLM on top of learned representation sK−1


of the input x. This is the main appeal of MLPs: they learn the feature
representation and the output model at the same time! We will see that
MLPs can also be used as subcomponents in other architectures.
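A minimal sketch of an MLP forward pass (two dense layers with a ReLU activation on the hidden layer; shapes and initialization are arbitrary):

```python
import jax
import jax.numpy as jnp

def mlp(x, w):
    # w is a list of (W_k, b_k) pairs; ReLU on hidden layers, identity output.
    *hidden, last = w
    s = x
    for W, b in hidden:
        s = jax.nn.relu(W @ s + b)
    W, b = last
    return W @ s + b

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
w = [(jax.random.normal(k1, (4, 3)), jnp.zeros(4)),  # layer 1: R^3 -> R^4
     (jax.random.normal(k2, (2, 4)), jnp.zeros(2))]  # layer 2: R^4 -> R^2
print(mlp(jnp.array([1.0, -0.5, 2.0]), w))
```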

4.4 Activation functions

As we saw in Section 4.3, feedforward networks typically use an ac-


tivation function ak at each layer. In this section, we present various
nonlinearities from scalar to scalar or from vector to scalar. We also
present probability mappings that can be used as such activations.

4.4.1 ReLU and softplus

Many activations are scalar-to-scalar functions, but they can also


be applied to vectors in an element-wise fashion. The ReLU (rectified
linear unit) is a popular nonlinearity defined as the non-negative part
of its input 
u, u≥0
relu(u) := max(u, 0) = .
0, u<0
It is a piecewise linear function and includes a kink at u = 0. A multilayer
perceptron with ReLU activations is called a rectifier neural network.
The layers take the form

sk = relu(Ak sk−1 + bk ),

where the ReLU is applied element-wise. The ReLU can be replaced


with a smooth approximation (i.e., without kinks), called the softplus

softplus(u) := log(1 + eu ).

Unlike the ReLU, it is always strictly positive. Other smoothed variants


of the ReLU are possible, see Section 13.4.

4.4.2 Max pooling and log-sum-exp

Many activations are vector-to-scalar functions: they reduce vec-


tors to a scalar value. This scalar value can be seen as a statistic,
“summarizing” the vector.

Max pooling

A simple way to do so is to use the maximum value, also known as max


pooling. Given a vector u ∈ RM , it is defined as

max(u) := max_{j∈[M]} u_j.

Log-sum-exp as a soft maximum

As a smooth approximation of it, we can use the log-sum-exp function


logsumexp(u) := softmax(u) := log ∑_{j=1}^{M} e^{u_j},

which is known to behave like a soft maximum. The log-sum-exp can


be seen as a generalization of the softplus, as we have for all u ∈ R

logsumexp((u, 0)) = softplus(u).

A numerically stable implementation of the log-sum-exp is given by

logsumexp(u) = logsumexp(u − c 1) + c,

where c := maxj∈[M ] uj .
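A minimal sketch of this numerically stable implementation (a library version is available as jax.scipy.special.logsumexp):

```python
import jax.numpy as jnp

def logsumexp(u):
    # Subtract the maximum before exponentiating, then add it back:
    # logsumexp(u) = logsumexp(u - c 1) + c with c = max_j u_j.
    c = jnp.max(u)
    return jnp.log(jnp.sum(jnp.exp(u - c))) + c

u = jnp.array([1000.0, 1001.0, 999.0])
print(logsumexp(u))                  # finite, about 1001.41
print(jnp.log(jnp.sum(jnp.exp(u))))  # naive evaluation overflows to inf
```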
More generally, we can introduce a temperature parameter γ > 0

logsumexpγ (u) = γ · logsumexp(u/γ).

It can be shown that for all u ∈ RM ,

max(u) ≤ logsumexpγ (u) ≤ max(u) + γ · log(M ).

Therefore, logsumexpγ (u) → max(u) as γ → 0. Other definitions of


soft maximum are possible; see Section 13.5.

Log-sum-exp as a log-domain sum


Besides its use as a soft maximum, the log-sum-exp often arises for computing sums in the log domain. Indeed, suppose we want to compute s := ∑_{i=1}^{M} u_i, where u_i > 0. If we define ũ_i := log u_i and s̃ := log s, we then have

s̃ = log ∑_{i=1}^{M} exp(ũ_i).

Written differently, we have the identity

log(∑_{i=1}^{M} u_i) = logsumexp(log(u)).

We can therefore see the log-sum-exp as the sum counterpart of the identity for products

log(∏_{i=1}^{M} u_i) = ∑_{i=1}^{M} log(u_i).

As an example, we use the log-sum-exp to perform the forward-backward


algorithm in the log-domain in Section 10.7.1.

4.4.3 Sigmoids: binary step and logistic functions


Oftentimes, we want to map a real value to a number in [0, 1], that can
represent the probability of an event. For that purpose, we generally
use sigmoids. A sigmoid is a function with a characteristic “S”-shaped
curve. These functions are scalar-to-scalar probability mappings: they
are used to squash real values to [0, 1].

Binary step function


An example is the binary step function, also known as Heaviside
step function, 
step(u) := 1 if u ≥ 0, and 0 if u < 0.

It is a mapping from R to {0, 1}. Unfortunately, it has a discontinuity: a


jump in its graph at u = 0. Moreover, because the function is constant

at all other points, it has zero derivative at these points, which makes it
difficult to use as part of a neural network trained with backpropagation.

Logistic function
A better sigmoid is the logistic function, which is a mapping from R
to (0, 1) and is defined as
logistic(u) := 1/(1 + e^{−u}) = e^u/(1 + e^u) = 1/2 + (1/2) tanh(u/2).
It maps (−∞, 0) to (0, 0.5),[0, +∞) to [0.5, 1) and it satisfies logistic(0) =
0.5. It can therefore be seen as mapping from real values to probability
values. The logistic can be seen as a differentiable approximation to the
discontinuous binary step function step(u). The logistic function can
be shown to be the derivative of softplus, i.e., for all u ∈ R
softplus′ (u) = logistic(u).
Two important properties of the logistic function are that for all u ∈ R
logistic(−u) = 1 − logistic(u)
and
logistic′ (u) = logistic(u) · logistic(−u)
= logistic(u) · (1 − logistic(u)).
Other sigmoids are possible; see Section 13.6.

4.4.4 Probability mappings: argmax and softargmax


It is often useful to transform a real vector into a vector of probabilities.
This is a mapping from RM to the probability simplex, defined by
 
△^M := {π ∈ R^M : ∀j ∈ [M], π_j ≥ 0, ∑_{j=1}^{M} π_j = 1}.

Two examples of such vector-to-vector probability mappings are the


argmax and the softargmax.

Argmax
The argmax operator is defined by
argmax(u) := ϕ(argmax_{j∈[M]} u_j),

where ϕ(j) denotes the one-hot encoding of an integer j ∈ [M ], that is,


ϕ(j) := (0, . . . , 0, 1, 0, . . . , 0) = e_j ∈ {0, 1}^M , where the 1 is at position j.

This mapping puts all the probability mass onto a single coordinate (in
case of ties, we pick a single coordinate arbitrarily). Unfortunately, this
mapping is a discontinuous function.

Softargmax
As a differentiable everywhere relaxation, we can use the softargmax
defined by
softargmax(u) := exp(u) / ∑_{j=1}^{M} exp(u_j).
This operator is commonly known in the literature as softmax but this
is a misnomer: this operator really defines a differentiable relaxation
of the argmax. The output of the softargmax belongs to the relative
interior of the probability simplex, meaning that it can never reach the
borders of the simplex. If we denote π = softargmax(u), this means
that πj ∈ (0, 1), that is, πj can never be exactly 0 or 1. The softargmax
is the gradient of log-sum-exp,
∇logsumexp(u) = softargmax(u).
The softargmax can be seen as a generalization of the logistic function,
as we have for all u ∈ R
[softargmax((u, 0))]1 = logistic(u).

Remark 4.2 (Degrees of freedom and invertibility of softargmax). The softargmax operator satisfies, for all u ∈ R^M and c ∈ R,

π := softargmax(u) = softargmax(u + c 1).

This means that the softargmax operator has M − 1 degrees of freedom and is a non-invertible function. However, due to the above property, without loss of generality, we can impose u⊤1 = ∑_{i=1}^{M} u_i = 0 (if this is not the case, we simply do u_i ← u_i − ū, where ū := (1/M) ∑_{j=1}^{M} u_j). Using this constraint together with

log π_i = u_i − log ∑_{j=1}^{M} exp(u_j),

we then obtain

∑_{i=1}^{M} log π_i = −M log ∑_{j=1}^{M} exp(u_j),

so that

u_i = [softargmax⁻¹(π)]_i = log π_i − (1/M) ∑_{j=1}^{M} log π_j.

4.5 Residual neural networks

We now discuss another feedforward network parametrization: residual


neural networks. Consider a feedforward network with K + 1 layers
f1 , . . . , fK , fK+1 . Surely, as long as fK+1 can exactly represent the
identity function, the set of functions that this feedforward network
can express should be a superset of the functions that f1 , . . . , fK can
express. In other words, depth should in theory not hurt the expressive
power of feedforward networks. Unfortunately, the assumption that each
fk can exactly represent the identity function may not hold in practice.
This means that deeper networks can sometimes be more difficult to
train than shallower ones, making the accuracy saturate or degrade as
a function of depth.
The key idea of residual neural networks (He et al., 2016) is to design
layers fk , called residual blocks, that make it easier to represent the
identity function. Formally, a residual block takes the form

sk = fk (sk−1 , wk ) := sk−1 + hk (sk−1 , wk ).



The function hk is called residual, since it models the difference sk −


sk−1 . The addition with sk−1 is often called a skip connection. As
long as it is easy to adjust wk so that hk (sk−1 , wk ) = 0, fk can freely
become the identity function. For instance, if we use

hk (sk−1 , wk ) := Ck ak (Wk sk−1 + bk ) + dk ,

where wk := (Wk , bk , Ck , dk ), it suffices to set Ck and dk to a zero


matrix and vector. Residual blocks are known to remedy the so-called
vanishing gradient problem.
Many papers and software packages include an additional activation
and instead define the residual block as

sk = fk (sk−1 , wk ) := ak (sk−1 + hk (sk−1 , wk )),

where ak is typically chosen to be the ReLU activation. Whether to


include this additional activation or not is essentially a modelling choice.
In practice, residual blocks may also include additional operations such
as batch norm and convolutional layers.
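A minimal sketch of a residual block of the form above, with h_k a one-hidden-layer MLP (no batch norm or convolution; shapes are arbitrary):

```python
import jax
import jax.numpy as jnp

def residual_block(s, w):
    # s_k = s_{k-1} + h_k(s_{k-1}, w_k) with h_k(s, w) = C relu(W s + b) + d.
    W, b, C, d = w
    return s + (C @ jax.nn.relu(W @ s + b) + d)

dim, hidden = 4, 8
key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
w = (jax.random.normal(k1, (hidden, dim)), jnp.zeros(hidden),
     jax.random.normal(k2, (dim, hidden)), jnp.zeros(dim))
s = jnp.ones(dim)
print(residual_block(s, w))

# With C = 0 and d = 0, the block is exactly the identity function.
w_id = (w[0], w[1], jnp.zeros((dim, hidden)), jnp.zeros(dim))
print(jnp.allclose(residual_block(s, w_id), s))  # True
```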

4.6 Recurrent neural networks

Recurrent neural networks (RNNs) are a class of neural networks that


operate on sequences of vectors, either as input, output or both. Their
actual parametrization depends on the setup but the core idea is to
maintain a state vector that is updated from step to step by a recursive
function that uses shared parameters across steps. Unrolling this
recursion defines a valid computational graph, as we will see in Chapter 8.
We distinguish between the following setups illustrated in Fig. 4.5:

• Vector to sequence (one to many):


f d : RD × RP → RL×M

• Sequence to vector (many to one):


f e : RL×D × RP → RM

• Sequence to sequence (many to many, aligned):


f a : RL×D × RP → RL×M

Figure 4.5: Recurrent neural network architectures: (a) one to many (decoder), (b) many to one (encoder), (c) sequence to sequence aligned, (d) sequence to sequence unaligned.

• Sequence to sequence (many to many, unaligned):



f u : RL×D × RP → RL ×M

where L stands for length. Note that we use the same number of
parameters P for each setup for notational convenience, but this of
course does not need to be the case. Throughout this section, we use
the notation p1:L := (p1 , . . . , pL ) for a sequence of L vectors.

4.6.1 Vector to sequence


In this setting, we define a decoder function p1:L = f d (x, w) from an
input vector x ∈ RD and parameters w ∈ RP to an output sequence
p1:L ∈ RL×M . This is for instance useful for image caption generation,
where a sentence (a sequence of word embeddings) is generated from
an image (a vector of pixels). Formally, we may define p1:L := f d (x, w)

through the recursion

zl := g(x, zl−1 , wg ) l ∈ [L]


pl := h(zl , wh ) l ∈ [L].

where w := (wg , wh , z0 ). The goal of g is to update the current decoder


state zl given the input x, and the previous decoder state zl−1 . The
goal of h is to generate the output pl given the current decoder state
zl . Importantly, the parameters of g and h are shared across steps.
Typically, g and h are parametrized using one-hidden-layer MLPs. Note
that g has multiple inputs; we discuss how to deal with such cases in
Section 4.3.

4.6.2 Sequence to vector

In this setting, we define an encoder function p = f e (x1:L , w) from an


input sequence x1:L ∈ RL×D and parameters w ∈ RP to an output
vector p ∈ RM . This is for instance useful for sequence classification,
such as sentiment analysis. Formally, we may define p := f e (x1:L , w)
using the recursion

sl := γ(xl , sl−1 , wg ) l ∈ [L]


p = pooling(s1:L )

where w := (wg , s0 ). The goal of γ is similar to that of g, except that it updates


encoder states and does not take previous predictions as input. The
pooling function is typically parameter-less. Its goal is to reduce a
sequence to a vector. Examples include using the last state, the average
of states and the coordinate-wise maximum of states.

4.6.3 Sequence to sequence (aligned)

In this setting, we define a function p1:L = f a (x1:L , w) from an in-


put sequence x1:L ∈ RL×D and parameters w ∈ RP to an output
sequence p1:L ∈ RL×M , which we assume to be of the same length.
An example of application is part-of-speech tagging, where the goal is
to assign each word xl to a part-of-speech (noun, verb, adjective, etc).

Formally, we may define p1:L = f a (x1:L , w) as

sl := γ(xl , sl−1 , wγ ) l ∈ [L]


pl = h(sl , wh ) l ∈ [L]

where w := (wγ , wh , s0 ). The functions γ and h are as before.

4.6.4 Sequence to sequence (unaligned)

In this setting, we define a function p1:L′ = f u (x1:L , w) from an input


sequence x1:L ∈ RL×D and parameters w ∈ RP to an output sequence

p1:L′ ∈ RL ×M , which potentially has a different length. An example
of application is machine translation, where the sentences in the source
and target languages do not necessarily have the same length. Typically,
p1:L′ = f u (x1:L , w) is defined as the following two steps

c := f e (x1:L , we )
p1:L′ := f d (c, wd )

where w := (we , wd ), and where we reused the previously-defined


encoder fe and decoder fd . Putting the two steps together, we obtain

sl := γ(xl , sl−1 , wγ ) l ∈ [L]


c = pooling(s1:L )
zl := g(c, pl−1 , zl−1 , wg ) l ∈ [L′ ]
pl := h(zl , wh ) l ∈ [L′ ].

This architecture is aptly named the encoder-decoder architecture.


Note that we denoted the length of the target sequence as L′ . However,
in practice, the target length can be input dependent and is often not
known ahead of time. To deal with this issue, the vocabulary (of size D
in our notation) is typically augmented with an “end of sequence” (EOS)
token so that, at inference time, we know when to stop generating the
output sequence. One disadvantage of this encoder-decoder architecture,
however, is that all the information about the input sequence is contained
in the context vector c, which can therefore become a bottleneck.

4.7 Summary

• Programs can be mathematically represented as a directed acyclic


graph.

• Neural networks are parameterized programs.

• Feedforward networks are parameterized computation chains.

• Multilayer perceptrons (MLPs), residual neural networks (ResNets)


and convolutional neural networks (CNNs) are all particular parametriza-
tions of feedforward networks.
5
Control flows

Control flows, such as conditionals or loops, are an essential part of


computer programming, as they allow us to express complex programs.
It is therefore natural to ask whether these constructs can be included
in a differentiable program. This is what we study in this chapter.

5.1 Comparison operators

Control flows rely on comparison operators, a.k.a. relational oper-


ators. Formally, we can define a comparison operator π = op(u1 , u2 )
as a function from u1 ∈ R and u2 ∈ R to π ∈ {0, 1}. The binary
(Boolean) output π can then be used within a conditional statement
(see Section 5.6, Section 5.7) to decide whether to execute one branch
or another. We define the following operators, illustrated in Fig. 5.1:

• greater than:

gt(u1 , u2 ) := 1 if u1 ≥ u2 , and 0 otherwise
= step(u1 − u2 )
99
100 Control flows

• less than:

1 if u1 ≤ u2
lt(u1 , u2 ) :=
0 otherwise
= 1 − gt(u1 , u2 )
= step(u2 − u1 )

• equal:

1 if |u1 − u2 | = 0
eq(u1 , u2 ) :=
0 otherwise
= gt(u2 , u1 ) · gt(u1 , u2 )
= step(u2 − u1 ) · step(u1 − u2 )

• not equal:

1 if |u1 − u2 | > 0
neq(u1 , u2 ) :=
0 otherwise
= 1 − eq(u1 , u2 )
= 1 − step(u2 − u2 ) · step(u1 − u2 ),

where step : R → {0, 1} is the Heaviside step function

step(u) := 1 if u ≥ 0, and 0 otherwise.

The Heaviside step function is piecewise constant. At u = 0, the function


is discontinuous. At u ̸= 0, it is continuous and has null derivative.
Since the comparison operators we presented are all expressed in terms
of the step function, they are all continuous and differentiable almost
everywhere, with null derivative. Therefore, while their derivatives are
well-defined almost everywhere, they are uninformative and prevent
gradient backpropagation.

Figure 5.1: The greater than and equal to operators are discontinuous functions, leading to black or white pictures. They can be smoothed with appropriate approximations of the Heaviside step function.

5.2 Soft inequality operators

5.2.1 Heuristic definition

To obtain a continuous relaxation of inequality operators, we can


heuristically replace the step function in the expression of “greater
than” and “less than” by a sigmoid function sigmoidσ , where σ > 0 is a
scaling parameter. Such a sigmoid function should satisfy the following
properties:

• sigmoidσ (−u) = 1 − sigmoidσ (u),

• limu→∞ sigmoidσ (u) = 1,

• limu→−∞ sigmoidσ (u) = 0,

• sigmoidσ (0) = 1/2.

Two examples of sigmoids satisfying the aforementioned properties are


the logistic function
sigmoidσ (u) := logisticσ (u) := 1/(1 + e^{−u/σ}) ∈ (0, 1)
and the standard Gaussian’s CDF

sigmoidσ (u) := Φ(u/σ).

We may then define the soft “greater than”

gt(µ1 , µ2 ) = step(µ1 − µ2 )
≈ sigmoidσ (µ1 − µ2 )
=: gtσ (µ1 , µ2 )

and the soft “less than”

lt(µ1 , µ2 ) = step(µ2 − µ1 )
≈ sigmoidσ (µ2 − µ1 )
=: ltσ (µ1 , µ2 )
= 1 − sigmoidσ (µ1 − µ2 )
= 1 − gtσ (µ1 , µ2 ).

In the limit, we have that sigmoidσ (µ1 − µ2 ) → 1 when µ1 − µ2 → ∞.


In the limit, sigmoidσ therefore outputs a value of 1 if µ1 and µ2 are
infinitely apart. Besides the logistic function and the standard Gaussian’s
CDF, other sigmoid functions are possible, as discussed in Section 13.6.
In particular, with sparse sigmoids, there exists a finite value τ such
that µ1 − µ2 ≥ τ =⇒ sigmoidσ (µ1 − µ2 ) = 1.

5.2.2 Stochastic process perspective


When the sigmoid used to replace the step function is the logistic
function or the standard Gaussian’s CDF, we can revisit the previous
heuristic definition of gtσ (µ1 , µ2 ) and ltσ (µ1 , µ2 ) from a more formal
perspective. Indeed, to real values µ1 ∈ R and µ2 ∈ R, we can associate
random variables

U1 ∼ pµ1 ,σ1 and U2 ∼ pµ2 ,σ2 ,



thereby forming a stochastic process (we assume that σ1 and σ2 are


fixed). Alternatively, we can also write

(U1 , U2 ) ∼ pµ1 ,σ1 ⊗ pµ2 ,σ2 ,

where for two distributions p1 and p2 , we denote p1 ⊗ p2 their outer


product (p1 ⊗ p2 )(u1 , u2 ) := p1 (u1 )p2 (u2 ). We can then define

gtσ (µ1 , µ2 ) = E [gt(U1 , U2 )]


= E [step(U1 − U2 )]
= P(U1 − U2 > 0)
= 1 − P(U1 − U2 ≤ 0)
= 1 − FU1 −U2 (0),

where FX is the cumulative distribution function (CDF) of the


random variable X, and σ is a function of σ1 and σ2 . Similarly, we
obtain

ltσ (µ1 , µ2 ) = E [lt(U1 , U2 )]


= E [step(U2 − U1 )]
= P(U1 − U2 ≤ 0)
= FU1 −U2 (0).

We see that the soft inequality operators are based on the CDF of the
difference between U1 and U2 .
From a perturbation perspective, we can also define noise variables
Z1 ∼ p0,1 and Z2 ∼ p0,1 such that U1 = µ1 + σ1 Z1 and U2 = µ2 + σ2 Z2
(Section 12.4.1). We then have

gtσ (µ1 , µ2 ) = E [gt(µ1 + σ1 Z1 , µ2 + σ2 Z2 )]


ltσ (µ1 , µ2 ) = E [lt(µ1 + σ1 Z1 , µ2 + σ2 Z2 )] .

Gaussian case

When U1 ∼ Normal(µ1 , σ1²) and U2 ∼ Normal(µ2 , σ2²), we have

U1 − U2 ∼ Normal(µ1 − µ2 , σ1² + σ2²).    (5.1)



Denoting Φ the standard Gaussian’s CDF, we then obtain


gtσ (µ1 , µ2 ) = Φ((µ1 − µ2)/σ)
ltσ (µ1 , µ2 ) = Φ((µ2 − µ1)/σ),

where σ := √(σ1² + σ2²). The corresponding distribution for Z1 and Z2 is
Gaussian noise.

Logistic case
When U1 ∼ Gumbel(µ1 , σ) and U2 ∼ Gumbel(µ2 , σ), we have

U1 − U2 ∼ Logistic(µ1 − µ2 , σ). (5.2)

We then obtain (see Proposition 14.3)


gtσ (µ1 , µ2 ) = logistic((µ1 − µ2)/σ)
ltσ (µ1 , µ2 ) = logistic((µ2 − µ1)/σ).
The corresponding distribution for Z1 and Z2 is Gumbel noise.
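A minimal sketch of the two soft “greater than” operators derived above (logistic and Gaussian-CDF variants; the inputs are arbitrary):

```python
import jax.numpy as jnp
from jax.nn import sigmoid
from jax.scipy.stats import norm

def soft_gt_logistic(mu1, mu2, sigma=1.0):
    # Gumbel perturbations: gt_sigma(mu1, mu2) = logistic((mu1 - mu2) / sigma).
    return sigmoid((mu1 - mu2) / sigma)

def soft_gt_gaussian(mu1, mu2, sigma=1.0):
    # Gaussian perturbations: gt_sigma(mu1, mu2) = Phi((mu1 - mu2) / sigma).
    return norm.cdf((mu1 - mu2) / sigma)

print(soft_gt_logistic(2.0, 0.0), soft_gt_gaussian(2.0, 0.0))  # close to 1
print(soft_gt_logistic(0.0, 0.0), soft_gt_gaussian(0.0, 0.0))  # exactly 0.5
print(soft_gt_logistic(2.0, 0.0, sigma=0.1))                   # sharper, closer to 1
```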

Recovering hard inequality operators


We easily recover the “hard” inequality operator by

gt(µ1 , µ2 ) = E [gt(U1 , U2 )] ,

where Ui ∼ δµi and where δµi is the delta distribution that assigns a
probability of 1 to µi .

5.3 Soft equality operators

5.3.1 Heuristic definition


The equality operator eq(µ1 , µ2 ) can be seen as an extreme kind of
similarity function between numbers, that can only output the values
0 or 1. To define soft equality operators, a natural idea is therefore

to replace the equality operator by a more general similarity function.


A similarity function should achieve its maximum at µ1 = µ2 and it
should decrease as µ1 and µ2 move apart. A common family of similarity
functions are kernels. Briefly, a kernel k(µ1 , µ2 ) can be seen as the
inner product
k(µ1 , µ2 ) := ⟨ϕ(µ1 ), ϕ(µ2 )⟩
between the embeddings ϕ(µ1 ) and ϕ(µ2 ) of µ1 and µ2 in some (po-
tentially infinite-dimensional) space H, a reproducing kernel Hilbert
space to be precise; see Schölkopf and Smola (2002) and Shawe-Taylor
and Cristianini (2004) for an in-depth review of kernels. To obtain a
similarity measure between 0 and 1 approximating the equality operator,
we can normalize to obtain

eq(µ1 , µ2 ) ≈ k(µ1 , µ2 ) / √(k(µ1 , µ1 ) k(µ2 , µ2 ))
            = ⟨ϕ(µ1 ), ϕ(µ2 )⟩ / (∥ϕ(µ1 )∥ ∥ϕ(µ2 )∥),

where ∥ϕ(µ)∥ := √⟨ϕ(µ), ϕ(µ)⟩ = √k(µ, µ). This is the cosine similarity
between ϕ(µ1 ) and ϕ(µ2 ).


A particular kind of kernel are isotropic kernels of the form

k(µ1 , µ2 ) := κ(µ1 − µ2 ),

that depend only on the difference between inputs. When the kernel has
a scale parameter σ > 0, we use the notation κσ . We can then define a
soft equality operator as
eq(µ1 , µ2 ) ≈ eqσ (µ1 , µ2 ) := κσ (µ1 − µ2 )/κσ (0).
Several isotropic kernels can be chosen such as the Gaussian kernel

κσ (t) := exp(−t²/(2σ²))

or the logistic kernel

κσ (t) := sech²(t/(2σ)),


[Figure 5.2 shows two panels, "Soft equal zero" and "Soft greater than zero", comparing the hard, logistic and Gaussian variants of each operator.]
Figure 5.2: Soft equality and soft greater than operators can be defined as normalized
kernels (PDF) and as CDF functions, respectively.

where we defined the hyperbolic secant

sech(u) := 2/(exp(u) + exp(−u)).

As their names suggest, these kernels arise naturally from a probabilistic
perspective, which we present below.
The soft equality operators obtained with these kernels are illustrated
in Fig. 5.2. Intuitively, we replaced a bar located at µ1 = µ2 with a
bump function. The soft equality operator obtained with the logistic
kernel coincides with the expression that Petersen et al. (2021) arrive
at in a different manner (see their Eq. 9).
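A minimal sketch of such kernel-based soft equality operators (the function names and the default σ = 1 are our own choices) is given below:

import jax.numpy as jnp

def gaussian_kernel(t, sigma=1.0):
    # Gaussian kernel; satisfies kappa(0) = 1.
    return jnp.exp(-t**2 / (2 * sigma**2))

def logistic_kernel(t, sigma=1.0):
    # Logistic kernel sech^2(t / (2 sigma)); also satisfies kappa(0) = 1.
    return 1.0 / jnp.cosh(t / (2 * sigma))**2

def soft_eq(mu1, mu2, kernel=gaussian_kernel, sigma=1.0):
    # Soft equality: normalized kernel evaluated at the difference.
    return kernel(mu1 - mu2, sigma) / kernel(0.0, sigma)

print(soft_eq(1.0, 1.0))  # 1.0: maximal when the two inputs coincide
print(soft_eq(1.0, 3.0))  # decays towards 0 as the inputs move apart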

5.3.2 Stochastic process perspective


We again adopt the stochastic process perspective, in which we associate
random variables

U1 ∼ pµ1 ,σ1 and U2 ∼ pµ2 ,σ2

to real values µ1 ∈ R and µ2 ∈ R. However, to handle the equality


operator, we cannot simply use the expectation of eq(U1 , U2 ) since

E[eq(U1 , U2 )] = P(U1 = U2 ) = 0,

U1 and U2 being independent continuous variables. While we cannot


use the probability of U1 = U2 , or equivalently of U1 − U2 = 0, we can
consider using the probability density function (PDF) fU1 −U2 of U1 − U2
evaluated at 0. To ensure that the maximum is achieved at 0 with value

1, we can normalize the PDF to define

eqσ (µ1 , µ2 ) = fU1 −U2 (0)/f0 (0).
It is well-known that the PDF of the sum of two independent random variables is
the convolution of their respective PDFs. We therefore have

fU1 −U2 (t) = (fU1 ∗ f−U2 )(t) = ∫_{−∞}^{∞} fU1 (τ )f−U2 (t − τ )dτ.

In particular, with t = 0, if fX is the PDF of a location-scale family
distributed random variable, we obtain

fU1 −U2 (0) = (fU1 ∗ f−U2 )(0)
           = ∫_{−∞}^{∞} fU1 (τ )f−U2 (−τ )dτ
           = ∫_{−∞}^{∞} fU1 (τ )fU2 (τ )dτ
           =: ⟨fU1 , fU2 ⟩
           =: k(µ1 , µ2 ).

We indeed recover an inner product and therefore a kernel.

CDF and PDF of absolute difference

While P(U1 = U2 ) = 0, we can also consider using P(|U1 − U2 | ≤ ε) =


F|U1 −U2 | (ε) as an alternative notion of soft equality. For any random
variable X, we have

F|X| (x) = P(|X| ≤ x)


= P(−x ≤ X ≤ x)
= P(X ≤ x) − P(X ≤ −x)
= FX (x) − FX (−x).

Therefore,

P(|U1 − U2 | ≤ ε) = FU1 −U2 (ε) − FU1 −U2 (−ε).



We can also derive the PDF of |X| as

f|X| (x) = d/dx [FX (x) − FX (−x)] = fX (x) + fX (−x)

and in particular

f|X| (0) = 2fX (0).
Therefore
f|U1 −U2 | (0) = 2fU1 −U2 (0),
further justifying using the PDF of U1 − U2 evaluated at 0. When X
follows a normal distribution, |X| follows the so-called folded normal
distribution.

Gaussian case
When U1 ∼ Normal(µ1 , σ1²) and U2 ∼ Normal(µ2 , σ2²), we obtain from
Eq. (5.1)

fU1 −U2 (t) = (1/√(2π(σ1² + σ2²))) exp(−(t − (µ1 − µ2 ))²/(2(σ1² + σ2²)))

so that

eqσ (µ1 , µ2 ) = exp(−(µ1 − µ2 )²/(2(σ1² + σ2²))) ∈ [0, 1].

We indeed recover κσ (µ1 − µ2 )/κσ (0), where κσ is the Gaussian kernel
with σ = √(σ1² + σ2²). For the CDF of the absolute difference, we obtain

P(|U1 − U2 | ≤ ε) = Φ((ε − (µ1 − µ2 ))/σ) − Φ((−ε − (µ1 − µ2 ))/σ).

Logistic case
When U1 ∼ Gumbel(µ1 , σ) and U2 ∼ Gumbel(µ2 , σ), recalling that

sech(u) := 2/(exp(u) + exp(−u)),

we obtain from Eq. (5.2)

fU1 −U2 (t) = (1/(4σ)) sech²((t − (µ1 − µ2 ))/(2σ))

so that

eqσ (µ1 , µ2 ) = sech²((µ1 − µ2 )/(2σ)) ∈ [0, 1].

We indeed recover κσ (µ1 − µ2 )/κσ (0), where κσ is the logistic kernel
with σ = σ1 = σ2 .

5.3.3 Gaussian process perspective


The previous approach relied on mapping µ1 and µ2 to two indepen-
dent random variables U1 ∼ pµ1 ,σ1 and U2 ∼ pµ2 ,σ2 (we assume that σ1
and σ2 are fixed). Instead, we can consider mapping µ1 and µ2 to two
dependent random variables U1 and U2 , whose covariance depends on
the similarity between µ1 and µ2 . We can do so by using a Gaussian
process (Hida and Hitsuda, 1976).
A Gaussian process on R is a stochastic process {Uµ : µ ∈ R} indexed
by µ ∈ R such that any subset of K random variables (Uµ1 , . . . , UµK )
associated with (µ1 , . . . , µK ) ∈ RK is a multivariate Gaussian random
variable. The Gaussian process is characterized by the mean function
µ 7→ E[Uµ ], and its covariance function (µi , µj ) 7→ Cov(Uµi , Uµj ). For
the mean function, we may simply choose E[Uµ ] = µ. For the covariance
function, we need to ensure that the variance of any combination of
random variables in the Gaussian process is non-negative. This property
is satisfied by kernel functions. We can therefore define

Cov(Uµi , Uµj ) := k(µi , µj ),

for some kernel k. Equipped with such a mapping from real numbers
to random variables, we need a measure of similarity between random
variables. A natural choice is their correlation
corr(Uµi , Uµj ) := Cov(Uµi , Uµj ) / √(Var(Uµi ) Var(Uµj )) ∈ [0, 1].

We therefore obtain

corr(Uµ1 , Uµ2 ) = k(µ1 , µ2 ) / √(k(µ1 , µ1 ) k(µ2 , µ2 ))
                = ⟨ϕ(µ1 ), ϕ(µ2 )⟩ / (∥ϕ(µ1 )∥ ∥ϕ(µ2 )∥),

which coincides with the cosine similarity measure we saw before. In
the particular case K = 2 and when kσ (µ1 , µ2 ) = κ(µ1 − µ2 ), we then
recover the previous heuristically-defined soft equality operator

eqσ (µ1 , µ2 ) = corr(Uµ1 , Uµ2 ) = κ(µ1 − µ2 )/κ(0).

5.4 Logical operators

Logical operators can be used to perform Boolean algebra. Formally,


we can define them as functions from {0, 1} × {0, 1} to {0, 1}. The and
(logical conjunction a.k.a. logical product), or (logical disjunction a.k.a.
logical addition) and not (logical negation a.k.a. logical complement)
operators, for example, are defined by

and(π, π ′ ) := 1 if π = π ′ = 1, 0 otherwise
or(π, π ′ ) := 1 if 1 ∈ {π, π ′ }, 0 otherwise
not(π) := 0 if π = 1, 1 if π = 0.

Classical properties of these operators include

• Commutativity:

and(π, π ′ ) = and(π ′ , π)
or(π, π ′ ) = or(π ′ , π)

• Associativity:

and(π, and(π ′ , π ′′ )) = and(and(π, π ′ ), π ′′ )


or(π, or(π ′ , π ′′ )) = or(or(π, π ′ ), π ′′ )

• Distributivity of and over or:

and(π, or(π ′ , π ′′ )) = or(and(π, π ′ ), and(π, π ′′ ))


5.5. Continuous extensions of logical operators 111

• Neutral element:

and(π, 1) = π
or(π, 0) = π

• De Morgan’s laws:

not(or(π, π ′ )) = and(not(π), not(π ′ ))


not(and(π, π ′ )) = or(not(π), not(π ′ )).

More generally, for a binary vector π = (π1 , . . . , πK ) ∈ {0, 1}K ,


we can define all (universal quantification, ∀) and any (existential
quantification, ∃) operators, which are functions from {0, 1}K to {0, 1},
as

all(π) := 1 if π1 = · · · = πK = 1, 0 otherwise

and

any(π) := 1 if 1 ∈ {π1 , . . . , πK }, 0 otherwise.

5.5 Continuous extensions of logical operators

5.5.1 Probabilistic continuous extension

We can equivalently write the and, or and not operators as

and(π, π ′ ) = π · π ′
or(π, π ′ ) = π + π ′ − π · π ′
not(π) = 1 − π.

These are extensions of the previous definitions: we can use them as


functions from [0, 1]×[0, 1] → [0, 1], as illustrated in Fig. 5.3. This means
that we can use the soft comparison operators defined in Section 5.2 to
obtain π, π ′ ∈ [0, 1]. Likewise, we can define continuous extensions of

[Figure 5.3 shows the surfaces of the relaxed and and or operators over [0, 1] × [0, 1].]
Figure 5.3: The Boolean and and or operators are functions from {0, 1} × {0, 1}
to {0, 1} (corners in the figure) but their continuous extensions and(π, π ′ ) := π · π ′
as well as or(π, π ′ ) := π + π ′ − π · π ′ define a function from [0, 1] × [0, 1] to [0, 1].

all and any, which are functions from [0, 1]K to [0, 1], as

K
all(π) =
Y
πi
i=1
K
any(π) = 1 − (1 − πi ).
Y

i=1

From a probabilistic perspective, if we let Y and Y ′ be two
independent random variables distributed according to Bernoulli dis-
tributions with parameters π and π ′ , then

and(π, π ′ ) = P(Y = 1 ∩ Y ′ = 1) = P(Y = 1) · P(Y ′ = 1)


or(π, π ′ ) = P(Y = 1 ∪ Y ′ = 1)
= P(Y = 1) + P(Y ′ = 1) − P(Y = 1 ∩ Y ′ = 1)
= P(Y = 1) + P(Y ′ = 1) − P(Y = 1)P(Y ′ = 1)
not(π) = P(Y ̸= 1) = 1 − P(Y = 1).

In probability theory, these correspond to the product rule of two


independent variables, the addition rule, and the complement rule.
Likewise, if we let Y = (Y1 , . . . , YK ) ∈ {0, 1}K be a random variable
distributed according to a multivariate Bernoulli distribution with

parameters π = (π1 , . . . , πK ), then

all(π) = P(Y1 = 1 ∩ · · · ∩ YK = 1)
       = ∏_{i=1}^K P(Yi = 1)
any(π) = P(Y1 = 1 ∪ · · · ∪ YK = 1)
       = 1 − P(¬(Y1 = 1 ∪ · · · ∪ YK = 1))
       = 1 − P(Y1 ̸= 1 ∩ · · · ∩ YK ̸= 1)
       = 1 − ∏_{i=1}^K (1 − P(Yi = 1)).

These are the chain rule of probability and the addition rule of proba-
bility for K independent variables.
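These continuous extensions are straightforward to implement; a small sketch (function names are our own) is:

import jax.numpy as jnp

def soft_and(pi1, pi2):
    # Product rule for independent events.
    return pi1 * pi2

def soft_or(pi1, pi2):
    # Addition rule.
    return pi1 + pi2 - pi1 * pi2

def soft_not(pi):
    # Complement rule.
    return 1.0 - pi

def soft_all(pi):
    # Continuous extension of all: product of the coordinates.
    return jnp.prod(pi)

def soft_any(pi):
    # Continuous extension of any: complement of "all false".
    return 1.0 - jnp.prod(1.0 - pi)

pi = jnp.array([0.9, 0.8, 0.7])
print(soft_all(pi), soft_any(pi))  # ~0.504 and ~0.994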

5.5.2 Triangular norms and co-norms

More generally, in the fuzzy logic literature (Klir and Yuan, 1995;
Jayaram and Baczynski, 2008), the concepts of triangular norms and
co-norms have been introduced to provide continuous relaxations of the
and and or operators, respectively.

Definition 5.1 (Triangular norms and conorms). A triangular norm,


a.k.a. t-norm, is a function from [0, 1] × [0, 1] to [0, 1] which is
commutative, associative, neutral w.r.t. 1 and is monotone, meaning
that t(π, π ′ ) ≤ t(τ, τ ′ ) for all π ≤ τ and π ′ ≤ τ ′ . A triangular
conorm, a.k.a. t-conorm, is defined similarly but is neutral w.r.t. 0.

The previously-defined probabilistic extensions of and and or are
examples of triangular norm and conorm, respectively. More examples are given in
Table 5.1. Thanks to the associative property of these operators, we can
generalize them to vectors π ∈ [0, 1]K to define continuous extensions
of the all and any operators, as shown in Table 5.2. For more examples
and analysis, see for instance van Krieken (2024, Chapters 2 and 3).

[Figure 5.4 shows the surfaces of the and and or relaxations obtained with the extremum and Łukasiewicz t-norms and t-conorms.]
Figure 5.4: Alternative relaxations of the Boolean and and or operators using
triangular norms (t-norms).

Table 5.1: Examples of triangular norms and conorms, which are continuous
relaxations of the and and or operators, respectively. More instances can be obtained
by smoothing out the min and max operators.

                t-norm (relaxed and)      t-conorm (relaxed or)
Probabilistic   π · π ′                   π + π ′ − π · π ′
Extremum        min(π, π ′ )              max(π, π ′ )
Łukasiewicz     max(π + π ′ − 1, 0)       min(π + π ′ , 1)
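For concreteness, the three pairs of Table 5.1 can be implemented as follows (a sketch with our own function names); all of them agree with the Boolean operators on {0, 1}:

import jax.numpy as jnp

def and_prob(p, q):  return p * q
def or_prob(p, q):   return p + q - p * q

def and_min(p, q):   return jnp.minimum(p, q)
def or_max(p, q):    return jnp.maximum(p, q)

def and_lukasiewicz(p, q): return jnp.maximum(p + q - 1.0, 0.0)
def or_lukasiewicz(p, q):  return jnp.minimum(p + q, 1.0)

for t_norm in (and_prob, and_min, and_lukasiewicz):
    print(t_norm(1.0, 1.0), t_norm(1.0, 0.0), t_norm(0.0, 0.0))  # 1, 0, 0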

5.6 If-else statements

An if-else statement executes different code depending on a condition.


Formally, we can define the ifelse : {0, 1} × V × V → V function by

ifelse(π, v1 , v0 ) := v1 if π = 1, v0 if π = 0     (5.3)
                     = π · v1 + (1 − π) · v0 .

The π variable is called the predicate. It is a binary (Boolean) variable,


making the function ifelse undefined if π ̸∈ {0, 1}. The function is
therefore discontinuous and nondifferentiable w.r.t. π ∈ {0, 1}. On the
other hand, v0 ∈ V and v1 ∈ V, which correspond to the false and
true branches, can be continuous variables. If π = 1, the function
is linear w.r.t. v1 and constant w.r.t. v0 . Conversely, if π = 0, the
function is linear w.r.t. v0 and constant w.r.t. v1 . We now discuss how
to differentiate through ifelse.

Table 5.2: Continuous extensions of the all and any operators.

                All (∀)                             Any (∃)
Probabilistic   ∏_{i=1}^K πi                        1 − ∏_{i=1}^K (1 − πi )
Extremum        min(π1 , . . . , πK )               max(π1 , . . . , πK )
Łukasiewicz     max(∑_{i=1}^K πi − (K − 1), 0)      min(∑_{i=1}^K πi , 1)

5.6.1 Differentiating through branch variables


For π ∈ {0, 1} fixed, ifelse(π, v1 , v0 ) is a valid function w.r.t. v1 ∈ V
and v0 ∈ V, and can therefore be used as a node in a computational
graph (Section 8.3). Due to the linearity w.r.t. v1 and v0 , we obtain
that the Jacobians w.r.t. v1 and v0 are

∂v0 ifelse(π, v1 , v0 ) = 0 if π = 1, I if π = 0, that is, (1 − π) · I,
and
∂v1 ifelse(π, v1 , v0 ) = I if π = 1, 0 if π = 0, that is, π · I,
where I is the identity matrix of appropriate size. Most of the time,
if-else statements are composed with other functions. Let g1 : U1 → V
and g0 : U0 → V be differentiable functions. We then define v1 := g1 (u1 )
and v0 := g0 (u0 ), where u1 ∈ U1 and u0 ∈ U0 . The composition of
ifelse, g1 and g0 is then the function f : {0, 1} × U1 × U0 → V defined by
f (π, u1 , u0 ) := ifelse(π, g1 (u1 ), g0 (u0 ))
= π · g1 (u1 ) + (1 − π) · g0 (u0 ).
We obtain that the Jacobians are
∂u1 f (π, u1 , u0 ) = π · ∂g1 (u1 )
and
∂u0 f (π, u1 , u0 ) = (1 − π)∂g0 (u0 ).

As long as g1 and g0 are differentiable functions, we can therefore


differentiate through the branch variables u1 and u0 without any issue.
More problematic is the predicate variable π, as we now discuss.

5.6.2 Differentiating through predicate variables


The predicate variable π is binary and therefore cannot be differentiated
directly. However, π can be the output of a comparison operator. For
example, suppose we want to express the function fh : R × U1 × U0 → V
defined by

fh (p, u1 , u0 ) := g1 (u1 ) if p ≥ 0, g0 (u0 ) otherwise.

Using our notation, this can be rewritten as


fh (p, u1 , u0 ) := ifelse(gt(p, 0), g1 (u1 ), g0 (u0 ))
= ifelse(step(p), g1 (u1 ), g0 (u0 ))
= step(p)g1 (u1 ) + (1 − step(p))g0 (u0 ).
The Heaviside step function has a discontinuity at p = 0, but it is
continuous and differentiable with derivative step′ (p) = 0 for all p ̸= 0.
The function fh therefore has null derivative w.r.t. p ̸= 0,
∂p fh (p, u1 , u0 ) = ∂1 fh (p, u1 , u0 )
= step′ (p)(g1 (u1 ) − g0 (u0 ))
= 0.
In other words, while fh has well-defined derivatives w.r.t. p for p ̸= 0,
the derivatives are uninformative. As another example, let us now
consider the function
gh (u1 , u0 ) := fh (t(u1 ), u1 , u0 ),
for some differentiable function t. This time, u1 influences both the
predicate and the true branch. Then, using Proposition 2.8, we obtain
∂u1 gh (u1 , u0 ) = ∂t(u1 )∂1 fh (t(u1 ), u1 , u0 ) + ∂2 fh (t(u1 ), u1 , u0 )
= ∂2 fh (t(u1 ), u1 , u0 ).
In other words, the derivatives of the predicate t(u1 ) do not influence
the derivatives of gh .

5.6.3 Continuous relaxations


Fortunately, we recall that
ifelse(π, v1 , v0 ) = π · v1 + (1 − π) · v0 .
This function is perfectly well-defined, even if π ∈ [0, 1], instead of
π ∈ {0, 1}. That is, this definition is an extension of Eq. (5.3) from the
discrete set {0, 1} to the continuous unit segment [0, 1]. We saw that
gt(a, b) ≈ sigmoid(a − b) ∈ [0, 1],
where we use sigmoid to denote a differentiable S-shaped function
mapping R to [0, 1]. For instance, we can use the logistic function or
the standard Gaussian’s CDF. If we now define
fs (p, u1 , u0 ) := ifelse(sigmoid(p), g1 (u1 ), g0 (u0 ))
= sigmoid(p)g1 (u1 ) + (1 − sigmoid(p))g0 (u0 ), (5.4)
the Jacobian becomes
∂p fs (p, u1 , u0 ) = sigmoid′ (p)(g1 (u1 ) − g0 (u0 )).
If sigmoid = logistic, the Jacobian is non-null everywhere, allowing
gradients to backpropagate through the computational graph. This is
an example of smoothing by regularization, as studied in Chapter 13.
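A minimal sketch of this relaxation (the branch functions and test values are our own, for illustration) is:

import jax
import jax.numpy as jnp

def soft_ifelse(p, u1, u0, g1, g0):
    # Sigmoid-weighted combination of the two branches.
    pi = jax.nn.sigmoid(p)              # soft predicate in [0, 1]
    return pi * g1(u1) + (1 - pi) * g0(u0)

g1 = lambda u: u ** 2                   # "true" branch
g0 = lambda u: jnp.sin(u)               # "false" branch

value = soft_ifelse(0.3, 1.5, -0.7, g1, g0)
# The derivative w.r.t. the predicate input p is non-zero, unlike with
# the hard Heaviside step function.
grad_p = jax.grad(soft_ifelse, argnums=0)(0.3, 1.5, -0.7, g1, g0)
print(value, grad_p)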

Probabilistic perspective
From a probabilistic perspective, we can view Eq. (5.4) as the expecta-
tion of gi (ui ), where i ∈ {0, 1} is a binary random variable distributed
according to a Bernoulli distribution with parameter π = sigmoid(p):
fs (p, u1 , u0 ) = Ei∼Bernoulli(sigmoid(p)) [gi (ui )] .
Taking the expectation over the two possibles branches makes the
function differentiable with respect to p, since sigmoid(p) is differentiable.
Of course, this comes at the cost of evaluating both branches, instead
of a single one. The probabilistic perspective suggests that we can also
compute the variance if needed as
Vi∼Bernoulli(sigmoid(p)) [gi (ui )] = Ei∼Bernoulli(sigmoid(p)) [(fs (p, u1 , u0 ) − gi (ui ))²].

Figure 5.5: Computation graphs of programs using if-else statements with either
hard or soft comparison operators. By using a hard comparison operator (step
function, left panel) the predicate π is a discrete variable (represented by a dashed
line). Depending on the value (0 or 1) of the predicate π, only one branch (red or blue)
contributes to the output. Derivatives along a path of continuous variables (dense
lines) can be computed. However, discrete variables such as the predicate prevent the
propagation of meaningful derivatives. By using a soft comparison operator (sigmoid,
right panel), the predicate is a continuous variable and derivatives with respect to
the input p can be taken. In this case both branches (corresponding to g0 and g1 )
contribute to the output and therefore need to be evaluated.

The probabilistic viewpoint also suggests different scales at which a


smoothing can be defined as illustrated in Fig. 5.6.
Another perspective (Petersen et al., 2021) is based on the logistic
distribution. Indeed, if P is a random variable following a logistic
distribution with mean p and scale 1, we saw in Remark 3.1 that the
CDF is P(P ≤ 0) = logistic(−p) = 1 − logistic(p) and therefore

fs (p, u1 , u0 ) = ifelse(logistic(p), g1 (u1 ), g0 (u0 ))


= logistic(p)g1 (u1 ) + (1 − logistic(p))g0 (u0 )
= P(P > 0) · g1 (u1 ) + P(P ≤ 0) · g0 (u0 ).

Remark 5.1 (Global versus local smoothing). Consider the function

f (x, y, z) := y if a ≤ x ≤ b, z otherwise.

The derivatives w.r.t. y and z are well-defined. The derivative w.r.t.


x on the other hand is not well-defined since it involves comparison
operators and the logical operator and. Using our notation, we can
5.6. If-else statements 119

rewrite the function as

f (x, y, z) = ifelse(and(gt(x, a), lt(x, b)), y, z).

A local smoothing approach consists in replacing gt and lt by gtσ


and ltσ locally in the program:

fσloc (x, y, z) := ifelse(and(gtσ (x, a), ltσ (x, b)), y, z)


= πa πb y + (1 − πa πb )z

where

πa := sigmoidσ (x − a)
πb := sigmoidσ (b − x),

for any sigmoid function sigmoidσ . A global smoothing approach


instead uses the expectation of the entire program

fσglob (x, y, z) := E[ifelse(and(gt(x + σZ, a), lt(x + σZ, b)), y, z)]


= ifelse(π, y, z)

where

π := E[and(gt(x + σZ, a), lt(x + σZ, b))]
  = P(a ≤ x + σZ ≤ b)
  = sigmoidσ (b − x) − sigmoidσ (a − x)
  = πa + πb − 1,

for sigmoidσ the CDF of σZ, using the symmetry sigmoidσ (a − x) = 1 − sigmoidσ (x − a) = 1 − πa .
We therefore obtain

fσglob (x, y, z) = (πa + πb − 1)y + (2 − πa − πb )z.

The difference stems from the fact that the local approach smoothes
out a ≤ x and x ≤ b independently (treating 1x≥a and 1x≤b as
independent random variables), while the global approach smoothes
out a ≤ x ≤ b simultaneously. In practice, both approaches ap-
proximate the original function well as σ → 0 and coincide for σ
sufficiently small as illustrated in Fig. 5.6.
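Both smoothing strategies of Remark 5.1 are easy to compare numerically; a sketch with Gaussian noise (our own function names and parameter values) is:

import jax.numpy as jnp
from jax.scipy.stats import norm

def local_smooth_gate(x, a=-1.0, b=1.0, sigma=0.5):
    # Smooth gt and lt independently, then combine with the soft "and".
    pi_a = norm.cdf((x - a) / sigma)
    pi_b = norm.cdf((b - x) / sigma)
    return pi_a * pi_b

def global_smooth_gate(x, a=-1.0, b=1.0, sigma=0.5):
    # Smooth the whole program: P(a <= x + sigma * Z <= b) for Gaussian Z.
    return norm.cdf((b - x) / sigma) - norm.cdf((a - x) / sigma)

x = jnp.linspace(-3.0, 3.0, 7)
print(local_smooth_gate(x))
print(global_smooth_gate(x))  # close to the local version for small sigma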

[Figure 5.6 shows three panels, for σ = 1.0, σ = 0.5 and σ = 0.1, comparing the original function with its locally and globally smoothed versions.]
Figure 5.6: Global versus local smoothing approaches on a gate function f (x) := 1
if x ∈ [−1, 1], and f (x) := 0 otherwise. In our notation, we can write f (x) =
ifelse(and(gt(x, −1), lt(x, 1)), 1, 0). A local approach smoothes out gt and lt separately.
A global approach uses the expectation of the whole program, see Remark 5.1. We
observe that, though the approaches differ for large σ, they quickly coincide for
smaller σ.

5.7 Else-if statements

In the previous section, we focused on if-else statements: conditionals


with only two branches. We now generalize our study to conditionals
including else-if statements, that have K branches.

5.7.1 Encoding K branches

For conditionals with only 2 branches, we encoded the branch that


the conditional needs to take using the binary variable π ∈ {0, 1}. For
conditionals with K branches, we need a way to encode which of the K
branches the conditional needs to take. To do so, we can use a vector
π ∈ {e1 , . . . , eK }, where ei denotes the standard basis vector (a.k.a.
one-hot vector)

ei := (0, . . . , 1, . . . , 0),

a vector with a single one in the coordinate i and K − 1 zeros. The


vector ei is the encoding of a categorical variable i ∈ [K].

Combining booleans
To form such a vector π ∈ {e1 , . . . , eK }, we can combine the previously-
defined comparison and logical operators to define π = (π1 , . . . , πK ).
However, we need to ensure that only one πi is non-zero. We give an
example in Example 5.1.

Argmax and argmin operators


Another way to form π is to use the argmax and argmin operators

argmax(p) := arg max_{π∈{e1 ,...,eK }} ⟨π, p⟩
argmin(p) := arg min_{π∈{e1 ,...,eK }} ⟨π, p⟩ = argmax(−p).

They can be seen as a natural generalization of the greater than and


less than operators. In case of ties, we break them arbitrarily.

5.7.2 Conditionals
We can now express a conditional statement as the function
cond : {e1 , . . . , eK } × V K → V defined by

cond(π, v1 , . . . , vK ) := v1 if π = e1 , . . . , vK if π = eK     (5.5)
                          = ∑_{i=1}^K πi vi .

Similarly as for the ifelse function, the cond function is discontinuous


and nondifferentiable w.r.t. π ∈ {e1 , . . . , eK }. However, given π = ei
fixed for some i, the function is linear in vi and constant in vj for j ̸= i.
We illustrate how to express a simple example, using this formalism.

Example 5.1 (Soft-thresholding operator). The soft-thresholding op-


erator (see also Section 16.4) is a commonly-used operator to pro-

mote sparsity. It is defined by

SoftThreshold(u, λ) := 0 if |u| ≤ λ, u − λ if u ≥ λ, u + λ if u ≤ −λ.

To express it in our formalism, we can define π ∈ {e1 , e2 , e3 } using


comparison operators as

π := (lt(|u|, λ), gt(u, λ), lt(u, −λ))


= (step(λ − |u|), step(u − λ), step(−u − λ)).

Equivalently, we can also define π using an argmax operator as

π := argmax((λ − |u|, u − λ, −u − λ)).

In case of ties, which happens at |u| = λ, we keep only one non-zero


coordinate in π. We can then rewrite the operator as

SoftThreshold(u, λ) = cond(π, 0, u − λ, u + λ).

As we will see, replacing argmax with softargmax induces a cat-


egorical distribution over the three possible branches. The mean
value can be seen as a smoothed out version of the operator, and we
can also compute the standard deviation, as illustrated in Fig. 5.7.
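A sketch of this relaxed soft-thresholding operator (our own naming; λ = 1 and the temperature σ = 0.1 are illustrative choices) is:

import jax
import jax.numpy as jnp

def relaxed_soft_threshold(u, lam=1.0, sigma=0.1):
    # Branch scores: |u| <= lam, u >= lam, u <= -lam.
    scores = jnp.array([lam - jnp.abs(u), u - lam, -u - lam])
    pi = jax.nn.softmax(scores / sigma)          # softargmax over branches
    branches = jnp.array([0.0, u - lam, u + lam])
    mean = jnp.dot(pi, branches)                 # smoothed output
    std = jnp.sqrt(jnp.dot(pi, (branches - mean) ** 2))
    return mean, std

print(relaxed_soft_threshold(2.0))  # mean close to u - lam = 1.0
print(jax.grad(lambda u: relaxed_soft_threshold(u)[0])(2.0))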

5.7.3 Differentiating through branch variables


For π fixed, cond(π, v1 , . . . , vK ) is a valid function w.r.t. vi , and can
therefore again be used as a node in a computational graph. Due to the
linearity w.r.t. vi , we obtain that the Jacobian w.r.t. vi is

∂vi cond(π, v1 , . . . , vK ) = I if π = ei , 0 if π ̸= ei .

Let gi : Ui → V be a differentiable function and ui ∈ Ui . If we define


the composition

f (π, u1 , . . . , uK ) := cond(π, g1 (u1 ), . . . , gK (uK )),



[Figure 5.7 plots the hard soft-thresholding operator together with the mean and standard deviation of its softargmax relaxation.]

Figure 5.7: A conditional with three branches: the soft-thresholding operator (see
Example 5.1). It is a piecewise linear function (dotted black line). Using a softargmax,
we can induce a categorical probability distribution over the three branches. The
expected value (blue line) can be seen as a smoothed out version of the operator.
The induced distribution allows us to also compute the standard deviation.

we then obtain that the Jacobian w.r.t. ui is

∂ui f (π, u1 , . . . , uK ) = ∂gi (ui ) if π = ei , 0 if π ̸= ei .

As long as the gi functions are differentiable, we can therefore differen-


tiate through the branch variables ui for π fixed.

5.7.4 Differentiating through predicate variables


As we saw, π can be obtained by combining comparison and logical
operators, or it can be obtained by argmax and argmin operators.
We illustrate here why these operators are problematic. For example,
suppose we want to express the function

fa (p, u1 , . . . , uK ) := g1 (u1 ) if p1 = maxj pj , . . . , gK (uK ) if pK = maxj pj .

In our notation, this can be expressed as

fa (p, u1 , . . . , uK ) := cond(argmax(p), g1 (u1 ), . . . , gK (uK )).



As for the ifelse case, the Jacobian w.r.t. p is null almost everywhere,

∂p fa (p, u1 , . . . , uK ) = 0.

5.7.5 Continuous relaxations


Similarly to the Heaviside step function, the argmax and argmin func-
tions are piecewise constant, with discontinuities in case of ties. Their
Jacobian are zero almost everywhere, and undefined in case of ties.
Therefore, while their Jacobian is well-defined almost everywhere, they
are uninformative and prevent gradient backpropagation. We can replace
the argmax with a softargmax
exp(p)
softargmax(p) := PK ∈ △K
i=1 exp(p i )
and similarly

softargmin(p) := softargmax(−p) ∈ △K .

Other relaxations of the argmax are possible, as discussed in Section 13.7.


See also Section 14.5.3 for the perturbation perspective.
Fortunately, the definition

cond(π, v1 , . . . , vK ) = ∑_{i=1}^K πi vi

is perfectly valid if we use π ∈ △K instead of π ∈ {e1 , . . . , eK }, and


can therefore be seen as an extension of Eq. (5.5). If we now define

fs (p, u1 , . . . , uK ) := cond(softargmax(p), g1 (u1 ), . . . , gK (uK ))


K
= [softargmax(p)]i · gi (ui ), (5.6)
X

i=1

the Jacobian becomes

∂p fs (p, u1 , . . . , uK ) = ∂softargmax(p)(g1 (u1 ), . . . , gK (uK )),

which is non-null everywhere, allowing gradients to backpropagate


through the computational graph.

Probabilistic perspective
From a probabilistic perspective, we can view Eq. (5.6) as the ex-
pectation of gi (ui ), where i ∈ [K] is a categorical random variable
distributed according to a categorical distribution with parameter
π = softargmax(p):
fs (p, u1 , . . . , uK ) = Ei∼Categorical(softargmax(p)) [gi (ui )] .
Taking the expectation over the K possible branches makes the function
differentiable with respect to p, at the cost of evaluating all branches,
instead of a single one. Similarly as for the if-else case, we can compute
the variance if needed as
Vi∼Categorical(softargmax(p)) [gi (ui )] = Ei∼Categorical(softargmax(p)) [(fs (p, u1 , . . . , uK ) − gi (ui ))²].
This is illustrated in Fig. 5.7.

5.8 For loops

For loops are a control flow for sequentially calling a fixed number K
of functions, reusing the output from the previous iteration. In full
generality, a for loop can be written as follows.

Algorithm 5.1 r = forloop(s0 )


for k := 1, . . . , K do
sk := fk (sk−1 )
r := sK

As illustrated in Fig. 5.8, this defines a computation chain. Assuming


the functions fk are all differentiable, this defines a valid computation
graph, and we can therefore use automatic differentiation to differentiate
forloop w.r.t. its input s0 . Feedforward networks, reviewed in Section 4.2,
can be seen as parameterized for loops, i.e.,
fk (sk−1 ) := gk (sk−1 , wk ),
for some differentiable function gk .

Example 5.2 (Unrolled gradient descent). Suppose we want to min-


imize w.r.t. w the function
L(w, λ) := (1/N ) ∑_{i=1}^N ℓ(h(xi , w), yi ) + (λ/2)∥w∥2².

Given an initialization w0 , gradient descent (Section 16.1) performs


iterations of the form

wk = f (wk−1 , γk , λ) := wk−1 − γk ∇1 L(wk−1 , λ).

Gradient descent can therefore be expressed as a for loop with

fk (wk−1 ) := f (wk−1 , γk , λ).

This means that we can differentiate through the iterations of


gradient descent, as long as f is differentiable, meaning that L is
twice differentiable. This is useful for instance to perform gradient-
based optimization of the hyperparameters γk or λ. This is a special
case of bilevel optimization; see also Chapter 11.
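A sketch of this idea (the data, model and hyperparameter values are made up for illustration) is given below, where we differentiate the unrolled iterations w.r.t. the regularization parameter λ:

import jax
import jax.numpy as jnp

x_data = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y_data = jnp.array([1.0, 2.0, 3.0])

def loss(w, lam):
    preds = x_data @ w
    return 0.5 * jnp.mean((preds - y_data) ** 2) + 0.5 * lam * jnp.sum(w ** 2)

def unrolled_gd(lam, gamma=0.01, num_steps=100):
    # For loop with a fixed number of iterations: a valid computation graph.
    w = jnp.zeros(2)
    for _ in range(num_steps):
        w = w - gamma * jax.grad(loss)(w, lam)
    return loss(w, 0.0)  # outer objective, without the penalty

# Hypergradient: derivative of the final loss w.r.t. lambda.
print(jax.grad(unrolled_gd)(0.1))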

Example 5.3 (Bubble sort). Bubble sort is a simple sorting algo-


rithm that works by repeatedly swapping elements if necessary.
Mathematically, swapping two elements i and j can be written as
a function from RN × [N ] × [N ] to RN defined by

swap(v, i, j) := v + (vj − vi )ei + (vi − vj )ej .

We can then write bubble sort as

for i := 1, . . . , N do
for j := 1, . . . , N − i − 1 do
v ′ := swap(v, j, j + 1)
π := step(vj − vj+1 )
v ← ifelse(π, v ′ , v)

Replacing the Heaviside step function with the logistic function


gives a smoothed version of the algorithm.
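Below is a sketch of such a smoothed bubble sort (our own implementation; the temperature σ = 0.1 is an arbitrary choice):

import jax
import jax.numpy as jnp

def soft_swap_pass(v, sigma=0.1):
    # One bubble-sort pass where the hard step is replaced by a logistic
    # sigmoid, making the pass differentiable.
    n = v.shape[0]
    for j in range(n - 1):
        vj, vj1 = v[j], v[j + 1]
        pi = jax.nn.sigmoid((vj - vj1) / sigma)  # soft "vj > vj+1"
        v = v.at[j].set(pi * vj1 + (1 - pi) * vj)
        v = v.at[j + 1].set(pi * vj + (1 - pi) * vj1)
    return v

def soft_bubble_sort(v, sigma=0.1):
    for _ in range(v.shape[0]):
        v = soft_swap_pass(v, sigma)
    return v

print(soft_bubble_sort(jnp.array([3.0, 1.0, 2.0])))  # approximately sorted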

Figure 5.8: A for loop forms a computation chain. A feedforward network can be seen
as a parameterized for loop, where each function fk depends on some parameters wk .

Figure 5.9: Computation graph of the scan function. Sequence-to-sequence RNNs can
be seen as a parameterized scan function.

5.9 Scan functions

Scan is a higher-order function (meaning a function of a function)


originating from functional programming. It is useful to perform an
operation f on individual elements uk while carrying the result sk of
that operation to the next iteration.

Algorithm 5.2 r = scan(s0 , u1 , . . . , uK )


for k := 1, . . . , K do
sk , vk := f (sk−1 , uk )
r := (sK , v1 , . . . , vK )

As illustrated in Fig. 5.9, this again defines a valid computational


graph and can be differentiated through using autodiff, assuming the
function f is differentiable. Sequence-to-sequence RNNs, reviewed in
Section 4.6, can be seen as a parameterized scan. An advantage
of this abstraction is that parallel scan algorithms have been studied
extensively in computer science (Blelloch, 1989; Sengupta et al., 2010).

Example 5.4 (Prefix sum). Scan can be seen as a generalization of


the prefix sum (a.k.a. cumulative sum) from the addition to any

binary operation. Indeed, a prefix sum amounts to perform

v1 := u1
v2 := u1 + u2
v3 := u1 + u2 + u3
..
.

which can be expressed as a scan by defining

vk := sk−1 + uk
f (sk−1 , uk ) := (vk , vk )

starting from s0 = 0 (sK and vK are redundant in this case).
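In JAX, this corresponds directly to jax.lax.scan; a minimal sketch is:

import jax
import jax.numpy as jnp

def step(carry, u):
    # f(s_{k-1}, u_k) := (v_k, v_k) with v_k = s_{k-1} + u_k.
    v = carry + u
    return v, v

u = jnp.array([1.0, 2.0, 3.0, 4.0])
final_carry, prefix_sums = jax.lax.scan(step, 0.0, u)
print(prefix_sums)  # [1., 3., 6., 10.]

# The scan defines a valid computation graph: it can be differentiated.
print(jax.jacobian(lambda u: jax.lax.scan(step, 0.0, u)[1])(u))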

5.10 While loops

5.10.1 While loops as cyclic graphs

A while loop is a control flow used to repeatedly perform an operation,


reusing the output of the previous iteration, until a certain condition is
met. Suppose f : S → {0, 1} is a function to determine whether to stop
(π = 1) or continue (π = 0) and g : S → S is a function for performing
an operation. Then, without loss of generality, a while loop can be
written as follows.

Algorithm 5.3 r = whileloop(s)


π ← f (s)
while π = 0 do
s ← g(s)
π ← f (s)
r := s

This definition is somewhat cyclic, as we used the while keyword.


However, we can equivalently rewrite the algorithm recursively.


Figure 5.10: A while loop can be represented as a cyclic graph. The while loop
stops if π = 1 and performs another iteration s ← g(s), π ← f (s) if π = 0.

Algorithm 5.4 r = whileloop(s)


π := f (s)
if π = 0 then
r := s
else
r := whileloop(g(s))

Unlike for loops and scan, the number of iterations of while loops is
not known ahead of time, and may even be infinite. In this respect, a
while loop can be seen as a cyclic graph, as illustrated in Fig. 5.10.

Importance of lazy evaluation


We can also implement Algorithm 5.4 in terms of the ifelse function
defined in Section 5.6 as

r := ifelse(f (s), s, whileloop(g(s)))


= f (s) · s + (1 − f (s)) · whileloop(g(s)).

However, to avoid an infinite recursion, it is crucial that ifelse supports


lazy evaluation. That is, whileloop(g(s)) in the definition above should
be evaluated if and only if π = f (s) = 0. In other words, the fact that
f (s) ∈ {0, 1} is crucial to ensure that the recursion is well-defined.

5.10.2 Unrolled while loops


To avoid the issues with unbounded while loops, we can enforce that a
while loop stops after T iterations, i.e., we can truncate the while loop.
130 Control flows

Unrolling Algorithm 5.4 gives (here with T = 3)

π0 := f (s0 )
if π0 = 1 then
r := s0
else
s1 := g(s0 ), π1 := f (s1 )
if π1 = 1 then
r := s1
else
s2 := g(s1 ), π2 := f (s2 )
if π2 = 1 then
r := s2
else
r := s3 := g(s2 )

Using the ifelse function, we can rewrite it as

r = ifelse(π0 ,
s0 ,
ifelse(π1 ,
s1 ,
ifelse(π2 ,
s2 ,
s3 )))

which is itself equivalent to

r = π0 s0 + (1 − π0 ) [π1 s1 + (1 − π1 ) [π2 s2 + (1 − π2 )s3 ]]


= π0 s0 + (1 − π0 )π1 s1 + (1 − π0 )(1 − π1 )π2 s2 + (1 − π0 )(1 − π1 )(1 − π2 )s3 .

More generally, for T ∈ N, the formula is

r = ∑_{i=0}^T ((1 − π0 ) . . . (1 − πi−1 )) πi si
  = ∑_{i=0}^T (∏_{j=0}^{i−1} (1 − πj )) πi si ,

where we defined

si := g(si−1 ) = g^i (s0 ) = (g ◦ · · · ◦ g)(s0 ) ∈ S   (i compositions of g)
πi := f (si ) ∈ {0, 1}.
See also (Petersen et al., 2021). If we further define the shorthand
notation

π̃0 := π0
π̃i := (∏_{j=0}^{i−1} (1 − πj )) πi ,   i ∈ {1, . . . , T },

so that π̃ := (π̃0 , π̃1 , . . . , π̃T ) ∈ △^{T +1} is a discrete probability distribu-
tion containing the probabilities to stop at each of the iterations, we
can rewrite the output of a truncated while loop using a conditional,

r = cond(π̃, s0 , s1 , . . . , sT ).
This is illustrated in Fig. 5.11.

Example 5.5 (Computing the square root using Newton’s method).



Computing the square root x of a real number x > 0 can be cast
as a root finding problem, which we can solve using Newton’s
method. Starting from an initialization s0 , the iterations read
1 x
 
si+1 := g(si ) := si + .
2 si
To measure the error on iteration i, we can define
1
ε(si ) := (s2i − x)2 .
2

Figure 5.11: Computation graph of an unrolled truncated while loop. As in Fig. 5.5,
we depict continuous variables in dense lines and discrete variables in dashed lines.
The output of a while loop with at most T iterations can be written as a conditional
with T + 1 branches, cond(π̃, s0 , . . . , sT ) = ∑_{t=0}^T π̃t st .

As a stopping criterion, we can then use

πi := 1 if ε(si ) ≤ τ, 0 otherwise
    = step(τ − ε(si )),

where 0 < τ ≪ 1 is an error tolerance and step is the Heaviside
step function.

5.10.3 Markov chain perspective

Given the function g : S → S and the initialization s0 ∈ S, a while


loop can only go through a discrete set of values s0 , s1 , s2 , . . . defined
by si = g(si−1 ). This set is potentially countably infinite if the while
loop is unbounded, and finite if the while loop is guaranteed to stop.
Whether the loop moves from the state si to the state si+1 , or stays
at si , is determined by the stopping criterion πi ∈ {0, 1}. To model
the state of the while loop, we can then consider a Markov chain
with a discrete space {s0 , s1 , s2 , . . . }, which we can always identify with

{0, 1, 2, . . . }, with transition probabilities

P(St+1 = sj |St = si ) = pi,j := πi if j = i, (1 − πi ) if j = i + 1, 0 otherwise,

and initial state S0 = s0 . Here, St is the value at iteration t of the loop.


Note that since πi ∈ {0, 1}, the pi,j values are “degenerate” probabilities.
However, this framework lets us generalize to a smooth version of the
while loop naturally. To illustrate the framework, if the while loop stops
at T = 3, the transition probabilities can be cast as a matrix

s0 s1 s2 s3
s0 0 1 0 0
0 0 1 0
P := (pi,j )Ti,j=0 := s1  .
s2 0 0 0 1
s3 0 0 0 1

The output r of the while-loop is determined by the time at which the


state stays at the same value

I = min{i ∈ {1, 2, . . .} s.t. Si = Si−1 }.

Note that I itself is a random variable, as it is defined by the Si variables.


It is called a stopping time. The output of the chain is then

r = E[SI ]
  = ∑_{i=1}^{+∞} P(I = i) E[Si |I = i]
  = ∑_{i=1}^{+∞} P(I = i) si−1
  = ∑_{i=1}^{+∞} (∏_{j=0}^{i−2} (1 − πj )) πi−1 si−1
  = ∑_{i=0}^{+∞} (∏_{j=0}^{i−1} (1 − πj )) πi si .

Because the stopping time is not known ahead of time, the sum over i
goes from 0 to ∞. However, if we enforce in the stopping criterion that
the while loop runs no longer than T iterations, by setting

πi := or(f (si ), eq(i, T )) ∈ {0, 1},

we then naturally recover the expression found by unrolling the while
loop before,

r = E[SI ] = ∑_{i=0}^T (∏_{j=0}^{i−1} (1 − πj )) πi si .

For example, with T = 3, the transition probability matrix is

         s0       s1        s2        s3
   s0 [  π0    1 − π0       0         0    ]
P = s1 [  0      π1      1 − π1       0    ]
   s2 [  0      0          π2      1 − π2  ]
   s3 [  0      0          0         1     ]

Smoothed while loops

With the help of this framework, we can backpropagate even through the
while loop’s stopping criterion, provided that we smooth out the predi-
cate. For example, we saw that the stopping criterion in Example 5.5 is
f (si ) = step(τ − ε(si )) and therefore

πi := or(f (si ), eq(i, T )) ∈ {0, 1}.

Due to the step function, the derivative of the while loop with respect
to τ will always be 0, just like it was the case for if-else statements. If
we change the stopping criterion to f (si ) = sigmoid(τ − ε(si )), we then
have (recall that or is well defined on [0, 1] × [0, 1])

πi := or(f (si ), eq(i, T )) ∈ [0, 1].

With sigmoid, we obtain more informative derivatives. In particular,


with sigmoid = logistic, the derivatives w.r.t. τ are always non-zero.

The smoothed output is expressed as before as the expectation

r = E[SI ] = ∑_{i=0}^T (∏_{j=0}^{i−1} (1 − πj )) πi si
  = ∑_{i=0}^T (∏_{j=0}^{i−1} sigmoid(ε(sj ) − τ )) (1 − sigmoid(ε(si ) − τ )) si ,

using 1 − πj = sigmoid(ε(sj ) − τ ) and πi = 1 − sigmoid(ε(si ) − τ ).

Instead of enforcing a number T of iterations, it is also possible to stop


when the probability of stopping becomes high enough (Petersen et al.,
2021), assuming that the probability of stopping converges to 1.
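Putting the pieces together, a sketch of a smoothed, truncated while loop for the square-root example (our own implementation; the tolerance, temperature and horizon are illustrative choices) is:

import jax
import jax.numpy as jnp

def smoothed_sqrt_while(x, tau=1e-3, temperature=1e-3, T=10, s0=1.0):
    # Unrolled while loop for Newton's square-root iterations, with the
    # stopping criterion step(tau - eps) replaced by a logistic sigmoid.
    s, keep_going, r = s0, 1.0, 0.0   # keep_going = prod_j (1 - pi_j)
    for i in range(T + 1):
        eps = 0.5 * (s ** 2 - x) ** 2
        pi = jax.nn.sigmoid((tau - eps) / temperature)  # soft "stop here"
        if i == T:
            pi = 1.0                                    # force stopping at T
        r = r + keep_going * pi * s
        keep_going = keep_going * (1.0 - pi)
        s = 0.5 * (s + x / s)                           # Newton update g(s)
    return r

print(smoothed_sqrt_while(2.0))            # close to sqrt(2) ~ 1.414
print(jax.grad(smoothed_sqrt_while)(2.0))  # roughly 1 / (2 sqrt(2))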

5.11 Summary

• For conditionals, we saw that differentiating through the branch


variables is not problematic given a fixed predicate.

• However, for the predicate variable, we saw that a differentiable


relaxation is required to avoid null derivatives.

• We introduced soft comparison operators in a principled manner,


using a stochastic process perspective, as well as the continuous
extension of logical operators.

• For loops and scan define valid computational graphs, as their


number of iterations is fixed ahead of time. Feedforward networks
and RNNs can be seen as parameterized for loops and scan,
respectively.

• Unlike for loops and scan, the number of iterations of while loops
is not known ahead of time and may even be infinite. However,
unrolled while loops define valid directed acyclic graphs. We
defined a principled way to differentiate through the stopping
criterion of a while loop, thanks to a Markov chain perspective.
6
Data structures

In computer science, a data structure is a specialized format for organiz-


ing, storing and accessing data. Mathematically, a data structure forms
a so-called algebraic structure: it consists of a set and the functions
to operate on that set. In this chapter, we review how to incorporate
data structures into differentiable programs, with a focus on lists and
dictionaries.

6.1 Lists

A list is an ordered sequence of elements. We restrict ourselves to lists


whose elements all belong to the same value space V. Formally, we
denote a list of fixed length K with values in V by a K-tuple

l := (l1 , . . . , lK ) ∈ LK (V)

where each li ∈ V and where

LK (V) := V^K = V × · · · × V   (K times).


6.1.1 Basic operations


Getting values
We first present how to retrieve values from a list l ∈ LK (V). We define
the function list.get : LK (V) × [K] → V as
list.get(l, i) := li .
The function is continuous and differentiable in l ∈ LK (V) but not in
i ∈ [K], as it is a discrete variable. In the particular case V = R, LK (V)
is equivalent to RK and we can therefore write
list.get(l, i) = ⟨l, ei ⟩,
where {e1 , . . . , eK } is the standard basis of RK .

Setting values
We now present how to replace values from a list l ∈ LK (V). We define
the function list.set : LK (V) × [K] × V → LK (V) as

[list.set(l, i, v)]j := v if i = j, lj if i ̸= j,
for j ∈ [K]. In the functional programming spirit, the function returns
the whole new list, even though a single element has been modified.
Again, the function is continuous and differentiable in l ∈ LK (V) and
v ∈ V but not in i ∈ [K]. In the particular case V = R, given a list
l = (l1 , . . . , lK ), we can write
list.set(l, i, v) = l + (v − li )ei .
That is, we subtract the old value li and add the new value v at the
location i ∈ [K].

Implementation
A fixed-length list can be implemented as an array, which enables O(1)
random access to individual elements. The hardware counterpart of
an array is random access memory (RAM), in which memory can be
retrieved by address (location).

6.1.2 Operations on variable-length lists

So far, we focused on lists of fixed length K. We now turn our attention


to variable-length lists, whose size can decrease or increase over time.
In addition to the list.get and list.set functions, they support functions
that can change the size of a list.

Initializing lists

In order to initialize a list, we define list.init : V → L1 (V) as

list.init(v) := (v),

where we used (v) to denote a 1-tuple.

Pushing values

In order to add new values either to the left or to the right, we define
list.pushLeft : LK (V) × V → LK+1 (V) as

list.pushLeft(l, v) := (v, l1 , . . . , lK ).

and list.pushRight : LK (V) × V → LK+1 (V) as

list.pushRight(l, v) := (l1 , . . . , lK , v).

Popping values

In order to remove values either from the left or from the right, we
define list.popLeft : LK (V) → LK−1 (V) × V as

list.popLeft(l) := (l2 , . . . , lK ), l1

and list.popRight : LK (V) → LK−1 (V) × V as

list.popRight(l) := (l1 , . . . , lK−1 ), lK .

The set L0 (V) is a singleton which contains the empty list.



Inserting values
The pushLeft and pushRight functions can only insert values at the
beginning and at the end of a list, respectively. We now study the insert
function, whose goal is to be able to add a new value at an arbitrary
location, shifting all values to the right and increasing the list size by 1.
We define the function list.insert : LK (V) × [K + 1] × V → LK+1 (V) as

[list.insert(l, i, v)]j := lj if j < i, v if j = i, lj−1 if j > i,

for j ∈ [K+1]. As for the list.set function, list.insert is readily continuous


and differentiable in l and v, but not in i, as it is a discrete variable.
As special cases, we naturally recover
list.insert(l, 1, v) = pushLeft(l, v),
list.insert(l, K + 1, v) = pushRight(l, v).

Differentiability
The list.init, list.push and list.pop functions are readily continuous and
differentiable with respect to their arguments (a continuous relaxation
is not needed). As for the list.set function, the list.insert function is
continuous and differentiable in l and v, but not in i.

Implementation
Under the hood, a variable-length list can be implemented as a linked
list or as a dynamic array. A linked list gives O(K) random access while
a dynamic array allows O(1) random access, at the cost of memory
reallocations.

Stacks and queues


The list.pushRight and list.popRight functions can be used to implement
a stack (last in first out a.k.a. LIFO behavior). The list.pushLeft and
list.popRight functions can be used to implement a queue (first in first
out a.k.a. FIFO behavior).

6.1.3 Continuous relaxations using soft indexing


Getting values
In order to be able to differentiate list.get w.r.t. indexing, a natural
idea is to replace the integer index i ∈ [K] by a distribution πi ∈ △K ,
which we can interpret as a soft index. An integer index i ∈ [K] is
then equivalent to a delta distribution πi ∈ {e1 , . . . , eK }. We define
the continuous relaxation list.softGet : LK (V) × △K → conv(V) as
list.softGet(l, πi ) := ∑_{j=1}^K πi,j lj
                     = cond(πi , l1 , . . . , lK )
                     = EI∼Categorical(πi ) [lI ],

where cond is studied in Section 5.7. In the particular case V = R, we
obtain

list.softGet(l, πi ) = ⟨l, πi ⟩.
This is illustrated in Fig. 6.1.
The choice of the distribution πi = (πi,1 , . . . , πi,K ) encodes the
importance of the elements (l1 , . . . , lK ) w.r.t. li . If we consider that the
smaller |i − j| is, the more related li and lj are, then it makes sense
to define a distribution centered around i (i.e., such that the mode of
the distribution is achieved at i). For example, limiting ourselves to the
neighbors li−1 and li+1 (i.e., a window of size 1), we can define the
sparse distribution
πi := (1/4) · ei−1 + (1/2) · ei + (1/4) · ei+1 ∈ △K .
In this particular case, the continuous relaxation of the list.get function
can then be expressed as a discrete convolution,

list.softGet(l, πi ) = (list.get(l, ·) ∗ κ)(i) = ∑_{j=−∞}^{∞} list.get(l, i − j)κ(j),

where κ(−1) := 1/4, κ(1) := 1/4, κ(0) := 1/2, and κ(j) := 0 for j ̸∈ {−1, 0, 1}.


Assuming V = RM , the computational complexity of list.softGet is
O(M · |supp(πi )|).
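A sketch of this soft indexing operation in the case V = R (our own function names; the example distribution is the window above) is:

import jax.numpy as jnp

def list_soft_get(l, pi):
    # Convex combination of the list entries, weighted by pi.
    return jnp.dot(pi, l)

l = jnp.array([10.0, 20.0, 30.0, 40.0, 50.0])
# Soft index centered on the third entry, window of size 1.
pi = jnp.array([0.0, 0.25, 0.5, 0.25, 0.0])
print(list_soft_get(l, pi))  # 30.0: a weighted average around that entry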


Figure 6.1: The list.get(l, i) function is continuous and differentiable in l but not
in i. Its relaxation list.softGet(l, πi ) is differentiable in both l and πi . When V = R,
list.softGet(l, πi ) can be seen as taking the inner product between the list l and the
probability distribution πi , instead of the delta distribution (canonical vector) ei .

Setting values

To differentiate w.r.t. indexing, we can define the continuous relaxation


list.softSet : LK (V) × △K × V → LK (conv(V)) as

[list.softSet(l, πi , v)]j := E[list.set(l, I, v)]j


= P(I = j)v + P(I ̸= j)lj
= πi,j v + (1 − πi,j )lj ,

where j ∈ [K] and I ∼ Categorical(πi ). Equivalently, we can write

list.softSet(l, πi , v) = (πi,1 v + (1 − πi,1 )l1 , . . . , πi,K v + (1 − πi,K )lK )
                       = (ifelse(πi,1 , v, l1 ), . . . , ifelse(πi,K , v, lK )),

where ifelse is studied in Section 5.6. Since

ifelse(π, u1 , u0 ) = EI∼Bernoulli(π) [uI ],

this relaxation amounts to using an element-wise expectation. As a


result, the list output by list.softSet takes values in conv(V) instead of
V. Note however that when V = RM , then conv(V) = RM as well.

Inserting values

To differentiate value insertion w.r.t. indexing, we can define the contin-


uous relaxation list.softInsert : LK (V) × △K+1 × V → LK+1 (conv(V)) as

[list.softInsert(l, πi , v)]j := E[list.insert(l, I, v)]j
                              = P(I > j)lj + P(I = j)v + P(I < j)lj−1 ,

where I ∼ Categorical(πi ). The three necessary probabilities can easily


be calculated for j ∈ [K + 1] by

P(I > j) = 0 if j = K + 1, ∑_{k=j+1}^{K+1} πi,k otherwise
P(I = j) = πi,j
P(I < j) = 0 if j = 1, ∑_{k=1}^{j−1} πi,k otherwise.

Multi-dimensional indexing

In multi-dimensional lists (arrays or tensors), each element li ∈ V of


a list l ∈ LK1 ,...,KT (V) can now be indexed by a multivariate integer
i = (i1 , . . . , iT ) ∈ [K1 ]×· · ·×[KT ], where T ∈ N is the number of axes of
l. We can always flatten a multi-dimensional list into an uni-dimensional
list by replacing the multi-dimensional index i ∈ [K1 ] × · · · × [KT ] by
a flat index i ∈ [K1 . . . KT ]. The converse operation, converting a flat
uni-dimensional array into a multi-dimensional array, is also possible.
Therefore, there is a bijection between [K] and [K1 ] × · · · × [KT ] for
K := K1 . . . KT .
This means that the previous discussion on soft indexing in the
uni-dimensional setting readily applies to the multi-dimensional setting.
All it takes is the ability to define a probability distribution πi ∈
△K1 ×···×KT . For example, when working with images, we can define a
probability distribution putting probability mass only on the neighboring
pixels of pixel i, a standard approach in image processing. Another
simple approach is to use a product of axis-wise probability distributions.

6.2 Dictionaries

A dictionary (a.k.a. associative array or map) is an unordered list of


key-value pairs, such that each possible key appears at most once in
the list. We denote the set of keys by K and the set of values by V (both
being potentially infinite). We can then define the set of dictionaries of
size L from K to V by

DL (K, V) := LL (K × V) = (K × V)L

and one such dictionary by

d := ((k1 , v1 ), . . . , (kL , vL )) ∈ DL (K, V).

6.2.1 Basic operations

Getting values

The goal of the dict.get function is to retrieve the value associated with
a key, assuming that the dictionary contains this key. Formally, we
define the dict.get : DL (K, V) × K → V ∪ {∞} function as

dict.get(d, k) := vi if ∃i ∈ [L] s.t. k = ki , ∞ if k ̸∈ {k1 , . . . , kL }.

The function is continuous and differentiable in the dictionary d, but


not in the key k. Equivalently, we can write the function as

dict.get(d, k) = ∑_{i=1}^L eq(k, ki )vi / ∑_{i=1}^L eq(k, ki ).

The denominator encodes the fact that the function is undefined if no


key in the dictionary d matches the key k. Assuming k ∈ {k1 , . . . , kL }
and V = RM , we can also write

dict.get(d, k) = vi where i = arg min_{j∈[L]} ∥k − kj ∥2 ,

which shows that we can see dict.get as a nearest neighbor search.



Figure 6.2: Given a set of key-value pairs (ki , vi ) ∈ K × V defining a dictionary d,
we can estimate a continuous mapping from K to V using Nadaraya–Watson kernel
regression (here, illustrated with K = V = R). When keys are normalized to have
unit norm, this recovers softargmax attention from Transformers.

Setting values

The goal of the dict.set function is to replace the value associated with
an existing key. Formally, we define the dict.set : DL (K, V) × K × V →
DL (K, V) function as

(dict.set(d, k, v))i := (ki , v) if ki = k, (ki , vi ) if ki ̸= k.

The function leaves the dictionary unchanged if no key in the dictionary


matches the input key k. The function is continuous and differentiable
in d and v, but not in k.

Implementation

While we view dictionaries as lists of key-value pairs, in practice, a


dictionary (a.k.a. associative array) is often implemented using a hash
table or search trees. The hardware counterpart of a dictionary is called
content-addressable memory (CAM), a.k.a. associative memory.

6.2.2 Continuous relaxation using kernel regression

A dictionary can been seen as a (potentially non-injective) function that


associates a value v to each key k. To obtain a continuous relaxation of
the operations associated to a dictionary, we can adopt a probabilistic
perspective of the mapping from keys to values. We can view keys and
values as two continuous random variables K and V . We can express
the conditional PDF f (v|k) of V |K in terms of the joint PDF f (k, v)
of (K, V ) and the marginal PDF f (k) of K as

f (v|k) = f (k, v)/f (k).

Integrating, we obtain the conditional expectation

E[V |K = k] = ∫_V f (v|k) v dv = ∫_V (f (k, v)/f (k)) v dv.

This is the Bayes predictor, in the sense that E[V |K] is the minimizer
of E[(h(K) − V )2 ] over the space of measurable functions h : K →
V. Using a sample of L input-output pairs (ki , vi ), corresponding to
key-value pairs in our case, Nadaraya–Watson kernel regression
estimates the joint PDF and the marginal PDF using kernel density
estimation (KDE). Using a product of isotropic kernels κσ and ρσ for
key-value pairs, we can define

f̂σ (k, v) := (1/L) ∑_{i=1}^L κσ (k − ki )ρσ (v − vi ).

The corresponding marginal distribution on the keys is then given as

f̂σ (k) := ∫_V f̂σ (k, v) dv
        = (1/L) ∑_{i=1}^L κσ (k − ki ) ∫_V ρσ (v − vi ) dv
        = (1/L) ∑_{i=1}^L κσ (k − ki ).

Replacing f with f̂σ , we obtain the following estimator of the conditional
expectation

Ê[V |K = k] := ∫_V (f̂σ (k, v)/f̂σ (k)) v dv
            = ∫_V ((1/L) ∑_{i=1}^L κσ (k − ki )ρσ (v − vi ) / ((1/L) ∑_{i=1}^L κσ (k − ki ))) v dv
            = ∑_{i=1}^L κσ (k − ki ) ∫_V ρσ (v − vi ) v dv / ∑_{i=1}^L κσ (k − ki )
            = ∑_{i=1}^L κσ (k − ki )vi / ∑_{i=1}^L κσ (k − ki ).

In the above, we assumed that ρσ (v − vi ) = pvi ,σ (v), where pvi ,σ (v) is
the PDF of a distribution whose mean is vi , so that

∫_V ρσ (v − vi ) v dv = EV ∼pvi ,σ [V ] = vi .

Given a dictionary d = ((k1 , v1 ), . . . , (kL , vL )), we can therefore define
the dict.softGet : DL (K, V) × K → conv(V) function as

dict.softGet(d, k) := ∑_{i=1}^L κσ (k − ki )vi / ∑_{i=1}^L κσ (k − ki ).

This kernel regression perspective on dictionaries was previously pointed


out by Zhang et al. (2021). It is illustrated in Fig. 6.2 with K = V = R.

6.2.3 Discrete probability distribution perspective

While the set of possible keys K is potentially infinite, the set of


keys {k1 , . . . , kL } ⊂ K associated with a particular dictionary d =
((k1 , v1 ), . . . , (kL , vL )) is finite. To a particular key k, we can therefore
associate a discrete probability distribution πk = (πk,1 , . . . , πk,L ) ∈ △L
over the keys (k1 , . . . , kL ) of d, defined by

πk,i := κσ (k − ki ) / ∑_{j=1}^L κσ (k − kj )   ∀i ∈ [L].


Figure 6.3: Computation graph of the dict.softGet function. We can use a kernel
κσ to produce a discrete probability distribution πk = (πk,1 , . . . , πk,L ) ∈ △L , that
captures the affinity between the dictionary keys (k1 , . . . , kL ) and the input key k.
The dict.softGet function can then merely be seen as a convex combination (weighted
average) of values (v1 , . . . , vL ) using the probability values (πk,1 , . . . , πk,L ) as weights.

This distribution captures the affinity between the input key k and the
keys (k1 , . . . , kL ) of dictionary d. As illustrated in Fig. 6.3, we obtain

dict.softGet(d, k) = Ei∼Categorical(πk ) [vi ] = ∑_{i=1}^L πk,i vi .

In the limit σ → 0, we recover

dict.get(d, k) = ∑_{i=1}^L eq(k, ki )vi / ∑_{i=1}^L eq(k, ki ).

While the dict.get function is using a mapping from keys k ∈


{k1 , . . . , kL } to integer indices [L], the dict.softGet function is using a
mapping from keys k ∈ {k1 , . . . , kL } to distributions πk ∈ △L . This
perspective allows us to reuse the soft functions we developed for lists
in Section 6.1. For example, we can softly replace the value associated
with key k by performing

list.softSet(d, πk , (k, v)).

Unlike dict.set, the function is differentiable w.r.t. the distribution πk .

6.2.4 Link with attention in Transformers


In the case when κσ is the Gaussian kernel, assuming that the keys
are normalized to have unit norm (which is often the case in practical

implementations (Schlag et al., 2021; Dehghani et al., 2023)), we obtain

κσ (k − ki ) = exp(−∥k − ki ∥2² /(2σ²))
            = exp(−(∥k∥2² + ∥ki ∥2² )/(2σ²)) exp(⟨k, ki ⟩/σ²)
            = exp(−1/σ²) exp(⟨k, ki ⟩/σ²)

so that

πk,i = κσ (k − ki ) / ∑_{j=1}^L κσ (k − kj )
     = exp(⟨k, ki ⟩/σ²) / ∑_{j=1}^L exp(⟨k, kj ⟩/σ²).

We recognize the softargmax operator. Given a dictionary
d = ((k1 , v1 ), . . . , (kL , vL )), we thus recover attention from Transform-
ers (Vaswani et al., 2017) as

dict.softGet(d, k) = ∑_{i=1}^L exp(⟨k, ki ⟩/σ²)vi / ∑_{j=1}^L exp(⟨k, kj ⟩/σ²).

Transformers can therefore be interpreted as relying on a differentiable


dictionary mechanism. Besides Transformers, content-based memory
addressing is also used in neural Turing machines (Graves et al., 2014).
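A sketch of this differentiable lookup (our own function names; the keys, values and query are toy data, with unit-norm keys as assumed above) is:

import jax
import jax.numpy as jnp

def dict_soft_get(keys, values, query, sigma=1.0):
    # Soft dictionary lookup: softargmax attention over the stored keys.
    scores = keys @ query / sigma**2
    weights = jax.nn.softmax(scores)    # pi_k, a distribution over keys
    return weights @ values             # convex combination of the values

keys = jnp.array([[1.0, 0.0], [0.0, 1.0]])
values = jnp.array([[10.0], [20.0]])
query = jnp.array([0.9, 0.1])

print(dict_soft_get(keys, values, query, 0.3))  # close to 10
# The lookup is differentiable w.r.t. the query, keys and values.
print(jax.jacobian(dict_soft_get, argnums=2)(keys, values, query, 0.3))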

6.3 Summary

• Operations on lists are continuous and differentiable w.r.t. the


list, but not w.r.t. the integer index. Similarly, operations on
dictionaries are continuous and differentiable w.r.t. the dictionary,
but not w.r.t. the input key.

• Similarly to the way we handled the predicate in conditionals,


we can replace the integer index (respectively the key) with a
probability distribution over the indices (respectively the keys).

• This allows us to obtain a probabilistic relaxation of operations


on lists. In particular, the relaxation for list.get amounts to per-
forming a convolution. The relaxation for dict.get amounts to
computing a conditional expectation using kernel regression.

• When using a Gaussian kernel with keys normalized to unit norm,


we recover softargmax attention from Transformers.
Part III

Differentiating through
programs
7
Finite differences

One of the simplest way to numerically compute derivatives is to use


finite differences, which approximate the infinitesimal definition of deriva-
tives. Finite differences only require function evaluations, and can
therefore work with blackbox functions (i.e., they ignore the composi-
tional structure of functions). Without loss of generality, our exposition
focuses on computing directional derivatives ∂f (w)[v], for a function
f : E → F, evaluated at w ∈ E, in the direction v ∈ E.

7.1 Forward differences

From Definition 2.4 and Definition 2.13, the directional derivative and
more generally the JVP are defined as a limit,

∂f (w)[v] := lim_{δ→0} (f (w + δv) − f (w))/δ.

This suggests that we can approximate the directional derivative and
the JVP using

∂f (w)[v] ≈ (f (w + δv) − f (w))/δ,


for some 0 < δ ≪ 1. This formula is called a forward difference. From


the Taylor expansion in Section 2.5.4, we indeed have

f (w + δv) − f (w) = δ ∂f (w)[v] + (δ²/2) ∂²f (w)[v, v] + (δ³/3!) ∂³f (w)[v, v, v] + . . .

so that

(f (w + δv) − f (w))/δ = ∂f (w)[v] + (δ/2) ∂²f (w)[v, v] + (δ²/3!) ∂³f (w)[v, v, v] + . . .
                       = ∂f (w)[v] + o(δ).

The error incurred by choosing a finite rather than infinitesimal δ in the


forward difference formula is called the truncation error. The Taylor
approximation above shows that this error is of the order of o(δ).
However, we cannot choose too small a value of δ, because the
evaluation of the function f on a computer rounds the value of f to
machine precision. Mathematically, a scalar-valued function f evaluated
on a computer becomes a function f˜ such that f˜(w) ≈ [f (w)/ε]ε,
where [f (w)/ε] denotes the integer closest to f (w)/ε ∈ R and ε is
the machine precision, i.e., the smallest positive number such that
1 + ε is distinguishable from 1 in machine arithmetic. This means that the difference f (w + δv) − f (w)
evaluated on a computer is prone to round-off error of the order of
o(ε). We illustrate the trade-off between truncation and round-off errors
in Fig. 7.1.
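
To make this trade-off tangible, the following small NumPy script (our own toy example, mirroring the setting of Fig. 7.1) evaluates the forward-difference error on softplus at x = 1 for several step sizes δ.

```python
import numpy as np

def softplus(x):
    return np.log1p(np.exp(x))

def logistic(x):
    return 1.0 / (1.0 + np.exp(-x))

x, true_derivative = 1.0, logistic(1.0)
for delta in [1e-1, 1e-4, 1e-8, 1e-12]:
    forward = (softplus(x + delta) - softplus(x)) / delta
    print(f"delta={delta:.0e}  error={abs(forward - true_derivative):.2e}")
# The error first shrinks with delta (truncation error decreases), then
# grows again as round-off error in the subtraction dominates.
```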

7.2 Backward differences

As an alternative, we can approximate the directional derivative and


the JVP by
∂f (w)[v] ≈ (f (w) − f (w − δv))/δ,
for some 0 < δ ≪ 1. This formula is called a backward difference.
From the Taylor expansion in Section 2.5.4, we easily verify that (f (w)−
f (w − δv))/δ = ∂f (w)[v] + o(δ), so that the truncation error is the
same as for the forward difference.


Figure 7.1: Numerical differentiation of f (x) := softplus(x) = log(1 + exp(x)),


to approximate f ′ (x) = logistic(x) at x = 1. The forward and central difference
methods induce both truncation error (for large δ) and round-off error (for small δ).
The complex step method enjoys smaller round-off error.

7.3 Central differences

Rather than using an asymmetric formula to approximate the derivative,


we can use
∂f (w)[v] ≈ (f (w + δv) − f (w − δv))/(2δ),
for some 0 < δ ≪ 1. This formula is called a central difference. From
the Taylor expansion in Section 2.5.4, we have
f (w + δv) − f (w − δv) = 2δ ∂f (w)[v] + (2δ³/3!) ∂³f (w)[v, v, v] + (2δ⁵/5!) ∂⁵f (w)[v, . . . , v] + . . .

so that

(f (w + δv) − f (w − δv))/(2δ) = ∂f (w)[v] + (δ²/3!) ∂³f (w)[v, v, v] + . . . = ∂f (w)[v] + o(δ²).

We see that the terms corresponding to derivatives of even order


cancel out, allowing the formula to achieve o(δ 2 ) truncation error.
For any δ < 1, the truncation error of the central difference is much

smaller than that of the forward or backward differences, as confirmed


empirically in Fig. 7.1.

7.4 Higher-accuracy finite differences

The truncation error can be further reduced by making use of additional


function evaluations. One can generalize the forward difference scheme
by a formula of the form
∂f (w)[v] ≈ Σ_{i=0}^p (ai /δ) f (w + iδv),

requiring p + 1 evaluations. To select the ai and reach a truncation error


of order o(δ p ), we can use a Taylor expansion on each term of the sum
to get
Σ_{i=0}^p (ai /δ) f (w + iδv) = Σ_{i=0}^p Σ_{j=0}^p ai (i^j δ^{j−1} / j!) ∂^j f (w)[v, . . . , v] + o(δ^p ).

By grouping the terms in the sum for each order of derivative, we obtain
a set of p + 1 equations to be satisfied by the p + 1 coefficients a0 , . . . , ap ,
that is,

a0 + a1 + . . . + ap = 0
a1 + 2 a2 + . . . + p ap = 1
a1 + 2^j a2 + . . . + p^j ap = 0     ∀j ∈ {2, . . . , p}.

This system of equations can be solved analytically to derive the co-


efficients. Backward differences can be generalized similarly by using
∂f (w)[v] ≈ Σ_{i=0}^p (ai /δ) f (w − iδv). Similarly, the central difference scheme
can be generalized by using

∂f (w)[v] ≈ Σ_{i=−p}^p (ai /δ) f (w + iδv),

to reach a truncation error of order o(δ 2p ). Solving for the coefficients


a−p , . . . , ap as above reveals that a0 = 0. Therefore, only 2p evaluations
are necessary.
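
Since the conditions above form a small Vandermonde-type linear system, the coefficients can be obtained numerically. The following NumPy sketch (the function name is ours) does so for the forward scheme.

```python
import numpy as np

def forward_difference_coefficients(p):
    """Coefficients a_0, ..., a_p of the forward scheme
    sum_i (a_i / delta) f(w + i delta v) ~= df(w)[v], accurate to o(delta^p)."""
    i = np.arange(p + 1)
    # Row j encodes the constraint sum_i i**j a_i = [j == 1] (with 0**0 = 1).
    V = np.vander(i, p + 1, increasing=True).T   # V[j, i] = i**j
    rhs = np.zeros(p + 1)
    rhs[1] = 1.0
    return np.linalg.solve(V, rhs)

print(forward_difference_coefficients(1))  # [-1.  1.]       standard forward difference
print(forward_difference_coefficients(2))  # [-1.5  2. -0.5]  second-order accurate scheme
```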

7.5 Higher-order finite differences

To approximate higher order derivatives, we can follow a similar rea-


soning. Namely, we can generalize the forward difference scheme to
approximate the derivative of order k by
∂^k f (w)[v, . . . , v] ≈ Σ_{i=0}^p (ai /δ^k ) f (w + iδv).

As before, we can expand the terms in the sum. For the approximation
to capture only the k th derivative, we now require the coefficients ai to
satisfy

0^j a0 + 1^j a1 + 2^j a2 + . . . + p^j ap = 0     ∀j ∈ {0, . . . , k − 1}
0^k a0 + 1^k a1 + 2^k a2 + . . . + p^k ap = k!
0^j a0 + 1^j a1 + 2^j a2 + . . . + p^j ap = 0     ∀j ∈ {k + 1, . . . , p}.

With the resulting coefficients, we obtain a truncation error of order


o(δ p−k+1 ), while making p + 1 evaluations. For example, for p = k = 2,
we can approximate the second-order derivative as

∂²f (w)[v, v] ≈ (f (w) − 2f (w + δv) + f (w + 2δv))/δ²,
with a truncation error of order o(δ).
The central difference scheme can be generalized similarly by
∂^k f (w)[v, . . . , v] ≈ Σ_{i=−p}^p (ai /δ^k ) f (w + iδv),

to reach truncation errors of order o(δ 2p+2−2⌈(k+1)/2⌉ ). For example, for


k = 2, p = 1, we obtain the second-order central difference

∂²f (w)[v, v] ≈ (f (w + δv) + f (w − δv) − 2f (w))/δ².
By using a Taylor expansion we see that, this time, the terms corre-
sponding to derivatives of odd order cancel out and the truncation
error is o(δ 2 ) while requiring 3 evaluations.

7.6 Complex-step derivatives

Suppose f is well defined on CP , the space of P -dimensional complex
vectors. Let us denote the imaginary unit by i = √−1. Then, the
Taylor expansion of f reads
f (w + (iδ)v) = f (w) + (iδ) ∂f (w)[v] + ((iδ)²/2) ∂²f (w)[v, v] + ((iδ)³/3!) ∂³f (w)[v, v, v] + . . .
              = f (w) + (iδ) ∂f (w)[v] − (δ²/2) ∂²f (w)[v, v] − (iδ³/3!) ∂³f (w)[v, v, v] + . . .
We see that the real part corresponds to even-degree terms and the
imaginary part corresponds to odd-degree terms. We therefore obtain

Re(f (w + (iδ)v)) = f (w) + o(δ 2 )

and
Im(f (w + (iδ)v)/δ) = ∂f (w)[v] + o(δ²).
This suggests that we can compute directional derivatives using the
approximation
∂f (w)[v] ≈ Im(f (w + (iδ)v)/δ),
for 0 < δ ≪ 1. This is called the complex-step derivative approxi-
mation (Squire and Trapp, 1998; Martins et al., 2003).
Contrary to forward, backward and central differences, we see that
only a single function call is necessary. A function call on complex
numbers may take roughly twice the cost of a function call on real
numbers. However, since no subtraction of nearly equal function values
is needed, the complex-step derivative approximation usually
enjoys a much smaller round-off error, as illustrated in Fig. 7.1.
drawback of the method is that all elementary operations within the
program implementing the function f must be well-defined on complex
numbers, e.g., using overloading.
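
As an illustration, the complex-step approximation fits in a few lines of NumPy, provided the program implementing f accepts complex inputs; the helper name and the choice δ = 10⁻²⁰ below are ours.

```python
import numpy as np

def complex_step_derivative(f, w, v, delta=1e-20):
    """Directional derivative df(w)[v] from a single complex function call."""
    return np.imag(f(w + 1j * delta * v)) / delta

softplus = lambda x: np.log(1.0 + np.exp(x))   # well-defined for complex inputs
logistic = lambda x: 1.0 / (1.0 + np.exp(-x))

print(complex_step_derivative(softplus, 1.0, 1.0))  # ~ logistic(1.0), with no
print(logistic(1.0))                                # subtractive cancellation
```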

Table 7.1: Computational complexity in number of function evaluations for com-


puting the directional derivative and the gradient of a function f : RP → R by finite
differences and complex step derivatives.

Method                  Directional derivative    Gradient

Forward difference      2                         P + 1
Backward difference     2                         P + 1
Central difference      2                         2P
Complex step            1                         P

7.7 Complexity

We now discuss the computational complexity in terms of function


evaluations of finite differences and complex-step derivatives. For con-
creteness, as this is the most common use case in machine learning, we
discuss the case of a single M = 1 output, i.e., we want to differentiate
a function f : RP → R. Whether we use forward, backward or central
differences, the computational complexity of computing the directional
derivative ∂f (w)[v] in any direction v amounts to two calls to f . For
computing the gradient ∇f (w), we can use (see Definition 2.7) that

[∇f (w)]j = ⟨∇f (w), ej ⟩ = ∂f (w)[ej ],

for j ∈ [P ]. For forward and backward differences, we therefore need


P + 1 function calls to compute the gradient, while we need 2P function
calls for central differences. For the complex step approximation, we need
P complex function calls. We summarize the complexities in Table 7.1.

7.8 Summary

• Finite differences are a simple way to numerically compute deriva-


tives using only function evaluations.

• Central differences achieve smaller truncation error than forward


and backward differences. It is possible to achieve smaller trunca-
tion error, at the cost of more function evaluations.

• Complex-step derivatives achieve smaller round-off error than


central differences but require the function and the program
implementing it to be well-defined on complex numbers.

• However, whatever the method used, finite differences require a


number of function calls that is proportional to the number of
dimensions. They are therefore seldom used in machine learning,
where there can be millions or billions of dimensions. The main
use cases of finite differences are i) blackbox functions of low
dimension and ii) testing (e.g., checking that a
gradient function is correctly implemented).

• For modern machine learning, the main workhorse is automatic


differentiation, as it leverages the compositional structure of func-
tions. This is what we study in the next chapter.
8
Automatic differentiation

In Chapter 2, we reviewed the fundamentals of differentiation and


stressed the importance of two linear maps: the Jacobian-vector product
(JVP) and its adjoint, the vector-Jacobian product (VJP). In this chap-
ter, we review forward-mode and reverse-mode autodiff using these
two linear maps. We start with computation chains and then gener-
alize to feedforward networks and general computation graphs. We
also review checkpointing, reversible layers and randomized estimators.

8.1 Computation chains

To begin with, consider a computation chain (Section 4.1.1) repre-


senting a function f : S0 → SK expressed as a sequence of compositions
f := fK ◦ · · · ◦ f1 , where fk : Sk−1 → Sk . The computation of f can be


unrolled into a sequence of operations

s0 ∈ S0
s1 := f1 (s0 ) ∈ S1
..
.
sK := fK (sK−1 ) ∈ SK
f (s0 ) := sK . (8.1)

Our goal is to compute the variations of f around a given input s0 . In


a feedforward network, this amounts to estimating the influence of a
given input s0 for fixed parameters (we will see how to estimate the
variations w.r.t. parameters w in the sequel).

Jacobian matrix. We first consider the computation of the full Jaco-


bian ∂f (s0 ), seen as a matrix, as the notation ∂ indicates. Following
Proposition 2.2, we have

∂f (s0 ) = ∂fK (sK−1 )∂fK−1 (sK−2 ) . . . ∂f2 (s1 )∂f1 (s0 ), (8.2)

where ∂fk (sk−1 ) are the Jacobians of the intermediate functions com-
puted at s0 , . . . , sK , as defined in Eq. (8.1). The main drawback of
this approach is computational: computing the full ∂f (s0 ) requires
materializing the intermediate Jacobians in memory and performing
matrix-matrix multiplications. However, in practice, computing the full
Jacobian is rarely needed. Indeed, oftentimes, we only need to right-
multiply or left-multiply with ∂f (s0 ). This gives rise to forward-mode
and reverse-mode autodiff, respectively.

8.1.1 Forward-mode

We now interpret the Jacobian ∂f (s0 ) as a linear map, as the non-bold


∂ indicates. Following Proposition 2.6, ∂f (s0 ) is the composition of the
intermediate linear maps,

∂f (s0 ) = ∂fK (sK−1 ) ◦ ∂fK−1 (sK−2 ) ◦ . . . ◦ ∂f2 (s1 ) ◦ ∂f1 (s0 ).



Figure 8.1: Forward-mode autodiff for a chain of computations. For readability,


we denoted the intermediate JVP as a function of two variables ∂fk : sk−1 , tk−1 7→
∂fk (sk−1 )[tk−1 ] with ∂fk (sk−1 )[tk−1 ] = tk .

Evaluating ∂f (s0 ) on an input direction v ∈ S0 can be decomposed,


like the function Eq. (8.1) itself, into intermediate computations

t0 := v
t1 := ∂f1 (s0 )[t0 ]
..
.
tK := ∂fK (sK−1 )[tK−1 ]
∂f (s0 )[v] := tK .

Each intermediate ∂fk (sk−1 )[tk−1 ] amounts to a Jacobian-vector prod-


uct (JVP) and can be performed in a forward manner, along the
computation of the intermediate states sk . This can also be seen as
multiplying the matrix defined in Eq. (8.2) with a vector, from right
to left. This is illustrated in Fig. 8.1 and the procedure is summarized
in Algorithm 8.1.

Computational complexity. The JVP follows exactly the computations


of f , with an additional variable tk being propagated. If we consider that
computing ∂fk is roughly as costly as computing fk , then computing a
JVP has roughly twice the computational cost of f . See Section 8.3.3
for a more general and more formal statement.

Algorithm 8.1 Forward-mode autodiff for chains of computations


Functions: f := fK ◦ . . . ◦ f1
Inputs: input s0 ∈ S0 , input direction v ∈ S0
1: Initialize t0 := v
2: for k := 1, . . . , K do
3: Compute sk := fk (sk−1 ) ∈ Sk
4: Compute tk := ∂fk (sk−1 )[tk−1 ] ∈ Sk
Outputs: f (s0 ) := sK , ∂f (s0 )[v] = tK
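
A direct Python transcription of Algorithm 8.1, in which each layer supplies its function together with its JVP, could look as follows; the chain_jvp helper and the two toy layers are ours.

```python
import numpy as np

def chain_jvp(layers, s0, v):
    """Forward-mode autodiff (Algorithm 8.1) for a chain f_K o ... o f_1.
    Each layer is a pair (f_k, jvp_k) with jvp_k(s, t) = df_k(s)[t]."""
    s, t = s0, v
    for f, jvp in layers:
        t = jvp(s, t)   # JVP evaluated at the layer input s_{k-1}
        s = f(s)        # s_{k-1} can then be overwritten by s_k
    return s, t         # f(s_0), df(s_0)[v]

A = np.array([[1.0, 2.0], [0.0, -1.0]])
layers = [
    (lambda s: A @ s, lambda s, t: A @ t),                           # linear layer
    (lambda s: np.tanh(s), lambda s, t: (1 - np.tanh(s) ** 2) * t),  # element-wise layer
]
s0, v = np.array([0.3, -0.2]), np.array([1.0, 0.0])
print(chain_jvp(layers, s0, v))
```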

Memory usage. The memory usage of a program at a given evaluation


step is the number of variables that need to be stored in memory to
ensure the execution of all remaining steps. The memory cost of a
program is then the maximal memory usage over all evaluation steps.
For our purposes, we analyze the memory usage and memory cost
by examining the given program. Formal definitions of operations on
memory such as read, write, delete and associated memory costs are
presented by Griewank and Walther (2008, Chapter 4).
For example, to execute the chain f = fK ◦ · · · ◦ f1 , at each step k,
we only need to have access to sk−1 to execute the rest of the program.
As we compute sk , we can delete sk−1 from memory and replace it by
sk . Therefore, the memory cost associated to the evaluation of f is just
the maximal dimension of the sk variables.
For forward mode autodiff, as we follow the computations of f , at
each step k, we only need to have access to sk−1 and tk−1 to execute the
rest of the program. The memory used by sk−1 and tk−1 can directly be
used for sk , tk once they are computed. The memory usage associated
to the JVP is summarized in Fig. 8.2. Overall the memory cost of the
JVP is then exactly twice the memory cost of the function itself.

8.1.2 Reverse-mode
In machine learning, most functions whose gradient we need to compute
take the form ℓ ◦ f , where ℓ is a scalar-valued loss function and f is a
network. As seen in Proposition 2.3, the gradient takes the form

∇(ℓ ◦ f )(s0 ) = ∂f (s0 )∗ [∇ℓ(f (s0 ))].



Figure 8.2: Memory usage of forward-mode autodiff for a computation chain. Here
t0 = v, sK = f (s0 ), tK = ∂f (s0 )[v].

This motivates the need for applying the adjoint ∂f (s0 )∗ to ∇ℓ(f (s0 )) ∈
SK and more generally to any output direction u ∈ SK . From Propo-
sition 2.7, we have

∂f (s0 )∗ = ∂f1 (s0 )∗ ◦ . . . ◦ ∂fK (sK−1 )∗ .

Evaluating ∂f (s0 )∗ on an output direction u ∈ SK is decomposed as

rK = u
rK−1 = ∂fK (sK−1 )∗ [rK ]
..
.
r0 = ∂f1 (s0 )∗ [r1 ]
∂f (s0 )∗ [u] = r0 .

Each intermediate adjoint ∂fk (sk−1 )∗ amounts to a vector-Jacobian


product (VJP). The key difference with the forward mode is that the
procedure runs backward through the chain, hence the name reverse
mode autodiff. This can also be seen as multiplying Eq. (8.2) from left
to right. The procedure is illustrated in Fig. 8.3 and summarized in
Algorithm 8.2.

Computational complexity. In terms of number of operations, the


VJP simply passes twice through the chain, once forward, then
backward. If we consider the intermediate VJPs to be roughly as costly
as the intermediate functions themselves, the VJP amounts to twice
the cost of the original function, just as the JVP. See Section 8.3.3 for
a more generic and formal statement.

Figure 8.3: Reverse mode of automatic differentiation for a computation chain.


For readability, we denoted the intermediate VJPs as functions of two variables
∂fk∗ : (sk−1 , rk ) 7→ ∂fk (sk−1 )∗ [rk ], with ∂fk (sk−1 )∗ [rk ] = rk−1 .

Algorithm 8.2 Reverse-mode autodiff for chains of computations


Functions: f := fK ◦ . . . ◦ f1 ,
Inputs: input s0 ∈ S0 , output direction u ∈ SK
1: for k := 1, . . . , K do ▷ Forward pass
2: Compute sk := fk (sk−1 ) ∈ Sk
3: Initialize rK := u.
4: for k := K, . . . , 1 do ▷ Backward pass
5: Compute rk−1 := ∂fk (sk−1 )∗ [rk ] ∈ Sk−1
Outputs: f (s0 ) := sK , ∂f (s0 )∗ [u] = r0

Figure 8.4: Memory usage of reverse mode autodiff for a computation chain.

Memory usage. Recall that the memory usage of a program at a given


evaluation step is the number of variables that need to be stored in
memory to ensure the execution of the remaining steps. If we inspect
Algorithm 8.2, to execute all backward steps, i.e., the loop at line 4,
we need to have access to all the intermediate inputs s0 , . . . , sK−1 .
Therefore, the memory cost of reverse-mode autodiff is proportional to
the length of the chain K. Fig. 8.4 illustrates the memory usage during
reverse mode autodiff. It grows linearly until the end of the forward
pass and then progressively decreases until it outputs the value of the
function and the VJP. The memory cost can be mitigated by means of
checkpointing techniques presented in Section 8.5.

Decoupled function and VJP evaluations. The additional memory


cost of reverse mode autodiff comes with some advantages. If we need
to compute ∂f (s0 )∗ [ui ] for n different output directions ui , we only
need to compute and store once the intermediate computations sk and
then make n calls to the backward pass. In other words, by storing in
memory the intermediate computations sk , we may instantiate a VJP
operator, which we may apply to any u through the backward pass.

Formally, the forward and backward passes can be decoupled as

forward(f, s0 ) := (f (s0 ), ∂f (s0 )∗ )

where
∂f (s0 )∗ [u] := backward(u; s0 , . . . , sK−1 ).
In functional programming terminology, the VJP ∂f (s0 )∗ is a closure,
as it captures the intermediate computations s0 , . . . , sK−1 . The same can
be done for the JVP ∂f (s0 ) if we want to apply it to multiple directions
vi .

Example 8.1 (Multilayer perceptron with fixed parameters ). As a run-


ning example, consider a multilayer perceptron (MLP) with one
hidden layer and (for now) given fixed weights. As presented in
Chapter 4, an MLP can be decomposed as

s0 = x
s1 = f1 (s0 ) = σ(A1 s0 + b1 )
s2 = f2 (s1 ) = A2 s1 + b2
f (x) = s2 ,

for A1 , A2 , b1 , b2 some fixed parameters and σ an activation func-


tion such as the softplus activation function σ(x) = log(1 + ex ) with
derivative σ ′ (x) = ex /(1 + ex ).
Evaluating the JVP of f on an input x along a direction v can
then be decomposed as

t0 = v
t1 = σ ′ (A1 s0 + b1 ) ⊙ (A1 t0 )
t2 = A2 t1
∂f (x)[v] = t2 ,

where we used in the second line the JVP of an element-wise function,
as in Example 8.3.
Evaluating the VJP of f at x requires evaluating the interme-
diate VJPs at the stored activations

r2 = u
r1 = ∂f2 (s1 )∗ [r2 ] = A2⊤ r2
r0 = ∂f1 (s0 )∗ [r1 ] = A1⊤ (σ ′ (A1 s0 + b1 ) ⊙ r1 )
∂f (x)∗ [u] = r0 .
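
These manual formulas can be checked against an autodiff system. For instance, in JAX (the small random weights and shapes below are our own toy choices), jax.jvp and jax.vjp should reproduce them.

```python
import jax
import jax.numpy as jnp

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
A1, b1 = jax.random.normal(k1, (3, 2)), jnp.zeros(3)
A2, b2 = jax.random.normal(k2, (1, 3)), jnp.zeros(1)

def f(x):
    s1 = jax.nn.softplus(A1 @ x + b1)   # sigma = softplus, sigma' = logistic
    return A2 @ s1 + b2

x, v, u = jnp.array([0.5, -1.0]), jnp.array([1.0, 0.0]), jnp.array([1.0])

# Manual JVP and VJP from Example 8.1.
t1 = jax.nn.sigmoid(A1 @ x + b1) * (A1 @ v)
manual_jvp = A2 @ t1
r1 = A2.T @ u
manual_vjp = A1.T @ (jax.nn.sigmoid(A1 @ x + b1) * r1)

print(jnp.allclose(manual_jvp, jax.jvp(f, (x,), (v,))[1]))  # True
print(jnp.allclose(manual_vjp, jax.vjp(f, x)[1](u)[0]))     # True
```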

8.1.3 Complexity of entire Jacobians


In this section, we analyze the time and space complexities of forward-
mode and reverse-mode autodiff for computing the entire Jacobian
matrix ∂f (s0 ) of a computation chain f = fK ◦ · · · ◦ f1 . We assume
Sk ⊆ RDk , DK = M and D0 = D.

Complexity of forward-mode autodiff


Using Definition 2.9, we find that we can extract each column ∂j f (s0 ) ∈
RM of the Jacobian matrix, for j ∈ [D], by multiplying with the standard
basis vector ej ∈ RD :

∂1 f (s0 ) = ∂f (s0 )e1


..
.
∂D f (s0 ) = ∂f (s0 )eD .

Computing the full Jacobian matrix therefore requires D JVPs with


vectors in RD . Assuming each fk in the chain composition has the form
fk : RDk−1 → RDk , seen as a matrix, ∂fk (sk−1 ) has size Dk × Dk−1 .
Therefore, the computational cost of D JVPs is O(D Σ_{k=1}^K Dk Dk−1 ).

The memory cost is O(maxk∈[K] Dk ), since we can release intermediate


computations after each layer is processed. Setting D1 = · · · = DK−1 =
D for simplicity and using DK = M , we obtain that the computational
cost of computing D JVPs and therefore of computing the full Jacobian
matrix by forward-mode autodiff is O(M D2 + KD3 ). The memory cost
is O(max{D, M }). If the function has a single input (D = 1), then forward
mode computes the entire Jacobian at once, as it reduces to a single
directional derivative.

Forward-mode Reverse-mode
Time O(M D2 + KD3 ) O(M 2 D + KM D2 )
Space O(max{M, D}) O(KD + M )

Table 8.1: Time and space complexities of forward-mode and reverse-mode autodiff
for computing the full Jacobian of a chain of functions f = fK ◦ · · · ◦ f1 , where
fk : RD → RD if k = 1, . . . , K − 1 and fK : RD → RM . We assume ∂fk is a dense
linear operator. Forward mode requires D JVPs. Reverse mode requires M VJPs.

Complexity of reverse-mode autodiff

Using Definition 2.9, we find that we can extract each row of the Jaco-
bian matrix, which corresponds to the transposed gradients ∇fi (s0 ) ∈
RD , for i ∈ [M ], by multiplying with the standard basis vector ei ∈ RM :

∇f1 (s0 ) = ∂f (s0 )∗ e1


..
.
∇fM (s0 ) = ∂f (s0 )∗ eM .

Computing the full Jacobian matrix therefore requires M VJPs with


vectors in RM . Assuming as before that each fk in the chain composition
has the form fk : RDk−1 → RDk , the computational cost of M VJPs
is O(M Σ_{k=1}^K Dk Dk−1 ). However, the memory cost is O(Σ_{k=1}^K Dk ),
as we need to store the intermediate computations for each of the
K layers. Setting D0 = · · · = DK−1 = D for simplicity and using
DK = M , we obtain that the computational cost of computing M VJPs
and therefore of computing the full Jacobian matrix by reverse-mode
autodiff is O(M 2 D + KM D2 ). The memory cost is O(KD + M ). If the
function is single-output (M = 1), reverse-mode autodiff computes the
entire Jacobian at once, as it reduces to the gradient.

When to use forward-mode vs. reverse-mode autodiff?

We summarize the time and space complexities in Table 8.1. Generally,


if M < D, reverse-mode is more advantageous at the price of some
memory cost. If M ≥ D, forward mode is more advantageous.
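
In a system like JAX, this choice is exposed through jax.jacfwd (D JVPs) and jax.jacrev (M VJPs); both return the same Jacobian, only the cost profile differs. A minimal sketch (the toy function is ours):

```python
import jax
import jax.numpy as jnp

def f(x):                            # f : R^3 -> R^2
    return jnp.array([jnp.sum(x ** 2), jnp.prod(x)])

x = jnp.array([1.0, 2.0, 3.0])
J_fwd = jax.jacfwd(f)(x)             # D = 3 JVPs, one per input dimension
J_rev = jax.jacrev(f)(x)             # M = 2 VJPs, one per output dimension
print(jnp.allclose(J_fwd, J_rev))    # True; only the cost differs
```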

8.2 Feedforward networks

In the previous section, we derived forward-mode autodiff and reverse-


mode autodiff for computation chains with an input s0 ∈ S0 . In this
section, we now derive reverse-mode autodiff for feedforward networks,
in which each layer fk is now allowed to depend explicitly on some
additional parameters wk ∈ Wk . The recursion is
s0 := x ∈ S0
s1 := f1 (s0 , w1 ) ∈ S1
..
.
sK := fK (sK−1 , wK ) ∈ SK
f (x, w) := sK ,
where S0 = X and w = (w1 , . . . , wK ) ∈ W1 × · · · × WK . Each fk is
now a function of two arguments. The first argument depends on the
previous layer, but the second argument does not. This is illustrated in
Fig. 8.5. We now explain how to differentiate a feedforward network.

8.2.1 Computing the adjoint


The function has the form f : E → F, where E := X × (W1 × · · · ×
WK ) and F := SK . From Section 2.3, we know that the VJP has the
form ∂f (x, w)∗ : F → E. Therefore, we want to be able to compute
∂f (x, w)∗ [u] ∈ E for any u ∈ F.
Fortunately, the backward recursion is only a slight modification
of the computation chain case. Indeed, since fk : Ek → Fk , where
Ek := Sk−1 × Wk and Fk := Sk , the intermediate VJPs have the form
∂fk (sk−1 , wk )∗ : Fk → Ek . We therefore arrive at the recursion
rK = u ∈ SK
(rK−1 , gK ) = ∂fK (sK−1 , wK )∗ [rK ] ∈ SK−1 × WK
..
.
(r0 , g1 ) = ∂f1 (s0 , w1 )∗ [r1 ] ∈ S0 × W1 .
The final output is
∂f (x, w)∗ [u] = (r0 , (g1 , . . . , gK )).

Figure 8.5: Computation graph of an MLP as a function of its parameters.

8.2.2 Computing the gradient


We often compose a network with a loss function

L(w; x, y) := ℓ(f (x, w); y) ∈ R.

From Proposition 2.7, the gradient is given by

∇L(w; x, y) = (g1 , . . . , gK ) ∈ W1 × · · · × WK

where
∂f (x, w)∗ [u] = (r0 , (g1 , . . . , gK )),
with u = ∇ℓ(f (x, w); y) ∈ SK . The output r0 ∈ S0 , where S0 = X ,
corresponds to the gradient w.r.t. x ∈ X and is typically not needed,
except in generative modeling settings. The full procedure is summarized
in Algorithm 8.3.

Figure 8.6: Reverse mode of automatic differentiation, a.k.a., gradient back-


propagation to compute the gradient of the loss of an MLP on an input label pair.
For readability, we denoted the intermediate VJPs as functions of three variables
∂fk∗ : (sk−1 , wk−1 , rk ) 7→ ∂fk (sk−1 , wk )[rk ] with ∂fk (sk−1 , wk )∗ [rk ] = (rk−1 , gk ).

Algorithm 8.3 Gradient back-propagation for feedforward networks


Functions: f1 , . . . , fK in sequential order
Inputs: data point (x, y) ∈ X × Y
parameters w = (w1 , . . . wK ) ∈ W1 × · · · × WK
1: Initialize s0 := x ▷ Forward pass
2: for k := 1, . . . , K do
3: Compute and store sk := fk (sk−1 , wk ) ∈ Sk
4: Compute ℓ(sK ; y) and u := ∇ℓ(sK ; y) ∈ SK
5: Initialize rK := u ∈ SK ▷ Backward pass
6: for k := K, . . . , 1 do
7: Compute (rk−1 , gk ) := ∂fk (sk−1 , wk )∗ [rk ] ∈ Sk−1 × Wk
8: Outputs: L(w; x, y) := ℓ(sK ; y), ∇L(w; x, y) = (g1 , . . . , gK )
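
Up to implementation details, Algorithm 8.3 is what a reverse-mode gradient transformation such as jax.grad performs under the hood. A minimal sketch with a toy MLP and squared loss (all names and shapes are ours) returns one gradient block per parameter, as in the algorithm's output.

```python
import jax
import jax.numpy as jnp

def mlp(params, x):
    (A1, b1), (A2, b2) = params
    s1 = jax.nn.softplus(A1 @ x + b1)
    return A2 @ s1 + b2

def loss(params, x, y):                       # L(w; x, y) = l(f(x, w); y)
    return 0.5 * jnp.sum((mlp(params, x) - y) ** 2)

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = [(jax.random.normal(k1, (3, 2)), jnp.zeros(3)),
          (jax.random.normal(k2, (1, 3)), jnp.zeros(1))]
x, y = jnp.array([0.5, -1.0]), jnp.array([2.0])

# Forward pass, then back-propagation of VJPs layer by layer: the result is
# one gradient block (g_1, ..., g_K) per parameter of the network.
grads = jax.grad(loss)(params, x, y)
print(jax.tree_util.tree_map(jnp.shape, grads))
```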

Algorithm 8.4 Forward-mode autodiff for computation graphs


Functions: f1 , . . . , fK in topological order
Inputs: input s0 ∈ S0 , input direction v ∈ S0
1: Initialize t0 := v
2: for k := 1, . . . , K do ▷ Forward pass
3: Retrieve parent nodes i1 , . . . , ipk := pa(k)
4: Compute sk := fk (si1 , . . . , sipk )
5: Compute

tk := ∂fk (si1 , . . . , sipk )[ti1 , . . . , tipk ] = Σ_{j=1}^{pk} ∂j fk (si1 , . . . , sipk )[tij ].

6: Outputs: f (s0 ) := sK , ∂f (s0 )[v] = tK

8.3 Computation graphs

In the previous sections, we reviewed autodiff for computation chains


and its extension to feedforward networks. In this section, we generalize
autodiff to computation graphs introduced in Section 4.1.3.

8.3.1 Forward-mode
The forward mode corresponds to computing a JVP in an input direction
v ∈ S0 . The algorithm consists in computing intermediate JVPs along
the forward pass. We initialize t0 := v ∈ S0 . Using Proposition 2.8, the
derivatives on iteration k ∈ [K] are propagated as

tk := ∂fk (si1 , . . . , sipk )[ti1 , . . . , tipk ] = Σ_{j=1}^{pk} ∂j fk (si1 , . . . , sipk )[tij ],

where i1 , . . . , ipk = pa(k). The final output is ∂f (s0 )[v] = tK . The


resulting generic forward-mode autodiff is summarized in Algorithm 8.4.
Although not explicitly mentioned, we can release sk and tk from
memory when no child node depends on node k.

8.3.2 Reverse-mode

The reverse mode corresponds to computing a VJP in an output di-


rection u ∈ SK . We first perform a forward pass to compute the
intermediate values and store the corresponding VJP as a pointer, since
the VJP shares some computations with the function itself. From Propo-
sition 2.8, the VJP returns a tuple with the same length as the number
of inputs to the function:

δi1 ,k , . . . , δipk ,k = ∂fk (si1 , . . . , sipk )∗ [rk ].

After the forward pass, we traverse the graph in reverse topological


order as illustrated in Fig. 8.7. If an intermediate value sk is used by
later functions fj1 , . . . , fjck for j1 , . . . , jck = ch(k), the derivatives with
respect to sk need to sum all the variations through the fj functions in
a variable rk ,
rk := Σ_{j∈{j1 ,...,jck }} δk,j .

In practice, as we go backward through the computation graph, we can


accumulate the VJPs corresponding to node k by doing in-place updates
of the rk values. The topological ordering ensures that rk has been fully
computed when we reach node k. The resulting generic reverse-mode
autodiff is presented in Algorithm 8.5.

8.3.3 Complexity, the Baur-Strassen theorem

For computing the gradient of a function f : E → R represented by


a computation graph, we saw that the reverse mode is more efficient
than the forward mode. As we previously stated, assuming that the
elementary functions fk in the DAG and their VJP have roughly the
same computational complexity, then f and ∇f have roughly the same
computational complexity. This fact is crucial and is the pillar on which
modern machine learning relies: it allows us to optimize high-dimensional
functions by gradient descent.
For arithmetic circuits, reviewed in Section 4.1.4, this crucial fact
is made more precise in the celebrated Baur-Strassen theorem (Baur
and Strassen, 1983). If f is a polynomial, then so is its gradient ∇f .

Algorithm 8.5 Reverse-mode autodiff for computation graphs


Functions: f1 , . . . , fK in topological order
Inputs: input s0 ∈ S0 , output direction u ∈ SK
1: for k := 1, . . . , K do ▷ Forward pass
2: Retrieve parent nodes i1 , . . . , ipk := pa(k)
3: Compute sk := fk (si1 , . . . , sipk )
4: Instantiate VJP lk := ∂fk (si1 , . . . , sipk )∗
5: Initialize rK := u, rk := 0 ∀k ∈ {0, . . . , K − 1} ▷ Backward pass
6: for k := K, . . . , 1 do
7: Retrieve parent nodes i1 , . . . , ipk = pa(k)
8: Compute δi1 ,k , . . . , δipk ,k := lk [rk ]
9: Compute rij ← rij + δij ,k ∀j ∈ {1, . . . , pk }
10: Outputs: f (s0 ) := sK , ∂f (s0 )∗ [u] = r0

The theorem gives an upper bound on the size of the best circuit for
computing ∇f from the size of the best circuit for computing f .

Proposition 8.1 (Baur-Strassen’s theorem). For any polynomial


f : E → R, we have
S(∇f ) ≤ 5 · S(f ),
where the size S(f ) of a polynomial f is defined in Definition 4.1.

A simpler proof by backward induction was given by Morgenstern


(1985). See also the proof of Theorem 9.10 in Chen et al. (2011). For
general computation graphs, that have more primitive functions than
just + and ×, a similar result can be obtained; see, e.g., (Bolte et al.,
2022, Theorem 2).

8.4 Implementation

8.4.1 Primitive functions

An autodiff system implements a set A of primitive or elementary


functions, which serve as building blocks for creating other functions, by
function composition. For instance, we saw that in arithmetic circuits

Figure 8.7: Left: Assuming a topological order, the computation of fk on iteration


k involves pk inputs computed by fi1 , . . . , fipk , where {i1 , . . . , ipk } = pa(k), and is
used in ck functions fj1 , . . . , fjck , where {j1 , . . . , jck } = ch(k). Middle: Forward-
mode autodiff on iteration k, denoting the shorthand spk := (si1 , . . . , sipk ). This
computes ∂i fk (spk )[tij ] for incoming tij , j = 1, . . . , pk , then sum these results to
pass tk to the next iterations. Right: Reverse-mode autodiff on iteration k. As we
traverse the graph backward, we first sum the contributions coming from each child
computation fj1 , . . . , fjck for {j1 , . . . , jck } = ch(k) to get rk . We then feed this rk
to each ∂i fk (spk ) and continue the procedure in reverse topological order.

(Section 4.1.4), A = {+, ×}. More generally, A may contain all the
necessary functions for expressing programs. We emphasize, however,
that A is not necessarily restricted to low-level functions such as log
and exp, but may also contain higher-level functions. For instance,
even though the log-sum-exp can be expressed as the composition
of elementary operations (log, sum, exp), it is usually included as a
primitive on its own, both because it is a very commonly-used building
block and for numerical stability reasons.

8.4.2 Closure under function composition

Each function fk in a computation graph belongs to a set F, the class of


functions supported by the system. A desirable property of an autodiff
implementation is that the set F is closed under function composition,
meaning that if f ∈ F and g ∈ F, then f ◦ g ∈ F. This means that
composed functions can themselves be used for composing new functions.
This property is also crucial for supporting higher-order differentiation
(Chapter 9) and automatic linear transposition (Section 8.4.4). When

fk is a composition of elementary functions in A, then fk itself is a


nested DAG. However, we can always inline each composite function,
such that all functions in the DAG belong to A.

8.4.3 Examples of JVPs and VJPs

An autodiff system must implement for each f ∈ A its JVP for support-
ing the forward mode, and its VJP for supporting the reverse mode.
We give a couple of examples. We start with the JVP and VJP of linear
functions.

Example 8.2 (JVP and VJP of linear functions). Consider the matrix-
vector product f (W ) = W x ∈ RM , where x ∈ RD is fixed and
W ∈ RM ×D . As already mentioned in Section 2.3.1, the JVP of f
at W ∈ RM ×D along an input direction V ∈ RM ×D is simply

∂f (W )[V ] = f (V ) = V x ∈ RM .

To find the associated VJP, we note that for any u ∈ RM and


V ∈ RM ×D , we must have ⟨∂f (W )[V ], u⟩ = ⟨V , ∂f (W )∗ [u]⟩.
Using the properties of the trace, we have

⟨∂f (W )[V ], u⟩ = ⟨V x, u⟩ = tr(x⊤ V ⊤ u) = ⟨V , ux⊤ ⟩.

Therefore, we find that the VJP is given by

∂f (W )∗ [u] = ux⊤ ∈ RM ×D .

Similarly, consider now a matrix-matrix product f (W ) = W X,


where W ∈ RM ×D and where X ∈ RD×N is fixed. The JVP at
W ∈ RM ×D along an input direction V ∈ RM ×D is simply

∂f (W )[V ] = f (V ) = V X ∈ RM ×N .

The VJP along the output direction U ∈ RM ×N is

∂f (W )∗ [U ] = U X ⊤ ∈ RM ×D .

Another simple example is given by element-wise separable functions.



Example 8.3 (JVP and VJP of separable function). Consider the func-
tion f (w) := (g1 (w1 ), . . . , gP (wP )), where each gi : R → R has a
derivative gi′ . The Jacobian matrix is then a diagonal matrix

∂f (w) = diag(g1′ (w1 ), . . . , gP′ (wP )) ∈ RP ×P .

In this case, the JVP and VJP are actually the same

∂f (w)[v] = ∂f (w)∗ [v] = (g1′ (w1 ), . . . , gP′ (wP )) ⊙ v,

where ⊙ indicates element-wise multiplication.
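
In an autodiff system, such JVPs and VJPs are registered for each primitive. As an illustration, a JAX sketch registering a custom VJP for a softplus-like element-wise primitive could look as follows; the names are ours, and the built-in softplus of course already ships with its own rule.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def my_softplus(w):
    return jnp.log1p(jnp.exp(w))

def my_softplus_fwd(w):
    # Forward pass: return the output and the residuals needed by the VJP.
    return jnp.log1p(jnp.exp(w)), w

def my_softplus_bwd(w, u):
    # Element-wise function: diagonal Jacobian, so the VJP is g'(w) * u.
    return (jax.nn.sigmoid(w) * u,)

my_softplus.defvjp(my_softplus_fwd, my_softplus_bwd)

w = jnp.array([0.1, -2.0, 3.0])
print(jax.grad(lambda w: jnp.sum(my_softplus(w)))(w))  # equals sigmoid(w)
```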

8.4.4 Automatic linear transposition


At first sight, if we want to support both forward and reverse modes, it
appears like we need to implement both the JVP and the VJP for each
primitive operation f ∈ A. Fortunately, there exists a way to recover
VJPs from JVPs, and vice-versa.
We saw in Section 2.3 that if l(w) is a linear map, then its JVP is
∂l(w)[v] = l(v) (independent of w). Conversely, the VJP is ∂l(w)∗ [u] =
l∗ (u), where l∗ is the adjoint operator of l (again, independent of w).
Let us define l(u; w) := ∂f (w)∗ [u], i.e., the VJP of f in the output
direction u. Since l(u; w) is linear in u, we can apply the reasoning
above to compute its VJP
∂l(u; w)∗ [v] = l∗ (v; w) = ∂f (w)∗∗ [v] = ∂f (w)[v],
which is independent of u. In words, the VJP of a VJP is the cor-
responding JVP! This means that we can implement forward-mode
autodiff even if we only have access to VJPs. As an illustration and
sanity check, we give the following example.

Example 8.4 (Automatic transpose of “dot”). If we define


f (x, W ) := W x, from Example 8.2, we know that

∂f (x, W )∗ [u] = (W ⊤ u, ux⊤ )


= (f (u, W ⊤ ), f (x⊤ , u))
=: l(u; x, W ).

Using Proposition 2.9, we obtain

∂l(u; x, W )∗ [v, V ] = f (v, W ) + f (x, V )


= Wv + V x
= ∂f (x, W )[v, V ].

The other direction, automatically creating a VJP from a JVP, is


also possible but is more technical and relies on the notion of partial
evaluation (Frostig et al., 2021; Radul et al., 2022).
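
In JAX, the automatic transposition of linear maps is exposed as jax.linear_transpose, which recovers the adjoint of a linear function from the function alone; a minimal sketch with a toy matrix (ours) is given below.

```python
import jax
import jax.numpy as jnp

W = jnp.arange(6.0).reshape(2, 3)
f = lambda x: W @ x                          # a linear map from R^3 to R^2

# The primal argument is only used to specify the input shape/dtype.
f_transpose = jax.linear_transpose(f, jnp.zeros(3))
u = jnp.array([1.0, -1.0])
print(f_transpose(u))                        # (W.T @ u,), the adjoint applied to u
print(W.T @ u)
```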

8.5 Checkpointing

We saw that forward-mode autodiff can release intermediate computa-


tions from memory along the way, while reverse-mode autodiff needs to
cache all of them. This means that the memory complexity of reverse-
mode autodiff, in its standard form, grows linearly with the number of
nodes in the computation graph. A commonly-used technique to circum-
vent this issue is checkpointing, which trades off computation time for
better memory usage. Checkpointing works by selectively storing only a
subset of the intermediate nodes, called checkpoints, and by comput-
ing the others on the fly. The specific choice of the checkpoint locations in
the computation graph determines the memory-computation trade-off.
While it is possible to heuristically set checkpoints at user-specified
locations, it is also possible to perform a checkpointing strategy algo-
rithmically, as studied in-depth by Griewank (1992) and Griewank and
Walther (2008). In this section, we review two such algorithms: recursive
halving and dynamic programming (divide-and-conquer). Our exposi-
tion focuses on computation chains f = fK ◦ . . . ◦ f1 with fi : RD → RD
for simplicity.
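
In practice, user-specified checkpoints can be set by wrapping sub-computations. For instance, in JAX, jax.checkpoint (also known as jax.remat) marks a block whose intermediate activations are recomputed during the backward pass rather than stored; the toy layer below is ours.

```python
import jax
import jax.numpy as jnp

def layer(s):
    return jnp.tanh(s) + 0.1 * s

def f(s0):
    s = s0
    for _ in range(4):
        # Activations inside this block are not cached; they are recomputed
        # from the block input during the backward pass.
        s = jax.checkpoint(layer)(s)
    return jnp.sum(s)

print(jax.grad(f)(jnp.ones(3)))
```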

Computational and memory complexities at two extremes. Let C(K)


be the number of calls to the individual functions fi (we ignore the
cost of computing the intermediate VJPs) and M(K) be the number
of function inputs cached, when performing reverse-mode autodiff on
a chain f = fK ◦ . . . ◦ f1 . On one extreme, if we store all intermediate
computations, as done in Algorithm 8.2, to compute only the VJP

Algorithm 8.6 Reverse-mode autodiff with constant memory


vjp_full_recompute(fK ◦ . . . ◦ f1 , s0 , u) := ∂(fK ◦ . . . ◦ f1 )(s0 )∗ [u]
Inputs: Chain fK ◦ . . . ◦ f1 , input s0 ∈ S0 , output direction u ∈ SK
1: if K = 1 then
2: return ∂f1 (s0 )∗ [u]
3: else
4: Set rK = u
5: for k := K, . . . , 1 do
6: Compute sk−1 = (fk−1 ◦ . . . ◦ f1 )(s0 )
7: Compute rk−1 = ∂fk (sk−1 )∗ [rk ]
8: return: r0

∂f (s0 )∗ [u], we have


C(K) = K − 1 and M(K) = K.
This is optimal w.r.t. computational complexity, but suboptimal w.r.t.
memory. On the other extreme, if we only store the initial input, as
done in Algorithm 8.6, then we have
C(K) = K(K − 1)/2 and M(K) = 1.
This is optimal w.r.t. memory but leads to a computational complexity
that is quadratic in K.

8.5.1 Recursive halving


As a first step towards obtaining a better computation-memory trade-off,
we may split the chain sK = fK ◦ · · · ◦ f1 (s0 ) as
sK/2 = fK/2 ◦ . . . ◦ f1 (s0 )
sK = fK ◦ . . . ◦ fK/2+1 (sK/2 ),
for K even. Then, rather than recomputing all intermediate computa-
tions sk from the input s0 as in Algorithm 8.6, we can store sK/2 and
recompute sk for k > K/2 starting from sK/2 . Formally, this strategy
amounts to the following steps.
1. Compute sK/2 = fK/2 ◦ . . . ◦ f1 (s0 )

Algorithm 8.7 Reverse-mode autodiff with recursive halving


vjp_halving(fK ◦ . . . ◦ f1 , s0 , u) := ∂(fK ◦ . . . ◦ f1 )(s0 )∗ [u]
Functions: Chain fK ◦ . . . ◦ f1
Inputs: input s0 ∈ S0 , output direction u ∈ SK
1: if K = 1 then
2: return ∂f1 (s0 )∗ [u]
3: else
4: Compute sK/2 = fK/2 ◦ . . . ◦ f1 (s0 )
5: Compute rK/2 = vjp_halving(fK ◦ . . . ◦ fK/2+1 , sK/2 , u)
6: Compute r0 = vjp_halving(fK/2 ◦ . . . ◦ f1 , s0 , rK/2 )
7: return: r0

2. Compute rK/2 = vjp_full_recompute(fK ◦ . . . ◦ fK/2+1 , sK/2 , u)

3. Compute r0 = vjp_full_recompute(fK/2 ◦ . . . ◦ f1 , s0 , rK/2 )

At the expense of having to store the additional checkpoint sK/2 , this


already roughly halves the computational complexity compared to Al-
gorithm 8.6.
We can then apply this reasoning recursively, as formalized in Al-
gorithm 8.7. The algorithm is known as recursive binary schedule
(Griewank, 2003) and illustrated in Fig. 8.8. In terms of number of
function evaluations C(K), for K even, we make K/2 function calls,
and we call the procedure recursively twice, that is,

C(K) = 2C(K/2) + K/2.

If the chain is of length 1, we directly use the VJP, so C(1) = 0. Hence,


the numbers of function calls, if K is a power of 2, is
K
C(K) = log2 K.
2
In terms of memory usage, Algorithm 8.7 uses s0 not only at line 4
but also at line 6. So when the algorithm is called recursively on the
second half of the chain at line 5, one memory slot is taken by s0 . This
line is called recursively until the chain is reduced to a single function.
At that point, the total number of memory slots used is equal to the

Figure 8.8: Illustration of checkpointing with recursive halving, for a chain of


8 functions. The chain is first fully evaluated while storing some computations
as checkpoints in memory. Then, during the backward pass, we recompute some
intermediate values from the latest checkpoint available. In contrast, vanilla reverse-
mode autodiff (with full caching of the intermediate computations) would lead to a
simple triangle shape.

number of times we split the function in half, that is log2 K for K a


power of 2. On the other hand, the input s0 is no longer used after line 6
of Algorithm 8.7. At that line, the memory slot taken by s0 can be
consumed by the recursive call on the first-half. In other words, calling
the algorithm recursively on the first half does not incur extra memory
cost. So if K is a power of 2, the memory cost of Algorithm 8.7 is

M(K) = log2 K.

8.5.2 Dynamic programming


Recursive halving requires log2 K memory slots for a chain of length
K. However, as illustrated in Fig. 8.8, at a given time step, all memory
slots may not be exploited.
To optimize the approach, we observe that recursive halving is
just one instance of a program that splits the chain and calls itself
recursively on each part. In other words, it is a form of divide-and-
conquer algorithm. Rather than splitting the chain in half, we may
consider splitting the chain at some index l. One split is used to reverse

the computations from l + 1 to K by a recursive call that consumes one


memory slot. The other split is used on a recursive call that reverses
the computations from 0 to l. That second call does not require an
additional memory slot, as it can use directly the memory slot used
by the original input s0 . To split the chain in such two parts, we need
l intermediate computations to go from s0 to sl . The computational
complexity C(k, s), counted as the number of function evaluations, for
a chain of length k with s memory slots then satisfies the recurrence

C(k, s) = C(k − l, s − 1) + C(l, s) + l,

for all l ∈ {1, . . . , k − 1}. By simply taking l = k/2, we recover exactly


the computational complexity of recursive halving. To refine the latter,
we may split the chain by selecting l to minimize the complexity. An
optimal scheme must satisfy the recursive equation,

C ∗ (k, s) := min_{1≤l≤k−1} {C ∗ (k − l, s − 1) + C ∗ (l, s) + l}. (8.3)

Note that C ∗ (K, S) can be computed from C ∗ (k, s) for k = 1, . . . , K − 1,


s = 1, . . . , S − 1. This suggests a dynamic programming approach to
find an optimal scheme algorithmically. For a chain of length k = 1, the
cost is null as we directly reverse the computation, so C ∗ (1, s) := 0. On
the other hand for a memory s = 1, there is only one possible scheme
that saves only the initial input as in Algorithm 8.6, so C ∗ (k, 1) :=
(k(k − 1))/2. The values C ∗ (k, s) can then be computed incrementally
from k = 1 to K and s = 1 to S using Eq. (8.3). The optimal splits can
be recorded along the way as

l∗ (k, s) := arg min_{1≤l≤k−1} {C ∗ (k − l, s − 1) + C ∗ (l, s) + l}.

The optimal split for K, S can then be found by backtracking the


optimal splits along both branches corresponding to C ∗ (k − l, s − 1)
and C ∗ (l, s). As the final output consists in traversing a binary tree,
it was called treeverse (Griewank, 1992). Note that the dynamic
programming procedure is generic and could a priori incorporate varying
computational costs for the intermediate functions fk .

Analytical formula

It turns out that we can also find an optimal scheme analytically. This
scheme was found by Griewank (1992), following the analysis of optimal
inversions of sequential programs by divide-and-conquer algorithms
done by Grimm et al. (1996); see also Griewank (2003, Section 6) for
a simple proof. The main idea consists in considering the number of
times an evaluation step fk is repeated. As we split the chain at l, all
steps from 1 to l will be repeated at least once. In other words, treating
the second half of the chain incurs one memory cost, while treating the
first half of the chain incurs one repetition cost. Griewank (1992) shows
that for fixed K, S, we can find the minimal number of repetitions
analytically and build the corresponding scheme with simple formulas
for the optimal splits.
Compared to the dynamic programming approach, it means that we
do not need to compute the pointers l∗ (k, s), and we can use a simple
formula to set l∗ (k, s). We still need to traverse the corresponding binary
tree given K, S and the l∗ (k, s) to obtain the schedules. Note that such
an optimal scheme does not take into account varying computational costs
for the functions fk .

8.5.3 Online checkpointing

The optimal scheme presented above requires knowing the total number
of nodes in the computation graph ahead of time. However, when
differentiating through for example a while loop (Section 5.10), this is
not the case. To circumvent this issue, online checkpointing schemes
have been developed and proven to be nearly optimal (Stumm and
Walther, 2010; Wang et al., 2009). These schemes start by defining a set
of S checkpoints with the first S computations, then these checkpoints
are rewritten dynamically as the computations keep going. Once the
computations terminate, the optimal approach presented above for a
fixed length is applied on the set of checkpoints recorded.

Algorithm 8.8 Reverse-mode autodiff for reversible chains.


Functions: f := fK ◦ . . . ◦ f1 , with each fk invertible
Inputs: input s0 ∈ S0 , output direction u ∈ SK
1: Compute sK = fK ◦ . . . ◦ f1 (s0 )
2: Initialize rK := u
3: for k := K, . . . , 1 do
4: Compute sk−1 = fk⁻¹ (sk )
5: Compute rk−1 = ∂fk (sk−1 )∗ [rk ]
Outputs: f (s0 ) := sK , ∂f (s0 )∗ [u] = r0

8.6 Reversible layers

8.6.1 General case

The memory requirements of reverse-mode autodiff can be completely


alleviated when the functions fk are invertible (meaning that fk⁻¹ exists)
and when fk⁻¹ is easily accessible. In that case, rather than storing the
intermediate computations sk−1 , necessary to compute the VJP rk 7→
∂fk (sk−1 )∗ [rk ], one can compute them on the fly during the backward
pass from sk using sk−1 = fk⁻¹ (sk ). We summarize the procedure for
the case of computation chains in Algorithm 8.8. Compared to vanilla
reverse-mode autodiff in Algorithm 8.2, the algorithm has optimal
memory complexity, as we can release sk and rk as we go.
In practice, fk⁻¹ often does not exist or may not be easily accessible.
However, network architectures can be constructed to be easily invertible
by design. Examples include reversible residual networks (Gomez et
al., 2017), orthonormal RNNs (Helfrich et al., 2018), neural ODEs
(Section 12.6), and momentum residual neural networks (Sander et al.,
2021a); see also references therein.

8.6.2 Case of orthonormal JVPs

When the JVP of each fk is an orthonormal linear mapping, i.e.,

∂fk (sk−1 )⁻¹ = ∂fk (sk−1 )∗ ,



it is easy to check that the VJP of f = fK ◦ . . . ◦ f1 is equal to the JVP


of f⁻¹ = f1⁻¹ ◦ . . . ◦ fK⁻¹ , that is
∂f (s0 )∗ [u] = ∂f⁻¹ (sK )[u].
In other words, in the case of orthonormal JVPs, reverse-mode autodiff of
f coincides with forward-mode autodiff of f⁻¹ .

8.7 Randomized forward-mode estimator

Forward-mode autodiff does not require to store intermediate activations.


However, for a function f : RP → R, computing the gradient ∇f using
forward-mode autodiff requires P JVPs, which is intractable if P is large.
Can we approximate ∇f with fewer JVPs? The following proposition
gives an unbiased estimator of ∇f that only involves JVPs.

Proposition 8.2 (Unbiased forward-mode estimator of the gradient).


Let f : RP → R be a differentiable function. Then,

∇f (µ) = EZ∼p [∂f (µ)[Z]Z]


= EZ∼p [⟨∇f (µ), Z⟩Z] .

where p := Normal(0, 1)P is the isotropic Gaussian distribution.

This estimator is for instance used by Baydin et al. (2022). It


can be seen as the zero-temperature limit of the gradient of a
perturbed function, estimated by the score-function estimator (SFE);
see Section 14.4.6.
In practice, the expectation above can be approximated by drawing
M noise vectors z1 , . . . , zM , and averaging ⟨∇f (µ), zi ⟩zi over i ∈ [M ].
A word of caution: while this estimator can be useful for example
when we do not want to store the intermediate activations for memory
reasons, this of course comes at the cost of increasing the variance,
which influences the convergence rate of SGD, as seen in Section 16.2.
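
A minimal JAX sketch of this estimator (the toy function, sample size and helper name are ours) averages ⟨∇f(µ), zi⟩ zi over Gaussian directions obtained purely from JVPs.

```python
import jax
import jax.numpy as jnp

def f(w):
    return jnp.sum(jnp.sin(w) ** 2)

def forward_gradient(f, w, key, num_samples=1000):
    """Unbiased estimate of grad f(w) from JVPs along Gaussian directions."""
    zs = jax.random.normal(key, (num_samples,) + w.shape)
    jvps = jax.vmap(lambda z: jax.jvp(f, (w,), (z,))[1])(zs)   # <grad f(w), z>
    return jnp.mean(jvps[:, None] * zs, axis=0)                # E[<grad, Z> Z]

w = jnp.array([0.3, -1.2, 2.0])
print(forward_gradient(f, w, jax.random.PRNGKey(0)))
print(jax.grad(f)(w))   # the estimate concentrates around this as M grows
```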

8.8 Summary

• Computer programs can be seen as directed acyclic graphs, where


nodes correspond to the output of intermediate operations in

the program, and edges represent the dependencies of current


operations on past operations.

• Automatic differentiation (autodiff) for a function f : RP → RM


has two main modes: forward mode and reverse mode.

• The forward mode: i) uses JVPs, ii) builds the Jacobian one
column at a time, iii) is efficient for tall Jacobians (M ≥ P ), iv)
need not store intermediate computations.

• The reverse mode: i) uses VJPs, ii) builds the Jacobian one row at
a time, iii) is efficient for wide Jacobians (P ≥ M ), iv) needs to
store intermediate computations, in order to be computationally
optimal.

• To trade computational efficiency for better memory efficiency,


we can use checkpointing techniques.

• The complexity of computing the gradient of a function f : RP →


R using the reverse mode is at most a constant factor larger than
that of evaluating the function itself. This is the Baur-Strassen
theorem for arithmetic circuits. This astonishing result is one of
the pillars of modern machine learning.
9
Second-order automatic differentiation

We review in this chapter how to perform automatic differentiation for


second-order derivatives.

9.1 Hessian-vector products

We consider in this section a function f : E → R. Similarly to the


Jacobian, for most purposes, we do not need access to the full Hessian
but rather to the Hessian-vector product (HVP) ∇2 f (w)[v] at w ∈ E,
in a direction v ∈ E, as defined in Definition 2.19. The latter can be
computed in four different ways, depending on how we combine the two
main modes of autodiff.

9.1.1 Four possible methods


An HVP can be computed in four different ways.

1. Reverse on reverse: The Hessian can be seen as the transposed


Jacobian of the gradient, hence the HVP can be computed as the
VJP of the gradient,

∇2 f (w)[v] = ∂(∇f )(w)∗ [v].


2. Forward on reverse: Owing to its symmetry (see Proposi-


tion 2.10), the Hessian can also be seen as the Jacobian of the
gradient, hence the HVP can be computed as the JVP of the
gradient,
∇2 f (w)[v] = ∂(∇f )(w)[v].

3. Reverse on forward: Recall that for any function g : E → E, the


VJP can equivalently be defined as the gradient along an output
direction v ∈ E, that is,
∂g(w)∗ [v] = ∇⟨g, v⟩(w),
where we recall the shorthand ⟨g, v⟩(w) := ⟨v, g(w)⟩, so that
⟨g, v⟩ is a function of w. In our case, we can therefore rewrite the
reverse-on-reverse approach as
∂(∇f )(w)∗ [v] = ∇⟨∇f, v⟩(w).
We know that ⟨∇f, v⟩(w) = ⟨∇f (w), v⟩ = ∂f (w)[v] is the JVP
of f at w along v. Therefore, we can also compute the HVP as
the gradient of the JVP of f at w along v,
∇2 f (w)[v] = ∇(∂f (·)[v])(w),
where we use the notation (∂f (·)[v])(w) := ∂f (w)[v] to insist on
the fact that it is a function of w.

4. Forward on forward: Finally, we can use the definition of the


HVP in Definition 2.19 as a vector of second partial derivatives
along v and each canonical direction. That is, assuming E = RP ,
we can compute the JVP of the JVP P times,
∇2 f (w)[v] = (∂ 2 f (w)[v, ei ])Pi=1 .

The four different ways of computing the HVP are summarized in


Table 9.1.
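
The first three methods are easy to express by composing autodiff transformations. A minimal JAX sketch (the toy function is ours) is given below; the forward-on-forward variant is omitted since it requires P separate JVPs.

```python
import jax
import jax.numpy as jnp

def f(w):
    return jnp.sum(jnp.sin(w) ** 2) + jnp.prod(w)

w = jnp.array([0.1, -0.5, 2.0])
v = jnp.array([1.0, 0.0, -1.0])

# Forward on reverse: JVP of the gradient.
hvp_for = jax.jvp(jax.grad(f), (w,), (v,))[1]
# Reverse on reverse: VJP of the gradient.
hvp_rev = jax.vjp(jax.grad(f), w)[1](v)[0]
# Reverse on forward: gradient of the JVP w |-> df(w)[v].
hvp_rof = jax.grad(lambda w: jax.jvp(f, (w,), (v,))[1])(w)

print(jnp.allclose(hvp_for, hvp_rev), jnp.allclose(hvp_for, hvp_rof))
```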

9.1.2 Complexity
To get a sense of the computational and memory complexity of the four
approaches, we consider a chain of functions f := fK ◦ · · · ◦ f1 as done

Method Computation
Reverse on reverse (VJP of gradient) ∂(∇f )(w)∗ [v]
Forward on reverse (JVP of gradient) ∂(∇f )(w)[v]
Reverse on forward (gradient of JVP) ∇(∂f (·)[v])(w)
Forward on forward (JVPs of JVPs) (∂ 2 f (w)[v, ei ])Pi=1

Table 9.1: Four different ways of computing the HVP ∇2 f (w)[v].

Figure 9.1: Computation graph corresponding to reverse mode autodiff for eval-
uating the gradient of f = fK ◦ . . . ◦ f1 . While f is a simple chain, ∇f is a DAG.

in Section 8.1. To simplify our analysis, we assume fk : RP → RP for


k ∈ {1, . . . , K − 1} and fK : RP → R.
The computation graph of the reverse mode is illustrated in Fig. 9.1.
While f = fK ◦ · · · ◦ f1 would be represented by a simple chain, the
computational graph of ∇f is no longer a chain: it is a DAG. This
is due to the computations of ∂fk (sk−1 )∗ [rk ], where both sk−1 and rk
depend on s0 .
We illustrate the computation graphs of reverse-on-reverse and
forward-on-reverse in Fig. 9.2 and Fig. 9.3 respectively. By applying
reverse mode on reverse mode, at each fan-in operation sk−1 , rk 7→
∂fk (sk−1 )∗ [rk ], the reverse mode on ∇f branches out in two paths that
are later merged by a sum. By applying forward mode on top of reverse
mode, the flow of computations simply follows the one of ∇f .
With this in mind, following a similar calculation as for Table 8.1,
we obtain the following results. We assume that each ∂fk (sk−1 ) is a
dense linear operator, so that its application has the same cost as a
matrix-vector multiplication. For the memory complexity, we consider
that the inputs of each operation are saved to compute the required

Figure 9.2: Computation graph for computing the HVP ∇2 f (x)[v] by using reverse
mode on top of reverse mode. As the computation graph of ∇f induces fan-in
operations sk−1 , rk 7→ ∂fk (sk−1 )[rk ], the reverse mode applied on ∇f induces
branching of the computations at each such node.


Figure 9.3: Computation graph for computing the HVP ∇2 f (x)[v] by using forward
mode on top of reverse mode. The forward mode naturally follows the computations
done for the gradient, except that it passes through the derivatives of the intermediate
operations.

derivatives in the backward passes.

1. Reverse on reverse: O(KP 2 ) time and O(KP ) space.

2. Forward on reverse: O(KP 2 ) time and O(KP ) space.

3. Reverse on forward: O(KP 2 ) time and O(KP ) space.

4. Forward on forward: O(KP 3 ) time and O(3P ) space for the


P JVPs with e1 , . . . , eP .

We see that, for chains of functions, “reverse on reverse”, “forward


on reverse” and “reverse on forward” all have similar time complexities
up to some constant factors. Using reverse mode on top of reverse
mode requires storing the information backpropagated, i.e., the rk
(resp. the information forwarded, i.e., the tk in Fig. 8.1), to perform
the final reverse pass. By using forward mode on top of reverse mode,
this additional cost is not incurred, making it slightly less memory
expensive. In addition, reverse mode on top of reverse mode induces a
few additional summations due to the branching and merge operations
depicted in Fig. 9.2. The same holds when using reverse on top of
forward as we cannot avoid fan-in operations (this time of the form
sk−1 , tk−1 7→ ∂fk (sk−1 )[tk−1 ]). Unfortunately, “forward on forward” is
prohibitively expensive.
To summarize, among the four approaches presented to compute
HVPs, the forward-over-reverse mode is a priori the preferable one
in terms of computational and memory complexities. Note, however,
that computations of higher derivatives can benefit from dedicated
autodiff implementations such as Taylor mode autodiff, that do not
merely compose forward and reverse modes. For general functions f , it
is reasonable to benchmark the first three methods to determine which
method is the best for the function at hand.
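To make the comparison concrete, the following minimal sketch shows how the first three strategies can be composed from standard autodiff transformations. We assume JAX as the autodiff framework; the objective f and the evaluation point are purely illustrative, and “forward on forward” is omitted since it would require P separate JVP-of-JVP calls.

import jax
import jax.numpy as jnp

def f(w):  # illustrative twice-differentiable objective
    return jnp.sum(jnp.tanh(w) ** 2)

def hvp_forward_over_reverse(f, w, v):
    # JVP of the gradient map: the a priori preferred method.
    return jax.jvp(jax.grad(f), (w,), (v,))[1]

def hvp_reverse_over_reverse(f, w, v):
    # Gradient of w -> <grad f(w), v>.
    return jax.grad(lambda w_: jnp.vdot(jax.grad(f)(w_), v))(w)

def hvp_reverse_over_forward(f, w, v):
    # Gradient of the directional derivative w -> ∂f(w)[v].
    return jax.grad(lambda w_: jax.jvp(f, (w_,), (v,))[1])(w)

w = jnp.array([0.1, -0.3, 0.7])
v = jnp.array([1.0, 0.0, -1.0])
print(hvp_forward_over_reverse(f, w, v))
print(hvp_reverse_over_reverse(f, w, v))
print(hvp_reverse_over_forward(f, w, v))

All three calls return the same HVP, but they trace different computation graphs and therefore have the different memory footprints discussed above.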

9.2 Gauss-Newton matrix

9.2.1 An approximation of the Hessian


The Hessian matrix ∇2 L(w) of a function L : W → R is often used to
construct a quadratic approximation of L(w),
L(w + v) ≈ L(w) + ⟨∇L(w), v⟩ + (1/2)⟨v, ∇²L(w)v⟩.
Unfortunately, when L is nonconvex, ∇2 L(w) is typically an indefinite
matrix, which means that the above approximation is a nonconvex
quadratic w.r.t. v. For instance, if L = ℓ ◦ f with ℓ convex, then L is
convex if f is linear, but it is typically nonconvex if f is nonlinear. The
(generalized) Gauss-Newton matrix is a principled alternative to the
Hessian, which is defined for L := ℓ ◦ f .

Definition 9.1 (Gauss-Newton matrix). Given a differentiable func-


tion f : W → M and a twice differentiable function ℓ : M → R, the
(generalized) Gauss-Newton matrix of the composition L = ℓ ◦ f
evaluated at a point w ∈ W is defined as

∇2GN (ℓ ◦ f )(w) := ∂f (w)∗ ∇2 ℓ(f (w))∂f (w).

As studied in Section 17.2, the Gauss-Newton matrix is a key ingre-


dient of the Gauss-Newton method. An advantage of the Gauss-Newton
matrix is its positive semi-definiteness provided that ℓ is convex.

Proposition 9.1 (Positive semi-definiteness of the GN matrix). If ℓ is


convex, then ∇2GN (ℓ ◦ f )(w) is positive semi-definite for all f .

This means that the approximation

L(w + v) ≈ L(w) + ⟨∇L(w), v⟩ + (1/2)⟨v, ∇²_GN L(w)v⟩

is a convex quadratic w.r.t. v.
Using the chain rule, we find that the Hessian of L = ℓ◦f decomposes
into the sum of two terms (see also Proposition 9.7).

Proposition 9.2 (Approximation of the Hessian). For f differentiable
and ℓ twice differentiable, we have

∇²(ℓ ◦ f)(w) = ∂f(w)* ∇²ℓ(f(w)) ∂f(w) + ∂²f(w)*[∇ℓ(f(w))]
             = ∇²_GN(ℓ ◦ f)(w) + Σ_{j=1}^M ∇_j ℓ(f(w)) ∇²f_j(w).

If f is linear, then the Hessian and Gauss-Newton matrices coincide,

∇2 (ℓ ◦ f )(w) = ∇2GN (ℓ ◦ f )(w).

The Gauss-Newton operator ∇2GN (ℓ ◦ f ) can therefore be seen as an


approximation of the Hessian ∇2 (ℓ ◦ f ), with equality if f is linear.

9.2.2 Gauss-Newton chain rule

A chain rule for computing the Hessian of a composition of two functions


is presented in Proposition 9.7, but the formula is relatively complicated,
due to the cross-terms. In contrast, a Gauss-Newton chain rule is
straightforward.

Proposition 9.3 (Gauss-Newton chain rule).

∇2GN (ℓ ◦ f ◦ g)(w) = ∂g(w)∗ ∇2GN (ℓ ◦ f )(g(w))∂g(w).

9.2.3 Gauss-Newton vector product

As for the Hessian, we rarely need to materialize the full Gauss-Newton


matrix in memory. Indeed, we can define the Gauss-Newton vector
product (GNVP), a linear map for a direction v ∈ W, as

∇2GN (ℓ ◦ f )(w)[v] := ∂f (w)∗ ∇2 ℓ(f (w))∂f (w)v, (9.1)

where u ↦ ∇²ℓ(θ)[u] is the HVP of ℓ, a linear map from M to M. The
GNVP can be computed using the JVP of f, the HVP of ℓ and the
VJP of f . Instantiating the VJP requires 1 forward pass through f ,
from which we get both the value f (w) and the adjoint linear map
u 7→ (∂f (w)∗ u). Evaluating the VJP requires 1 backward pass through

f . Evaluating the JVP requires 1 forward pass through f . In total,


evaluating v 7→ ∇2GN (ℓ ◦ f )(w)v therefore requires 2 forward passes and
1 backward pass through f .
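As an illustration, here is a minimal sketch of the GNVP in Eq. (9.1), assuming JAX; the model f, the loss ℓ and the test point below are illustrative placeholders, not prescribed by the text.

import jax
import jax.numpy as jnp

def gnvp(f, loss, w, v):
    # Forward pass through f; also builds the VJP (adjoint) of f at w.
    theta, f_vjp = jax.vjp(f, w)
    # JVP of f: u = ∂f(w)[v].
    _, u = jax.jvp(f, (w,), (v,))
    # HVP of the loss at theta applied to u: ∇²ℓ(θ)[u].
    hu = jax.jvp(jax.grad(loss), (theta,), (u,))[1]
    # Pull back through the VJP of f: ∂f(w)*[∇²ℓ(θ)[u]].
    return f_vjp(hu)[0]

# Illustrative nonlinear model and convex (squared-error) loss.
A = jnp.array([[1.0, 2.0], [0.5, -1.0]])
f = lambda w: jnp.tanh(A @ w)
loss = lambda theta: 0.5 * jnp.sum((theta - 1.0) ** 2)
w = jnp.array([0.3, -0.2])
v = jnp.array([1.0, 0.5])
print(gnvp(f, loss, w, v))

The sketch follows the accounting above: two forward passes (one inside jax.vjp, one inside jax.jvp) and one backward pass through f.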

9.2.4 Gauss-Newton matrix factorization

In this section, we assume W ⊆ RP and M ⊆ RM . When ℓ is convex,


we know that the Gauss-Newton matrix is positive semi-definite and
therefore it can be factorized into ∇2GN (ℓ ◦ f )(w) = V V ⊤ for some
V ∈ RP ×R , where R ≤ min{P, M } is the rank of the matrix. Such a
decomposition can actually be computed easily from a factorization of
the Hessian of ℓ. For instance, suppose we know the eigendecomposition
of the Hessian of ℓ, ∇²ℓ(f(w)) = Σ_{i=1}^M λ_i u_i u_i^⊤, where the u_i are the
eigenvectors and the λ_i ≥ 0 are the eigenvalues (which we know are
non-negative due to positive semidefiniteness). Then, the Gauss-Newton
matrix can be decomposed as

∇²_GN(ℓ ◦ f)(w) = Σ_{i=1}^M λ_i ∂f(w)* u_i u_i^⊤ ∂f(w)
               = Σ_{i=1}^M (√λ_i ∂f(w)* u_i)(√λ_i ∂f(w)* u_i)^⊤
               = Σ_{i=1}^M v_i v_i^⊤, where v_i := √λ_i ∂f(w)* u_i.

Stacking the vectors vi into a matrix V = (v1 , . . . , vM ), we recover


the factorization ∇2GN (ℓ ◦ f )(w) = V V ⊤ . To form this decomposition,
we need to perform the eigendecomposition of ∇2 ℓ(f (w)) ∈ RM ×M ,
which takes O(M 3 ) time. We also need M calls to the VJP of f at w.
Compared to the direct implementation in Eq. (9.1), the factorization,
once computed, allows us to compute the Gauss-Newton vector product
(GNVP) as ∇2GN (ℓ ◦ f )(w)[v] = V V ⊤ v. The factorization only requires
P × M memory, while the direct implementation in Eq. (9.1) requires
us to maintain the intermediate computations of f . The computation-
memory trade-offs therefore depend on the function considered.
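The construction above can be sketched as follows, assuming JAX; forming the dense Hessian of ℓ and looping over its M eigenvectors is only reasonable when M is small, and the model and loss below are illustrative.

import jax
import jax.numpy as jnp

def gauss_newton_factor(f, loss, w):
    # Returns V of shape (P, M) such that the Gauss-Newton matrix is V V^T.
    theta, f_vjp = jax.vjp(f, w)
    H_loss = jax.hessian(loss)(theta)        # M x M Hessian of the loss
    lam, U = jnp.linalg.eigh(H_loss)         # eigendecomposition, O(M^3)
    lam = jnp.maximum(lam, 0.0)              # guard against round-off
    # One VJP of f per eigenvector: M calls in total.
    cols = [jnp.sqrt(lam[i]) * f_vjp(U[:, i])[0] for i in range(U.shape[1])]
    return jnp.stack(cols, axis=1)

# Illustrative model (R^2 -> R^3) and convex loss.
A = jnp.array([[1.0, 2.0], [0.5, -1.0], [0.3, 0.2]])
f = lambda w: jnp.tanh(A @ w)
loss = lambda theta: 0.5 * jnp.sum((theta - 1.0) ** 2)
w = jnp.array([0.3, -0.2])
V = gauss_newton_factor(f, loss, w)
v = jnp.array([1.0, 0.5])
print(V @ (V.T @ v))                         # GNVP via the factorization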

9.2.5 Stochastic setting

Suppose the objective function is of the form

L(w; x, y) := ℓ(f (w; x); y).

With some slight abuse of notation, we then have that the Gauss-Newton
matrix associated with a pair (x, y) is

∇2GN L(w; x, y) := ∂f (w; x)∗ ∇2 ℓ(θ; y)∂f (w; x).

Given a distribution ρ over (x, y) pairs, the Gauss-Newton matrix


associated with the averaged loss

L(w) := EX,Y ∼ρ [L(w; X, Y )]

is then

∇²_GN L(w) = E_{X,Y∼ρ}[∇²_GN L(w; X, Y)].

9.3 Fisher information matrix

9.3.1 Definition using the score function

The Fisher information is a way to measure the amount of information


in a random variable S.

Definition 9.2 (Fisher information matrix). The Fisher informa-


tion matrix, or Fisher for short, associated with the negative
log-likelihood L(w; S) = − log qw (S) of a probability distribution
qw with parameters w is the covariance of the gradients of L at w
for S distributed according to qw ,

∇2F L(w) := ES∼qw [∇L(w; S) ⊗ ∇L(w; S)]


= ES∼qw [∇w log qw (S) ⊗ ∇w log qw (S)].

The gradient ∇w log qw (S) is known as the score function.

As studied in Section 17.3, the Fisher information matrix is a key


ingredient of the natural gradient descent method.

9.3.2 Link with the Hessian


Provided that the probability distribution is twice differentiable w.r.t.
w with integrable second derivatives, the Fisher information matrix can
also be expressed as the Hessian of the negative log-likelihood (Amari,
1998; Martens, 2020).

Proposition 9.4 (Connection with the Hessian). The Fisher infor-


mation matrix of the negative log-likelihood L(w; S) = − log qw (S)
satisfies

∇2F L(w) = ES∼qw [∇2 L(w; S)] = ES∼qw [−∇2w log qw (S)].

Remark 9.1 (Empirical Fisher). We emphasize that in the above


definitions, S is sampled from the model distribution qw , not from
the data distribution ρ. That is, we have

∇²_F L(w) = E_{S∼qw}[∇_w log q_w(S) ∇_w log q_w(S)^⊤]
          ≠ E_{S∼ρ}[∇_w log q_w(S) ∇_w log q_w(S)^⊤].

The latter is sometimes ambiguously called the “empirical” Fisher,
though this name has generated confusion (Kunstner et al., 2019).

9.3.3 Equivalence with the Gauss-Newton matrix


So far, we discussed the Fisher information for a generic random variable
S ∼ qw . We now discuss the supervised probabilistic learning setting
where S = (X, Y ) and where, using the product rule of probability,
we define the PDF qw (X, Y ) := ρX (X)pθ (Y ), with the shorthand θ :=
f (w; X).

Proposition 9.5 (Fisher matrix in supervised setting). Suppose


(X, Y ) ∼ qw where the PDF of qw is qw (X, Y ) := ρX (X)pθ (Y ).
In that case, the Fisher information matrix of the negative log-

likelihood L(w; x, y) = − log qw (x, y) decomposes as,

∇²_F L(w) = E_{(X,Y)∼qw}[∇_w log q_w(X, Y) ⊗ ∇_w log q_w(X, Y)]
          = E_{X∼ρX}[E_{Y∼pθ}[∇_w log p_θ(Y) ⊗ ∇_w log p_θ(Y)]]
          = E_{X∼ρX}[∂f(w; X)* ∇²_F ℓ(θ) ∂f(w; X)],

where we used the shorthand θ := f(w; X) and defined
the negative log-likelihood loss ℓ(θ; Y) := − log p_θ(Y).

When pθ is an exponential family distribution, we can show that the


Fisher information matrix and the Gauss-Newton matrix are equivalent.

Proposition 9.6 (Equivalence between Fisher and Gauss-Newton). If


pθ is an exponential family distribution, then

∇²_F L(w) = E_{X∼ρX} E_{Y∼pθ}[∇L(w; X, Y) ⊗ ∇L(w; X, Y)]
          = E_{X∼ρX} E_{Y∼pθ}[∂f(w; X)* ∇ℓ(θ; Y) ⊗ ∇ℓ(θ; Y) ∂f(w; X)]
          = E_{X∼ρX} E_{Y∼pθ}[∂f(w; X)* ∇²ℓ(θ; Y) ∂f(w; X)]
          = E_{X,Y∼ρ}[∇²_GN L(w; X, Y)],

where ρX(x) := ∫ ρ(x, y) dy.

Proof. From Proposition 3.3, if pθ is an exponential family distribution,


∇2 ℓ(θ, y) is actually independent of y. Using Bartlett’s second identity
Eq. (12.3), we then obtain

∇²ℓ(θ; ·) = E_{Y∼pθ}[∇²ℓ(θ; Y)]
          = E_{Y∼pθ}[−∇²_θ log p_θ(Y)]
          = E_{Y∼pθ}[∇_θ log p_θ(Y) ⊗ ∇_θ log p_θ(Y)]
          = E_{Y∼pθ}[∇ℓ(θ; Y) ⊗ ∇ℓ(θ; Y)],

where we used · to indicate that the result holds for all y. Plugging the
result back in the Fisher information matrix concludes the proof.

9.4 Inverse-Hessian vector product

9.4.1 Definition as a linear map

We saw in Section 17.1 that Newton’s method uses iterations of the form

wt+1 = wt − ∇2 L(wt )−1 ∇L(wt ).

The inverse is well-defined if, for example, L is strictly convex. Otherwise,
we saw that some additional regularization can be added. Newton’s
method therefore requires access to inverse-Hessian vector products
(IHVPs), as defined below.

Definition 9.3 (Inverse-Hessian vector product). For a twice differ-


entiable function L : RP → R, we define the inverse-Hessian
Vector Product (IHVP) of L at w ∈ RP as the linear map

u 7→ ∇2 L(w)−1 u,

provided that it exists. In other words, it is the linear map that
associates to u the vector v such that ∇²L(w)v = u.

9.4.2 Implementation with matrix-free linear solvers

Numerous direct methods exist to compute the inverse of a matrix,


such as the Cholesky decomposition, QR decomposition and Gaussian
elimination. However, these algorithms require accessing elementary
entries of the matrix, while an autodiff framework gives access to the
Hessian through HVPs. Fortunately, there exist so-called matrix-free
algorithms that can solve a linear system of equations

H[v] = u

by only accessing the linear map v ↦ H[v] for any v. Among such
algorithms, we have the conjugate gradient (CG) method, which applies
to positive-definite H, i.e., such that ⟨v, H[v]⟩ > 0 for all v ≠ 0, or the
generalized minimal residual (GMRES) method, which applies to
any invertible H. A longer list of solvers can be found in public software
such as SciPy (Virtanen et al., 2020). The IHVP of a strictly convex

function (ensuring that the Hessian is positive definite) can therefore


be computed by instantiating CG on the HVP,

∇2 L(w)−1 u ≈ CG(∇2 L(w)[·], u).

Positive-definiteness of the Hessian is indeed guaranteed for strictly


convex functions for example, while for generic non-convex functions,
such a property may hold around a minimizer but not in general.
The conjugate gradient method is recalled in Algorithm 9.1 in its
simplest form. In theory, the exact solution of the linear system is found
after at most T = P iterations of CG, though in practice numerical
errors may prevent us from obtaining an exact solution.

Algorithm 9.1 Conjugate gradient method

Inputs: linear map H[·] : R^P → R^P, target u ∈ R^P, initialization
v_0 (default 0), number of iterations T (default P), target accuracy
ε (default machine precision)
1: r_0 = u − H[v_0]
2: p_0 = r_0
3: for t = 0, . . . , T do
4:   α_t = ⟨r_t, r_t⟩ / ⟨p_t, H[p_t]⟩
5:   v_{t+1} = v_t + α_t p_t
6:   r_{t+1} = r_t − α_t H[p_t]
7:   if ⟨r_{t+1}, r_{t+1}⟩ ≤ ε then break
8:   β_t = ⟨r_{t+1}, r_{t+1}⟩ / ⟨r_t, r_t⟩
9:   p_{t+1} = r_{t+1} + β_t p_t
Output: v_T, such that H[v_T] ≈ u
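As an illustrative sketch, assuming JAX (whose jax.scipy.sparse.linalg.cg accepts a matrix-free linear operator), the IHVP can be obtained by running CG on the HVP; the strictly convex objective and test vectors below are only examples.

import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

def ihvp(L, w, u, **cg_kwargs):
    # Matrix-free Hessian: HVP by forward-over-reverse.
    hvp = lambda v: jax.jvp(jax.grad(L), (w,), (v,))[1]
    v, _ = cg(hvp, u, **cg_kwargs)           # solves ∇²L(w)[v] = u
    return v

# Illustrative strictly convex objective, so the Hessian is positive definite.
L = lambda w: jnp.sum(jnp.log(1.0 + jnp.exp(w))) + 0.5 * jnp.sum(w ** 2)
w = jnp.array([0.2, -0.5, 1.0])
u = jnp.array([1.0, 2.0, 3.0])
print(ihvp(L, w, u))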

9.4.3 Complexity

For a given matrix H ∈ RP ×P , solving Hv = u can be done with


decomposition methods (LU, QR, Cholesky) in O(P 3 ) time. For matrix-
free methods such as CG or GMRES, the cost per iteration is O(P 2 ).
Since they theoretically solve the linear system in O(P ) iterations, the
cost to obtain an exact solution is theoretically the same, O(P 3 ).

However, CG or GMRES differ from decomposition methods in that


they are iterative methods, meaning that, at each iteration, they get
closer to a solution. Unlike decomposition methods, this means that we
can stop them before an exact solution is found. In practice, the number
of iterations required to find a good approximate solution depends on
the matrix. Well conditioned matrices require only few iterations. Badly
conditioned matrices lead to some numerical instabilities for CG, so
that more than P iterations may be needed to get a good solution. In
contrast, decomposition methods proceed in two steps: first they build
a decomposition of H at a cost of O(P 3 ), and second they solve a linear
system at a cost of O(P 2 ), by leveraging the structure. LU and QR
decompositions are known to be generally more stable and are therefore
often preferred in practice, when we can access entries of H at no cost.
If we do not have access to the Hessian H, but only to its HVP,
accessing entries of H comes at a prohibitive cost. Indeed, entries of H
can still be recovered from HVPs, since e_i^⊤ H e_j = H_{i,j}, but accessing
each row or column of H costs one HVP (matrix-vector product). To
access the information necessary to use a decomposition method, we
therefore need P calls to HVPs before being able to actually compute
the solution. For the same number of calls, CG or GMRES will already
have found an approximate solution. In addition, a CG method does
not require storing any additional matrices in memory.

9.5 Second-order backpropagation

9.5.1 Second-order Jacobian chain rule

The essential ingredient for developing forward-mode and reverse-mode
autodiff was the chain rule for composed functions, h = g ◦ f.
For second derivatives, a similar rule can be obtained. To do so, we
slightly abuse notation and denote

∂ 2 h(w)∗ [u] := ∇2 ⟨h, u⟩(w) ∈ RP ×P ,

where h : RP → RQ , w ∈ RP , u ∈ RQ , and where we recall the


shorthand notation ⟨u, h⟩(w) := ⟨u, h(w)⟩. Moreover, we view the
above quantity as a linear map. Strictly speaking, the superscript ∗ is

not a linear adjoint anymore, since v1 , v2 7→ ∂ 2 h(w)[v1 , v2 ] is no longer


linear but bilinear. However, this superscript plays the same role as the
VJP, since it takes an output vector and returns the input derivatives
that correspond to infinitesimal variations along that output vector.

Proposition 9.7 (Hessian chain-rule). For two twice differentiable


functions f : RP → RM and g : RM → RQ , the second directional
derivative of the composition g ◦ f is a bilinear map from RP × RP
to RQ along input directions v1 , v2 ∈ RP of the form

∂ 2 (g ◦ f )(w)[v1 , v2 ] = ∂g(f (w))[∂ 2 f (w)[v1 , v2 ]]


+ ∂ 2 g(f (w))[∂f (w)[v1 ], ∂f (w)[v2 ]].

The Hessian of the composition g ◦ f along an output direction


u ∈ RQ is, seen as a linear map,

∂²(g ◦ f)(w)*[u] = ∂²f(w)*[∂g(f(w))* u]                    (9.2)
                 + ∂f(w)* ∂²g(f(w))*[u] ∂f(w).

For the composition of f : R^P → R^M with a scalar-valued function
ℓ : R^M → R, we have, in matrix form,

∇²(ℓ ◦ f)(w) = Σ_{j=1}^M (∇ℓ(f(w)))_j ∇²f_j(w) + ∂f(w)^⊤ ∇²ℓ(f(w)) ∂f(w).

Note that, while the Hessian is usually defined for scalar-valued


functions h : RP → R, the above definition is for a generalized notion
of Hessian that works for any function h : RP → RQ .
The Hessian back-propagation rule in Eq. (9.2) reveals two terms.
The first one ∂ 2 f (w)∗ [∂g(f (w))∗ u] simply computes the Hessian of
the intermediate function along the output direction normally back-
propagated by a VJP. The second term ∂f (w)∗ ∂ 2 g(f (w))∗ [u]∂f (w)
shows how intermediate first-order variations influence second order
derivatives of the output.

Example 9.1 (Composition with an elementwise nonlinear function).
Consider the element-wise application of a twice differentiable
scalar-valued function, f(x) = (f(x_i))_{i=1}^M, followed by some twice
differentiable function ℓ. Note that ∇²f_i(x) = f″(x_i) e_i e_i^⊤. Hence,
the Hessian of the composition reads

∇²(ℓ ◦ f)(x) = Σ_{i=1}^M (∇ℓ(f(x)))_i f″(x_i) e_i e_i^⊤ + diag(f′(x)) ∇²ℓ(f(x)) diag(f′(x))
             = diag(∇ℓ(f(x)) ⊙ f″(x)) + ∇²ℓ(f(x)) ⊙ (f′(x) f′(x)^⊤),

where f′(x) := (f′(x_i))_{i=1}^M and f″(x) := (f″(x_i))_{i=1}^M.
′′ ′′ M

Example 9.2 (Hessian of the composition with a linear function). Consider


a linear function f (W ) = W x, for W ∈ RM ×D , composed with
some twice differentiable function ℓ : RM → R. From Proposi-
tion 9.7, we get, in terms of linear maps,

∇2 (ℓ ◦ f )(W ) = ∂f (W )∗ ∇2 ℓ(f (W ))∂f (W ).

As already noted in Section 2.3, we have that ∂f (W )[V ] = V x


and ∂f (W )∗ [u] = ux⊤ . Hence, the Hessian seen as a linear map
reads

∇2 (ℓ ◦ f )(W )[V ] = ∂f (W )∗ [∇2 ℓ(f (W ))[∂f (W )[V ]]] = HV xx⊤ ,

where H := ∇2 ℓ(f (W )).

9.5.2 Computation chains

For a simple computation chain f = fK ◦ . . . ◦ f1 as in Section 8.1, the


formula derived in Proposition 9.7 suffices to develop an algorithm that
backpropagates the Hessian, as shown in Algorithm 9.2. Compared to
Algorithm 8.2, we simply backpropagate both the vectors rk and the
matrices Rk using intermediate first and second derivatives.

Algorithm 9.2 Hessian backprop for computation chains

Functions: f := f_K ◦ . . . ◦ f_1, loss ℓ
Inputs: input x
1: Initialize and store s_0 := x ▷ Forward pass
2: for k := 1, . . . , K do
3:   Compute and store s_k := f_k(s_{k-1})
4: Initialize r_K := ∇ℓ(s_K), R_K := ∇²ℓ(s_K) ▷ Backward pass
5: for k := K, . . . , 1 do
6:   Compute r_{k-1} := ∂f_k(s_{k-1})*[r_k]
7:   Compute R_{k-1} := ∂²f_k(s_{k-1})*[r_k] + ∂f_k(s_{k-1})* R_k ∂f_k(s_{k-1})
8:   Release s_{k-1} from memory
Outputs: ℓ(f(x)) = ℓ(s_K), ∇(ℓ ◦ f)(x) = r_0, ∇²(ℓ ◦ f)(x) = R_0
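A minimal sketch of Algorithm 9.2, assuming JAX. For clarity it materializes each intermediate Jacobian ∂f_k(s_{k-1}) and obtains ∂²f_k(s_{k-1})*[r_k] as the Hessian of s ↦ ⟨f_k(s), r_k⟩; the two-layer chain and loss at the end are illustrative.

import jax
import jax.numpy as jnp

def hessian_backprop_chain(fs, loss, x):
    # Forward pass: store s_0, ..., s_K.
    states = [x]
    for fk in fs:
        states.append(fk(states[-1]))
    r = jax.grad(loss)(states[-1])            # r_K
    R = jax.hessian(loss)(states[-1])         # R_K
    # Backward pass.
    for fk, s in zip(reversed(fs), reversed(states[:-1])):
        J = jax.jacobian(fk)(s)                                 # ∂f_k(s_{k-1})
        H = jax.hessian(lambda s_: jnp.vdot(fk(s_), r))(s)      # ∂²f_k(s_{k-1})*[r_k]
        R = H + J.T @ R @ J                                     # R_{k-1}
        r = J.T @ r                                             # r_{k-1}
    return loss(states[-1]), r, R

fs = [jnp.tanh, lambda s: jnp.sin(2.0 * s)]   # illustrative chain f_2 ∘ f_1
loss = lambda s: 0.5 * jnp.sum(s ** 2)
x = jnp.array([0.4, -0.7, 1.2])
value, grad, hess = hessian_backprop_chain(fs, loss, x)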

9.5.3 Fan-in and fan-out


For generic computation graphs (see Section 8.3), we saw that multi-
input functions (fan-in) were crucial. For Hessian backpropagation in
computation graphs, we therefore need to develop a similar formula.

Proposition 9.8 (Hessian chain-rule for fan-in). Consider n+1 twice


differentiable functions f1 , . . . , fn and g with fi : RP → RMi and
g : RM1 × . . . × RMn → RQ . The Hessian of g ◦ f for f (w) =
(f1 (w), . . . , fn (w)) along an output direction u ∈ RQ is given by
∂²(g ◦ f)(w)*[u] = Σ_{i=1}^n ∂²f_i(w)*[∂_i g(f(w))*[u]]
                 + Σ_{i,j=1}^n ∂f_i(w)* ∂²_{i,j} g(f(w))*[u] ∂f_j(w).

The gradient backpropagation expression for fan-in is simple because


the functions fi are not linked by any path. In contrast, the Hessian
backpropagation involves cross-product terms
∂f_i(w)* ∂²_{i,j} g(f(w))*[u] ∂f_j(w) for i ≠ j. The nodes associated to the
f_i computations cannot be treated independently anymore.
On the other hand, developing a backpropagation rule for fan-out
does not pose any issue, since each output function can be treated

independently.

Proposition 9.9 (Hessian chain-rule for fan-out). Consider n+1 twice


differentiable functions g1 , . . . , gn and f with gi : RM → RQi and
f : RP → RM . The Hessian of g ◦ f for g(w) = (g1 (w), . . . , gn (w))
along a direction u = (u1 , . . . , un ) ∈ RQ1 × . . . × RQn is given by
∂²(g ◦ f)(w)*[u] = Σ_{i=1}^n ∂²f(w)*[∂g_i(f(w))*[u_i]]
                 + Σ_{i=1}^n ∂f(w)* ∂²g_i(f(w))*[u_i] ∂f(w).

9.6 Block diagonal approximations

Rather than computing the whole Hessian or Gauss-Newton matrices,


we can consider computing block-diagonal or diagonal approximations,
which are easier to invert. The approximation rules we present in this
section build upon the Hessian chain rule studied in Section 9.5.

9.6.1 Feedforward networks

Recall the definition of a feedforward network:

s0 := x
sk := fk (sk−1 , wk ) ∀k ∈ {1, . . . , K}
f (x, w) := sK ,

where w := (w1 , . . . , wK ). Rather than computing the entire Hessian of


ℓ ◦ f w.r.t. w, we can compute the Hessians w.r.t. each set of parameters
wk . For the case of computation chains, the Hessian backpropagation
recursion we used in Algorithm 9.2 was

Rk−1 := ∂ 2 fk (sk−1 )∗ [rk ] + ∂fk (sk−1 )∗ Rk ∂fk (sk−1 ).



Extending this recursion to the feedforward network case, we obtain,
starting from r_K := ∇ℓ(s_K) and R_K := ∇²ℓ(s_K),

r_{k-1} := ∂f_k(s_{k-1}, w_k)*[r_k]

[ R_{k-1}   ∼  ]
[   ∼      H_k ] := ∂²f_k(s_{k-1}, w_k)*[r_k] + ∂f_k(s_{k-1}, w_k)* R_k ∂f_k(s_{k-1}, w_k),

where we used ∼ to indicate that these blocks are not used. The Hessians
w.r.t. each set of parameters are then

R_0 = ∇²_{xx}(ℓ ◦ f)(x, w)
H_1 = ∇²_{w_1 w_1}(ℓ ◦ f)(x, w)
...
H_K = ∇²_{w_K w_K}(ℓ ◦ f)(x, w).
The validity of this result stems from the fact that the Hessian w.r.t. w_k
is the Hessian w.r.t. w_k of

f̃_K ◦ . . . ◦ f̃_{k+1} ◦ f_k(s_{k-1}, w_k),

where f̃_i := f_i(·, w_i) for i ∈ {k + 1, . . . , K}. As the block-wise Hessians
share most of their computations, they can be evaluated in a single
backward pass, just like the gradients.

Example 9.3 (Block-wise computation of the Gauss-Newton matrix).


Our blockwise backpropagation scheme can readily be adapted for
the Gauss-Newton matrix as
[ R_{k-1}   ∼  ]
[   ∼      G_k ] := ∂f_k(s_{k-1}, w_k)* R_k ∂f_k(s_{k-1}, w_k),

starting from R_K := ∇²ℓ(s_K). The outputs R_0, G_1, . . . , G_K give a


block-wise approximation of the Gauss-Newton matrix.
Now, consider a simple multilayer perceptron such that

fk (sk−1 , wk ) := a(Wk sk−1 ) with wk := vec(Wk )

Using Example 9.2 and Example 9.1 adapted to the Gauss-Newton



matrix, we can compute the block-wise decomposition of the Gauss-


Newton matrix as, for k = K, . . . , 1,

J_k := R_k ⊙ (a′(W_k s_{k-1}) a′(W_k s_{k-1})^⊤)
R_{k-1} := W_k^⊤ J_k W_k
G_k := J_k ⊗ s_{k-1} s_{k-1}^⊤,

starting from RK := ∇2 ℓ(sK ). The outputs G1 , . . . , GK correspond


to the block-wise elements of the Gauss-Newton matrix of f for the
vectorized weights w1 , . . . , wK . Similar computations were done in
KFRA (Botev et al., 2017) and BackPack (Dangel et al., 2019).

9.6.2 Computation graphs


For generic computation graphs, consider a function f (x, w) defined
by, denoting i1 , . . . , ipk := pa(k),
sk := fk (si1 , . . . , sipk ) ∀k ∈ {1, . . . , K}
such that f (x, w) = sK , and k is following a topological ordering of the
graph (see Section 8.3). We can consider the following backpropagation
scheme, for k = K, . . . , 1 and j ∈ pa(k)
r_{i_j} ← r_{i_j} + ∂_j f_k(s_{i_1}, . . . , s_{i_{p_k}})*[r_k]    (9.3)
R_{i_j} ← R_{i_j} + ∂²_{jj} f_k(s_{i_1}, . . . , s_{i_{p_k}})*[r_k] + ∂_j f_k(s_{i_1}, . . . , s_{i_{p_k}})* R_k ∂_j f_k(s_{i_1}, . . . , s_{i_{p_k}}),    (9.4)
starting from R_K := ∇²ℓ(s_K) and r_K := ∇ℓ(s_K). Recall that for
multiple inputs, the chain-rule presented in Proposition 9.8 involves
the cross-derivatives. For this reason the back-propagation scheme
in Eq. (9.3) only computes an approximation. For example, one can
verify that using Eq. (9.3) to compute the Hessian of ℓ(f1 (w), f2 (w))
does not provide an exact expression for this Hessian. This scheme
is easy to implement and may provide a relevant proxy for the Hessian.

9.7 Diagonal approximations

Similarly to the idea of designing a backpropagation scheme that ap-


proximates blocks of the Hessian, we can design a backpropagation

scheme that approximates the diagonal of the Hessian. The approach


was originally proposed by Becker and Le Cun (1988) for feedforward
networks, but our exposition, new to our knowledge, has the benefit
that it naturally extends to computational graphs, as we shall see.

9.7.1 Computation chains


The idea stems from modifying the Hessian backpropagation rule
in Proposition 9.7 to only keep the diagonal of the Hessian. Formally,
given a matrix M ∈ R^{D×D}, we denote by diag(M) = (M_{ii})_{i=1}^D ∈ R^D
the vector of diagonal entries of M, and for a vector m ∈ R^D, we denote
by Diag(m) = Σ_{i=1}^D m_i e_i e_i^⊤ the diagonal matrix with entries m_i.

For the backpropagation of the Hessian of ℓ ◦ fK ◦ . . . ◦ f1 , we see from


Algorithm 9.2 that diag(Hk−1 ) can be expressed in terms of Hk as
diag(Hk−1 ) = diag(∂ 2 fk (sk−1 )∗ rk )
+ diag(∂fk (sk−1 )∗ Hk ∂fk (sk−1 )).
Unfortunately, that recursion needs access to the whole Hessian Hk ,
and would therefore be too expensive. A natural idea is to modify the
recursion to approximate diag(Hk ) by backpropagating vectors:
dk−1 := diag(∂ 2 fk (sk−1 )∗ rk )
+ diag(∂fk (sk−1 )∗ Diag(dk )∂fk (sk−1 )).
The diagonal matrix Diag(dk ) serves as a surrogate for Hk . Each
iteration of this recursion can be computed in linear time in the output
dimension Dk since
d_{k-1,i} = Σ_{j=1}^{D_k} r_{k,j} · ∂²_{i,i} f_{k,j}(s_{k-1}) + Σ_{j=1}^{D_k} d_{k,j} (∂_i f_{k,j}(s_{k-1}))².

To initialize the recursion, we can set dK := diag(∇2 ℓ(sK )). As an


alternative, as proposed by Elsayed and Mahmood (2022), if HK has
a simple form, we can use ∇2 ℓ(sK ) instead of Diag(dK ) at the first
iteration. This is the case for instance if ℓ is the cross-entropy loss. The
recursion is repeated until we obtain the approximate diagonal Hessian
d0 ≈ diag(∇2 (ℓ ◦ f )(x)). The gradients rk , needed to compute dk , are
computed along the way and the algorithm can therefore also return
r0 = ∇(ℓ ◦ f )(x).
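The recursion above can be sketched as follows, assuming JAX. For clarity, each layer's Jacobian and the diagonal of its second derivatives are materialized with generic autodiff calls; in practice these per-layer oracles would be specialized, as in Example 9.4. The chain and loss below are illustrative.

import jax
import jax.numpy as jnp

def diag_hessian_approx_chain(fs, loss, x):
    states = [x]
    for fk in fs:
        states.append(fk(states[-1]))
    r = jax.grad(loss)(states[-1])                    # r_K
    d = jnp.diag(jax.hessian(loss)(states[-1]))       # d_K
    for fk, s in zip(reversed(fs), reversed(states[:-1])):
        J = jax.jacobian(fk)(s)                       # shape (D_k, D_{k-1})
        # Diagonal (over inputs) of the Hessian of each output f_{k,j}.
        H_diag = jnp.diagonal(jax.hessian(fk)(s), axis1=1, axis2=2)
        d = r @ H_diag + d @ (J ** 2)                 # d_{k-1}
        r = J.T @ r                                   # r_{k-1}
    return r, d   # gradient and approximate Hessian diagonal at x

fs = [jnp.tanh, lambda s: jnp.sin(2.0 * s)]           # illustrative chain
loss = lambda s: 0.5 * jnp.sum(s ** 2)
x = jnp.array([0.4, -0.7, 1.2])
grad, hess_diag = diag_hessian_approx_chain(fs, loss, x)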

9.7.2 Computation graphs


Although this diagonal approximation was originally derived for feed-
forward networks (Becker and Le Cun, 1988), it is straightforward to
generalize it to computation graphs. Namely, for a function f (x, w) de-
composed along a computation graph, we can backpropagate a diagonal
approximation in reverse topological order as

r_{i_j} ← r_{i_j} + ∂_j f_k(s_{i_1}, . . . , s_{i_{p_k}})*[r_k]
d_{i_j} ← d_{i_j} + diag(∂²_{jj} f_k(s_{i_1}, . . . , s_{i_{p_k}})*[r_k]) + diag(∂_j f_k(s_{i_1}, . . . , s_{i_{p_k}})* Diag(d_k) ∂_j f_k(s_{i_1}, . . . , s_{i_{p_k}})),    (9.5)

for j ∈ pa(k), starting from rK = ∇ℓ(sK ) and dK = diag(∇2 ℓ(sK ))


or Diag(dK ) = ∇2 ℓ(sK ). To implement such an algorithm, each ele-
mentary function in the computational graph needs to be augmented
with an oracle that computes the Hessian diagonal approximation of
the current function, given the previous ones. An example with MLPs
is presented in Example 9.4.

Example 9.4 (Hessian diagonal approximation for MLPs ). Consider


a multilayer perceptron

sk := ak (Wk sk−1 ) ∀k ∈ {1, . . . , K − 1}


f (w, x) := sK

starting from s0 = x. Here ak is the element-wise activation func-


tion (potentially the identity) and w encapsulates the weight ma-
trices W1 , . . . , WK . We consider the derivatives w.r.t. the flattened
matrices, so that gradients and diagonal approximations w.r.t. these
flattened quantities are vectors. The backpropagation scheme (9.5)

then reduces to, denoting tk = Wk sk−1 ,

r_{k-1} := W_k^⊤ (a′(t_k) ⊙ r_k)
g_k := vec((a′(t_k) ⊙ r_k) s_{k-1}^⊤)
δ_k := r_k ⊙ a″(t_k) + d_k ⊙ a′(t_k)²
d_{k-1} := (Σ_{j=1}^{D_k} W_{k,ji}² δ_{k,j})_{i=1}^{D_{k-1}}, i.e., d_{k-1} = (W_k ⊙ W_k)^⊤ δ_k
h_k := vec(δ_k (s_{k-1}²)^⊤)

starting from rK = ∇ℓ(sK ) and, e.g., dK = diag(∇2 ℓ(sK )). The


algorithm returns g_1, . . . , g_K as the gradients of ℓ ◦ f w.r.t. w_1, . . . , w_K,
with wi = vec(Wi ), and h1 , . . . , hK as the diagonal approximations
of the Hessian w.r.t. w1 , . . . , wK .

9.8 Randomized estimators

In this section, we describe randomized estimators of the diagonal of


the Hessian or Gauss-Newton matrices.

9.8.1 Girard-Hutchinson estimator

We begin with a generic estimator, originally proposed for trace es-


timation by Girard (1989) and extended by Hutchinson (1989). Let
A ∈ RP ×P be an arbitrary square matrix, whose matrix-vector product
(matvec) is available. Suppose ω ∈ RP is an isotropic random vector,
i.e., such that Eω∼p [ωω ⊤ ] = I. For example, two common choices are
the Rademacher distribution p = Uniform({−1, 1}) and the standard
normal distribution p = Normal(0, I). Then, we have

Eω∼p [⟨ω, Aω⟩] = tr(A).

Applications include generalized cross-validation, computing the Kullback-


Leibler divergence between two Gaussians, and computing the deriva-
tives of the log-determinant.
The approach can be extended (Bekas et al., 2007; Baston and
Nakatsukasa, 2022; Hallman et al., 2023) to obtain an estimator of the

diagonal of A,
E_{ω∼p}[ω ⊙ Aω] = diag(A),
where ⊙ denotes the Hadamard product (element-wise multiplication).
This suggests that we can use the Monte-Carlo method to estimate the
diagonal of A,
diag(A) ≈ (1/S) Σ_{i=1}^S ω_i ⊙ Aω_i,
with equality as S → ∞, since the estimator is unbiased. Since, as
reviewed in Section 9.1 and Section 9.2, we know how to multiply
efficiently with the Hessian and the Gauss-Newton matrices, we can
apply the technique with these matrices. The variance is determined
by the number S of samples drawn and therefore by the number of
matvecs performed. More elaborate approaches have been proposed to
further reduce the variance (Meyer et al., 2021; Epperly et al., 2023).
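As a sketch (assuming JAX), the diagonal of the Hessian can be estimated with Rademacher probes and HVPs only; the objective, the evaluation point and the sample size below are illustrative.

import jax
import jax.numpy as jnp

def hutchinson_diag_hessian(L, w, key, num_samples=100):
    # Matrix-free HVP of L at w.
    hvp = lambda v: jax.jvp(jax.grad(L), (w,), (v,))[1]
    def one_sample(k):
        # Rademacher probe in {-1, +1}^P.
        omega = 2.0 * jax.random.bernoulli(k, 0.5, w.shape) - 1.0
        return omega * hvp(omega)
    keys = jax.random.split(key, num_samples)
    return jnp.mean(jax.vmap(one_sample)(keys), axis=0)

L = lambda w: jnp.sum(jnp.tanh(w) ** 4)          # illustrative objective
w = jnp.array([0.5, -1.0, 2.0])
print(hutchinson_diag_hessian(L, w, jax.random.PRNGKey(0)))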

9.8.2 Bartlett estimator for the factorization

Suppose the objective function is of the form L(w; x, y) := ℓ(f (w; x); y)
where ℓ is the negative log-likelihood ℓ(θ; y) := − log pθ (y) of an ex-
ponential family distribution, and θ := f (w; x), for some network f .
We saw from the equivalence between the Fisher and Gauss-Newton
matrices in Proposition 9.6 (which follows from the Bartlett identity)
that

∇2GN L(w; x, ·) = EY ∼pθ [∂f (w; x)∗ ∇ℓ(θ; Y ) ⊗ ∇ℓ(θ; Y )∂f (w; x)]
= EY ∼pθ [∇L(w; x, Y ) ⊗ ∇L(w; x, Y )],

where · indicates that the result holds for any value of the second
argument. This suggests a Monte-Carlo scheme
∇²_GN L(w; x, ·) ≈ (1/S) Σ_{j=1}^S ∇L(w; x, y_{ij}) ⊗ ∇L(w; x, y_{ij}),

where yi1 , . . . , yiS ∼ pθ and θ = f (w, x). In words, we can approxi-


mate the Gauss-Newton matrix with S gradient computations. This
factorization can also be used to approximate the GNVP in Eq. (9.1).

9.8.3 Bartlett estimator for the diagonal


Following a similar approach, we obtain
diag(∇2GN L(w; x, ·)) = EY ∼pθ [∇L(w; x, Y ) ⊙ ∇L(w; x, Y )],
where ⊙ indicates the element-wise (Hadamard) product. Using a Monte-
Carlo scheme, sampling yi1 , . . . , yiS from pθ , we therefore obtain
diag(∇²_GN L(w; x, ·)) ≈ (1/S) Σ_{j=1}^S ∇L(w; x, y_{ij}) ⊙ ∇L(w; x, y_{ij}),
with equality when all labels in the support of pθ have been sampled.
That estimator, used for instance in (Wei et al., 2020, Appendix C.1.),
requires access to individual gradients evaluated at the sampled
labels. Another possible estimator of the diagonal is given by
(1/S) diag(∇²_GN L(w; x, ·)) = E_{Y_1,...,Y_S ∼ pθ}[∇((1/S) Σ_{i=1}^S L(w; x, Y_i)) ⊙ ∇((1/S) Σ_{i=1}^S L(w; x, Y_i))].
Letting γ_i := ∇L(w; x, Y_i), this follows from

E[(Σ_i γ_i) ⊙ (Σ_j γ_j)] = E[Σ_i γ_i ⊙ γ_i + Σ_{i≠j} γ_i ⊙ γ_j]
                         = E[Σ_i γ_i ⊙ γ_i],
where we used that E[γi ⊙ γj ] = E[γi ] ⊙ E[γj ] = 0 since γi and γj are
independent variables for i ≠ j and have zero mean, from Bartlett’s
first identity Eq. (12.2). We can then use the Monte-Carlo method to
obtain

(1/S) diag(∇²_GN L(w; x, ·)) ≈ ∇((1/S) Σ_{j=1}^S L(w; x, y_{ij})) ⊙ ∇((1/S) Σ_{j=1}^S L(w; x, y_{ij})),

with equality when all labels in the support of pθ have been sampled.
This estimator can be more convenient to implement, since it only needs
access to the gradient of the averaged losses. However, it may suffer
from higher variance. A special case of this estimator is used by Liu
et al. (2023), where they draw only one y for each x.
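The first estimator above can be sketched as follows, assuming JAX and a categorical model p_θ parameterized by logits θ = f(w; x); the model f, the flat parameter vector w and the sample size are illustrative assumptions, not prescribed by the text.

import jax
import jax.numpy as jnp

def bartlett_diag_gn(f, w, x, key, num_samples=10):
    logits = f(w, x)                                        # θ = f(w; x)
    keys = jax.random.split(key, num_samples)
    # Labels sampled from the model distribution p_θ, not the data.
    ys = jax.vmap(lambda k: jax.random.categorical(k, logits))(keys)
    def sq_grad(y):
        nll = lambda w_: -jax.nn.log_softmax(f(w_, x))[y]   # L(w; x, y)
        g = jax.grad(nll)(w)
        return g * g
    return jnp.mean(jax.vmap(sq_grad)(ys), axis=0)

# Illustrative linear model with 3 classes and flat parameters.
f = lambda w, x: w.reshape(3, 2) @ x
w = jnp.array([0.1, -0.2, 0.3, 0.0, 0.5, -0.4])
x = jnp.array([1.0, -1.0])
print(bartlett_diag_gn(f, w, x, jax.random.PRNGKey(0)))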

9.9 Summary

• By using a Hessian chain rule, we can develop a “Hessian backprop-


agation”. While it is reasonably simple for computation chains,
it becomes computationally prohibitive for computation graphs,
due to the cross-product terms occurring with fan-in.

• A better approach is to use Hessian-vector products (HVPs). We


saw that there are four possible methods to compute HVPs, but
the forward-over-reverse method is a priori the most efficient.
Similarly as for computing gradients, computing HVPs is only a
constant times more expensive than evaluating the function itself.

• The Gauss-Newton matrix associated with the composition ℓ ◦ f


can be seen as an approximation of the Hessian. It is a positive
semidefinite matrix if ℓ is convex, and can be used to build a
principled quadratic approximation of a function. It is equivalent
to the Fisher information matrix in the case of exponential families.
Gauss-Newton-vector products can be computed efficiently, like
HVPs.

• We also described other approximations, such as (block) diagonal


approximations, and randomized estimators.
10
Inference in graphical models as differentiation

A graphical model specifies how random variables depend on each


other and therefore determines how their joint probability distribution
factorizes. In this chapter, we review key concepts in graphical models
and how they relate to differentiation, drawing in the process analogies
with computation chains and computation graphs.

10.1 Chain rule of probability

The chain rule of probability is a fundamental law in probability theory


for computing the joint probability of events. In the case of only two
events A1 and A2 , it reduces to the product rule

P(A1 ∩ A2 ) = P(A2 |A1 )P(A1 ).

For two discrete random variables S1 and S2 , using the events A1 :=


{S1 = s1 } and A2 := {S2 = s2 }, the product rule becomes

P(S1 = s1 , S2 = s2 ) = P(S2 = s2 |S1 = s1 )P(S1 = s1 ).

More generally, using the product rule, we have for K events

P (A1 ∩ . . . ∩ AK ) = P (AK | A1 ∩ . . . ∩ AK−1 ) P (A1 ∩ . . . ∩ AK−1 ) .


Applying the product rule one more time, we have


P (A1 ∩ . . . ∩ AK−1 ) = P (AK−1 | A1 ∩ . . . ∩ AK−2 ) P (A1 ∩ . . . ∩ AK−2 ) .
Repeating the process recursively, we arrive at the chain rule of
probability
P(A_1 ∩ . . . ∩ A_K) = Π_{j=1}^K P(A_j | A_1 ∩ · · · ∩ A_{j-1})
                     = Π_{j=1}^K P(A_j | ∩_{i=1}^{j-1} A_i).

For K discrete random variables Sj , using the events Aj := {Sj = sj },


the chain rule of probability becomes
P(S_1 = s_1, . . . , S_K = s_K) = Π_{j=1}^K P(S_j = s_j | S_1 = s_1, . . . , S_{j-1} = s_{j-1}).

Importantly, this factorization holds without any independence assump-


tion on the variables S1 , . . . , SK . We can further simplify this expression
if we make conditional independence assumptions.

10.2 Conditional independence

We know that if two events A and B are independent, then


P(A|B) = P(A).
Similarly, if two random variables S1 and S2 are independent, then
P(S2 = s2 |S1 = s1 ) = P(S2 = s2 ).
More generally, if we work with K variables S1 , . . . , SK , some variables
may depend on each other, while others may not. To simplify the
notation, given a set C, we define the shorthands
SC := (Si : i ∈ C)
sC := (si : i ∈ C).
We say that a variable Sj is independent of SD conditioned on SC ,
with C ∩ D = ∅, if for any sj , sC , sD
P(Sj = sj | SC = sC , SD = sD ) = P(Sj = sj | SC = sC ).

10.3 Inference problems

10.3.1 Joint probability distributions


We consider a collection of K variables s := (s1 , . . . , sK ), potentially
ordered or unordered. Each s belongs to the Cartesian product
S := S1 × · · · × SK . Throughout this chapter, we assume that the sets
Sk are discrete for concreteness, with Sk := {v1 , . . . , vMk }. Note that
because Sk is discrete, we can always identify it with {1, . . . , Mk }. A
graphical model specifies a joint probability distribution
P(S = s) = P(S1 = s1 , . . . , SK = sK )
= p(s)
= p(s1 , . . . , sK ),
where p is the probability mass function of the joint probability distribu-
tion. Summing over the Cartesian product of all possible configurations,
we obtain
Σ_{s∈S} p(s) = Σ_{(s_1,...,s_K) ∈ S} p(s_1, . . . , s_K) = 1.
As we shall see, the graph of a graphical model encodes the dependen-
cies between the variables (S1 , . . . , SK ) and therefore how their joint
distribution factorizes. Given access to a joint probability distribution,
there are several inference problems one typically needs to solve.

10.3.2 Likelihood
A simple task is to compute the likelihood of some observations
s = (s1 , . . . , sK ),
P(S_1 = s_1, . . . , S_K = s_K) = p(s_1, . . . , s_K).
It is also common to compute the log-likelihood,
log P(S_1 = s_1, . . . , S_K = s_K) = log p(s_1, . . . , s_K).

10.3.3 Maximum a-posteriori inference


Another common task is to compute the most likely configuration,
arg max p(s1 , . . . , sK ).
s1 ∈S1 ,...,sK ∈SK

This is the mode of the joint probability distribution. This is also known
as maximum a-posteriori (MAP) inference in the literature (Wainwright
and Jordan, 2008).

10.3.4 Marginal inference


The operation of marginalization consists in summing (or integrat-
ing) over all possible values of a given variable in a joint probability
distribution. This allows us to compute the marginal probability of
the remaining variables. For instance, we may want to marginalize all
variables but Sk = sk . To do so, we define the Cartesian product
C_k(s_k) := S_1 × · · · × S_{k-1} × {s_k} × S_{k+1} × · · · × S_K,    (10.1)

where we write A_{k-1} := S_1 × · · · × S_{k-1} for the first group of factors
and B_{k+1} := S_{k+1} × · · · × S_K for the last one.

Summing over all variables but Sk , we obtain the marginal probability


of Sk = sk as
P(S_k = s_k) = Σ_{(s_1,...,s_K) ∈ C_k(s_k)} p(s_1, . . . , s_K)
             = Σ_{(s_1,...,s_{k-1}) ∈ A_{k-1}} Σ_{(s_{k+1},...,s_K) ∈ B_{k+1}} p(s_1, . . . , s_K).

Defining similarly
Ck,l (sk , sl ) := S1 × · · · × {sk } × · · · × {sl } × · · · × SK ,
we obtain
P(S_k = s_k, S_l = s_l) = Σ_{(s_1,...,s_K) ∈ C_{k,l}(s_k, s_l)} p(s_1, . . . , s_K).

In particular, we may want to compute the marginal probability of two


consecutive variables, P(Sk−1 = sk−1 , Sk = sk ).

10.3.5 Expectation, convex hull, marginal polytope


Another common operation is to compute the expectation of ϕ(S) under
a distribution p. It is defined by
µ := E_{S∼p}[ϕ(S)] = Σ_{s∈S} p(s) ϕ(s) ∈ M.

For the expectation under pθ , we write

µ(θ) := ES∼pθ [ϕ(S)] = pθ (s)ϕ(s) ∈ M.


X

s∈S

In exponential family distributions (Section 3.4), the function ϕ is called


a statistic. It decomposes as

ϕ(s) := (ϕC (sC ))C∈C ,

where each C ∈ C is a subset of [K]. Intuitively, ϕ(s) can be thought of as an encoding or


embedding of s (a potentially discrete object such as a sequence of
integers) in a vector space. Under this decomposition, we can also
compute
µC := ES [ϕC (SC )] = p(s)ϕC (sC ).
X

s∈S

Convex hull

The mean µ belongs to the convex hull of ϕ(S) := {ϕ(s) : s ∈ S},


M := conv(ϕ(S)) := { Σ_{s∈S} p(s) ϕ(s) : p ∈ P(S) },

where P(S) is the set of all possible probability distributions over S. In


other words, M is the set of all possible convex combinations of ϕ(s)
for s ∈ S. The vertices of M are the ϕ(s) for s ∈ S.

Case of binary encodings: the marginal polytope

In the special case of a discrete set Sk = {v1 , . . . , vM } and of a binary


encoding (indicator function) ϕ(s), the set M is called the marginal
polytope (Wainwright and Jordan, 2008), because each point µ ∈
M contains marginal probabilities. To see why, consider the unary
potential
[ϕ(s)]k,i = [ϕk (sk )]i = I(sk = vi ) (10.2)

where I(p) := 1 if p is true, 0 otherwise. We then obtain the marginal


probability of Sk = vi ,

[µ]_{k,i} = E_S[ϕ(S)_{k,i}]
          = E_{S_k}[ϕ_k(S_k)_i]
          = E_{S_k}[I(S_k = v_i)]
          = Σ_{s_k ∈ S_k} P(S_k = s_k) I(s_k = v_i)
          = P(S_k = v_i).

Likewise, consider the pairwise potential

[ϕ(s)]k,l,i,j = [ϕk,l (sk , sl )]i,j = I(sk = vi , sl = vj ). (10.3)

We then obtain the marginal probability of Sk = vi and Sl = vj ,

[µ]_{k,l,i,j} = E_S[ϕ(S)_{k,l,i,j}]
              = E_{S_k,S_l}[ϕ_{k,l}(S_k, S_l)_{i,j}]
              = E_{S_k,S_l}[I(S_k = v_i, S_l = v_j)]
              = Σ_{s_k ∈ S_k} Σ_{s_l ∈ S_l} P(S_k = s_k, S_l = s_l) I(s_k = v_i, s_l = v_j)
              = P(S_k = v_i, S_l = v_j).

We can do the same with higher-order potential functions.

10.3.6 Complexity of brute force

Apart from computing the likelihood, which is trivial, computing the


marginal, mode and expectation by brute force takes O(Π_{k=1}^K |S_k|) time.
In particular, if |S_k| = M for all k ∈ [K], brute force takes O(M^K) time.

10.4 Markov chains

In this section, we briefly review Markov chains. Our notation is chosen


to emphasize the analogies with computation chains.

10.4.1 The Markov property

When random variables are organized sequentially as S1 , . . . , SK ,


a simple example of conditional independence is when each variable
Sk ∈ Sk only depends on the previous variable Sk−1 ∈ Sk−1 , that is,

P(Sk = sk | Sk−1 = sk−1 , . . . , S1 = s1 ) = P(Sk = sk | Sk−1 = sk−1 )


:= pk (sk | sk−1 ),

A probability distribution satisfying the above is said to satisfy the


Markov property, and is called a Markov chain. A computation
chain is specified by the functions fk , that take sk−1 as input and
output sk . In analogy, a Markov chain is specified by the conditional
probability distributions pk of Sk given Sk−1 . We can then define the
generative process

S0 := s0
S1 ∼ p1 (· | S0 )
S2 ∼ p2 (· | S1 )
..
.
SK ∼ pK (· | SK−1 ).

Strictly speaking, we should write Sk | Sk−1 ∼ pk (· | Sk−1 ). We choose


our notation both for conciseness and for analogy with computation
chains. Furthermore, to simplify the notation, we assume without loss
of generality that S0 is deterministic (if this is not the case, we can
always move S0 to S1 and add a dummy variable as S0 ). That is,
P(S0 = s0 ) = p0 (s0 ) := 1 and S0 := {s0 }. This amounts to setting the
initial distribution of S1 as

P(S1 = s1 ) := P(S0 = s0 )P(S1 = s1 |S0 = s0 ) = P(S1 = s1 |S0 = s0 ).


Figure 10.1: Left: Markov chain. Right: Computation graph of the forward-
backward and the Viterbi algorithms: a lattice.

We can then compute the joint probability of the Markov chain by



P(S_1 = s_1, . . . , S_K = s_K) = p(s_1, . . . , s_K)
                               = Π_{k=1}^K P(S_k = s_k | S_{k-1} = s_{k-1})
                               = Π_{k=1}^K p_k(s_k | s_{k-1}),

where we left the dependence on s0 implicit, since p0 (s0 ) = 1. A Markov


chain with Sk = {1, 2, 3} is illustrated in Fig. 10.1. A chain defines
a totally ordered set {1, . . . , K}, since two nodes in the graph are
necessarily linked to each other by a path.

Example 10.1 (Chain of categorical distributions). Suppose our goal


is to predict, from x ∈ X , a sequence of length K, where each Sk
belongs to Sk = {1, . . . , M }. In natural language processing, this
task is called sequence tagging. We can define

Sk ∼ Categorical(πk−1,k,Sk−1 )

where

π_{k-1,k,i} := softargmax(θ_{k-1,k,i}) ∈ △^M = (π_{k-1,k,i,j})_{j=1}^M
θ_{k-1,k,i} := (θ_{k-1,k,i,j})_{j=1}^M ∈ R^M
θ_{k-1,k,i,j} := f_{k-1,k}(x, i, j, w_k) ∈ R.



We therefore have

P(S_k = j | S_{k-1} = i) = p_k(j | i)
                         = π_{k-1,k,i,j}
                         = [softargmax(θ_{k-1,k,i})]_j
                         = exp(θ_{k-1,k,i,j}) / Σ_{j′} exp(θ_{k-1,k,i,j′})

and

log P(S_k = j | S_{k-1} = i) = log p_k(j | i)
                             = θ_{k-1,k,i,j} − logsumexp(θ_{k-1,k,i})
                             = θ_{k-1,k,i,j} − log Σ_{j′} exp(θ_{k-1,k,i,j′}).

We emphasize that because k − 1 and k are always consecutive,


the representation θk−1,k,i,j is inefficient; we could use θk,i,j instead.
Our notation is designed for consistency with Markov random fields.

10.4.2 Time-homogeneous Markov chains

A time-homogeneous discrete-time Markov chain corresponds to the


case when the distribution of Sk given Sk−1 is the same regardless of k:

p1 = · · · = pK = p.

The finite-space case corresponds to when each Sk ∈ S can take a


finite set of values S = {v1 , . . . , vM } and

P(Sk = vj | Sk−1 = vi ) = p(vj |vi ) = πi,j ,

where πi,j ∈ [0, 1] is the transition probability from vi to vj . Because


the set S = {v1 , . . . , vM } is discrete, we can always identify it with
{1, . . . , M }. That is, we can instead write

P(Sk = j | Sk−1 = i) = p(j|i) = πi,j .



10.4.3 Higher-order Markov chains

More generally, an nth-order Markov chain may depend not only on the
last variable, but on the last n variables,

P(Sk = sk | Sk−1 = sk−1 , . . . , S1 = s1 )


=P(Sk = sk | Sk−1 = sk−1 , . . . , Sk−n = sk−n )
=pk (sk |sk−1 , . . . , sk−n ).

Autoregressive models such as Transformers can be seen as specifying


a higher-order Markov chain, with a context window of size n. The
larger context makes exact inference using dynamic programming com-
putationally intractable. This is why practitioners use beam search or
ancestral sampling (Section 10.5.3) instead.

10.5 Bayesian networks

In this section, we briefly review Bayesian networks. Our notation is


chosen to emphasize the analogies with computation graphs.

10.5.1 Expressing variable dependencies using DAGs

Markov chains and more generally higher-order Markov chains are a spe-
cial case of Bayesian network. Similarly to computation graphs reviewed
in Section 8.3, variable dependencies can be expressed using a directed
acyclic graph (DAG) G = (V, E), where the vertices V = {1, . . . , K}
represent variables and edges E represent variable dependencies. The set
{i1 , . . . , ink } = pa(k) ⊆ V, where nk := |pa(k)|, indicates the variables
Si1 , . . . , Sink that Sk depends on. This defines a partially ordered
set (poset). For notational simplicity, we again assume without loss of
generality that S0 is deterministic. A computation graph is specified
by functions f1 , . . . , fK in topological order. In analogy, a Bayesian
network is specified by conditional probability distributions pk of Sk

given Spa(k) . We can then define the generative process

S0 := s0
S1 ∼ p1 (· | S0 )
S2 ∼ p2 (· | Spa(2) )
..
.
SK ∼ pK (· | Spa(K) ).

Using the chain rule of probability and variable independencies expressed


by the DAG, the joint probability distribution is then (assuming a
topological order for S0 , S1 , . . . , SK )

P(S = s) := P(S_1 = s_1, . . . , S_K = s_K)
          = Π_{k=1}^K P(S_k = s_k | S_{pa(k)} = s_{pa(k)})
          := Π_{k=1}^K p_k(s_k | s_{pa(k)}).

This representation is well suited to express causal relationships between


random variables.

10.5.2 Parameterizing Bayesian networks

In a Bayesian framework, observed data, latent variables, parameters


and noise variables are all treated as random variables. If the conditional
distribution pk associated to node k depends on some parameters, they
can be provided to pk as conditioning, using parent nodes.
A Bayesian network is specified by the conditional distributions pk .
Therefore, unlike computation graphs, there is no notion of function fk in
a Bayesian network. However, the root nodes of the Bayesian network can
be the output of a neural network. For instance, autoregressive models,
such as RNNs or Transformers, specify the conditional probability
distribution of a token given past tokens, and the chain rule of probability
is used to obtain a probability distribution over entire sequences.

10.5.3 Ancestral sampling


A major advantage of Bayesian networks is that, provided that each
conditional distribution pk is normalized, the joint distribution of S =
(S1 , . . . , SK ) is automatically normalized. This means that we can very
easily draw i.i.d. samples from the joint distribution, by following the
generative process: we follow the topological order k = 1, . . . , K and
on iteration k we draw a value sk ∼ pk (·|spa(k) ) conditioned on the
previous values spa(k) . This is known as ancestral sampling.
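For the special case of a chain-structured Bayesian network over finite states, ancestral sampling can be sketched as follows (assuming JAX; the parameterization of the conditionals p_k by transition logits is an illustrative choice, not prescribed by the text).

import jax
import jax.numpy as jnp

def ancestral_sample_chain(transition_logits, s0, key):
    # transition_logits[k] has shape (M, M): row i gives the logits of
    # p_{k+1}(· | S_k = i).
    s, samples = s0, []
    for logits in transition_logits:
        key, subkey = jax.random.split(key)
        s = jax.random.categorical(subkey, logits[s])   # draw S_{k+1} | S_k
        samples.append(s)
    return jnp.stack(samples)

K, M = 5, 3
logits = jnp.zeros((K, M, M))   # uniform transitions, for illustration
print(ancestral_sample_chain(logits, 0, jax.random.PRNGKey(0)))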

10.6 Markov random fields

10.6.1 Expressing factors using undirected graphs


A Markov random field (MRF), a.k.a. undirected graphical model,
specifies a distribution that factorizes as
P(S = s) = p(s) := (1/Z) Π_{C∈C} ψ_C(s_C),

where C is the set of maximal cliques of G, that is, subsets of V that


are fully connected, Z is a normalization constant defined by

Z := Σ_{s∈S} Π_{C∈C} ψ_C(s_C),

and ψC : SC → R+ is a potential function (a.k.a. compatibility func-


tion), with SC := (Sj )j∈C . According to the Hammersley-Clifford theo-
rem, an MRF can be equivalently defined in terms of Markov properties;
we refer the interested reader to Wainwright and Jordan (2008). For the
sake of this chapter, the definition above is sufficient for our purposes.

Example 10.2 (Markov chains as Markov random fields). For a chain,


letting S = (S1 , . . . , SK ) and s = (s1 , . . . , sK ), recall that
P(S = s) = Π_{k=1}^K p_k(s_k | s_{k-1}).

This is equivalent to an MRF with Z = 1 (since a chain is auto-



matically normalized),

C := {{0, 1}, {1, 2}, . . . , {K − 1, K}}

and with potential function

ψ{k−1,k} (sk−1 , sk ) := pk (sk |sk−1 ).

More generally, a Bayesian network can be similarly written as an


MRF by creating appropriate potential functions corresponding to
the parents of each node.

10.6.2 MRFs as exponential family distributions

Let us define the potential functions

ψC (sC ; θC ) := exp(⟨θC , ϕC (sC )⟩)

for some sufficient statistic function ϕC : SC → ΘC and parameters


θC ∈ ΘC . Then,

p_θ(s) := (1/Z(θ)) Π_{C∈C} ψ_C(s_C; θ_C)
        = (1/Z(θ)) Π_{C∈C} exp(⟨θ_C, ϕ_C(s_C)⟩)
        = (1/Z(θ)) exp(Σ_{C∈C} ⟨θ_C, ϕ_C(s_C)⟩)
        = (1/Z(θ)) exp(⟨θ, ϕ(s)⟩)
        = exp(⟨θ, ϕ(s)⟩ − A(θ)),

where

ϕ(s) := (ϕ_C(s_C))_{C∈C}
θ := (θ_C)_{C∈C}
Z(θ) := Σ_{s∈S} Π_{C∈C} ψ_C(s_C; θ_C) = Σ_{s∈S} exp(⟨θ, ϕ(s)⟩)
A(θ) := log Z(θ).

Therefore, for this choice of potential functions, we can view an MRF


as an exponential family distribution (Section 3.4) with natural pa-
rameters θ, sufficient statistic ϕ and log-partition function A(θ).

Example 10.3 (Ising model). The Ising model is a classical example


of MRF. Let Y = (Y1 , . . . , YM ) ∈ {0, 1}M be an unordered collec-
tion of binary variables Yi ∈ {0, 1}. This forms a graph G = (V, E),
where V = [M ] and E ⊆ V 2 , such that (i, j) ∈ E means that Yi in-
teracts with Yj . In statistical physics, Yi may indicate the presence
or absence of particles, or the orientation of magnets. In image
processing, Yi may represent a black and white pixel. In multi-label
classification, Yi may indicate the presence or absence of a label.
The probability of y = (y1 , . . . , yM ) ∈ {0, 1}M is then

P(Y = y) = p_θ(y)
         = exp(Σ_{i∈V} θ_i y_i + Σ_{(i,j)∈E} θ_{i,j} y_i y_j − A(θ))
         = exp(Σ_{C∈C} ⟨θ_C, ϕ_C(y)⟩ − A(θ)),

where C := V ∪ E and θ ∈ R|V|+|E| is the concatenation of (θi )i∈V


and (θi,j )(i,j)∈E . These models are also known as Boltzmann ma-
chines in a neural network context. MAP inference in general
Ising models is known to be NP-hard, but when the interaction
weights θi,j are non-negative, MAP inference can be reduced to

graph cut algorithms (Greig et al., 1989). There are two ways the
above equation can be extended. First, we can use higher-order
interactions, such as yi yj yk for (i, j, k) ∈ V 3 . Second, we may want
to use categorical variables, which leads to the Potts model.

10.6.3 Conditional random fields


Conditional random fields (Lafferty et al., 2001; Sutton, McCallum,
et al., 2012) are a special case of Markov random field, in which a
conditioning variable is explicitly incorporated. For example, when
the goal is to predict a variable y conditioned on a variable x, CRFs
are defined as
P(Y = y | X = x) = p(y | x) = (1/Z(x)) Π_{C∈C} ψ_C(y_C, x).

Note that the potential functions ψ_C are allowed to depend on the


whole x, as x is just a conditioning variable.

10.6.4 Sampling
Contrary to Bayesian networks, MRFs require an explicit normalization
constant Z. As a result, sampling from a distribution represented by a
general MRF is usually more involved than for Bayesian networks. A
commonly-used technique is Gibbs sampling.

10.7 Inference on chains

In this section, we review how to perform marginal inference and


maximum a-posteriori inference on joint distributions of the form
p(s_1, . . . , s_K) = (1/Z) Π_{k=1}^K ψ_k(s_{k-1}, s_k),

where
Z := Σ_{s∈S} Π_{k=1}^K ψ_k(s_{k-1}, s_k),
and where we used ψk as a shorthand for ψk−1,k , since k − 1 and k are
consecutive. As explained in Example 10.2, this also includes Markov

chains by setting

ψk (sk−1 , sk ) := pk (sk | sk−1 ),

in which case Z = 1.

10.7.1 The forward-backward algorithm

The key idea of the forward-backward algorithm is to use the distribu-


tivity of multiplication over addition to write

Z = Σ_{s_1∈S_1} ψ_1(s_0, s_1) Σ_{s_2∈S_2} ψ_2(s_1, s_2) · · · Σ_{s_K∈S_K} ψ_K(s_{K-1}, s_K).

We can compute these sums recursively, either forward or backward.


Recalling the definitions of Ak−1 and Bk+1 in Eq. (10.1), we define the
summations up to and down to k,
α_k(s_k) := Σ_{(s_1,...,s_{k-1}) ∈ A_{k-1}} Π_{j=1}^k ψ_j(s_{j-1}, s_j)
          = Σ_{s_{k-1}∈S_{k-1}} ψ_k(s_{k-1}, s_k) · · · Σ_{s_1∈S_1} ψ_2(s_1, s_2) ψ_1(s_0, s_1)

β_k(s_k) := Σ_{(s_{k+1},...,s_K) ∈ B_{k+1}} Π_{j=k+1}^K ψ_j(s_{j-1}, s_j)
          = Σ_{s_{k+1}∈S_{k+1}} ψ_{k+1}(s_k, s_{k+1}) · · · Σ_{s_K∈S_K} ψ_K(s_{K-1}, s_K).

We can compute the two quantities by recursing forward and backward

α_k(s_k) = Σ_{s_{k-1}∈S_{k-1}} ψ_k(s_{k-1}, s_k) α_{k-1}(s_{k-1})
β_k(s_k) = Σ_{s_{k+1}∈S_{k+1}} ψ_{k+1}(s_k, s_{k+1}) β_{k+1}(s_{k+1}),

where we defined the initializations

α1 (s1 ) := ψ1 (s0 , s1 ) ∀s1 ∈ S1


βK (sK ) := 1 ∀sK ∈ SK .

The normalization term can then be computed by

Z = Σ_{s_K∈S_K} α_K(s_K) β_K(s_K) = Σ_{s_1∈S_1} α_1(s_1) β_1(s_1)

and the marginal probabilities by

P(S_k = s_k) = (1/Z) α_k(s_k) β_k(s_k)
P(S_{k-1} = s_{k-1}, S_k = s_k) = (1/Z) α_{k-1}(s_{k-1}) ψ_k(s_{k-1}, s_k) β_k(s_k).

We can also compute the conditional probabilities by

P(S_k = s_k | S_{k-1} = s_{k-1}) = P(S_{k-1} = s_{k-1}, S_k = s_k) / P(S_{k-1} = s_{k-1})
                                 = α_{k-1}(s_{k-1}) ψ_k(s_{k-1}, s_k) β_k(s_k) / (α_{k-1}(s_{k-1}) β_{k-1}(s_{k-1}))
                                 = ψ_k(s_{k-1}, s_k) β_k(s_k) / β_{k-1}(s_{k-1}).
In practice, the two recursions are often implemented in the log-domain
for numerical stability,
log α_k(s_k) = log Σ_{s_{k-1}∈S_{k-1}} exp(log ψ_k(s_{k-1}, s_k) + log α_{k-1}(s_{k-1}))
log β_k(s_k) = log Σ_{s_{k+1}∈S_{k+1}} exp(log ψ_{k+1}(s_k, s_{k+1}) + log β_{k+1}(s_{k+1})).

We recognize the log-sum-exp operator, which can be implemented


in a numerically stable way (Section 4.4.2). The overall dynamic
programming procedure, a.k.a. forward-backward algorithm (Baum
and Petrie, 1966; Rabiner, 1989), is summarized in Algorithm 10.1. We
notice that the forward and backward passes are actually independent
of each other, and can therefore be performed in parallel.
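The two recursions can be sketched in the log-domain as follows, assuming JAX-style arrays; the indexing convention (log_psi1 gathers log ψ_1(s_0, ·), log_psi stacks log ψ_2, . . . , log ψ_K) and the small potentials at the end are ours, for illustration.

import jax.numpy as jnp
from jax.scipy.special import logsumexp

def forward_backward(log_psi1, log_psi):
    K = len(log_psi) + 1
    log_alpha = [log_psi1]
    for k in range(K - 1):                                  # forward pass
        log_alpha.append(
            logsumexp(log_alpha[-1][:, None] + log_psi[k], axis=0))
    log_beta = [jnp.zeros_like(log_psi1)]                   # log β_K = 0
    for k in reversed(range(K - 1)):                        # backward pass
        log_beta.insert(
            0, logsumexp(log_psi[k] + log_beta[0][None, :], axis=1))
    log_Z = logsumexp(log_alpha[-1] + log_beta[-1])
    unary = [jnp.exp(a + b - log_Z) for a, b in zip(log_alpha, log_beta)]
    pairwise = [jnp.exp(log_alpha[k][:, None] + log_psi[k]
                        + log_beta[k + 1][None, :] - log_Z)
                for k in range(K - 1)]
    return log_Z, unary, pairwise

# Illustrative chain with K = 3 variables and M = 2 states each.
log_psi1 = jnp.log(jnp.array([0.6, 0.4]))
log_psi = jnp.log(jnp.array([[[0.7, 0.3], [0.2, 0.8]],
                             [[0.5, 0.5], [0.9, 0.1]]]))
log_Z, unary, pairwise = forward_backward(log_psi1, log_psi)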

10.7.2 The Viterbi algorithm


Similarly, using the distributivity of multiplication over maximization,
max_{s_1∈S_1,...,s_K∈S_K} Π_{k=1}^K ψ_k(s_{k-1}, s_k)
= max_{s_K∈S_K} max_{s_{K-1}∈S_{K-1}} ψ_K(s_{K-1}, s_K) . . . max_{s_1∈S_1} ψ_2(s_1, s_2) ψ_1(s_0, s_1).

Algorithm 10.1 Marginal inference on a chain

Potential functions: ψ_1, . . . , ψ_K
Input: s_0
1: Initialize α_1(s_1) := ψ_1(s_0, s_1) ∀s_1 ∈ S_1
2: for k := 2, . . . , K do ▷ Forward pass
3:   for s_k ∈ S_k do
4:     α_k(s_k) := Σ_{s_{k-1}∈S_{k-1}} ψ_k(s_{k-1}, s_k) α_{k-1}(s_{k-1})
5: Initialize β_K(s_K) := 1 ∀s_K ∈ S_K
6: for k := K − 1, . . . , 1 do ▷ Backward pass
7:   for s_k ∈ S_k do
8:     β_k(s_k) := Σ_{s_{k+1}∈S_{k+1}} ψ_{k+1}(s_k, s_{k+1}) β_{k+1}(s_{k+1})
9: Compute Z = Σ_{s_K∈S_K} α_K(s_K) β_K(s_K) = Σ_{s_K∈S_K} α_K(s_K)
Outputs: ∀k ∈ [K]:
P(S_k = s_k) = (1/Z) α_k(s_k) β_k(s_k)
P(S_{k-1} = s_{k-1}, S_k = s_k) = (1/Z) α_{k-1}(s_{k-1}) ψ_k(s_{k-1}, s_k) β_k(s_k)

Let us define, for k ∈ [K],

δ_k(s_k) := max_{s_{k-1}∈S_{k-1}} ψ_k(s_{k-1}, s_k) . . . max_{s_1∈S_1} ψ_2(s_1, s_2) ψ_1(s_0, s_1).

We can compute these quantities recursively, since for k ∈ [K]

δ_k(s_k) = max_{s_{k-1}∈S_{k-1}} ψ_k(s_{k-1}, s_k) δ_{k-1}(s_{k-1}),

with δ_1(s_1) := ψ_1(s_0, s_1). We finally have

max_{s_1∈S_1,...,s_K∈S_K} p(s_1, . . . , s_K) = (1/Z) max_{s_K∈S_K} δ_K(s_K).
In practice, for numerical stability, we often implement the forward
recursion in the log-domain. Using the fact that the logarithm is
monotonic, we indeed have for all k ∈ [K]

log δ_k(s_k) = max_{s_{k-1}∈S_{k-1}} log ψ_k(s_{k-1}, s_k) + log δ_{k-1}(s_{k-1}).

To enable efficient backtracking, during the forward pass, we compute


q_k(s_k) := argmax_{s_{k-1}∈S_{k-1}} ψ_k(s_{k-1}, s_k) δ_{k-1}(s_{k-1}),

which can be thought of as backpointers from s*_k to s*_{k-1}.


The resulting dynamic programming procedure, a.k.a. Viterbi algo-
rithm (Viterbi, 1967; Forney, 1973), is summarized in Algorithm 10.2.

Algorithm 10.2 MAP inference on a chain

Potential functions: ψ_1, . . . , ψ_K
Input: s_0
1: Initialize δ_1(s_1) := ψ_1(s_0, s_1) ∀s_1 ∈ S_1
2: for k := 2, . . . , K do ▷ Forward pass
3:   for s_k ∈ S_k do
4:     δ_k(s_k) := max_{s_{k-1}∈S_{k-1}} ψ_k(s_{k-1}, s_k) δ_{k-1}(s_{k-1})
5:     q_k(s_k) := argmax_{s_{k-1}∈S_{k-1}} ψ_k(s_{k-1}, s_k) δ_{k-1}(s_{k-1})
6: δ* := max_{s_K∈S_K} δ_K(s_K)
7: s*_K := argmax_{s_K∈S_K} δ_K(s_K)
8: for k := K − 1, . . . , 1 do ▷ Backtracking
9:   s*_k := q_{k+1}(s*_{k+1})
Outputs: max_{s_1∈S_1,...,s_K∈S_K} p(s_1, . . . , s_K) ∝ δ*
         argmax_{s_1∈S_1,...,s_K∈S_K} p(s_1, . . . , s_K) = (s*_1, . . . , s*_K)
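With the same conventions as the forward-backward sketch above (assuming JAX-style arrays, with illustrative potentials), the Viterbi recursion and its backtracking can be sketched as follows.

import jax.numpy as jnp

def viterbi(log_psi1, log_psi):
    log_delta, backpointers = log_psi1, []
    for lp in log_psi:                                    # forward pass
        scores = log_delta[:, None] + lp                  # (s_{k-1}, s_k)
        backpointers.append(jnp.argmax(scores, axis=0))
        log_delta = jnp.max(scores, axis=0)
    s = int(jnp.argmax(log_delta))                        # backtracking
    path = [s]
    for q in reversed(backpointers):
        s = int(q[s])
        path.insert(0, s)
    return path, float(jnp.max(log_delta))

log_psi1 = jnp.log(jnp.array([0.6, 0.4]))
log_psi = jnp.log(jnp.array([[[0.7, 0.3], [0.2, 0.8]],
                             [[0.5, 0.5], [0.9, 0.1]]]))
path, best_log_score = viterbi(log_psi1, log_psi)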

10.8 Inference on trees

More generally, efficient inference based on dynamic programming


can be performed when dependencies between variables are expressed
using a tree or polytree. The resulting marginal inference and MAP
inference algorithms are often referred to as the sum-product and
max-sum algorithms. The sum-product algorithm is also known as
belief propagation or message passing, since it can be interpreted
as propagating “local messages” through the graph. See for instance
(Wainwright and Jordan, 2008, Section 2.5.1) for more details.

10.9 Inference as differentiation

In this section, we review the profound connections between differenti-


ating the log-partition function of an exponential family distribution
on one hand, and performing marginal inference (as well as maximum
a-posteriori inference in the zero-temperature limit) on the other hand.

10.9.1 Inference as gradient of the log-partition


We first discuss a well-known fact in the graphical model literature: when
using a binary encoding as the sufficient statistic ϕ in an exponential
family distribution, the gradient ∇A(θ) of the log-partition A(θ) gathers
all the marginals (Wainwright and Jordan, 2008).
To see why, recall from Section 3.4 the definition of an exponential
family distribution

p_θ(s) = h(s) exp[⟨θ, φ(s)⟩ − A(θ)]

and of its log-partition

A(θ) := log Σ_{s ∈ S} h(s) exp[⟨θ, φ(s)⟩].

From Proposition 3.2,

µ(θ) := ∇A(θ) = E_{Y∼p_θ}[φ(Y)] ∈ M.

Therefore, with the binary encodings in Eq. (10.2) and Eq. (10.3),

P(S_k = v_i) = [∇A(θ)]_{k,i}
P(S_k = v_i, S_l = v_j) = [∇A(θ)]_{k,l,i,j}.
Put differently, if we have an efficient algorithm for computing A(θ), we
can perform reverse-mode autodiff on A(θ) to obtain ∇A(θ), and
therefore obtain the marginal probabilities. Following Section 8.3.3, the
complexity of computing all marginal probabilities is therefore roughly
the same as that of computing A(θ).
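As a concrete sketch of this idea in JAX, we can compute the log-partition of a chain by a log-domain forward recursion and obtain all pairwise marginals with a single call to jax.grad. We use the parameterization introduced later in Section 10.9.3, log ψ_k(i, j) = θ_{k,i,j}; the shapes and the random θ below are our own toy choices.

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def log_partition(theta):
    # theta has shape (K, M, M); theta[0] only uses row 0 (the single initial state s_0).
    a = theta[0, 0, :]                       # log alpha_1
    for k in range(1, theta.shape[0]):
        a = logsumexp(theta[k] + a[:, None], axis=0)
    return logsumexp(a)                      # A(theta) = log Z

K, M = 4, 3
theta = jax.random.normal(jax.random.PRNGKey(0), (K, M, M))
# Pairwise marginals P(S_{k-1} = i, S_k = j) are the entries of grad A(theta).
marginals = jax.grad(log_partition)(theta)
print(marginals.sum(axis=(1, 2)))            # each slice sums to 1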
In the special case of chains, we obtain

P(S_k = v_i) = [∇A(θ)]_{k,i} = (1/Z) α_k(v_i) β_k(v_i)
P(S_{k-1} = v_i, S_k = v_j) = [∇A(θ)]_{k-1,k,i,j} = (1/Z) α_{k-1}(v_i) ψ_k(v_i, v_j) β_k(v_j),

where we left the dependence of Z, α and β on θ implicit.


If we define A_ε(θ) := εA(θ/ε), then in the zero-temperature limit ε → 0,
the corresponding µ_ε(θ) := ∇A_ε(θ) becomes a binary encoding of the mode,
i.e., of the maximum a-posteriori inference solution.
We now show (i) how to unify the forward passes of the forward-backward
and Viterbi algorithms using semirings and softmax operators, and (ii) how
to compute the gradient of the log-partition using backpropagation.

10.9.2 Semirings and softmax operators

The forward passes in the forward-backward and Viterbi algorithms are


clearly similar. In fact, they can be formally linked to each other using
semirings.

Definition 10.1 (Semiring). A semiring is a set K equipped with
two binary operations (⊕, ⊗) such that

• ⊕ is commutative and associative,

• ⊗ is associative and distributive over ⊕,

• ⊕ and ⊗ have identity elements 0̄ and 1̄, respectively.

We use the notations ⊕, ⊗, 0̄ and 1̄ to clearly distinguish them from
the classical addition, multiplication, 0 and 1.
We recall the following laws for binary operations:

• Commutativity of ⊕: a ⊕ b = b ⊕ a,

• Associativity of ⊕: a ⊕ (b ⊕ c) = (a ⊕ b) ⊕ c,

• Distributivity of ⊗ over ⊕: a ⊗ (b ⊕ c) = (a ⊗ b) ⊕ (a ⊗ c).

A set equipped with a binary operation supporting associativity and an


identity element is called a monoid. A monoid such that every element
has an inverse element is called a group. The difference between a ring
and a semiring is that the latter only requires (K, ⊕) and (K, ⊗) to be
monoids, not groups.

Equipped with these definitions, we can interpret the forward passes


in the Viterbi and forward-backward algorithms as follows:

• the forward-backward algorithm in the exponential domain uses


the semiring R+ equipped with (+, ×) and identity elements (0, 1);

• the Viterbi algorithm in the log domain uses the semiring R


equipped with (max, +) and identity elements (−∞, 0);

• the forward-backward algorithm in the log domain uses the semiring R
  equipped with (maxε, +) and identity elements (−∞, 0), where we defined
  the soft max operator (log-add-exp)

  maxε(a, b) := ε log(exp(a/ε) + exp(b/ε)),

  with ε := 1 by default.

It can be checked that maxε is indeed commutative and associative, and that
addition is distributive over maxε. Its identity element is −∞. By
associativity,

maxε(a_1, maxε(a_2, a_3)) = logsumexpε(a_1, a_2, a_3)
                          = ε log Σ_{i=1}^{3} exp(a_i/ε).
In contrast, note that the sparsemax in Section 13.5 is not associative.


Thanks to associativity, we can introduce the shorthand notations

maxε_{v ∈ V} f(v) := ε log Σ_{v ∈ V} exp(f(v)/ε) ∈ R

and

argmaxε_{v ∈ V} f(v) := ( exp(f(v)/ε) / Σ_{v′ ∈ V} exp(f(v′)/ε) )_{v ∈ V} ∈ P(V).

Many algorithms can be generalized thanks to the use of semirings;


see among others Aji and McEliece (2000) and Mohri et al. (2008). The
distributive and associative properties play a key role in breaking down
large problems into smaller ones (Verdu and Poor, 1987).
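To illustrate how the semiring viewpoint unifies the two algorithms, the following Python/JAX sketch parameterizes the chain forward pass by the "addition" operation alone: passing logsumexp recovers the (log-domain) forward algorithm, while passing jnp.max recovers the Viterbi forward pass. The (K, M, M) array of log-potentials is an assumption of ours, as in the previous sketch.

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def chain_forward(theta, oplus):
    # theta: (K, M, M) log-potentials; theta[0] only uses row 0 (the initial state).
    a = theta[0, 0, :]
    for k in range(1, theta.shape[0]):
        a = oplus(theta[k] + a[:, None], axis=0)   # reduce over the previous state
    return oplus(a, axis=0)                        # reduce the last layer

theta = jax.random.normal(jax.random.PRNGKey(0), (4, 3, 3))
log_Z = chain_forward(theta, logsumexp)            # sum-product semiring: A(theta) = log Z
map_value = chain_forward(theta, jnp.max)          # max-plus semiring: best unnormalized log-probability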

10.9.3 Inference as backpropagation

In this section, we show that, algorithmically, backtracking is recovered


as a special case of backpropagation. See also (Eisner, 2016; Mensch
and Blondel, 2018).
For notation simplicity, we assume S0 = {1} and Sk = {1, . . . , M }
for all k ∈ [K]. We focus on the case

log ψk (i, j) = ⟨θk , ϕk (i, j)⟩ = θk,i,j .

We also introduce the shorthands

a_{1,j} := log α_1(j) = θ_{1,1,j}

a_{k,j} := log α_k(j) = maxε_{i ∈ [M]} θ_{k,i,j} + a_{k-1,i}

and

q_{k,j} := argmaxε_{i ∈ [M]} θ_{k,i,j} + a_{k-1,i}.

Our goal is to compute the gradient w.r.t. θ ∈ R^{K×M×M} of

log Z = A = maxε_{j ∈ [M]} a_{K,j}.

The soft argmax counterpart of this quantity is

Q := argmaxε_{j ∈ [M]} a_{K,j} ∈ △^M,

where we used P([M]) = △^M.


Computing the gradient of A is similar to computing the gradient
of a feedforward network, in the sense that θ_k influences not only a_k
but also a_{k+1}, ..., a_K. Let us introduce the adjoint variable

r_{k,i} := ∂A/∂a_{k,i},

which we initialize as

r_{K,i} = ∂A/∂a_{K,i} = Q_i.

Since θ_{k,i,j} directly influences a_{k,j}, we have for k ∈ [K], i ∈ [M] and
j ∈ [M]

µ_{k,i,j} := ∂A/∂θ_{k,i,j}
           = (∂A/∂a_{k,j}) · (∂a_{k,j}/∂θ_{k,i,j})
           = r_{k,j} · q_{k,j,i}.

Since a_{k,i} directly influences a_{k+1,j} for j ∈ [M], we have for k ∈
{1, ..., K−1} and i ∈ [M]

r_{k,i} = ∂A/∂a_{k,i} = Σ_{j ∈ [M]} (∂A/∂a_{k+1,j}) · (∂a_{k+1,j}/∂a_{k,i})
        = Σ_{j ∈ [M]} r_{k+1,j} q_{k+1,j,i}
        = Σ_{j ∈ [M]} µ_{k+1,i,j}.

We summarize the procedure in Algorithm 10.3. The forward pass


uses the softmax operator maxε and the softargmax operator argmaxε .
In the hard max case, in Algorithm 10.2, we used q to store backpointers
from integer to integer. In the soft max case, in Algorithm 10.3, we used q
to store soft backpointers, that is, discrete probability distributions. In
the zero-temperature limit, backpropagation outputs a binary encoding
of the solution of backtracking.

Algorithm 10.3 Inference on a chain as backprop with max operators

Input: θ ∈ R^{K×M×M}
Max operator: maxε
1: Initialize a_{1,j} := θ_{1,1,j} ∀j ∈ [M]
2: for k := 2, ..., K do   ▷ Forward pass
3:   for j ∈ [M] do
4:     a_{k,j} := maxε_{i ∈ [M]} θ_{k,i,j} + a_{k-1,i} ∈ R
5:     q_{k,j} := argmaxε_{i ∈ [M]} θ_{k,i,j} + a_{k-1,i} ∈ △^M
6: A := maxε_{i ∈ [M]} a_{K,i} ∈ R
7: Q := argmaxε_{i ∈ [M]} a_{K,i} ∈ △^M
8: Initialize r_{K,j} := Q_j ∀j ∈ [M]
9: for k := K−1, ..., 1 do   ▷ Backward pass
10:   for i ∈ [M] do
11:     for j ∈ [M] do
12:       µ_{k+1,i,j} := r_{k+1,j} · q_{k+1,j,i}
13:     r_{k,i} := Σ_{j ∈ [M]} µ_{k+1,i,j}
Outputs: maxε_{i_1, ..., i_K ∈ [M]^K} θ_{1,1,i_1} + Σ_{k=2}^{K} θ_{k,i_{k-1},i_k} = A, ∇A(θ) = µ

10.10 Summary

• Graphical models represent the conditional dependencies between


variables and therefore specify how their joint distribution factor-
izes.

• There are clear analogies between the worlds of functions and of


distributions: the counterparts of computation chains and compu-
tation graphs are Markov chains and Bayesian networks.

• Inference on chains and more generally on trees, for exponential


family distributions, is equivalent, both statistically and algorith-
mically, to differentiating the log-partition function.

• The forward-backward algorithm can be seen as using a sum-



product algebra, while the Viterbi algorithm can be seen as using


a max-plus algebra. Equivalently, in the log domain, we can see
the former as using a soft max, and the latter as using a hard
max.
11
Differentiating through optimization

In this chapter, we study how to differentiate through optimization


problems, and more generally through nonlinear systems of equations.

11.1 Implicit functions

Implicit functions are functions that do not enjoy an explicit decompo-


sition into elementary functions, for which automatic differentiation, as
studied in Chapter 8, can therefore not be directly applied. We describe
in this chapter techniques to differentiate through such functions and
how to integrate them into an autodiff framework.
Formally, we will denote an implicit function by w⋆ (λ), where
w⋆ : Λ → W. One question is then how to compute the Jacobian
∂w⋆ (λ). As a first application one can consider sensitivity analysis
of a system. For example, w⋆ (λ) could correspond to the equilibrium
state of a physical system and in this case, ∂w⋆ (λ) would tell us about
the sensitivity of the system to some parameters λ ∈ Λ.


11.1.1 Optimization problems


Another example is a function implicitly defined as the solution (assumed
unique) of an optimization problem

w⋆ (λ) = arg max f (w, λ),


w∈W

where f : W × Λ → R and W denotes a constraint set. Note that we


use an arg max for convenience, but the same applies when using an
arg min.

11.1.2 Nonlinear equations


More generally, w⋆ (λ) can be defined as the root of some function
F : W × Λ → W, i.e., w⋆ (λ) is implicitly defined as the function
satisfying the (potentially nonlinear) system of equations

F (w, λ) = 0

for all λ ∈ Λ.

11.1.3 Application to bilevel optimization


Besides sensitivity analysis, another example of application is bilevel
optimization. Many times, we want to minimize a function defined as
the composition of a fixed function and the solution of an optimization
problem. Formally, let f, g : W × Λ → R. We consider the composition
h(λ) defined as

h(λ) := g(w⋆ (λ), λ), where w⋆ (λ) = arg max f (w, λ). (11.1)
w∈W

This includes for instance hyperparameter optimization, where f is an


inner log-likelihood objective, g is an outer validation loss, w ∈ W
are model parameters and λ ∈ Λ are model hyperparameters, such as
regularization strength, as illustrated in Fig. 11.1. To minimize h(λ) one
generally resorts to a gradient descent scheme w.r.t. λ, which requires
computing ∇h(λ). Assuming that w⋆ (λ) is differentiable at λ, by the
chain rule, we obtain the Jacobian

∂h(λ) = ∂1 g(w⋆ (λ), λ)∂w⋆ (λ) + ∂2 g(w⋆ (λ), λ).




Figure 11.1: Hyperparameter optimization in nonlinear regression can be cast as a bi-


level optimization problem. Each line corresponds to the estimator obtained by fitting
some training data (in blue circles) using a different hyperparameter λ. Formally,
denoting f the training objective, the estimators are w⋆ (λ) := arg minw f (w; λ).
The goal is to find the best hyperparameter that fits some validation data (here in
cyan diamonds), that is, minimizing h(λ) := g(w⋆ (λ), λ), where g is the validation
objective. A too small λ1 leads to overfitting: it fits the training objective well but performs
badly on the validation objective. Conversely, a too large λ2 underfits both the training and
validation objectives. The optimal parameter λ⋆ minimizes the validation objective
and may be obtained by iterating gradient descent w.r.t. λ. This requires gradients
of h(λ) = g(w⋆ (λ), λ) w.r.t. λ.

Using ∂h(λ)⊤ = ∇h(λ) (see Remark 2.4), we obtain the gradient

∇h(λ) = ∂w⋆ (λ)⊤ ∇1 g(w⋆ (λ), λ) + ∇2 g(w⋆ (λ), λ).

The only problematic term is ∂w⋆ (λ), as it requires argmax differenti-


ation. Indeed, most of the time, there is no explicit formula for w⋆ (λ)
and it does not decompose into elementary functions.

11.2 Envelope theorems

In the special case g = f , the composition h defined in Eq. (11.1) is


simply given by

h(λ) = f (w⋆ (λ), λ) = max f (w, λ).


w∈W

That is, we no longer need argmax differentiation, but only max


differentiation, which, as we shall now see is much easier. The function
h is often called a value function (Fleming and Rishel, 2012). The

Figure 11.2: The graph of h(λ) = maxw∈W f (w, λ) is the upper-envelope of the
graphs of the functions λ 7→ f (w, λ) for all w ∈ W.

reason for the name “envelope” is illustrated in Fig. 11.2. We emphasize


that there is not one, but several envelope theorems, depending on the
assumptions on f .

11.2.1 Danskin’s theorem

When f is concave-convex, we can use Danskin’s theorem.

Theorem 11.1 (Danskin’s theorem). Let f : W × Λ → R and W be


a compact convex set. Let

h(λ) := max f (w, λ)


w∈W

and
w⋆ (λ) := arg max f (w, λ).
w∈W
If f is concave in w, convex in λ, and the maximum w⋆ (λ) is
unique, then the function h is differentiable with gradient

∇h(λ) = ∇2 f (w⋆ (λ), λ).

If the maximum is not unique, we get a subgradient.



Informally, Danskin’s theorem means that we can treat w⋆(λ) as if
it were constant with respect to λ, i.e., we do not need to differentiate
through it, even though it depends on λ. Danskin’s theorem can also be
used to differentiate through a minimum, h(λ) = min_{w∈W} f(w, λ), if
f(w, λ) is convex in w and concave in λ, as we now illustrate.
Example 11.1 (Illustration of Danskin’s theorem). Let us define h(λ) :=
min_{w∈R} f(w, λ), where f(w, λ) := (λ/2) w² + bw + c and λ > 0. Let w⋆(λ)
be the minimum. The derivative of f w.r.t. λ is w²/2. From Danskin’s
theorem, we therefore have h′(λ) = w⋆(λ)²/2. Let us check that this result
is correct. The derivative of f w.r.t. w is λw + b. Setting it to zero, we
get w⋆(λ) = −b/λ, so that h′(λ) = b²/(2λ²). Plugging w⋆(λ) back into
f(w, λ), we get h(λ) = −b²/(2λ) + c. Using (1/λ)′ = −1/λ², we indeed
obtain the same expression for h′(λ).
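This example is easy to check numerically; the following JAX snippet (with arbitrary values for b, c and λ chosen by us) differentiates the value function and compares the result with the expression given by Danskin's theorem.

import jax

b, c = 2.0, 1.0
f = lambda w, lam: 0.5 * lam * w ** 2 + b * w + c
w_star = lambda lam: -b / lam            # closed-form minimizer
h = lambda lam: f(w_star(lam), lam)      # value function

lam = 3.0
print(jax.grad(h)(lam))                  # derivative of the value function
print(0.5 * w_star(lam) ** 2)            # Danskin: (1/2) w*(lam)^2, same value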
Danskin’s theorem has a simple interpretation for functions that are
linear in λ as shown below.
Example 11.2 (Convex conjugate). Let f (w, λ) := ⟨w, λ⟩ − Ω(w)
with Ω convex. We then have h(λ) = maxw∈W ⟨w, λ⟩ − Ω(w) =:
Ω∗ (λ), where Ω∗ denotes the convex conjugate of Ω. Since f satisfies
the conditions of Danskin’s theorem and since we have ∇2 f (w, λ) =
w, we obtain ∇h(λ) = ∇Ω∗ (λ) = w⋆ (λ). In other words, in this
special case, the gradient of the max is equal to the argmax. This
is due to the fact that f (w, λ) is linear in λ.
Another application is saddle point optimization.
Example 11.3 (Saddle point problem). Consider the saddle point
problem minλ∈Λ maxw∈W f (w, λ). If it is difficult to minimize w.r.t.
λ but easy to maximize w.r.t. w, we can rewrite the problem as
minλ∈Λ h(λ), where h(λ) := maxw∈W f (w, λ), and use ∇h(λ) to
perform (projected) gradient descent w.r.t. λ.

11.2.2 Rockafellar’s theorem


A related theorem can be proved under different assumptions about f ,
in particular without concavity w.r.t. w.

Theorem 11.2 (Rockafellar’s envelope theorem). Let f : W × Λ →


R and W be a compact convex set. Let

h(λ) := max f (w, λ)


w∈W

and
w⋆ (λ) := arg max f (w, λ).
w∈W
If f is continuously differentiable in λ for all w ∈ W, ∇1 f is
continuous and the maximum w⋆ (λ) is unique, then the function
h is differentiable with gradient

∇h(λ) = ∇2 f (w⋆ (λ), λ).

See Rockafellar and Wets (2009, Theorem 10.31). Compared to


Danskin’s theorem, Rockafellar’s theorem does not require f to be
concave-convex, but requires stronger assumptions on the differentiabil-
ity of f .

11.3 Implicit function theorem

11.3.1 Univariate functions


The implicit function theorem (IFT) provides conditions under which
an implicit relationship of the form F (x, λ) = 0 can be rewritten as a
function x = x⋆ (λ) locally, and provides a way to compute its derivative
w.r.t. λ.

Theorem 11.3 (Implicit function theorem, univariate case). Let


F : R × R → R. Assume F (x, λ) is a continuously differentiable
function in a neighborhood U of (x0 , λ0 ) such that F (x0 , λ0 ) = 0
and ∂1 F (x0 , λ0 ) ̸= 0. Then there exists a neighborhood V ⊆ U of
(x0 , λ0 ) in which there is a function x⋆ (λ) such that

• x⋆ (λ0 ) = x0 ,

• F (x⋆ (λ), λ) = 0 for all λ in the neighborhood V,




• ∂x⋆ (λ) = − ∂∂21 FF (x (λ),λ)
(x⋆ (λ),λ) .

We postpone the proof to the multivariate case and begin with a


classical example of application of the theorem.

Example 11.4 (Equation of the unit circle). We use w ≡ x and λ ≡
y for clarity. Let F(x, y) := x² + y² − 1. In general, we cannot
rewrite the unit circle equation F(x, y) = 0 as a function from y
to x, because for every y ∈ [−1, 1], there are always two possible
x values, namely, x = √(1 − y²) or x = −√(1 − y²). However, locally
around some point (x0, y0), e.g., such that x0 > 0 and y0 > 0
(upper-right quadrant), the function x = x⋆(y) = √(1 − y²) is well-
defined. Using ∂1F(x, y) = 2x and ∂2F(x, y) = 2y, we get
∂x⋆(y) = −2y/(2x⋆(y)) = −y/√(1 − y²) in that neighborhood (the
upper-right quadrant in this case). This is indeed the same derivative
expression as if we used the chain rule on √(1 − y²), and it is
well-defined on y ∈ [0, 1).

In the above simple example, we can easily derive an explicit function


relating y to x in a given neighborhood, but this is not always the case.
The IFT gives us conditions guaranteeing that such function exists and
a way to differentiate it, but not a way to construct such a function.
In fact, finding x⋆ (λ) such that F (x⋆ (λ), λ) = 0 typically involves a root
finding algorithm, an optimization algorithm, a nonlinear system solver,
etc.

Example 11.5 (Polynomial). Let F(w, λ) := w⁵ + w³ + w − λ. According
to the Abel–Ruffini theorem (Tignol, 2015), quintics (polynomials of
degree 5) do not enjoy roots in terms of radicals and one must resort
to numerical root finding. By the intermediate value theorem, odd-degree
polynomials have at least one real root. Moreover, ∂1F(w, λ) = 5w⁴ + 3w² + 1
is strictly positive, so that F is strictly increasing in w. Therefore,
there is exactly one root w⋆(λ) such that F(w⋆(λ), λ) = 0. This unique
root can for example be found by bisection. Using the IFT, its derivative
is found to be ∂w⋆(λ) = (5w⋆(λ)⁴ + 3w⋆(λ)² + 1)⁻¹.
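A minimal Python sketch of this example (with a bisection bracket and number of iterations chosen by us) computes the root numerically and checks the IFT derivative against a finite difference.

def root(lam, lo=-10.0, hi=10.0, iters=80):
    # F(w, lam) = w^5 + w^3 + w - lam is strictly increasing in w, so bisection applies.
    F = lambda w: w ** 5 + w ** 3 + w - lam
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        lo, hi = (mid, hi) if F(mid) < 0 else (lo, mid)
    return 0.5 * (lo + hi)

lam = 2.0
w = root(lam)
ift_derivative = 1.0 / (5 * w ** 4 + 3 * w ** 2 + 1)       # derivative from the IFT
finite_diff = (root(lam + 1e-4) - root(lam - 1e-4)) / 2e-4  # numerical check
print(ift_derivative, finite_diff)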

While an implicit function is differentiable at a point if the assump-


tions of the IFT hold in a neighborhood of that point, the reciprocal is

not true: failure of the IFT assumptions does not necessarily mean that
the implicit function is not differentiable, as we now illustrate.

Example 11.6 (IFT conditions are not necessary for differentiability).


Consider F (w, λ) = (w − λ)2 . We clearly have that F (w⋆ (λ), λ) = 0
if we define w⋆ (λ) = λ, the identity function. It is clearly differen-
tiable for all λ, yet the assumptions of the IFT fail, since we have
∂1 F (w, λ) = 2(w − λ) and therefore ∂1 F (0, 0) = 0.

11.3.2 Multivariate functions


We now present the IFT in the general multivariate setting. Informally,
if F (w⋆ (λ), λ) = 0, then by the chain rule, we have

∂1 F (w⋆ (λ), λ)∂w⋆ (λ) + ∂2 F (w⋆ (λ), λ) = 0,

meaning that the Jacobian ∂w⋆ (λ), assuming that it exists, satisfies

−∂1 F (w⋆ (λ), λ)∂w⋆ (λ) = ∂2 F (w⋆ (λ), λ).

The IFT gives us conditions for the existence of ∂w⋆ (λ).

Theorem 11.4 (Implicit function theorem, multivariate case). Let us


define F : W × Λ → W. Assume F (w, λ) is a continuously differen-
tiable function in a neighborhood of (w0 , λ0 ) such that F (w0 , λ0 ) =
0 and ∂1 F (w0 , λ0 ) is invertible, i.e., its determinant is nonzero.
Then there exists a neighborhood of λ0 in which there is a function
w⋆ (λ) such that

• w⋆ (λ0 ) = w0 ,

• F (w⋆ (λ), λ) = 0 for all λ in the neighborhood,

• −∂1 F (w⋆ (λ), λ)∂w⋆ (λ) = ∂2 F (w⋆ (λ), λ)


⇐⇒ ∂w⋆ (λ) = −∂1 F (w⋆ (λ), λ)−1 ∂2 F (w⋆ (λ), λ).

We begin with a simple unconstrained optimization example.



Example 11.7 (Unconstrained optimization). Assume we want to


differentiate through w⋆ (λ) = arg minw∈RP f (w, λ), where f is
strictly convex in w, which ensures that the solution is unique. From
the stationary conditions, if we define F (w, λ) := ∇1 f (w, λ), then
w⋆ (λ) is uniquely characterized as the root of F in the first argu-
ment, i.e., F (w⋆ (λ), λ) = 0. We have ∂1 F (w, λ) = ∇21 f (w, λ), the
Hessian of f in w, and ∂2 F (w, λ) = ∂2 ∇1 f (w, λ), the cross deriva-
tives of f . Therefore, assuming that the Hessian is well-defined and
invertible at (w⋆ (λ), λ), we can use the IFT to differentiate through
w⋆ (λ) and obtain ∂w⋆ (λ) = −∇21 f (w⋆ (λ), λ)∂2 ∇1 f (w⋆ (λ), λ).

Next, we generalize the previous example, by allowing constraints


in the optimization problem.

Example 11.8 (Constrained optimization). Now, assume we want
to differentiate through w⋆(λ) = arg min_{w∈C} f(w, λ), where f is
strictly convex in w and C ⊆ W is a convex set. A solution is
characterized by the fixed point equation w⋆(λ) = P_C(w⋆(λ) −
η∇1f(w⋆(λ), λ)), for any η > 0, where P_C(y) := arg min_{x∈C} ‖x − y‖²₂
is the Euclidean projection of y onto C. Therefore, w⋆(λ) is the
root of F(w, λ) := w − P_C(w − η∇1f(w, λ)) (see Chapter 16). We
can differentiate through w⋆(λ) using the IFT, assuming that the
conditions of the theorem apply. Note that ∂1F(w, λ) requires
the expression of the Jacobian ∂P_C(y). Fortunately, P_C(y) and its
Jacobian are easy to compute for many sets C (Blondel et al., 2021).

11.3.3 JVP and VJP of implicit functions


To integrate an implicit function w⋆ (λ) in an autodiff framework, we
need to be able to compute its JVP or VJP. This is the purpose of the
next proposition.

Proposition 11.1 (JVP and VJP of implicit functions). Let w⋆ : Λ →


W be a function implicitly defined as the solution of F (w⋆ (λ), λ) =

0, for some function F : W × Λ → W. Define

A := −∂1 F (w⋆ (λ), λ)


B := ∂2 F (w⋆ (λ), λ).

Assume the assumptions of the IFT hold. The JVP t := ∂w⋆ (λ)v
in the input direction v ∈ Λ is obtained by solving the linear system

At = Bv.

The VJP ∂w⋆ (λ)∗ u in the output direction u ∈ W is obtained by


solving the linear system

A∗ r = u.

Using the solution r, we get

∂w⋆ (λ)∗ u = ∂w⋆ (λ)∗ A∗ r = B ∗ r.


Note that in the above linear systems, we can access A and B as
linear maps, the JVPs of F. Their adjoints, A∗ and B∗, correspond to
the VJPs of F. To solve these systems, we can therefore use matrix-free
solvers as detailed in Section 9.4. For example, when A is symmetric
positive semi-definite, we can use the conjugate gradient method (Hestenes,
Stiefel, et al., 1952). When A is not symmetric positive definite, we can
use GMRES (Saad and Schultz, 1986) or BiCGSTAB (van der Vorst, 1992).
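As a concrete sketch (a toy instance of our own choosing), consider w⋆(λ) = arg min_w ½‖Xw − y‖²₂ + (λ/2)‖w‖²₂, whose optimality condition is F(w, λ) = Xᵀ(Xw − y) + λw = 0. The JVP of w⋆ can be obtained by solving the linear system of Proposition 11.1, written here with ∂1F on the left-hand side so that conjugate gradient applies (∂1F is symmetric positive definite in this example); the linear maps are accessed matrix-free through jax.jvp.

import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key1, (20, 5))
y = jax.random.normal(key2, (20,))

F = lambda w, lam: X.T @ (X @ w - y) + lam * w             # optimality condition

def w_star(lam):
    # Toy inner solver; in general this would be any optimization algorithm.
    return jnp.linalg.solve(X.T @ X + lam * jnp.eye(X.shape[1]), X.T @ y)

def w_star_jvp(lam, v):
    w = w_star(lam)
    # Solve d1F(w, lam) t = -d2F(w, lam) v with matrix-free conjugate gradient.
    matvec = lambda t: jax.jvp(lambda w_: F(w_, lam), (w,), (t,))[1]
    rhs = -jax.jvp(lambda l: F(w, l), (lam,), (v,))[1]
    t, _ = cg(matvec, rhs)
    return t

print(w_star_jvp(2.0, 1.0))    # = d w*(lam) / d lam, since lam is a scalar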

11.3.4 Proof of the implicit function theorem

We prove the theorem using the inverse function theorem presented
in Theorem 11.5. Define

f(λ, w) := (λ, F(w, λ)),

which maps R^Q × R^P onto R^Q × R^P. The Jacobian of f is

∂f(λ, w) = [ I            0
             ∂2F(w, λ)    ∂1F(w, λ) ].

So at (λ0, w0), we have det(∂f(λ0, w0)) = det(I) det(∂1F(w0, λ0)) ≠ 0,
since we assumed ∂1F(w0, λ0) invertible. By the inverse function
theorem, the function f is then invertible in a neighborhood N of
f(λ0, w0) = (λ0, 0). In particular, it is invertible in N ∩ {(λ, 0), λ ∈
R^Q}. The solution of the implicit equation in a neighborhood of λ0
is then (λ, w⋆(λ)) = f⁻¹(λ, 0). By the inverse function theorem, f⁻¹
is continuously differentiable, and so is w⋆(λ). The Jacobian ∂w⋆(λ)
is read off the differential of the inverse as

[ ∼          ∼
  ∂w⋆(λ)     ∼ ] = ∂f⁻¹(λ, 0),

and by the inverse function theorem, we have ∂f⁻¹(λ, 0) = (∂f(λ, w⋆(λ)))⁻¹.
So using the block matrix inversion formula

[ A   B ]⁻¹   [ ∼                       ∼
[ C   D ]   = [ −(D − CA⁻¹B)⁻¹CA⁻¹      ∼ ],

with A = I, B = 0, C = ∂2F(w⋆(λ), λ) and D = ∂1F(w⋆(λ), λ), the bottom-left
block is −∂1F(w⋆(λ), λ)⁻¹ ∂2F(w⋆(λ), λ), which is the claimed expression.
Though we expressed the proof in terms of Jacobians and matrices, the result
naturally holds for the corresponding linear operators, JVPs, VJPs, and their
inverses.

11.4 Adjoint state method

11.4.1 Differentiating nonlinear equations


We describe in this section the adjoint state method (a.k.a. adjoint
method, method of adjoints, adjoint sensitivity method). The method
can be used to compute the gradient of the composition of an explicit
function and an implicit function, defined through an equality
constraint (e.g., a nonlinear equation). The method dates back
to Céa (1986).
Suppose a variable s ∈ S (which corresponds to a state in optimal
control) is implicitly defined given some parameters w ∈ W through
the (potentially nonlinear) equation c(s, w) = 0, where c : S × W → S.
Assuming s is uniquely determined for all w ∈ W, this defines an
implicit function s⋆ (w) from W to S such that c(s⋆ (w), w) = 0.
Given an objective function L : S × W → R, the goal of the adjoint
state method is then to compute the gradient of

L(w) := L(s⋆ (w), w).



However, this is not trivial as s⋆ (w) is an implicit function. For instance,


this can be used to convert the equality-constrained problem

min L(s, w) s.t. c(s, w) = 0.


w∈W

into the unconstrained problem

min L(s⋆ (w), w).


w∈W

Access to ∇L(w) allows us to solve this problem by gradient descent.

Proposition 11.2 (Adjoint state method). Let c : S × W → S be a


mapping defining constraints of the form c(s, w). Assume that for
each w ∈ W, there exists a unique s⋆ (w) satisfying c(s⋆ (w), w) = 0
and that s⋆ (w) is differentiable. The gradient of

L(w) := L(s⋆ (w), w),

for some differentiable function L : S × W → R, is given by

∇L(w) = ∇2 L(s⋆ (w), w) + ∂2 c(s⋆ (w), w)∗ r ⋆ (w),

where r ⋆ (w) is the solution of the linear system

∂1 c(s⋆ (w), w)∗ r = −∇1 L(s⋆ (w), w).

As shown in the proof below, r ⋆ (w) corresponds to a Lagrange


multiplier. The linear system can be solved using matrix-free solvers.

11.4.2 Relation with envelope theorems


Because s is uniquely determined for any w ∈ W by c(s, w) = 0, we can
alternatively rewrite L(w) as the trivial minimization or maximization,

L(w) = min L(s, w) s.t. c(s, w) = 0


s∈S
= max L(s, w) s.t. c(s, w) = 0.
s∈S

Therefore, the adjoint state method can be seen as an envelope theorem


for computing ∇L(w), for the case when w is involved in both the
objective function and in the equality constraint.

11.4.3 Proof using the method of Lagrange multipliers


Classically, the adjoint state method is derived using the method of
Lagrange multipliers. Let us introduce the Lagrangian associated with
L and c,
L(s, w, r) := L(s, w) + ⟨r, c(s, w)⟩,
where r ∈ S is the Lagrange multiplier associated with the equality
constraint c(s, w) = 0. In the optimal control literature, r is often
called the adjoint variable or adjoint state. The gradients of the
Lagrangian are

∇s L(s, w, r) = ∇1 L(s, w) + ∂1 c(s, w)∗ r


∇w L(s, w, r) = ∇2 L(s, w) + ∂2 c(s, w)∗ r
∇r L(s, w, r) = c(s, w),

where ∂i c(s, w)∗ are the adjoint operators. Setting ∇r L(s, w, r) to


zero gives the constraint c(s, w) = 0. Setting ∇s L(s, w, r) to zero gives
the so-called adjoint state equation

∂1 c(s, w)∗ r = −∇1 L(s, w).

Solving this linear system w.r.t. r at s = s⋆ (w) gives the adjoint variable
r ⋆ (w). We then get

∇L(w) = ∇2 L(s⋆ (w), w, r ⋆ (w))


= ∇2 L(s⋆ (w), w) + ∂2 c(s⋆ (w), w)∗ r ⋆ (w),

which concludes the proof.

11.4.4 Proof using the implicit function theorem


A more direct proof is possible thanks to the implicit function theorem
(Section 11.3). Using the chain rule, we get

∇L(w) = ∇2 L(s⋆ (w), w) + ∂s⋆ (w)∗ ∇1 L(s⋆ (w), w),

where ∂s⋆ (w)∗ is the VJP of s⋆ , a linear map from S to W.


Computationally, the main difficulty is to apply ∂s⋆ (w)∗ to the
vector u = ∇1 L(s⋆ (w), w) ∈ S. Using the implicit function theorem

(Section 11.3) on the implicit function c(s⋆ (w), w) = 0, and Proposi-


tion 11.1, we get the linear system A∗ r = u, where A∗ := ∂1 c(s⋆ (w), w)∗
is a linear map from S to S. After solving for r, we get ∂s⋆ (w)∗ u = B ∗ r,
where B ∗ := ∂2 c(s⋆ (w), w)∗ is a linear map from S to W. Putting ev-
erything together, we get

∇L(w) = ∇2 L(s⋆ (w), w) + ∂2 c(s⋆ (w), w)∗ r.

11.4.5 Reverse mode as adjoint method with backsubstitution


In this section, we revisit reverse-mode autodiff from the perspective
of the adjoint state method. For clarity, we focus our exposition on
feedforward networks with input x ∈ X and network weights w =
(w_1, . . . , w_K) ∈ W_1 × . . . × W_K,

s0 := x ∈ X
s1 := f1 (s0 , w1 ) ∈ S1
..
.
sK := fK (sK−1 , wK ) ∈ SK
f (w) := sK . (11.2)

Here we focus on gradients with respect to the parameters w, hence


the notation f (w). We can use the adjoint state method to recover
reverse-mode autodiff, and prove its correctness in the process. While
we focus for simplicity on feedforward networks, our exposition can be
generalized to computation graphs.

Feedforward networks as the solution of a nonlinear equation


While we defined the set of intermediate computations s = (s_1, . . . , s_K) ∈
S_1 × . . . × S_K as a sequence of operations, they can also be defined as
the unique solution of the nonlinear equation c(s, w) = 0, where

c(s, w) := ( s_1 − f_1(x, w_1)
             s_2 − f_2(s_1, w_2)
             ...
             s_K − f_K(s_{K-1}, w_K) ).

This defines an implicit function s⋆ (w) = (s⋆1 (w), . . . , s⋆K (w)), the
solution of this nonlinear system, which is given by the variables
s1 , . . . , sK defined in (11.2). The output of the feedforward network is
then f (w) = s⋆K (w).
In machine learning, the final layer s⋆K (w) is typically fed into a
loss ℓ, to define
L(w) := ℓ(s⋆K (w); y).
Note that an alternative is to write L(w) as
L(w) = min ℓ(s; y) s.t. c(s, w) = 0.
s∈S

More generally, if we just want to compute the VJP of s⋆K (w) in


some direction uK ∈ SK , we can define the scalar-valued function
L(w) := ℓ(s⋆K (w); uK ) := ⟨s⋆K (w), uK ⟩
so that
∂f (w)∗ uK = ∇L(w).
Let us define u ∈ S1 × · · · × SK−1 × SK as u := (0, . . . , 0, ∇1 ℓ(f (w); y))
(gradient of the loss ℓ case) or u := (0, . . . , 0, uK ) (VJP of f in the
direction uK case). Using the adjoint state method, we know that the
gradient of this objective is obtained as
∇L(w) = ∂2 c(s(w), w)∗ r ⋆ (w),
for r ⋆ (w) the solution of the linear system
∂1 c(s(w), w)∗ r = −u.

Solving the linear system using backsubstitution

The JVP of the constraint function c at s⋆(w), materialized as a matrix,
takes the form of a block lower-triangular matrix

∂1c(s⋆(w), w) = [  I      0      0     ...     0
                  −A_2    I      0     ...     0
                   0     −A_3    I     ...     0
                   ...          ...    ...     0
                   0     ...     0    −A_K     I ],

where A_k := ∂1f_k(s_{k-1}, w_k). Crucially, the triangular structure of the
JVP stems from the fact that each intermediate activation only depends
on the past intermediate activations. Therefore, the constraints,
corresponding to the rows of the Jacobian, cannot introduce nonzero
values beyond the diagonal. The VJP takes the form of a block
upper-triangular matrix

∂1c(s⋆(w), w)∗ = [ I    −A_2∗    0      ...     0
                   0      I     −A_3∗   ...     0
                   ...                  ...   −A_K∗
                   0     ...     ...     0      I  ].

Solving an upper-triangular system of the form ∂1c(s(w), w)∗ r = u can then
be done efficiently by backsubstitution (flipping the sign of the right-hand
side only flips the sign of the solution). Starting from the last adjoint
state r_K = u_K, we can compute each adjoint state r_k from the one computed
at k + 1. Namely, for k ∈ (K − 1, . . . , 1), we have

r_k − A_{k+1}∗ r_{k+1} = 0 ⇐⇒ r_k = ∂1f_{k+1}(s_k, w_{k+1})∗ r_{k+1}.

The VJPs with respect to the parameters are then obtained by

∂2c(s(w), w)∗ r = ( ∂2f_1(x, w_1)∗ r_1
                    ∂2f_2(s_1(w), w_2)∗ r_2
                    ...
                    ∂2f_K(s_{K-1}(w), w_K)∗ r_K ),

recovering reverse-mode autodiff.
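The following JAX sketch carries out this backsubstitution explicitly on a tiny two-layer chain (the architecture, shapes and loss are our own toy choices) and checks that it matches reverse-mode autodiff on the composed function.

import jax
import jax.numpy as jnp

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
x = jnp.ones(3)
y = jnp.ones(2)
W1 = jax.random.normal(key1, (4, 3))
W2 = jax.random.normal(key2, (2, 4))

f1 = lambda s0, w1: jnp.tanh(w1 @ s0)
f2 = lambda s1, w2: w2 @ s1
loss = lambda s2: 0.5 * jnp.sum((s2 - y) ** 2)

# Forward pass.
s1 = f1(x, W1)
s2 = f2(s1, W2)

# Backsubstitution: r_K is the gradient of the loss, r_k = (d1 f_{k+1})^* r_{k+1}.
r2 = jax.grad(loss)(s2)
r1 = jax.vjp(lambda s: f2(s, W2), s1)[1](r2)[0]
# Gradients w.r.t. parameters: (d2 f_k)^* r_k.
g1 = jax.vjp(lambda w: f1(x, w), W1)[1](r1)[0]
g2 = jax.vjp(lambda w: f2(s1, w), W2)[1](r2)[0]

# Same result as reverse-mode autodiff on the whole chain.
auto = jax.grad(lambda params: loss(f2(f1(x, params[0]), params[1])))((W1, W2))
print(jnp.allclose(g1, auto[0]), jnp.allclose(g2, auto[1]))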


The Lagrangian perspective of backpropagation for networks with
separate parameters w = (w1 , . . . , wK ) is well-known; see for instance
LeCun (1988) or Recht (2016). The Lagrangian perspective of back-
propagation through time (Werbos, 1990) for networks with shared
parameter w is discussed for instance by Franceschi et al. (2017). Our
exposition uses the adjoint state method, which can itself be proved
either using the method of Lagrange multipliers (Section 11.4.3) or by
the implicit function theorem (Section 11.4.4), combined with backsubstitution
for solving the upper-triangular linear system. Past works often

minimize over w but we do not require this, as gradients are not neces-
sarily used for optimization. Our exposition also supports computing
the VJP of any vector-valued function f , while existing works derive
the gradient of a scalar-valued loss function.

11.5 Inverse function theorem

11.5.1 Differentiating inverse functions

In some cases (see for instance Section 12.4.4), it is useful to compute


the Jacobian of an inverse function f −1 . The inverse function theorem
below allows us to relate the Jacobian of f −1 with the Jacobian of f .

Theorem 11.5 (Inverse function theorem). Assume f : W → W is


continuously differentiable with invertible Jacobian ∂f (w0 ) at w0 .
Then f is bijective from a neighborhood of w0 to a neighborhood
of f (w0 ). Moreover, the inverse f −1 is continuously differentiable
near ω0 = f (w0 ) and the Jacobian of the inverse ∂f −1 (ω) is

∂f (w)∂f −1 (ω) = I ⇔ ∂f −1 (ω) = (∂f (w))−1 ,

with w = f −1 (ω).

11.5.2 Link with the implicit function theorem

The inverse function theorem can be used to prove the implicit function
theorem; see proof of Theorem 11.4. Conversely, recall that, in order to
use the implicit function theorem, we need to choose a root objective
F : W × Λ → W. If we set W = Λ = RQ and F (w, ω) = f (w) −
ω, with f : RQ → RQ , then we have that the root w⋆ (ω) satisfying
F (w⋆ (ω), ω) = 0 is exactly w⋆ (ω) = f −1 (ω). Moreover, ∂1 F (w, ω) =
∂f (w) and ∂2 F (w, ω) = −I. By applying the implicit function theorem
with this F , we indeed recover the inverse function theorem.

11.5.3 Proof of inverse function theorem

We first give a proof of the formula assuming that f −1 is well-defined


and continuously differentiable in a neighborhood of f (w0 ). In that

case, we have for any ω in a neighborhood of f (w0 ),

f ◦ f −1 (ω) = ω.

Differentiating both sides w.r.t. ω, we get

∂f (f −1 (ω))∂f −1 (ω) = I,

where I is the identity function in RQ . In particular, for w = f −1 (ω)


we recover the formula presented in the statement.
Now, it remains to show that invertibility of the JVP ensures that
the function is invertible in a neighborhood of f(w0) and that the
inverse is continuously differentiable. For that, denote l := ∂f(w0),
whose inverse l⁻¹ is well-defined by assumption. The function f is
invertible with continuously differentiable inverse if and only if
w ↦ l⁻¹(f(w) − f(w0)) is. So, without loss of generality, we
consider ∂f(w0) = I, f(w0) = 0, w0 = 0.

As f is continuously differentiable, there exists a neighborhood
N = {w : ‖w − w0‖2 ≤ δ} on which we have ‖∂f(w) − I‖2 ≤ 1/2. In
this neighborhood, the function g(w) := f(w) − w is contractive by the
mean value theorem, with contraction factor 1/2. For any ω such that
‖ω − f(w0)‖2 ≤ δ/2, the sequence w_{k+1} := w_k − (f(w_k) − ω) remains in
N and converges (since it is a Cauchy sequence by the contraction of g)
to a unique fixed point w_∞ satisfying w_∞ = w_∞ − (f(w_∞) − ω) ⇐⇒
f(w_∞) = ω. This shows the existence of the inverse from the neighborhood
M = {ω : ‖ω − ω0‖2 ≤ δ/2} of ω0 = f(w0) onto N.

We tackle now the differentiability (hence the continuity) of f⁻¹.
For any ω in the neighborhood of ω0 with inverse w := f⁻¹(ω) ∈ N,
the JVP of f at w satisfies by assumption ‖∂f(w) − I‖2 ≤ 1/2. Hence,
a := I − ∂f(w) defines a convergent series b := Σ_{k=0}^{+∞} a^k, and one
verifies easily that b = ∂f(w)⁻¹, that is, ∂f(w) is invertible and
‖∂f(w)⁻¹‖2 ≤ 2. To compute the JVP of the inverse, we consider
(∂f(w))⁻¹ as the candidate JVP and examine

‖f⁻¹(ω + η) − f⁻¹(ω) − (∂f(w))⁻¹η‖2 / ‖η‖2.

Denote then v such that f⁻¹(ω + η) = w + v. As g(w) = f(w) − w is
1/2-contractive in N, we have ‖v − η‖2 = ‖g(w + v) − g(w)‖2 ≤ ‖v‖2/2.

So ‖η‖2 ≥ ‖v‖2/2. We then get

‖f⁻¹(ω + η) − f⁻¹(ω) − (∂f(w))⁻¹η‖2 / ‖η‖2
  = ‖v − (∂f(w))⁻¹(f(w + v) − f(w))‖2 / ‖η‖2
  ≤ 4 ‖f(w + v) − f(w) − ∂f(w)v‖2 / ‖v‖2,

using ‖(∂f(w))⁻¹‖2 ≤ 2 and ‖η‖2 ≥ ‖v‖2/2. As ‖η‖2 → 0, we have ‖v‖2 → 0 and
so ‖f(w + v) − f(w) − ∂f(w)v‖2/‖v‖2 → 0. Hence, f⁻¹ is differentiable with
JVP ∂f⁻¹(ω) = (∂f(w))⁻¹ = (∂f(f⁻¹(ω)))⁻¹. This shows that f⁻¹ is continuous,
and so ∂f⁻¹(ω) is continuous as a composition of continuous functions.

11.6 Summary

• Implicit functions are functions that cannot be decomposed into


elementary operations and for which autodiff can therefore not
be directly applied. Examples are optimization problems and
nonlinear equations.

• Envelope theorems can be used for differentiating through the


min or max value (not solution) of a function.

• More generally, the implicit function theorem allows us to dif-


ferentiate through implicit functions. It gives conditions for the
existence of derivatives and how to obtain them.

• The adjoint state method can be used to obtain the gradient of


the composition of an explicit function and of an implicit function,
specified by equality constraints. It can be used to prove the
correctness of reverse-mode autodiff.

• The inverse function theorem can be used to differentiate function


inverses.

• In a sense, the implicit function theorem can be thought of as the
mother theorem, as it can be used to prove envelope theorems,
the adjoint state method and the inverse function theorem.
12
Differentiating through integration

In this chapter, we study how to differentiate through integrals, with a


focus on expectations and solutions of ordinary differential equations.

12.1 Differentiation under the integral sign

Given two Euclidean spaces Θ and Y, and a function f : Θ × Y → R,


we often want to differentiate an integral of the form

F(θ) := ∫_Y f(θ, y) dy.

Provided that we can swap integration and differentiation, we have

∇F(θ) = ∫_Y ∇_θ f(θ, y) dy.
The conditions enabling us to do so are best examined in the context
of measure theory. We refer the reader to e.g. (Cohn, 2013) for a course
on measure theory and Flanders (1973) for an in-depth study of the
differentiation under the integral sign. Briefly, if Θ = Y = R, the
following conditions are sufficient.
1. f is measurable in both its arguments, and f (θ, ·) is integrable
for almost all θ ∈ Θ fixed,


2. f(·, y) is absolutely continuous for almost all y ∈ Y, that is,
there exists an integrable function g(·, y) such that f(θ, y) =
f(θ0, y) + ∫_{θ0}^{θ} g(τ, y) dτ,

3. ∂1f(θ, y) (which exists almost everywhere if f(·, y) is absolutely
continuous) is locally integrable, that is, for any closed interval
[θ0, θ1], the integral ∫_{θ0}^{θ1} ∫_Y |∂1f(θ, y)| dy dθ is finite.

Any differentiable function f : Θ × Y → R is absolutely continuous.
However, the conditions also hold if f is just absolutely continuous, that
is, even if f(·, y) is only differentiable almost everywhere. This weaker
assumption can be used to smooth out almost-everywhere differentiable
functions, such as the ReLU, as we study in Section 14.4.

12.2 Differentiating through expectations

A special case of differentiating through integrals is differentiating


through expectations. We can distinguish between two cases, depending
on whether the parameters θ we wish to differentiate are involved in
the distribution or in the function, whose expectation we compute.

12.2.1 Parameter-independent distributions

We first consider expectations of the form

F(θ) := E_{Y∼p}[g(Y, θ)] = ∫_Y g(y, θ) p(y) dy,

for a random variable Y ∈ Y ⊆ R^M, distributed according to a distribution
p, and a function g : Y × Θ → R. Importantly, the distribution
is independent of the parameters θ. Under mild conditions recalled
in Section 12.1, we can swap differentiation and integration to obtain

∇F(θ) = ∇_θ ∫_Y g(y, θ) p(y) dy
       = ∫_Y ∇_θ g(y, θ) p(y) dy
       = E_{Y∼p}[∇_θ g(Y, θ)].

Generally, the expectation cannot be computed in closed form. However,
provided that we can sample from p, we can define a Monte-Carlo
estimator of the value

F̂_N(θ) := (1/N) Σ_{i=1}^{N} g(Y_i, θ)

and of the gradient

∇F̂_N(θ) = (1/N) Σ_{i=1}^{N} ∇_θ g(Y_i, θ),

for N i.i.d. samples Y_1, . . . , Y_N from p. These estimators are unbiased,
meaning that E[F̂_N(θ)] = F(θ) and E[∇F̂_N(θ)] = ∇F(θ), and converge
to the true quantity as N → +∞. This suggests a simple implementation
in an autodiff framework of the approximation of ∇F(θ):

1. Sample y_1, . . . , y_N from p.

2. Compute F̂_N(θ) = (1/N) Σ_{i=1}^{N} g(y_i, θ).

3. Compute the gradient ∇F̂_N(θ) by automatic differentiation.

Computing higher-order derivatives follows the same principle: to get
an approximation of ∇²F(θ), we can simply compute ∇²F̂_N(θ) by
autodiff. As such, the implementation delineated above is akin to the
“discretize-then-optimize” approach used to differentiate through the
solution of an ODE (Section 12.6): we implement an approximation of
the objective and simply call autodiff on it.
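In JAX, this recipe is a few lines; the choice of g and of the (parameter-independent) distribution p below is an arbitrary toy example of ours.

import jax
import jax.numpy as jnp

g = lambda y, theta: jnp.sum((y - theta) ** 2)    # any function differentiable in theta

def mc_objective(theta, key, num_samples=1000):
    ys = jax.random.normal(key, (num_samples,) + theta.shape)   # Y_i ~ p, independent of theta
    return jnp.mean(jax.vmap(lambda y: g(y, theta))(ys))        # Monte-Carlo estimate of F(theta)

theta = jnp.array([0.5, -1.0])
grad_estimate = jax.grad(mc_objective)(theta, jax.random.PRNGKey(0))
print(grad_estimate)   # unbiased estimate of grad F(theta); close to 2 * theta in this toy example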

12.2.2 Parameter-dependent distributions


A more challenging case arises when the distribution depends on the
parameters θ:

E(θ) := E_{Y∼p_θ}[g(Y)] = ∫_Y g(y) p_θ(y) dy,

where Y ∈ Y ⊆ RM is a random variable, distributed according to


a distribution pθ parameterized by θ ∈ Θ and where g : Y → R is,
depending on the setting, potentially a blackbox function (i.e., we do not

have access to its gradients). Typically, θ ∈ Θ could be parameters we


wish to estimate, or it could indirectly be generated by θ = f (x, w) ∈ Θ,
where f is a neural network with parameters w ∈ W we wish to estimate.
The main difficulty in computing ∇E(θ) stems from the fact that θ
are the parameters of the distribution pθ . Estimating an expectation
E(θ) = EY ∼pθ [g(Y )] using Monte-Carlo estimation requires us to sample
from pθ . However, it is not clear how to differentiate E w.r.t. θ if θ is
involved in the sampling process.

Continuous case

When Y is a continuous set (that is, p_θ(y) is a probability density
function), we can rewrite E(θ) as

E(θ) = ∫_Y p_θ(y) g(y) dy.

Provided that we can swap integration and differentiation (see Sec-
tion 12.1), we then have

∇E(θ) = ∇_θ ∫_Y p_θ(y) g(y) dy
       = ∫_Y ∇_θ p_θ(y) g(y) dy.

Unfortunately, this integral is not an expectation and it could be in-
tractable in general.

Discrete case

When Y is a discrete set (that is, p_θ(y) is a probability mass function),
we can rewrite E(θ) as

E(θ) = Σ_{y∈Y} p_θ(y) g(y).

We then obtain

∇E(θ) = Σ_{y∈Y} g(y) ∇_θ p_θ(y).

Again ∇E(θ) is not an expectation. We therefore cannot use Monte-


Carlo estimation to estimate the gradient. Instead, we can compute it

by brute force, i.e., by summing over all possible y ∈ Y. However, this is


clearly only computationally tractable if |Y| is small or if pθ is designed
to have sparse support, i.e., so that the set {y ∈ Y : pθ (y) ̸= 0} is
small. Moreover, even if these conditions hold, summing over y could be
problematic if g(y) is expensive to compute. Therefore, exact gradients
are seldom used in practice.
In Sections 12.3 and 12.4, we review the score function and pathwise
gradient estimators, to (approximately) compute ∇E(θ), allowing us
to optimize θ (or w using the chain rule) by gradient-based algorithms.

12.2.3 Application to expected loss functions

Differentiating through expectations is particularly useful when working


with expected loss functions of the form

L(θ; y) := EŶ ∼pθ [ℓ(Ŷ , y)],

where y is some ground truth. Equivalently, we can set ℓ = −r, where


r is a reward function. As we shall see, the score function estimator
will support a discrete loss function ℓ : Y × Y → R, while the pathwise
gradient estimator will require a differentiable loss function ℓ : RM ×Y →
R. Intuitively, L(θ; y) will be low if pθ assigns high probability to
predictions ŷ with low loss value ℓ(ŷ, y).
In the classification setting, where Y = [M], p_θ is often chosen to be
the Gibbs distribution, which is a categorical distribution induced
by a softargmax

p_θ(y) := exp(θ_y) / Σ_{i∈[M]} exp(θ_i) = [softargmax(θ)]_y ∈ (0, 1),

where θ_y := f(x, y, w) ∈ R are logits produced by a neural network f.
More generally, in the structured prediction setting, where Y ⊆ R^M but
|Y| ≫ M, we often use the distribution

p_θ(y) := exp(⟨φ(y), θ⟩) / Σ_{y′∈Y} exp(⟨φ(y′), θ⟩),

where θ = f(x, w) ∈ R^M.

Given a distribution ρ over X × Y, we then want to minimize the
expected loss function, also known as risk,

R(w) := E_{(X,Y)∼ρ}[L(f(X, w); Y)].

Typically, minimizing R(w) is done through some form of gradient
descent, which requires us to be able to compute

∇R(w) = E_{(X,Y)∼ρ}[∇_w L(f(X, w); Y)]
       = E_{(X,Y)∼ρ}[∂2f(X, w)∗ ∇L(f(X, w); Y)].

Computing ∇R(w) therefore boils down to computing the gradient of
L(θ; y), which is the gradient of an expectation.

12.2.4 Application to experimental design

In experimental design, we wish to minimize a function g(λ), which we


assume costly to evaluate. As an example, evaluating g(λ) could require
us to run a scientific experiment with parameters λ ∈ RQ . As another
example, in hyperparameter optimization, evaluating g(λ) would require
us to run a learning algorithm with hyperparameters λ ∈ RQ . Instead
of solving the problem arg minλ∈RQ g(λ), we can lift the problem to
probability distributions and solve arg minθ∈RM E(θ), where E(θ) =
Eλ∼pθ [g(λ)]. This requires the probability distribution pθ to assign
high probability to λ values that achieve small g(λ) value. Solving this
problem by stochastic gradient descent requires us to be able to compute
estimates of ∇E(θ). This can be done for instance with SFE explained
in Section 12.3, which does not require gradients of g, unlike implicit
differentiation explained in Chapter 11. This approach also requires us
to choose a distribution pθ over λ. For continuous hyperparameters,
a natural choice would be the normal distribution λ ∼ Normal(µ, Σ),
setting θ = (µ, Σ). Once we obtained θ by minimizing E(θ), we need a
way to recover λ. This can be done for example by choosing the mode of
the distribution, i.e., arg maxλ∈RQ pθ (λ), or the mean of the distribution
Eλ∼pθ (λ) [λ]. Of course, in the case of the normal distribution, they
coincide.

12.3 Score function estimators, REINFORCE

12.3.1 Scalar-valued functions

The key idea of the score function estimator (SFE), also known as
REINFORCE, is to rewrite ∇E(θ) as an expectation. The estimator is
based on the logarithmic derivative identity

∇_θ log p_θ(y) = ∇_θ p_θ(y) / p_θ(y) ⇐⇒ ∇_θ p_θ(y) = p_θ(y) ∇_θ log p_θ(y).
Using this identity, we obtain the following gradient estimator.

Proposition 12.1 (SFE for scalar-valued functions). Given a family
of distributions p_θ on Y, for θ ∈ Θ, define

E(θ) := E_{Y∼p_θ}[g(Y)] = ∫_Y p_θ(y) g(y) dy,

where Y ∈ Y ⊆ R^M and g : Y → R. Then,

∇E(θ) = E_{Y∼p_θ}[g(Y) ∇_θ log p_θ(Y)].

Proof.
∇E(θ) = ∫_Y ∇_θ p_θ(y) g(y) dy
       = ∫_Y p_θ(y) g(y) ∇_θ log p_θ(y) dy
       = E_{Y∼p_θ}[g(Y) ∇_θ log p_θ(Y)].

The gradient of the log-PDF w.r.t. θ, ∇_θ log p_θ(y), is known as the
score function, hence the estimator name. SFE is suitable when two
requirements are met: it is easy to sample from p_θ and the score function
is available in closed form. Since the SFE gradient is an expectation,
we can use Monte-Carlo estimation to compute an unbiased estimator
of ∇E(θ):

∇E(θ) ≈ γ̂_N(θ) := (1/N) Σ_{i=1}^{N} g(Y_i) ∇_θ log p_θ(Y_i),    (12.1)

where Y1 , . . . , YN are sampled from pθ .


Interestingly, the gradient of g is not needed in this estimator.
Therefore, there is no differentiability assumption about g. This is why
SFE is useful when g is a discrete loss function or more generally a
blackbox function.
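For concreteness, here is a small JAX sketch of the estimator (12.1) when p_θ is a Gaussian with mean θ and identity covariance, for which the score function is ∇_θ log p_θ(y) = y − θ in closed form; the black-box function g below is a toy choice of ours and is never differentiated.

import jax
import jax.numpy as jnp

g = lambda y: jnp.where(jnp.sum(y) > 0.0, 1.0, 0.0)     # non-differentiable black box

def sfe_gradient(theta, key, num_samples=10_000):
    ys = theta + jax.random.normal(key, (num_samples,) + theta.shape)  # Y_i ~ N(theta, I)
    scores = ys - theta                                  # score function, in closed form
    return jnp.mean(jax.vmap(g)(ys)[:, None] * scores, axis=0)

theta = jnp.array([0.5, -1.0])
print(sfe_gradient(theta, jax.random.PRNGKey(0)))        # Monte-Carlo estimate of grad E(theta)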

Example 12.1 (SFE with a language model). In a language model,


the probability of a sentence y = (y1 , . . . , yL ) is typically factored
using the chain rule of probability (see Section 10.1)

pθ (y) := pθ (y1 )pθ (y2 |y1 ) . . . pθ (yL |y1 , . . . , yL−1 ),

where pθ is modeled using a transformer or RNN. Note that the


probabilities are normalized by construction, so there is no need for
an explicit normalization constant. Thanks to this factorization, it is
easy to sample from pθ using ancestral sampling (see Section 10.5.3)
and the log-probability enjoys the simple expression

∇θ log pθ (y) = ∇θ log pθ (y1 ) + ∇θ log pθ (y2 |y1 ) + . . .


+ ∇θ log pθ (yL |y1 , . . . , yL−1 ).

This gradient is easy to compute, since the token-wise distributions


pθ (yj |y1 , . . . , yj−1 ) are typically defined using a softargmax. We can
therefore easily compute ∇E(θ) under pθ using SFE. This is for
instance useful to optimize an expected reward, in order to finetune
or align a language model (Ziegler et al., 2019).

Another example when ∇θ pθ (y) is available in closed form is in the


context of reinforcement learning, where pθ (y) is a Markov Decision
Process (MDP) and is called the policy. Applying the SFE leads to the
(vanilla) policy gradient method (Sutton et al., 1999) and can then be
used to compute the gradient of an expected cumulative reward. How-
ever, SFE is more problematic when used with the Gibbs distribution,
due to the explicit normalization constant.

Example 12.2 (SFE with a Gibbs distribution). The Gibbs distribu-



tion is parameterized, for θ ∈ R^Y,

p_θ(y) := exp(θ_y/γ − A(θ)) = exp(θ_y/γ) / exp(A(θ)),

where we defined the log-partition function

A(θ) := log Σ_{y∈Y} exp(θ_y/γ).

A typical parametrization is θy = f (x, y, w) with f the output of


network on a sample x with parameters w. We then have

log pθ (y) = θy /γ − A(θ),

so that
∇θ log pθ (y) = ey /γ − ∇A(θ).
We therefore see that ∇θ log pθ (y) crucially depends on ∇A(θ), the
gradient of the log-partition. This gradient is available for some
structured sets Y, see e.g. (Mensch and Blondel, 2018), but not in
general.

As another example, we apply SFE in Section 14.4 to derive the


gradient of perturbed functions.

Differentiating through both the distribution and the function

Suppose both the distribution and the function now depend on θ. When
g is scalar-valued and differentiable w.r.t. θ, we want to differentiate

E(θ) := EY ∼pθ [g(Y, θ)].

Using the product rule, we obtain

∇E(θ) = EY ∼pθ [g(Y, θ)∇θ log pθ (Y )] + EY ∼pθ [∇θ g(Y, θ)].

Differentiating through joint distributions

Suppose we now want to differentiate through

E(θ) := EY1 ∼pθ ,Y2 ∼qθ [g(Y1 , Y2 )].



The gradient is then given by


∇E(θ) = EY1 ∼pθ ,Y2 ∼qθ [(∇θ log pθ (Y1 ) + ∇ log qθ (Y2 ))g(Y1 , Y2 )],
which is easily seen by applying Proposition 12.1 on the joint distribution
ρθ := pθ ·qθ . The extension to more than two variables is straightforward.

12.3.2 Variance reduction


Bias and variance

Recall the definition of γ̂_N in Eq. (12.1). SFE is an unbiased estimator,
meaning that

∇E(θ) = E[γ̂_N(θ)],

where the expectation is taken with respect to the N samples drawn.
Since the gradient is vector-valued, we need to define a scalar-valued
notion of variance. We do so by using the squared Euclidean distance
in the usual variance definition to define

V[γ̂_N(θ)] := E[‖γ̂_N(θ) − ∇E(θ)‖²₂] = E[‖γ̂_N(θ)‖²₂] − ‖∇E(θ)‖²₂.

The variance naturally goes to zero as N → ∞.

Baseline
SFE is known to suffer from high variance (Mohamed et al., 2020).
This means that this estimator may require us to draw many samples
from the distribution pθ to work well in practice. One of the simplest
variance reduction technique consists in shifting the function g with a
constant β, called a baseline, to obtain
∇E(θ) = EY ∼pθ [(g(Y ) − β)∇θ log pθ (Y )].
The reason this is still a valid estimator of ∇E(θ) stems from

E_{Y∼p_θ}[∇_θ log p_θ(Y)] = E_{Y∼p_θ}[∇_θ p_θ(Y) / p_θ(Y)]
                          = ∫_Y ∇_θ p_θ(y) dy
                          = ∇_θ ∫_Y p_θ(y) dy
                          = ∇_θ 1
                          = 0,

for any valid distribution pθ . The baseline β is often set to the running
average of past values of the function g, though it is neither optimal
nor does it guarantee to lower the variance (Mohamed et al., 2020).

Control variates
Another general technique are control variates. Let us denote the
expectation of a function h : RM → R under the distribution pθ as
H(θ) := EY ∼pθ [h(Y )].
Suppose that H(θ) and its gradient ∇H(θ) are known in closed form.
Then, for any γ ≥ 0, we clearly have
E(θ) = EY ∼pθ [g(Y )]
= EY ∼pθ [g(Y ) − γ(h(Y ) − H(θ))]
= EY ∼pθ [g(Y ) − γh(Y )] + γH(θ)
and therefore
∇E(θ) = ∇θ EY ∼pθ [g(Y ) − γh(Y )] + γ∇H(θ).
Applying SFE, we then obtain
∇E(θ) = EY ∼pθ [(g(Y ) − γh(Y ))∇θ log pθ (Y )] + γ∇H(θ).
Examples of h include a bound on g or a second-order Taylor expansion
of g, assuming that these approximations are easier to integrate than g
(Mohamed et al., 2020).

12.3.3 Vector-valued functions


It is straightforward to extend the SFE to vector-valued functions.
Proposition 12.2 (SFE for vector-valued functions). Given a family
of distributions p_θ on Y, for θ ∈ Θ, define

E(θ) := E_{Y∼p_θ}[g(Y)] = ∫_Y p_θ(y) g(y) dy,

where Y ∈ Y, g : Y → G. The JVP of E at θ ∈ Θ along v ∈ Θ is

∂E(θ)v = EY ∼pθ [⟨∇θ log pθ (Y ), v⟩g(Y )] ∈ G



and the VJP of E at θ ∈ Θ along u ∈ G is

∂E(θ)∗ u = EY ∼pθ [∇θ log pθ (Y )⟨u, g(Y )⟩] ∈ Θ

The Jacobian of E at θ ∈ Θ can then be written as

∂E(θ) = EY ∼pθ [g(Y ) ⊗ ∇θ log pθ (Y )],

where ⊗ denotes the outer product.

Proof. The VJP of E at θ ∈ Θ along u ∈ G amounts to computing the
gradient of the scalar-valued function

⟨E(θ), u⟩ = E_{Y∼p_θ}[⟨g(Y), u⟩].

The expression of the VJP follows by using the SFE on the scalar-valued
integrand ⟨g(Y), u⟩. The JVP is obtained as the adjoint operator of the
VJP, and the Jacobian follows.

Differentiating through both the distribution and the function

If θ now influences both the distribution and the function,

E(θ) := EY ∼pθ [g(Y, θ)],

then, we obtain

∂E(θ) = EY ∼pθ [g(Y, θ) ⊗ ∇θ log pθ (Y )] + EY ∼pθ [∂θ g(Y, θ)].

12.3.4 Second derivatives

Using the previous subsection with g(y, θ) = g(y) ∇_θ log p_θ(y), we easily
obtain an estimator of the Hessian.

Proposition 12.3 (SFE for the Hessian). Let us define the scalar-
valued function E(θ) := EY ∼pθ [g(Y )]. Then,

∇2 E(θ) =EY ∼pθ [g(Y )∇θ log pθ (Y ) ⊗ ∇θ log pθ (Y )]+


EY ∼pθ [g(Y )∇2θ log pθ (Y )].

This can also be derived using the second-order log-derivative

∇²_θ log p_θ(y) = ∇²_θ p_θ(y) / p_θ(y) − (∇_θ p_θ(y) ⊗ ∇_θ p_θ(y)) / p_θ(y)²,

so that

∇²_θ p_θ(y) = p_θ(y) [∇²_θ log p_θ(y) + ∇_θ log p_θ(y) ⊗ ∇_θ log p_θ(y)].

Link with the Bartlett identities

The Bartlett identities are expressions relating the moments of the score
function (the gradient of the log-likelihood function). Using Proposition 12.1
with g(y) = 1 and ∫_Y p_θ(y) dy = 1, we obtain

E_{Y∼p_θ}[∇_θ log p_θ(Y)] = 0,                                     (12.2)

which is known as Bartlett’s first identity. Similarly, using Proposi-
tion 12.3, we obtain

E_{Y∼p_θ}[∇²_θ log p_θ(Y)] + E_{Y∼p_θ}[∇_θ log p_θ(Y) ⊗ ∇_θ log p_θ(Y)]
  = E_{Y∼p_θ}[∇²_θ log p_θ(Y)] + cov[∇_θ log p_θ(Y)]               (12.3)
  = 0,

which is known as Bartlett’s second identity.

12.4 Path gradient estimators, reparametrization trick

As we saw previously, the main difficulty in computing gradients of


expectations arises when the parameters θ play a role in the distribution
pθ being sampled. The key idea of path gradient estimators (PGE),
also known as reparametrization trick, is to rewrite the expectation in
such a way that the parameters are moved from the distribution to the
function, using a change of variable.

12.4.1 Location-scale transforms


The canonical example of path gradient estimator is differentiating
through the expectation

E(µ, σ) := EU ∼Normal(µ,σ2 ) [g(U )],



where g : R → R is a differentiable function. If we let Z ∼ Normal(0, 1),
it is easy to check that U = µ + σZ. We can therefore write

E(µ, σ) = E_{Z∼Normal(0,1)}[g(µ + σZ)].

The key advantage is that we can now easily compute the derivatives
by mere application of the chain rule, since the parameters µ and σ are
moved from the distribution to the function:

∂E(µ, σ)/∂µ = E_{Z∼Normal(0,1)}[g′(µ + σZ)]
∂E(µ, σ)/∂σ = E_{Z∼Normal(0,1)}[Z · g′(µ + σZ)].
The change of variable
U := µ + σZ (12.4)

is called a location-scale transform. Such a transformation exists,


not only for the normal distribution, but for location-scale family
distributions, i.e., distributions parametrized by a location parameter µ
and a scale parameter σ > 0, such that U is distributed according to a
distribution in the same family as Z is distributed. Besides the normal
distribution, examples of location-scale family distributions include the
Cauchy distribution, the uniform distribution, the logistic distribution,
the Laplace distribution, and Student’s t-distribution.
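As a rough sketch of how the location-scale transform is used in practice, the JAX snippet below samples standard normal noise once and lets µ and σ enter only through the function, so that the derivatives above are obtained directly by automatic differentiation; the integrand g and the sample size are placeholder choices.

import jax
import jax.numpy as jnp

def g(u):
    # Placeholder differentiable integrand.
    return jnp.tanh(u) ** 2

def expectation(params, z):
    # E(mu, sigma) approximated with fixed standard normal samples z,
    # using the location-scale transform U = mu + sigma * Z.
    mu, sigma = params
    return jnp.mean(g(mu + sigma * z))

key = jax.random.PRNGKey(0)
z = jax.random.normal(key, (10_000,))
# Unbiased pathwise estimate of (dE/dmu, dE/dsigma).
print(jax.grad(expectation)((1.0, 0.5), z))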
We can easily relate the cumulative distribution function (CDF)
and the probability density function (PDF) of Z to that of U , and
vice-versa.

Proposition 12.4 (CDF and PDF of location-scale family distributions).


Let FZ (z) := P(Z ≤ z) and fZ (z) := FZ′ (z). If U := µ + σZ, then
FZ(z) = FU(µ + σz) ⇐⇒ FU(u) = FZ((u − µ)/σ)

fZ(z) = σ fU(µ + σz) ⇐⇒ fU(u) = (1/σ) fZ((u − µ)/σ).

Proof. We have

FZ(z) = P(Z ≤ z) = P((U − µ)/σ ≤ z) = P(U ≤ µ + σz) = FU(µ + σz),

and we obtain fZ (z) by differentiating FZ (z).

12.4.2 Differentiable transforms


We can generalize the idea of path gradient estimator (PGE) to any
change of variable
U := T (Z, θ),
where T : RM × RQ → RM is a differentiable transformation. For exam-
ple, if we gather µ and σ as θ := (µ, σ), we can write the location-scale
transform as
U = T (Z, θ) = µ + σZ.
We can derive the path gradient estimator for any such differentiable
transformation T .

Proposition 12.5 (Path gradient estimator). Let us define

E(θ) := EU ∼pθ [g(U )],

where U ∈ U ⊆ RM and g : RM → R is differentiable. Suppose


there is a differentiable transformation T : RM × RQ → RM such
that if Z ∼ p (where p does not depend on θ) and U := T (Z, θ),
then U ∼ pθ . Then, we have

E(θ) = EZ∼p [h(Z, θ)] = EZ∼p [g(T (Z, θ))],



where h(z, θ) := g(T (z, θ)). This implies

∇E(θ) = EZ∼p [∇2 h(Z, θ)]


= EZ∼p [∂2 T (Z, θ)∗ ∇g(T (Z, θ))].

The path gradient estimator (a.k.a. reparametrization trick) gives an


unbiased estimator of ∇E(θ). It has, however, two key disadvantages.
First, it assumes that g is differentiable (almost everywhere), which
may not always be the case. Second, it assumes that g is well-defined
on all of RM, not only on U, which could be problematic for some discrete loss
functions, such as the zero-one loss function or ranking loss functions.
As an example of differentiable transform, in machine learning, we
can sample Gaussian noise Z and make it go through a neural network
with parameters w to generate an image X := T (Z, w). In statistics,
many distributions are related to each other through differentiable
transforms, as we recall below.

Example 12.3 (Some differentiable transforms in statistics). We give


below a non-exhaustive list of differentiable transform examples.

• If X ∼ Normal(µ, σ 2 ), then exp(X) ∼ Lognormal(µ, σ 2 ).

• If U ∼ Uniform(0, 1), then − log(U )/λ ∼ Exponential(λ).

• If X1, . . . , XN ∼ Exponential(λ) (i.i.d.), then Σ_{i=1}^N Xi ∼ Gamma(N, λ).

• If Xi ∼ Gamma(αi, θ) for i ∈ [K], then (X1/Σ_{i=1}^K Xi, . . . , XK/Σ_{i=1}^K Xi) ∼ Dirichlet(α1, . . . , αK).

12.4.3 Inverse transforms


The inverse transform method can be used for sampling from a proba-
bility distribution, given access to its associated quantile function.
Recall that the cumulative distribution function (CDF) associated with
a random variable Y is the function FY : R → [0, 1] defined by

FY (y) := P(Y ≤ y).



The quantile function is then a function QY : [0, 1] → R such that


QY (π) = y for π = FY (y). Assuming FY is continuous and strictly
increasing, we have that QY is the inverse CDF,
QY (π) = FY−1 (π).
In the general case of CDF functions that are not strictly increasing,
the quantile function is usually defined as
QY (π) := inf{y ∈ R : π ≤ FY (y)}.
Given access to the quantile function QY (π) associated with a distribu-
tion p, inverse transform sampling allows us to sample from p by first
drawing a sample from the uniform distribution and then making
this sample go through the quantile function.

Proposition 12.6 (Inverse transform sampling). Suppose Y ∼ p,


where p is a distribution with quantile function QY . If U ∼
Uniform(0, 1), then QY (U ) ∼ p.

Proof. If π ≤ FY (t), then by definition of QY , QY (π) ≤ t. If π ≥ FY (t),


then by definition of QY , FY (QY (π)) ≥ π, so FY (QY (π)) ≥ FY (t) and
since a CDF is always non-decreasing, QY (π) ≥ t. Hence, we have,
QY (π) ≤ t ⇐⇒ π ≤ FY (t), so
P(QY (U ) ≤ t) = P(U ≤ FY (t))
= FY (t).
The CDFs of QY (U ) and Y coincide, hence they have the same distri-
bution.

If the quantile function is differentiable, we can therefore use it


as a transformation within the reparametrization trick. Indeed,
if Y ∼ pθ , where pθ is a distribution with parameter θ and quantile
function QY (π, θ), then we have
E(θ) = EY ∼pθ [g(Y )] = Eπ∼Uniform(0,1) [g(QY (π, θ))]
and therefore, by the reparametrization trick (Proposition 12.5),
∇E(θ) = Eπ∼Uniform(0,1) [∂2 QY (π, θ)∗ ∇g(QY (π, θ))].

Example 12.4 (Examples of quantile functions). If Y ∼ Exponential(λ),
the CDF of Y is π = FY(y) = 1 − exp(−λy) for y ≥ 0 and therefore the
quantile function is QY(π, λ) = −log(1 − π)/λ.
If Y ∼ Normal(µ, σ²), the CDF is FY(y) = ½[1 + erf((y − µ)/(σ√2))] and
the quantile function is QY(π, θ) = µ + σ√2 · erf⁻¹(2π − 1), where
θ = (µ, σ). This therefore defines an alternative transformation to
the location-scale transformation in Eq. (12.4).

Note that, in the above example, the error function erf and its
inverse do not enjoy analytical expressions but autodiff packages usually
provide numerical routines to compute them and differentiate through
them. Nonetheless, one caveat of the inverse transform is that it indeed
requires access to (approximations of) the quantile function and its
derivatives, which may be difficult for complicated distributions.
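As an illustration, the sketch below (in JAX) uses the exponential quantile function of Example 12.4 as the transformation: uniform samples are pushed through QY(π, λ) = −log(1 − π)/λ and the derivative with respect to λ is obtained by automatic differentiation; the integrand g and the sample size are placeholder choices.

import jax
import jax.numpy as jnp

def quantile(pi, lam):
    # Quantile function of Exponential(lam).
    return -jnp.log(1.0 - pi) / lam

def g(y):
    # Placeholder differentiable integrand.
    return jnp.sin(y)

def expectation(lam, pi):
    # E_{Y ~ Exponential(lam)}[g(Y)] via uniform samples pushed through the quantile.
    return jnp.mean(g(quantile(pi, lam)))

key = jax.random.PRNGKey(0)
pi = jax.random.uniform(key, (10_000,))
print(jax.grad(expectation)(2.0, pi))  # d/d lambda of the expectation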

12.4.4 Pushforward operators


Pushforward distributions
We saw so far that the reparametrization trick is based on using a
change of variables in order to differentiate an expectation w.r.t. the
parameters of the distribution. In this section, we further formalize that
approach using pushforward distributions.

Definition 12.1 (Pushforward distribution). Suppose Z ∼ p, where


p is a distribution over Z. Given a continuous map T : Z → U,
the pushforward distribution of p through T is the distribution q
according to which U := T (Z) ∈ U is distributed, i.e., U ∼ q.

Although not explicit in the above, the transformation T can depend


on some learnable parameters, for example if T is a neural network.
Intuitively, the pushforward distribution is obtained by moving the
position of all the points in the support of p. Inverse transform sampling
studied in Section 12.4.3 can be seen as performing the pushforward
of the uniform distribution through T = Q, where Q is the quantile
function. The Gumbel trick studied in Section 14.5 can be seen as the
pushforward of Gumbel noise through T = argmax (a discontinuous

function) and Gumbel noise can itself be obtained by pushing forward


the uniform distribution through T = − log(− log(·)) (Remark 14.2). In
a generative modeling setting, as we mentioned previously, we use the
pushforward of Gaussian noise through a parametrized transformation
X = T (Z, w) called a generator, typically a neural network.
A crucial aspect of the pushforward distribution q is that it can be
implicitly defined, meaning that we do not necessarily need to know
the explicit form of the associated PDF. In fact, it is easy to sample
from q, provided that it is easy to sample from p:

U ∼ q ⇐⇒ Z ∼ p, U := T (Z).

Hence the usefulness of the pushforward distribution in generative


modeling. Furthermore, if p has associated PDF pZ , we can compute
the expectation of a function f according to q as
EU∼q[f(U)] = EZ∼p[f(T(Z))] = ∫_Z f(T(z)) pZ(z) dz,

even though we do not know the explicit form of the PDF of q.
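A minimal sketch of this implicit viewpoint is given below in JAX: standard Gaussian noise is pushed through a small parametrized transformation T (a stand-in for a generator network) and the expectation under the pushforward distribution q is estimated by Monte Carlo, without ever writing down the density of q. The transformation, parameters and integrand are placeholders.

import jax
import jax.numpy as jnp

def T(z, w):
    # A small parametrized transformation (placeholder "generator").
    return jnp.tanh(w[0] * z + w[1])

def f(u):
    return u ** 2

key = jax.random.PRNGKey(0)
z = jax.random.normal(key, (100_000,))   # Z ~ p (standard normal)
w = jnp.array([0.7, -0.3])
u = T(z, w)                              # U ~ q, the pushforward of p through T
# E_{U ~ q}[f(U)] estimated without an explicit density for q.
print(jnp.mean(f(u)))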

Pushforward measures
More generally, we can define the notion of pushforward, in the language
of measures. Denote M(Z) the set of measures on a set Z. A measure
α ∈ M(Z), that has a density dα(z) := pZ(z)dz, can be integrated
against a function f as

∫_Z f(z) dα(z) = ∫_Z f(z) pZ(z) dz.

A measure α is called a probability measure if it is positive and satisfies
α(Z) = ∫_Z dα(z) = ∫_Z pZ(z)dz = 1. See Peyré and Cuturi (2019,
Chapter 2) for a concise introduction.

Definition 12.2 (Pushforward operator and measure). Given a con-


tinuous map T : Z → U and some measure α ∈ M(Z), the push-
forward measure β = T♯ α ∈ M(U) is such that for all continuous

functions f ∈ C(U)
∫_U f(u) dβ(u) = ∫_Z f(T(z)) dα(z).

Equivalently, for any measurable set A ⊂ U, we have

β(A) = α({z ∈ Z : T (z) ∈ A}) = α(T −1 (A)),

where T −1 (A) = {z ∈ Z : T (z) ∈ A}.

Importantly, the pushforward operator preserves positivity and mass,


therefore if α is a probability measure, then so is T♯ α. The pushforward
of a probability measure therefore defines a pushforward distribution
(since a distribution can be parametrized by a probability measure).

12.4.5 Change-of-variables theorem


We saw that a pushforward distribution associated with a variable U
is implicitly defined through a transform U := T (Z) and can be easily
sampled from as long as it is easy to sample Z. However, in some
applications (e.g., density estimation), we may want to know the PDF
associated with U . Assuming the transform T is invertible, we have
Z = T −1 (U ) and therefore for A ⊆ U, we have
P(U ∈ A) = P(Z ∈ T⁻¹(A)) = ∫_{T⁻¹(A)} pZ(z) dz.

Using the change-of-variables theorem from multivariate calculus,


assuming T −1 is available, we can give an explicit formula for the PDF
of the pushforward distribution, see e.g. (Schwartz, 1954; Taylor, 2002).

Proposition 12.7 (PDF of the pushforward distribution). Suppose


Z ∼ p, where p is a distribution over Z, with PDF pZ . Given
a diffeomorphism T : Z → U (i.e., an invertible and differen-
tiable map), the pushforward distribution of p through T is the
distribution q such that U := T (Z) ∼ q and its PDF is

qU (u) = | det(∂T −1 (u))|pZ (T −1 (u)),



where ∂T −1 (u) is the Jacobian of T −1 : U → Z.

Using this formula, we obtain

P(U ∈ A) = ∫_A qU(u) du = ∫_A |det(∂T⁻¹(u))| pZ(T⁻¹(u)) du.

Using the inverse function theorem (Theorem 11.5), we then have

∂T −1 (u) = (∂T (T −1 (u)))−1 ,

under the assumption that T (z) is continuously differentiable and


has invertible Jacobian ∂T (z). Normalizing flows are parametrized
transformations T designed such that T −1 and its Jacobian ∂T −1 are
easy to compute; see e.g. Kobyzev et al. (2019) and Papamakarios et al.
(2021) for a review.
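As a small sanity-check sketch (in JAX), the change-of-variables formula can be applied to the one-dimensional transform T = exp of standard Gaussian noise, for which the pushforward distribution is lognormal; the Jacobian of T⁻¹ = log is obtained here by automatic differentiation. The transform and evaluation point are illustrative choices.

import jax
import jax.numpy as jnp

def p_Z(z):
    # Standard normal PDF.
    return jnp.exp(-0.5 * z ** 2) / jnp.sqrt(2.0 * jnp.pi)

# Transform T(z) = exp(z), with inverse T^{-1}(u) = log(u).
T_inv = jnp.log

def q_U(u):
    # PDF of U = T(Z) by the change-of-variables formula (1-D case):
    # q_U(u) = |d T^{-1}(u) / du| * p_Z(T^{-1}(u)).
    jac = jax.grad(T_inv)(u)
    return jnp.abs(jac) * p_Z(T_inv(u))

# Matches the lognormal(0, 1) density, e.g., at u = 2.0.
print(q_U(2.0))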

12.5 Stochastic programs

A stochastic program is a program that involves some form of random-


ness. In a stochastic program, the final output, as well as intermediate
variables, may therefore be random variables. In other words, a stochas-
tic program induces a probability distribution over program outputs,
as well as over execution trajectories.

12.5.1 Stochastic computation graphs


A stochastic program can be represented by a stochastic computation
graph as originally introduced by Schulman et al. (2015). Departing from
that work, our exposition explicitly supports two types of intermediate
operations: sampling from a conditional distribution or evaluating a
function. These operations can produce either deterministic variables
or random variables.

Function and distribution nodes


Formally, we define a stochastic computation graph as a directed acyclic
graph G = (V, E), where V = Vf ∪ Vp , Vf is the set of function nodes

and Vp is the set of distribution nodes. Similarly to computation graphs


reviewed in Section 4.1.3, we number the nodes as V = {0, 1, . . . , K}.
Node 0 corresponds to the input s0 ∈ S0 , which we assume to be deter-
ministic. It is the variable with respect to which we wish to differentiate.
Node K corresponds to the program output SK ∈ SK , which we assume
to be a random variable. A node k ∈ {1, . . . , K} can either be a func-
tion node k ∈ Vf with an associated function fk or a distribution
node k ∈ Vp , with associated conditional distribution pk . A stochastic
program has at least one distribution node, the source of randomness.
Otherwise, it is a deterministic program. As for computation graphs,
the set of edges E is used to represent dependencies between nodes. We
denote the parents of node k by pa(k).

Deterministic and random variables

We distinguish between two types of intermediate variables: deter-


ministic variables sk and random variables Sk . Therefore, a dis-
tribution pk or a function fk may receive both types of variables
as conditioning or input. It is then convenient to split pa(k) as
pa(k) = determ(k) ∪ random(k), where we defined the determinis-
tic parents determ(k) := {i1 , . . . , ipk } and the random parents
random(k) := {j1 , . . . , jqk }. Therefore, si1 , . . . , sipk are the determinis-
tic parent variables and Sj1 , . . . , Sjqk are the random parent variables,
of node k.

Executing a stochastic program

We assume that nodes 0, 1, . . . , K are in topological order (if this is


not the case, we need to perform a topological sort). Given parent
variables si1 , . . . , sipk and Sj1 , . . . , Sjqk , a node k ∈ {1, . . . , K} produces
an output as follows.

• If k ∈ Vp (distribution node), the output is

Sk ∼ pk (· | sdeterm(k) , Srandom(k) )
⇐⇒ Sk ∼ pk (· | si1 , . . . , sipk , Sj1 , . . . , Sjqk )

Note that technically pk is the distribution of Sk conditioned on its


parents, not the distribution of Sk . Therefore, we should in princi-
ple write Sk | sdeterm(k) , Srandom(k) ∼ pk (· | sdeterm(k) , Srandom(k) ).
We avoid this notation for conciseness and for symmetry with
function nodes.
Contrary to a function node, a distribution node can have no
parents. That is, if k ∈ Vp , it is possible that pa(k) = ∅. A good
example would be a parameter-free noise distribution.

• If k ∈ Vf (function node), the output is in general

Sk := fk (sdeterm(k) , Srandom(k) )
:= fk (si1 , . . . , sipk , Sj1 , . . . , Sjqk )

and in the special case qk = |random(k)| = 0, the output is

sk := fk (sdeterm(k) )
:= fk (si1 , . . . , sipk ).

Unless the associated conditional distribution pk is a delta distribution,


that puts all the probability mass on a single point, the output of a
distribution node k ∈ Vp is necessarily a random variable Sk ∈ Sk .
For function nodes k ∈ Vf , the output of the function fk is a random
variable Sk ∈ Sk if at least one of the parents of k produces a random
variable. Otherwise, if all parents of k produce deterministic variables,
the output of fk is a deterministic variable sk ∈ Sk .
The entire procedure is summarized in Algorithm 12.1. We emphasize
that SK = f (s0 ) ∈ SK is a random variable. Therefore, a stochastic
program (implicitly) induces a distribution over SK , and also over
intermediate random variables Sk . Executing the stochastic program
allows us to draw samples from that distribution.

Special cases
If all nodes are function nodes, we recover computation graphs, reviewed
in Section 4.1.3. If all nodes are distribution nodes, we recover Bayesian
networks, reviewed in Section 10.5.

Algorithm 12.1 Executing a stochastic program


Nodes: 1, . . . , K in topological order, where node k is either a
function fk or a conditional distribution pk
Input: input s0 ∈ S0
1: for k := 1, . . . , K do
2: Retrieve pa(k) = determ(k) ∪ random(k)
3: if k ∈ Vp then ▷ Distribution node
4: Sk ∼ pk (·|sdeterm(k) , Srandom(k) )
5: else if k ∈ Vf then ▷ Function node
6: if |random(k)| ≠ 0 then
7: Sk := fk (sdeterm(k) , Srandom(k) ) ▷ Output is a R.V.
8: else if |random(k)| = 0 then
9: sk := fk (sdeterm(k) ) ▷ Output is deterministic
10: Output: f (s0 ) := SK ∈ SK

12.5.2 Examples
We now present several examples that illustrate our formalism. We use
the legend below in the following illustrations.

(Legend: deterministic variable, stochastic variable, function node, sampler node.)

• Example 1 (SFE estimator):

S1 ∼ p1 (· | s0 )
S2 := f2 (S1 )
E(s0 ) := E[S2 ]
∇E(s0 ) = ES1 [f2 (S1 )∇s0 log p1 (S1 | s0 )]

• Example 2 (Pathwise estimator):



S1 ∼ p1
S2 := f2 (S1 , s0 )
E(s0 ) := E[S2 ]
∇E(s0 ) = ES1 [∇s0 f2 (S1 , s0 )]

• Example 3 (SFE estimator + chain rule):

s1 := f1 (s0 )
S2 ∼ p2 (· | s1 )
S3 := f3 (S2 )
E(s0 ) := E[S3 ]
∇E(s0 ) = ∂f1 (s0 )∗ ES2 [f3 (S2 )∇s1 log p2 (S2 | s1 )]

• Example 4:

s1 := f1 (s0 )
s2 := f2 (s0 )
S3 ∼ p3 (· | s1 )
S4 ∼ p4 (· | s2 , S3 )
S5 := f5 (S4 )
E(s0 ) := E[S5 ] = ES3 [ES4 [f5 (S4 )]]
∇E(s0 ) = ES3 [∂f1 (s0 )∗ ∇s1 log p3 (S3 | s1 )ES4 [f5 (S4 )]]
+ ES3 [ES4 [∂f2 (s0 )∗ ∇s2 log p4 (S4 |s2 , S3 )f5 (S4 )]]

As can be seen, the gradient expressions can quickly become quite


complicated, demonstrating the merits of automatic differentiation in
stochastic computation graphs.

12.5.3 Unbiased gradient estimators

The output of a stochastic program is a random variable

SK := f (s0 ).

It implicitly defines a probability distribution p(·|s0 ) such that SK ∼


p(·|s0 ). Executing the stochastic program once gives us an i.i.d. sample
from p(·|s0 ).
Since derivatives are defined for deterministic variables, we need a
way to convert a random variable to a deterministic variable. One way
to do so is to consider the expected value (another way would be the
mode)
E(s0 ) := E[SK ] = E[f (s0 )] ∈ conv(SK ),
where the expectation is over SK ∼ p(·|s0 ) or equivalently over the
intermediate random variables Sk

Sk ∼ pk (·|sdeterm(k) , Srandom(k) ),

for k ∈ Vp (the distribution nodes). We then wish to compute the


gradient or more generally the Jacobian of E(s0 ).

If all nodes in the stochastic computation graph are function nodes,


we can estimate the gradient of E(s0 ) using the pathwise estimator
a.k.a. reparametrization trick (Section 12.4). This is the approach taken
by Kingma and Welling (2013) and Rezende et al. (2014).
If all nodes in the stochastic computation graph are distribution
nodes, we can use the SFE estimator (Section 12.3). Schulman et al.
(2015) propose a surrogate loss so that using autodiff on that loss
produces an unbiased gradient of the expectation, using the SFE esti-
mator. Foerster et al. (2018) extend the approach to support high-order
differentiation. Krieken et al. (2021) further extend the approach by
supporting different estimators per node, as well as control variates.
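The surrogate-loss idea can be sketched in a few lines of JAX for a single distribution node: differentiating the quantity stop_gradient(g(Y)) · log pθ(Y), averaged over samples, reproduces the SFE. The Gaussian family, the function g and the sample size below are placeholder choices, and this sketch is not the exact surrogate construction of any particular library.

import jax
import jax.numpy as jnp

def g(y):
    # Placeholder downstream cost.
    return jnp.maximum(y, 0.0)

def log_pdf(y, theta):
    return -0.5 * (y - theta) ** 2 - 0.5 * jnp.log(2.0 * jnp.pi)

def surrogate(theta, key, num_samples=10_000):
    # Sample Y ~ Normal(theta, 1); stop_gradient blocks the pathwise dependence
    # so that differentiating the surrogate yields the SFE of grad E[g(Y)].
    y = jax.lax.stop_gradient(theta) + jax.random.normal(key, (num_samples,))
    return jnp.mean(jax.lax.stop_gradient(g(y)) * log_pdf(y, theta))

key = jax.random.PRNGKey(0)
print(jax.grad(surrogate)(0.3, key))  # unbiased estimate of grad_theta E[g(Y)]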

Converting distribution nodes into function nodes and vice-versa

Our formalism uses two types of nodes: distribution nodes with asso-
ciated conditional distribution pk and function nodes with associated
function fk . It is often possible to convert between node types.
Converting a distribution node into a function node is exactly the
reparametrization trick studied in Section 12.4. We can use transforma-
tions such as the location-scale transform or the inverse transform.
Converting a function node into a distribution node can be done
using the change-of-variables theorem, studied in Section 12.4.5, on a
pushforward distribution.
Because the pathwise estimator has lower variance than SFE, this is
the method of choice when the fk functions are available. The conversion
from distribution node to function node and vice-versa is illustrated in
Fig. 12.1.

12.5.4 Local vs. global expectations

A stochastic computation graph can be seen as a stochastic process,


a collection of random variables Sk , indexed by k, the position in the
topological order. However, random variables are incompatible with
autodiff. Replacing random variables by their expectation can be seen
as a way to make them compatible with autodiff. Two strategies are
then possible.

Figure 12.1: It is sometimes possible to convert a distribution node to a function
node and vice-versa using a suitable transformation: a parametric (explicit)
distribution can be turned into a pushforward (implicit) distribution via a
transformation (location-scale or inverse transform), enabling the path gradient
estimator (PGE); conversely, the change-of-variables theorem turns a pushforward
distribution back into an explicit one, enabling the score function estimator (SFE).

As we saw in the previous section, a strategy is to consider the


expectation of the last output SK . This strategy corresponds to a
global smoothing. The two major advantages are that i) we do not
need to assume that fk+1 is well-defined on conv(Sk ) and ii) this induces
a probability distribution over program executions. This is for instance
useful to compute the variance of the program. The gradient of the
program’s expected value can be estimated by the reparametrization
trick or by the SFE, depending on the type of nodes used.
A second strategy is to replace an intermediate random variable
Sk ∈ Sk , for k ∈ {1, . . . , K}, by its expectation E[Sk ] ∈ conv(Sk ). This
strategy corresponds to a local smoothing. A potential drawback of
this approach is that E[Sk ] belongs to conv(Sk ), the convex hull of Sk .
Therefore, the function fk+1 in which E[Sk ] is fed must be well-defined
on conv(Sk ), which may not always be the case. In the case of control
flows, another disadvantage is computational. We saw in Section 5.6 and
Section 5.7 that using a soft comparison operator within a conditional
statement induces a distribution on a binary or categorical random
variable, corresponding to the branch to be selected. A conditional
statement can then be locally smoothed out by replacing the random

variable by its expectation i.e., a convex combination of all the


branches. This means that, unless the distribution has sparse support,
all branches must be evaluated.

12.6 Differential equations

12.6.1 Parameterized differential equations


From residual networks to neural ODEs
Starting from s0 := x, residual networks, reviewed in Section 4.5, iterate
for k ∈ {1, . . . , K}
sk := sk−1 + hk (sk−1 , wk ).
A residual network can be seen as parameterizing incremental discrete-
time input changes (hence the name “residual”)
sk − sk−1 = hk (sk−1 , wk ).
Chen et al. (2018) proposed to parameterize continuous-time (instan-
taneous) changes instead. They considered the evolution s(t) of the
inputs in continuous time driven by a function h(t, s, w) parameterized
by w, starting from x. Formally, the evolution s(t) is the solution of
the ordinary differential equation (ODE)
s(0) = x
s′ (t) = h(t, s(t), w) t ∈ [0, T ] (12.5)
Here, s′ (t) is the vector of derivatives of s as defined in Remark 2.4, and
T denotes a final time for the trajectory. The output of such a neural
ODE (Chen et al., 2018) is then f (x, w) := s(T ). Alternatively, the
output can be seen as the solution of an integration problem
Z T
f (x, w) = s(T ) = x + h(t, s(t), w)dt. (12.6)
0
Differential equations like Eq. (12.5) arise in many contexts beyond neu-
ral ODEs, ranging from modeling physical systems to pandemics (Braun
and Golubitsky, 1983). Moreover, the differential equation presented
in Eq. (12.5) is just an example of an ordinary differential equation,
while controlled differential equations or stochastic differential equations
can also be considered.

Existence of a solution
First and foremost, the question is whether s(t) is well-defined. For-
tunately, the answer is positive under mild conditions, as shown by
Picard-Lindelöf’s theorem recalled below (Butcher, 2016, Theorem 16).

Theorem 12.1 (Existence and uniqueness of ODE solutions). If h :


[0, T ] × S → S is continuous in its first variable and Lipschitz-
continuous in its second variable, then there exists a unique differ-
entiable map s : [0, T ] → S satisfying

s(0) = s0
s′ (t) = h(t, s(t)) t ∈ [0, T ],

for some given s0 ∈ S.

For time-independent linear functions h(t, s) = As, the integral


in Eq. (12.6) can be computed in closed form as

s(t) = exp(tA)s0,

where exp(A) is the matrix exponential. Hence, the output s(T ) can
be expressed as a simple function of the parameters (A in this case).
However, generally, we do not have access to such analytical solutions,
and, just as for solving optimization problems in Chapter 11, we need
to resort to some iterative algorithms.

Integration methods
To numerically solve an ODE, we can use integration methods,
whose goal is to build a sequence sk that approximates the solution
s(t) at times tk . The simplest integration method is the explicit Euler
method, that approximates the solutions between times tk−1 and tk as
s(tk) − s(tk−1) = ∫_{tk−1}^{tk} h(t, s(t), w)dt ≈ δk h(tk−1, s(tk−1), w),

for a time-step
δk := tk − tk−1 .

The resulting integration scheme consists in computing starting from


s0 = x, for k ∈ {1, . . . , K},

sk := sk−1 + δk h(tk−1 , sk−1 , w).

Assimilating δk h(tk−1 , sk−1 , w) with hk (sk−1 , wk ), we find that residual


networks are essentially the discretization of a neural ODE by an
explicit Euler method; more precisely, of a non-autonomous neural ODE,
see e.g. (Davis et al., 2020).
Euler’s forward method is only one integration method among
many. To cite a few, there are implicit Euler methods, semi-implicit
methods, Runge-Kutta methods, linear multistep methods, etc. See,
e.g., Gautschi (2011) for a detailed review. The quality of an integration
method is measured by its consistency and its stability (Gautschi, 2011).
These concepts naturally influence the development of evaluation and
differentiation techniques for ODEs. We briefly summarize them below.
Given a fixed time interval δk = δ and K = ⌈T /δ⌉ points, an
integration method is consistent of order p if ∥sk − s(kδ)∥ = O(δ^p)
as δ → 0 and therefore K → +∞. The higher the order p, the fewer points
we need to reach an approximation error ε on the points considered.
The term ∥sk − s(kδ)∥ = O(δ^p) is reminiscent of the error encountered
in finite differences (Chapter 7) and is called the truncation error.
The (absolute) stability of a method is defined by the set of time-
steps such that the integration method can integrate s′ (t) = λs(t) for
some λ ∈ C without blowing up as t → +∞.
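A minimal explicit Euler integrator can be sketched as follows (in JAX); the dynamics h is a placeholder and the update mirrors sk := sk−1 + δ h(tk−1, sk−1, w), making the connection with residual networks explicit.

import jax.numpy as jnp

def h(t, s, w):
    # Placeholder dynamics, e.g. a one-layer "residual" update.
    return jnp.tanh(w[0] * s + w[1] * t)

def euler_solve(x, w, T=1.0, K=100):
    # Explicit Euler: s_k = s_{k-1} + delta * h(t_{k-1}, s_{k-1}, w).
    delta = T / K
    s = x
    for k in range(1, K + 1):
        s = s + delta * h((k - 1) * delta, s, w)
    return s  # approximates s(T) = f(x, w)

print(euler_solve(jnp.array(0.5), jnp.array([1.0, -0.2])))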

12.6.2 Continuous adjoint method

Since different parameters w induce different trajectories associated to


h(t, s, w) in Eq. (12.5), we may want to select one of these trajectories
by minimizing some criterion. For example, we may consider selecting
w ∈ W by minimizing a loss L on the final point of the trajectory,

min L(f (x, w), y), (12.7)


w∈W

where f(x, w) := s(T) = x + ∫₀ᵀ h(t, s(t), w)dt.

To solve such problems, we need to access gradients of L composed with


f through VJPs of the solution of the ODE. The VJPs can actually
be characterized as solutions of an ODE themselves thanks to the
continuous time adjoint method (Pontryagin, 1985), presented
below, and whose proof is postponed to Section 12.6.6.

Proposition 12.8 (Continuous-time adjoint method). Consider a func-


tion h : [0, T ]×S ×W → S, continuous in its first variable, Lipschitz-
continuous and continuously differentiable in its second variable.
Assume that ∂3 h(t, s, w) exists for any t, s, w, and is also continu-
ous in its first variable, Lipschitz-continuous in its second variable.
Denote s : [0, T] → S the solution of the ODE

s(0) = x
s′ (t) = h(t, s(t), w) t ∈ [0, T ],

and f (x, w) = s(T ) the final state of the ODE at time T .


Then, the function f is differentiable, and for an output direction
u ∈ S, its VJP along u is given by

∂f (x, w)∗ u = (r(0), g)

for g = ∫₀ᵀ ∂3h(t, s(t), w)∗ r(t)dt
and for r solving the adjoint (backward) ODE

r ′ (t) = −∂2 h(t, s(t), w)∗ r(t)


r(T ) = u.

In particular, the gradient ∇(L ◦ f )(x, w) for L : S → R a


differentiable loss is obtained by solving the adjoint ODE with
r(T ) = ∇L(s(T )).

Example 12.5 (Fitting data through the solution of an ODE). As an


illustrative example, we can consider optimizing the parameters of
an ODE to fit some data points. Namely, we may seek a continuous

Figure 12.2: Finding the optimal parameters of an ODE to fit some observed data.
The dots represent the trajectories of a dynamical system observed at regular times
(time is represented here by a gradient color, the lighter the color, the larger the
time). Each line represents the solution of an ODE given by some hyperparameters
w. The objective is to find the hyperparameters of the ODE such that its solution
fits the data points. Green and orange lines fail to do so while the blue line fits
the data. To compute such parameters w, we need to backpropagate through the
solution of the ODE.

time solution z(t; w) of a modified Lotka-Volterra ODE

z′(t; w) = (αz1(t; w) − βz1(t; w)z2(t; w), −γz2(t; w) + δz1(t; w)z2(t; w)) + c,

for w = (α, β, γ, δ, c), that fits some observations z1, . . . , zT. The
optimization problem then consists of

min_w Σ_{j=1}^T ∥z(tj; w) − zj∥₂²,

and requires backpropagating through the solution z(·; w) of the
ODE w.r.t. its candidate parameters w. Fig. 12.2 illustrates such
a problem with varying candidate parameters w.
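The principle of Example 12.5 can be sketched with jax.experimental.ode.odeint, whose reverse-mode rule is based on an adjoint computation (at the time of writing), so that jax.grad differentiates through the ODE solution. The dynamics below is a simplified Lotka-Volterra system (without the extra constant c), and the observations, time grid and parameter values are placeholders.

import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def dynamics(z, t, w):
    # Simplified Lotka-Volterra dynamics with parameters w = (alpha, beta, gamma, delta).
    alpha, beta, gamma, delta = w
    z1, z2 = z
    return jnp.stack([alpha * z1 - beta * z1 * z2,
                      -gamma * z2 + delta * z1 * z2])

t_obs = jnp.linspace(0.0, 5.0, 20)
z0 = jnp.array([1.0, 0.5])
z_obs = jnp.ones((20, 2))  # placeholder observations

def loss(w):
    # Solve the ODE at the observation times and compare to the data.
    z = odeint(dynamics, z0, t_obs, w)
    return jnp.sum((z - z_obs) ** 2)

w = jnp.array([1.0, 0.4, 1.0, 0.4])
# Gradient of the fitting loss through the ODE solution.
print(jax.grad(loss)(w))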

12.6.3 Gradients via the continuous adjoint method


Proposition 12.8 gives a formal definition of the gradient. However, just
as computing the mapping f (x, w) itself, computing its VJP or the

gradient of L ◦ f requires solving an integration problem. Note that


the integration of r(t) in Proposition 12.8 requires also values of s(t).
Therefore, we need to integrate both r(t) and s(t). Such an approach
is generally referred as optimize-then-discretize because we first
formulate the gradient in continuous time (the “optimize part”) and
then discretize the resulting ODE.

Simple discretization scheme


A first approach consists in defining a backward discretization scheme
that can approximate s(t) backward in time. Namely, by defining
σ(t) = s(T − t), ρ(t) = r(T − t), and γ(t) = ∫_{T−t}^{T} ∂3h(τ, s(τ), w)∗ r(τ)dτ,
the derivative of L ◦ f is given by (ρ(T), γ(T)). The functions σ, ρ, γ
are solutions of a standard ODE

σ(0) = s(T),        σ′(t) = −h(T − t, σ(t), w),
ρ(0) = ∇L(s(T)),    ρ′(t) = ∂2h(T − t, σ(t), w)∗ ρ(t),
γ(0) = 0,           γ′(t) = ∂3h(T − t, σ(t), w)∗ ρ(t).

The above ODE can then be solved by any integration method. Note,
however, that it requires first computing s(T ) and ∇L(s(T )) by an
integration method. The overall computation of the gradient using
an explicit Euler method to solve forward and backward ODEs is
summarized in Algorithm 12.2.
Algorithm 12.2 naturally looks like the reverse mode of autodiff for
a residual neural network with shared weights. A striking difference
is that the intermediate computations sk are not kept in memory and,
instead, new variables ŝk are computed along the backward ODE. One
may believe that by switching to continuous time, we solved the memory
issues encountered in reverse-mode autodiff. Unfortunately, this comes
at the cost of numerical stability. As we use a discretization scheme
to recompute the intermediate states backward in time through ŝk in
Algorithm 12.2, we accumulate some truncation errors.
To understand the issue here, consider applying Algorithm 12.2
repeatedly on the same parameters but using ŝ0 instead of s0 = x each
time. In the continuous realm, σ(T ) = s(0). But after discretization,
ŝ0 ≈ σ(T ) does not match s0 . Therefore, by applying Algorithm 12.2

Algorithm 12.2 Gradient computation via continuous adjoint method


with Euler explicit discretization
1: Functions: h : [0, T ] × S × W → R, L : S → R
2: Inputs: input x, parameters w, number of discretization steps K.
3: Set discretization step δ = T /K, denote hk (s, w) = h(kδ, s, w).
4: Set s0 := x
5: for k := 1, . . . , K do ▷ Forward discretization
6: Compute sk := sk−1 + δhk−1 (sk−1 , w).
7: Compute u := ∇L(sK ).
8: Initialize rK := u, ŝK = sK , gK = 0
9: for k := K, . . . , 1 do ▷ Backward discretization
10: Compute ŝk−1 := ŝk − δhk (ŝk , w)
11: Compute rk−1 := rk + δ∂2 hk (ŝk , w)∗ rk
12: Compute gk−1 := gk + δ∂3 hk (ŝk , w)∗ rk
13: Output: (r0 , g0 ) ≈ ∇(L ◦ f )(x, w)

with s0 = ŝ0 , we would not get the same output even if in continu-
ous time we naturally should have. This phenomenon is illustrated in
Fig. 12.3. It intuitively shows why Algorithm 12.2 induces some noise
in the estimation of the gradient.

Multiple shooting scheme

An alternative approach consists in integrating both the forward and


backward ODEs jointly. Namely, we may solve an ODE with boundary
values

s′(t) = h(t, s(t), w),             s(0) = x,
r′(t) = −∂2h(t, s(t), w)∗ r(t),    r(T) = ∇L(s(T)),
g′(t) = −∂3h(t, s(t), w)∗ r(t),    g(T) = 0,

by means of a multiple shooting method or a collocation method (Stoer


et al., 1980). This approach still requires ∇L(s(T )) to be approximated
first.


Figure 12.3: Forward and backward discretizations when using the continuous
adjoint method.

12.6.4 Gradients via reverse-mode on discretization


A simpler approach consists in replacing the objective in Eq. (12.7) by
its version discretized using some numerical method, such as an Euler
forward discretization scheme. That is, we seek to solve
min_{w∈W} L(sK) where sk = sk−1 + δh((k − 1)δ, sk−1, w), k ∈ {1, . . . , K},

with s0 = x and δ some discretization step. Gradients of the objective


can be computed by automatic differentiation. That approach is often
referred to as discretize-then-optimize. At first glance, this approach
may suffer from very high memory requirements. Indeed, to get an
accurate solution of the ODE, a numerical integration method may
require K to be very large. Since a naive implementation of reverse-mode
automatic differentiation has a memory that scales linearly with K,
computing the gradient by a discretize-then-optimize method could be
prohibitive. However, the memory requirements may easily be amortized
using checkpointing, as explained in Section 8.5; see also (Gholaminejad
et al., 2019).
As for the optimize-then-discretize method, we still accumulate
some truncation errors in the forward discretization process. This dis-
cretization error occurs when computing the gradient in reverse-mode

too. The discretize-then-optimize method can be seen as computing


gradients of a surrogate objective. For that objective, the gradients are
correct and well-defined. However, they may not match the gradients of
the true ODE formulation.
To compare the discretize-then-optimize and optimize-then-discretize
approaches, Gholaminejad et al. (2019) compared their performance on
an ODE whose solution can be computed analytically by selecting h
to be linear in s. The authors observed that discretize-then-optimize
generally outperformed optimize-then-discretize. A middle ground can
actually be found by using reversible differentiation schemes.

12.6.5 Reversible discretization schemes

Our exposition of the optimize-then-discretize or discretize-then-optimize


approaches used a simple Euler explicit discretization scheme. However,
for both approaches, we could have used other discretization schemes
instead, such as reversible discretization schemes.
A reversible discretization scheme is a discretization scheme such
that we have access to a closed-form formula for the inverse of its
discretization step. Formally, a discretization method M builds an
approximation (sk)_{k=1}^K of the solution of an ODE s′(t) = h(t, s(t)) on

an interval [0, T ] by computing for k ∈ (1, . . . , K)

tk , sk , ck = M(tk−1 , sk−1 , ck−1 ; h, δ), (12.8)

where δ > 0 is some fixed discretization step, tk is the time step (typically
tk = tk−1 +δ), sk is the approximation of s(tk ), and ck is some additional
context variables used by the discretization method to build the iterates.
An explicit Euler method does not have a context, but just as an
optimization method may update some internal states, a discretization
method can update some context variable. The discretization scheme
in Eq. (12.8) is a forward discretization scheme as we took a positive
discretization step. By taking a negative discretization step, we obtain
the corresponding backward discretization scheme, for k ∈ (K, . . . , 1),

tk−1 , sk−1 , ck−1 = M(tk , sk , ck ; h, −δ).



A discretization method is reversible if we have access to M−1 to


recompute the inputs of the discretization step from its outputs,
tk−1 , sk−1 , ck−1 = M−1 (tk , sk , ck ; h, δ).
A reversible discretization method is symmetric if the backward dis-
cretization scheme is exactly the inverse of the forward discretization
scheme, i.e.,
M(tk , sk , ck ; h, −δ) = M−1 (tk , sk , ck ; h, δ).
The explicit Euler method is clearly not symmetric and a priori not
reversible, unless we can solve, for yk−1, the equation yk = yk−1 +
δf(yk−1).

Leapfrog method
The (asynchronous) leapfrog method (Zhuang et al., 2021; Mutze,
2013) on the other hand is an example of symmetric reversible discretiza-
tion method. For a constant discretization step δ, given tk−1 , sk−1 , ck−1
and a function h, it computes
t̄k−1 := tk−1 + δ/2
s̄k−1 := sk−1 + (δ/2) ck−1
c̄k−1 := h(t̄k−1, s̄k−1)
tk := t̄k−1 + δ/2
sk := s̄k−1 + (δ/2) c̄k−1
ck := 2c̄k−1 − ck−1
M(tk−1, sk−1, ck−1; h, δ) := (tk, sk, ck).
One can verify that we indeed have M(tk , sk , ck ; h, −δ) = (tk−1 , sk , ck ).
By using a reversible symmetric discretization scheme in the optimize-
then-discretize approach, we ensure that, at the end of the backward
discretization pass, we recover exactly the original input. Therefore, by
repeating forward and backward discretization schemes we always get
the same gradient, which was not the case for an Euler explicit scheme.

By using a reversible discretization scheme in the discretize-then-


optimize method, we address the memory issues of reverse mode autodiff.
As explained in Section 8.6, we can recompute intermediate values during
the backward pass rather than storing them.

Momentum residual networks

In the leapfrog method, the additional variables ck may actually be


interpreted as velocities of a system whose acceleration is driven by
the given function, that is, s′′ (t) = h(t, s(t), w). Such an interpretation
suggests alternatives to the usual neural ODE paradigm. For instance,
momentum neural networks (Sander et al., 2021b) can be inter-
preted as the discretization of second-order ordinary differential
equations, which are naturally amenable to reversible differentiation
schemes with a low memory footprint.

12.6.6 Proof of the continuous adjoint method

In the following, we denote s(t, x, w) the solution of the ODE at time t


given the input x and the parameters w. We focus here on the formula-
tion of the VJP. The proof relies on the existence of partial derivatives
of s(t, x, w), which we do not cover here and refer to, e.g., Pontryagin
(1985) for a complete proof of such facts given the assumptions.
We use the ODE constraint to introduce adjoint variables, this time
in the form of a continuously differentiable function r. For any such
function r, we have

⟨f(x, w), u⟩ = ⟨s(T, x, w), u⟩ + ∫₀ᵀ ⟨r(t), h(t, s(t, x, w), w) − ∂t s(t, x, w)⟩dt,

using Leibniz notations such as ∂t s(t, x, w) = ∂1 s(t, x, w). The VJPs



then decompose as

∂w f(x, w)∗[u] = ∂w s(T, x, w)∗ u
  + ∫₀ᵀ (∂w s(t, x, w)∗ ∂s h(t, s(t, x, w), w)∗ − ∂²wt s(t, x, w)∗) r(t)dt
  + ∫₀ᵀ ∂w h(t, s(t, x, w), w)∗ r(t)dt,

∂x f(x, w)∗[u] = ∂x s(T, x, w)∗ u
  + ∫₀ᵀ (∂x s(t, x, w)∗ ∂s h(t, s(t, x, w), w)∗ − ∂²xt s(t, x, w)∗) r(t)dt.

Here the second derivative terms ∂²wt s(t, x, w)∗ r, ∂²xt s(t, x, w)∗ r correspond
to second derivatives of ⟨s(t, x, w), r⟩. Since the Hessian is
symmetric (Schwartz's theorem presented in Proposition 2.10), we can
swap the derivatives in t and w or x. Then, to express the gradient
uniquely in terms of first derivatives of s, we use an integration by parts
to have for example

∫₀ᵀ ∂²wt s(t, x, w)∗ r(t)dt = ∫₀ᵀ ∂²tw s(t, x, w)∗ r(t)dt
  = ∂w s(T, x, w)∗ r(T) − ∂w s(0, x, w)∗ r(0)
  − ∫₀ᵀ ∂w s(t, x, w)∗ ∂t r(t)dt.
Since s(0) = x, we have ∂w s(0, x, w)∗ r(0) = 0. The VJP w.r.t. w can
then be written as
∂w f(x, w)∗[u] = ∂w s(T, x, w)∗ [u − r(T)]
  + ∫₀ᵀ ∂w s(t, x, w)∗ [∂s h(t, s(t, x, w), w)∗ r(t) + ∂t r(t)]dt
  + ∫₀ᵀ ∂w h(t, s(t, x, w), w)∗ r(t)dt.
By choosing r(t) to satisfy the adjoint ODE
∂t r(t) = −∂s h(t, s(t, x, w), w)∗ r(t), r(T ) = u,

the expression of the VJP simplifies as

∂w f(x, w)∗[u] = ∫₀ᵀ ∂w h(t, s(t, x, w), w)∗ r(t)dt.

For the VJP w.r.t. x, we can proceed similarly. Using an integration


by part, we have, this time, ∂x s(0, x, w)∗ r(0) = r(0) since s(0) = x.
Choosing the same curve r(t) satisfying the adjoint ODE we get

∂x f (x, w)∗ [u] = r(0).

The existence of a curve r solution of the backward ODE can easily be


shown from Picard Lindelöf’s theorem and the assumptions.

12.7 Summary

• We studied how to differentiate integrals, with a focus on expec-


tations and solutions of a differential equation.

• For differentiating through expectations, we studied two main


methods: the score function estimator (SFE, a.k.a. REINFORCE)
and the path gradient estimator (PGE, a.k.a. reparametrization
trick).

• The SFE is suitable when it is easy to sample from the distribution


and its log-PDF is explicitly available. It is an unbiased estimator,
but is known to suffer from high variance.

• The PGE is suitable for pushforward distributions, distributions


that are implicitly defined through a transformation, or a se-
quence of them. These distributions can be easily sampled from,
by injecting a source of randomness (such as noise) through the
transformations. An unbiased, low-variance estimator of the gra-
dient of their expectation is easily obtained, provided that we can
interchange integration and differentiation.

• If we have an explicit distribution, we can sometimes convert it


to an implicit distribution, thanks to the location-scale trans-
formation or the inverse transformation.

• Conversely, if we have an implicit distribution, we can convert it to


an explicit distribution using the change-of-variables theorem.
However, this formula requires computing the determinant of
an inverse Jacobian, and is computationally expensive in general.
Normalizing flows use invertible transformations so that the inverse
Jacobian is cheap to compute, by design.

• Stochastic computation graphs can use a mix of explicit and


implicit distributions at each node.

• For differentiating through the solution of a differential equation,


two approaches can be considered.

• We can express the gradient as the solution of a differential equa-


tion thanks to the continuous adjoint method. We may then
discretize backwards in time the differential equation that the gra-
dient satisfies. This is the optimize-then-discretize approach.

• We can also first discretize the problem in such a way that the
gradient can simply be computed by reverse mode auto-diff, ap-
plied on the discretization steps. This is the discretize-then-
optimize approach. The optimize-then-discretize approach has
no memory cost, but discrepancies between the forward and back-
ward discretization passes often lead to numerical errors. The
discretize-then-optimize approach introduces no such discrepancies but may
come at a large memory cost.

• Reversible discretization schemes can circumvent the memory


cost, as they enable the recomputation of intermediate discretiza-
tion steps backwards in time.
Part IV

Smoothing programs
13
Smoothing by optimization

When a function is non-differentiable (or worse, discontinuous), a rea-


sonable approach is to replace it by a differentiable approximation
(or at least, by a continuous relaxation). We refer to the process of
transforming a non-differentiable function into a differentiable one as
“smoothing” the original function. In this chapter, we begin by review-
ing a smoothing technique based on infimal convolution. We then
review an equivalent dual approach, based on the Legendre-Fenchel
transform. We illustrate how to apply these techniques to compute
smoothed ReLUs and smoothed max operators, as well as continuous
relaxations of step functions and argmax operators.

13.1 Primal approach

We first review how to smooth functions in the original, primal space


of the function, using the infimal convolution and more particularly the
Moreau envelope, a.k.a. Moreau-Yoshida regularization. In this chapter,
we consider functions taking potentially infinite positive values, that is,
functions taking values in the half-extended real line R ∪ {∞}. For a


function f : RM → R ∪ {∞}, we define its domain as

dom(f ) = {u ∈ RM : f (u) < ∞}.

13.1.1 Infimal convolution


As we elaborate in Section 14.1.6, the infimal convolution, sometimes
abbreviated inf-conv, can be seen as a counterpart of the usual convolu-
tion, in which integration has been replaced by minimization (hence its
name). We give its formal definition below.

Definition 13.1 (Infimal convolution). The infimal convolution be-


tween two functions f : RM → R ∪ {∞} and g : RM → R ∪ {∞} is
defined by

(f □g)(µ) := inf f (u) + g(µ − u)


u∈RM
= inf f (µ + z) + g(z)
z∈RM
= inf f (u) + g(z) s.t. u = µ + z.
u,z∈RM

It is easy to check that the three definitions are indeed equivalent,


by using the change of variable u := µ + z, which is a location-scale
transform; see Section 12.4.1. Similarly to the classical convolution,
the infimal convolution between two functions f and g creates a new
function f □g, and it is commutative, meaning that for all µ ∈ RM ,
we have
(f □g)(µ) = (g□f )(µ).
Computing the infimal convolution involves the resolution of a mini-
mization problem, that may or may not enjoy an analytical solution.
Some examples are given in Table 13.1.

Existence
The infimal convolution (f □g)(µ) exists if the infimum inf u∈RM f (u) +
g(µ − u) is finite (Bauschke and Combettes, 2017, Proposition 12.6).
A sufficient condition to achieve this is that u 7→ f (u) + g(µ − u) is
convex for all µ ∈ RM . However, this is not a necessary condition. For

Table 13.1: Examples of infimal convolutions. We use ιC to denote the indicator


function of the set C.

f(u)      g(z)        (f □ g)(µ)
f(u)      0           inf_{u∈RM} f(u)
f(u)      ι{v}(z)     f(µ − v)
ιC(u)     ιD(z)       ιC+D(µ)
ιC(u)     ∥z∥₂        dC(µ) = inf_{u∈C} ∥µ − u∥₂
f(u)      ½∥z∥₂²      envf(µ) = inf_{u∈RM} ½∥µ − u∥₂² + f(u)

example, the infimum can be finite even if f or g are nonconvex, e.g.,
if their domain is a compact set.

Infimal convolution with a regularization function

When a function f is non-differentiable, a commonly-used technique is


to replace it by its infimal convolution f □R, with some regularization
R. The most used regularization is the squared 2-norm, leading to the
Moreau envelope, as we now review.

13.1.2 Moreau envelope

When R(z) := 12 ∥z∥22 , the infimal convolution f □R gives the so-called


Moreau envelope of f , also known as Moreau-Yoshida regularization
of f .

Definition 13.2 (Moreau envelope). Given a function f : RM →


R ∪ {∞}, its Moreau envelope is defined as
envf(µ) := (f □ ½∥·∥₂²)(µ)
  = inf_{u∈RM} f(u) + ½∥µ − u∥₂²
  = inf_{z∈RM} f(µ + z) + ½∥z∥₂².

Intuitively, the Moreau envelope is the minimal value over u ∈ RM
of a trade-off between staying close to the input µ according to the
proximity term ½∥µ − u∥₂² and minimizing f(u). Provided that the
minimizer exists and is unique, we can define the associated proximal
operator of f as

proxf(µ) := argmin_{u∈RM} ½∥µ − u∥₂² + f(u).

In other words, when proxf(µ) is well-defined, we have

envf(µ) = f(proxf(µ)) + ½∥µ − proxf(µ)∥₂².    (13.1)

Properties
A crucial property of the Moreau envelope envf is that for any convex
function f , it is always a smooth function, even when f itself is not
smooth. By smooth, we formally mean that the resulting function envf
is differentiable everywhere with Lipschitz-continuous gradients. We say
that a function is L-smooth if its gradients are L-Lipschitz continuous. Such a property
can determine the efficiency of optimization algorithms as reviewed in
Section 15.4. We recap below useful properties of the Moreau envelope.

Proposition 13.1 (Properties of Moreau envelope). Let f : RM →


R ∪ {∞}.

1. Smoothness: If f is convex, the function envf is 1-smooth.

2. Gradient: Provided that proxf (µ) is well-defined on µ ∈ RM ,


the gradient of the Moreau envelope can be expressed in terms
of the proximal operator as

∇envf (µ) = µ − proxf (µ).

3. Moreau decomposition: If f is convex, then for any µ ∈


RM , we have the following identity

proxf (µ) + proxf ∗ (µ) = µ,



where f ∗ is the convex conjugate of f , detailed in Section 13.2.


In particular, we get

∇envf ∗ (µ) = proxf (µ)

4. Convexity: envf is convex if f is convex.

5. Infimums coincide: envf has the same infimum as the origi-


nal function f :

min_{µ∈RM} envf(µ) = min_{u∈RM} f(u).

Proof.

1. This is best seen using the dual approach detailed in Section 13.3.

2. This follows from Danskin’s theorem, reviewed in Section 11.2.

3. See, e.g., Bauschke and Combettes (2017, Theorem 14.3).

4. This follows from the fact that the infimum of a jointly convex
function is convex.

5. We have

   inf_{µ∈RM} envf(µ) = inf_{µ∈RM} inf_{u∈RM} ½∥µ − u∥₂² + f(u)
                      = inf_{u∈RM} inf_{µ∈RM} ½∥µ − u∥₂² + f(u)
                      = inf_{u∈RM} f(u).

Examples
To illustrate smoothing from the Moreau envelope perspective, we
show how to smooth the 1-norm. In this case, we obtain an analytical
expression for the Moreau envelope.


Figure 13.1: The Huber loss is the Moreau envelope of the absolute loss.

Example 13.1 (Smoothing the 1-norm via infimal convolution). We
wish to smooth f(u) := ∥u∥₁ = Σ_{j=1}^M |uj|. The corresponding proximal
operator is the soft-thresholding operator (see Section 16.4),

proxf(µ) = argmin_{u∈RM} ½∥µ − u∥₂² + ∥u∥₁ = sign(µ) · max(|µ| − 1, 0).

Using Eq. (13.1) and after some algebraic manipulations, we obtain

envf(µ) = Σ_{j=1}^M huber(µj) ≈ Σ_{j=1}^M |µj|,

where we defined the Huber loss

huber(µj) := µj²/2 if |µj| ≤ 1,   |µj| − 1/2 if |µj| > 1.

This is illustrated in Fig. 13.1 with M = 1.
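The computation in Example 13.1 is easy to check numerically. The sketch below (in JAX) evaluates the Moreau envelope of the absolute value through the soft-thresholding operator and Eq. (13.1), compares it to the Huber loss, and also checks the gradient identity ∇envf(µ) = µ − proxf(µ) of Proposition 13.1; the evaluation points are arbitrary.

import jax
import jax.numpy as jnp

def prox_abs(mu):
    # Soft-thresholding: prox of f(u) = |u|.
    return jnp.sign(mu) * jnp.maximum(jnp.abs(mu) - 1.0, 0.0)

def env_abs(mu):
    # Moreau envelope via Eq. (13.1): f(prox(mu)) + 0.5 * (mu - prox(mu))^2.
    p = prox_abs(mu)
    return jnp.abs(p) + 0.5 * (mu - p) ** 2

def huber(mu):
    return jnp.where(jnp.abs(mu) <= 1.0, 0.5 * mu ** 2, jnp.abs(mu) - 0.5)

mu = jnp.linspace(-3.0, 3.0, 7)
print(jnp.max(jnp.abs(env_abs(mu) - huber(mu))))      # ~0: the two curves coincide
# Gradient property: grad env_f(mu) = mu - prox_f(mu).
print(jax.grad(env_abs)(2.0), 2.0 - prox_abs(2.0))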

We also illustrate in Fig. 13.2 that the Moreau envelope of nonconvex


functions can be approximately computed numerically.

(Panels of Fig. 13.2, from left to right: the ReLU, ramp, and step functions, each plotted with its Moreau envelope.)

Figure 13.2: The Moreau envelope is not limited to convex functions. For instance,
the ramp function is continuous but nonconvex, and the step function is not only
nonconvex but also discontinuous. In this figure, we approximately computed the
infimum over u ∈ R in Definition 13.2 by restricting the search on a finite grid, in a
closed interval.

13.1.3 Vector-valued functions


The Moreau envelope is defined by envf (µ) := inf u∈RM f (u)+ 12 ∥µ−u∥22 .
As such, it is limited to scalar-valued functions f : RM → R. To extend
the Moreau envelope to vector-valued functions f : RM → RT , where
f (u) = (f1 (u), . . . , fT (u)) and fi : RM → R for i ∈ [T ], we may choose
to smooth each fj separately to define
envf (µ) := (envf1 (µ), . . . , envfT (µ)),
where
1
envfi (µ) = inf fi (ui ) + ∥µ − ui ∥22 .
ui ∈RM 2
This approach requires to solve T separate minimization problems and
performs the smoothing of each output coordinate i ∈ [T ] independently.
From Proposition 2.9, we then have that the VJP of envf (u) with any
direction d ∈ RT is
T
∂envf (µ)∗ [d] = ∂envfi (µ)∗ [di ]
X

i=1
T
= di ∇envfi (µ).
X

i=1
In the particular case f (u) = (f1 (u1 ), . . . , fT (uT )), we obtain
T
∂envf (µ)∗ [d] = di envfi (µi ).
X

i=1

An alternative was proposed by Roulet and Harchaoui (2022). For a


differentiable function f : RM → RT , we recall that the VJP of f with
a direction d ∈ RT reads
∂f (u)∗ [d] = ∇⟨f, d⟩(u),
where we defined the scalar-valued function ⟨f, d⟩(u) := ⟨f (u), d⟩. As a
result, if f is non differentiable, a natural idea is to approximate its VJP
∂f (u)∗ [d] (had it existed) by the gradient ∇env⟨f,d⟩ (µ) of the Moreau
envelope
env⟨f,d⟩(µ) = inf_{u∈RM} ⟨f(u), d⟩ + ½∥µ − u∥₂².    (13.2)
This requires a single optimization problem to solve, independently of
the number of outputs T . Moreover, for d = ei , this recovers envfi (µ)
as a special case.
This approach allows in principle to perform reverse-mode autodiff
(gradient backpropagation) on a neural network whose layers use the
Moreau envelope. Indeed, following Proposition 13.1, the approximate
VJP of f with a direction d is given by
∂f (µ)∗ [d] ≈ ∇env⟨f,d⟩ (µ) = µ − u⋆ ,
where u⋆ is the solution of the minimization problem in Eq. (13.2).
However, we emphasize that this minimization problem could be difficult
to solve in general. Indeed, when performing gradient backpropagation,
the direction d is not necessarily non-negative, therefore the function
being minimized in Eq. (13.2) could be nonconvex, even if each fi is
convex. Another potential caveat is that the direction d influences the
smoothing strength, while in principle we should be able to smooth a
function independently of whether we compute its VJP or not. To see
that, for example in the particular case f (u) = (f1 (u1 ), . . . , fT (uT )),
one easily checks that for d = (d1 , . . . , dT ), we get
env⟨f,d⟩(µ) = Σ_{i=1}^T env_{di fi}(µi).
Smoothing vector-valued functions by Moreau envelope (or more gen-
erally, by infimal convolution) remains an open area of research. We
will see in Chapter 14 that smoothing by convolution more naturally
supports vector-valued functions.

13.2 Legendre–Fenchel transforms, convex conjugates

The Legendre-Fenchel transform, a.k.a. convex conjugate, is a way to


turn a function f into a new function, denoted f ∗ . We now review it in
detail, as it plays a major role for the dual approach to smoothing.

13.2.1 Definition
Consider the class of affine functions of the form

u 7→ ⟨u, v⟩ − b.

These functions are parametrized by their slope v ∈ RM and their


intercept −b ∈ R. Now, suppose we fix v. Given a function f (u), affine
lower bounds of f (u) are all the functions of u such that b satisfies for
all u ∈ RM ,

⟨u, v⟩ − b ≤ f (u) ⇐⇒ ⟨u, v⟩ − f (u) ≤ b.

The tightest lower bound is then defined by b such that

b := sup_{u∈dom(f)} ⟨u, v⟩ − f(u),

where we recall that the domain of f is defined by

dom(f ) := {u ∈ RM : f (u) < ∞}.

This leads to the definition of Legendre-Fenchel transform, a.k.a.


convex conjugate.

Definition 13.3 (Legendre-Fenchel transform, convex conjugate). Given


a function f : RM → R ∪ {∞}, its convex conjugate is defined by

f∗(v) := sup_{u∈dom(f)} ⟨u, v⟩ − f(u).

We use a sup rather than a max to indicate that f ∗ (v) is potentially


∞. Following the previous discussion, −f ∗ (v) is the intercept of the
tightest affine lower bound with slope v of f (u). This is illustrated
in Fig. 13.3.

Figure 13.3: For a fixed slope v, the function u 7→ uv − f ∗ (v) is the tightest affine
lower bound of f with slope v.

The Legendre-Fenchel transform is a function transformation, as it


produces a new function f ∗ . It can be seen as a dual representation
of a function: instead of representing a convex function f by its graph
(u, f (u)) for u ∈ dom(f ), we can represent it by the set of tangents
with slope v and intercept −f ∗ (v) for v ∈ dom(f ∗ ), as illustrated
in Fig. 13.4. As the name “convex conjugate” indicates, it is convex,
even if the original function is not.

13.2.2 Closed-form examples


Computing f ∗ (v) involves the resolution of a maximization problem,
which could be difficult in general without assumption on f . In some
cases, however, we can compute an analytical expression, as we now
illustrate.

Example 13.2 (Analytical conjugate examples). When f(u) = ½∥u∥₂²,
with dom(f) = RM, the conjugate is

f∗(v) = max_{u∈RM} ⟨u, v⟩ − ½∥u∥₂².
Figure 13.4: Left: instead of representing a convex function f by its graph (u, f (u))
for u ∈ dom(f ), we can represent it by the set of tangents with slope v and intercept
−f ∗ (v) for v ∈ dom(f ∗ ). Right: by varying the slope v of all possible tangents, we
obtain a function of the slope v rather than of the original input u. The colors of the
tangents on the left are chosen to match the colors of the vertical lines on the right.

Setting the gradient of u ↦ ⟨u, v⟩ − ½∥u∥₂² to zero, we obtain u⋆ = v. Plugging u⋆ back, we therefore obtain

f∗(v) = ⟨u⋆, v⟩ − ½∥u⋆∥₂² = ½∥v∥₂².

Therefore, f = f∗ in this case.
When f(u) = ⟨u, log u⟩, with dom(f) = RM+, the maximizer of u ↦ ⟨u, v⟩ − ⟨u, log u⟩ is u⋆ = exp(v − 1) and the conjugate is

f∗(v) = ∑_{j=1}^M exp(v_j − 1).

See for instance Boyd and Vandenberghe (2004) or Beck (2017) for
many more examples.

Constraining the domain


We can incorporate constraints using an indicator function with
values in the extended real line R ∪ {∞},

ι_C(u) := 0 if u ∈ C, and +∞ otherwise.

Example 13.3 (Incorporating constraints). If f (u) = ιC (u), where


C is a convex set, then

f∗(v) = sup_{u∈dom(f)} ⟨u, v⟩ − f(u) = sup_{u∈C} ⟨u, v⟩ := σ_C(v),

which is known as the support function of C. The corresponding


argmax (assuming that it exists),

v ↦ arg max_{u∈C} ⟨u, v⟩,

is known as the linear maximization oracle (LMO) of C. As


another example, if f(u) = ⟨u, log u⟩ + ι_△M(u), then

f∗(v) = logsumexp(v) = log ∑_{j=1}^M exp(v_j).

We postpone a proof to Proposition 13.9.

13.2.3 Properties
The conjugate enjoys several useful properties, that we now summarize.

Proposition 13.2 (Convex conjugate properties).

1. Convexity: f ∗ (v) is a convex function for all f : RM →


R ∪ {∞} (even if f is nonconvex).

2. Fenchel-Young inequality: for all u, v ∈ RM

f (u) + f ∗ (v) − ⟨u, v⟩ ≥ 0.



3. Gradient: if the supremum in Definition 13.3 is uniquely


achieved, then f ∗ (v) is differentiable at v and its gradient is

∇f∗(v) = arg max_{u∈dom(f)} ⟨u, v⟩ − f(u).

Otherwise, f ∗ (v) is sub-differentiable at v and we get a sub-


gradient instead.

4. Maps: If f and f ∗ are differentiable, then

v = ∇f (u) ⇐⇒ u = ∇f ∗ (v) ⇐⇒ f ∗ (v)+f (u)−⟨u, v⟩ = 0.

5. Biconjugate: f = f ∗∗ if and only if f is convex and closed


(i.e., its sublevel sets form a closed set), otherwise f ∗∗ ≤ f .

Proof.
1. This follows from the fact that v 7→ supu∈C g(u, v) is convex if g
is convex in v. Note that this is true even if g is nonconvex in u.
Here, g(u, v) = ⟨u, v⟩ − f (u), which is affine in v and therefore
convex in v.
2. This follows immediately from Definition 13.3.
3. This follows from Danskin’s theorem, reviewed in Section 11.2.
Another way to see this is by observing that
f ∗ (v) = ⟨g, v⟩ − f (g)
f ∗ (v ′ ) ≥ ⟨g, v ′ ⟩ − f (g),
where g := arg max_{u∈dom(f)} ⟨u, v⟩ − f(u). Subtracting the two, we obtain

f ∗ (v ′ ) ≥ f ∗ (v) + ⟨g, v ′ − v⟩.


Now, using that f ∗ is convex and Definition 15.6, we obtain that
g = ∇f ∗ (v).
4. See, e.g., Bauschke and Combettes (2017, Proposition 16.10).
5. See Boyd and Vandenberghe (2004, Section 3.3).

13.2.4 Conjugate calculus


While deriving a convex conjugate expression can be difficult in general,
in some cases, it is possible to use simple rules to derive conjugates in
terms of other conjugates.

Proposition 13.3 (Conjugate calculus rules).

1. Separable sum of functions: if f(u) = ∑_{j=1}^M f_j(u_j), then

f∗(v) = ∑_{j=1}^M f_j∗(v_j).

2. Scalar multiplication: if f (u) = c · g(u), for c > 0, then

f ∗ (v) = c · g ∗ (v/c).

3. Addition to an affine function and translation: if f (u) =


g(u) + ⟨α, u⟩ + β, then

f ∗ (v) = g ∗ (v − α) − β.

4. Composition with an invertible linear map: if f (u) =


g(M u), where x 7→ M x is an invertible linear map, then

f ∗ (v) = g ∗ (M −T v).

5. Non-separable sum of functions: if h1 and h2 are convex


functions, then (h1 + h2 )∗ = h∗1 □h∗2 , where □ is the infimal
convolution operator.

13.2.5 Fast Legendre transform


When an analytical expression is not available, we can resort to numeri-
cal schemes to approximately compute the transform / conjugate. When
f is convex, because −f is concave, the maximization in Definition 13.3
is that of a concave function. Therefore, the conjugate can be computed
to arbitrary precision in polynomial time using classical iterative algo-

rithms for constrained optimization such as projected gradient descent


(Section 16.3) or conditional gradient a.k.a. Frank-Wolfe (Jaggi, 2013).
Without convexity assumption on f , f ∗ (v) can be approximated by

f∗(v) ≈ sup_{u∈U} ⟨u, v⟩ − f(u),

where U ⊆ dom(f) is a discrete grid of values. We can then compute f∗(v) for several inputs v ∈ V, where V ⊆ dom(f∗) is another discrete grid. A brute-force evaluation costs O(|U| · |V|), while the linear-time Legendre transform algorithm (Lucet, 1997) achieves a cost linear in the grid sizes. However, the grid sizes are typically |U| = |V| = O(N^M), for N equally-distributed points in each of the M dimensions. Therefore, this approach is limited to small-dimensional settings, e.g., M ∈ {1, 2, 3}.
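As a concrete illustration, here is a small NumPy sketch of the brute-force grid approximation just described (the function name discrete_conjugate and the grids are ours, not from any library), checked on the self-conjugate function f(u) = ½u² in one dimension.

    import numpy as np

    def discrete_conjugate(f_vals, u_grid, v_grid):
        # Approximate f*(v) = sup_u <u, v> - f(u) by restricting u to a grid.
        # Brute force costs O(|U| * |V|); f_vals[i] stores f(u_grid[i]).
        return np.array([np.max(v * u_grid - f_vals) for v in v_grid])

    # Sanity check in 1D: the conjugate of f(u) = u^2 / 2 is f*(v) = v^2 / 2.
    u_grid = np.linspace(-5.0, 5.0, 1001)
    v_grid = np.linspace(-2.0, 2.0, 101)
    f_star = discrete_conjugate(u_grid ** 2 / 2, u_grid, v_grid)
    assert np.allclose(f_star, v_grid ** 2 / 2, atol=1e-3)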

13.3 Dual approach

Previously, we presented how to smooth a function by performing


its infimal convolution with a primal-space regularization R. We now
present how to smooth a function by regularizing its Legendre-Fenchel
transform (convex conjugate) instead. This dual, equivalent approach,
is often mathematically more convenient.

13.3.1 Duality between strong convexity and smoothness

We begin by stating a well-known result that will underpin this whole


section: smoothness and strong convexity are dual to each other (Hiriart-
Urruty and Lemaréchal, 1993; Kakade et al., 2009; Beck, 2017; Zhou,
2018).

Proposition 13.4 (Duality between strong convexity and smoothness).


f is 1/µ-strongly convex w.r.t. the norm ∥ · ∥ over dom(f) if and only
if f ∗ is µ-smooth w.r.t. the dual norm ∥ · ∥∗ over dom(f ∗ ).

For a review of the notions of smoothness and strong convexity,


see Section 15.4. We give two examples of strongly-convex and smooth
conjugate pairs in Table 13.2.

Table 13.2: Examples of strongly-convex and smooth conjugate pairs.

Function          Norm      Domain    Conjugate        Dual norm    Dual domain
½∥u∥₂²            ∥ · ∥₂    RM        ½∥v∥₂²           ∥ · ∥₂       RM
⟨u, log u⟩        ∥ · ∥₁    △M        logsumexp(v)     ∥ · ∥∞       RM

13.3.2 Smoothing by dual regularization


The duality between smoothness and strong convexity suggests a generic
approach in order to smooth a function f : RM → R, by going through
the dual space.
1. Compute the conjugate f ∗ :
f∗(v) := sup_{u∈dom(f)} ⟨u, v⟩ − f(u).

2. Add strongly-convex regularization Ω to the conjugate:


fΩ∗ (v) := f ∗ (v) + Ω(v). (13.3)

3. Go back to the primal space, by computing the conjugate of fΩ∗ :


fΩ(u) := fΩ∗∗(u) = max_{v∈RM} ⟨u, v⟩ − fΩ∗(v).

Note that u and v belong to different spaces, i.e., u ∈ dom(f ) and


v ∈ dom(f ∗ ). Following Proposition 13.4, if Ω is µ-strongly convex,
then fΩ(u) is 1/µ-smooth. Furthermore, following Proposition 13.2, fΩ(u)
is convex, even if f is nonconvex. Therefore, fΩ (u) is a smooth and
convex relaxation of f (u).
Steps 1 and 3 are the most challenging, as they both require the
derivation of a conjugate. While an analytical solution may not exist in
general, in some simple cases, there is, as we now illustrate.

Example 13.4 (Smoothing the 1-norm via dual regularization). We re-


visit Example 13.1, this time from the dual perspective. We wish
to smooth out the 1-norm f(u) := ∥u∥₁ = ∑_{j=1}^M |u_j|.

1. Compute the conjugate. The conjugate of any norm ∥ · ∥



is the indicator function of the dual norm’s unit ball {v ∈


RM : ∥v∥∗ ≤ 1} (see e.g. Boyd and Vandenberghe (2004,
Example 3.26)). The dual norm of ∥u∥1 is ∥v∥∞ . Moreover,

{v ∈ RM : ∥v∥∞ ≤ 1} = [−1, 1]M .

Recalling that ιC is the indicator function of C, we obtain

f ∗ (v) = ι[−1,1]M (v).

2. Adding strongly-convex regularization. We add quadratic


regularization Ω(v) := ½∥v∥₂² to define

fΩ∗ (v) := ι[−1,1]M (v) + Ω(v).

3. Going back to the primal.

fΩ(u) = fΩ∗∗(u) = ⟨u, v⋆⟩ − Ω(v⋆) = ∑_{i=1}^M huber(u_i),

where v ⋆ = clip (u) := max (min (u, 1) , −1).

We therefore indeed recover the Huber loss from Example 13.1.
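To make this concrete, the following NumPy sketch (the name smoothed_l1 and the grid-based check are ours) evaluates fΩ(u) through the closed-form maximizer v⋆ = clip(u, −1, 1) and compares it, coordinate by coordinate, against a brute-force maximization of u·v − ½v² over v ∈ [−1, 1].

    import numpy as np

    def smoothed_l1(u):
        # f_Omega(u) = <u, v*> - Omega(v*), with Omega = 0.5 ||.||_2^2 and
        # maximizer v* = clip(u, -1, 1); this equals sum_i huber(u_i).
        v_star = np.clip(u, -1.0, 1.0)
        return np.sum(u * v_star - 0.5 * v_star ** 2)

    # Coordinate-wise brute-force check of the maximization over v in [-1, 1].
    u = np.array([-2.0, 0.3, 1.5])
    v_grid = np.linspace(-1.0, 1.0, 10_001)
    brute = sum(np.max(ui * v_grid - 0.5 * v_grid ** 2) for ui in u)
    assert np.isclose(smoothed_l1(u), brute, atol=1e-6)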

ReLU functions can be smoothed out in a similar way, as we see in


more detail in Section 13.4.
The dual approach allows us to easily bound the smoothed function
in terms of the original function.

Proposition 13.5 (Bounds). If LΩ ≤ Ω(v) ≤ UΩ for all v ∈ dom(Ω),


then for all u ∈ RM ,

f (u) − UΩ ≤ fΩ (u) ≤ f (u) − LΩ .

Proof. Let us define

v⋆ := arg max_{v∈RM} ⟨u, v⟩ − f∗(v)
vΩ⋆ := arg max_{v∈RM} ⟨u, v⟩ − fΩ∗(v),

where we recall that fΩ∗ := f ∗ + Ω. We then have for all u ∈ RM

fΩ(u) = ⟨u, vΩ⋆⟩ − fΩ∗(vΩ⋆) ≥ ⟨u, v⋆⟩ − fΩ∗(v⋆) = f(u) − Ω(v⋆)

and similarly

f(u) − Ω(vΩ⋆) = ⟨u, v⋆⟩ − f∗(v⋆) − Ω(vΩ⋆) ≥ ⟨u, vΩ⋆⟩ − fΩ∗(vΩ⋆) = fΩ(u).

Combining the two with LΩ ≤ Ω(v) ≤ UΩ for all v ∈ dom(Ω), we obtain

f(u) − UΩ ≤ f(u) − Ω(v⋆) ≤ fΩ(u) ≤ f(u) − Ω(vΩ⋆) ≤ f(u) − LΩ.

Remark 13.1 (The gradient is differentiable almost everywhere). From


Proposition 13.2, the gradient of fΩ (u) equals

∇fΩ(u) = arg max_{v∈RM} ⟨u, v⟩ − fΩ∗(v).

If Ω is strongly convex, then fΩ is smooth, meaning that ∇fΩ


is Lipschitz continuous. From Rademacher’s theorem reviewed in
Section 2.7.1, ∇fΩ is then differentiable almost everywhere (that is,
fΩ is twice differentiable almost everywhere). We use this property
in the sequel to define continuous differentiable almost everywhere
relaxations of step functions and argmax operators.

13.3.3 Equivalence between primal and dual regularizations

So far, we saw two approaches to obtain a smooth approximation of a


function f . The first approach is based on the infimal convolution f □R,
where R : dom(f ) → R denotes primal regularization. The second ap-
proach is based on regularizing the Legendre-Fenchel transform (convex
conjugate) f ∗ of f with some dual regularization Ω, to define fΩ . It
turns out that both approaches are equivalent.

Proposition 13.6 (Equivalence between primal and dual regularizations).


Let f : RM → R ∪ {∞} and R : RM → R ∪ {∞}, both convex and
closed. Then, fΩ = f □R with Ω = R∗ .

Proof. We have

fΩ(u) = (f∗ + Ω)∗(u) = sup_{v∈dom(f∗)} ⟨u, v⟩ − f∗(v) − Ω(v).

If h1 and h2 are convex, we have (h1 + h2 )∗ = h∗1 □h∗2 (Beck, 2017,


Theorem 4.17). Using h1 = f ∗ and h2 = Ω = R∗ gives the desired result
using that f ∗∗ = f and R∗∗ = R since both are convex and closed (see
Proposition 13.2).

In particular, with Ω = ½∥ · ∥₂² = Ω∗, this shows that the Moreau


envelope can equivalently be written as

envf = fΩ = fΩ∗ .

Given the equivalence between the primal and dual approaches, using
one approach or the other is mainly a matter of mathematical or
algorithmic convenience, depending on the case.
In this book, we focus on applications of smoothing techniques to dif-
ferentiable programming. For applications to non-smooth optimization,
see Nesterov (2005) and Beck and Teboulle (2012).

13.3.4 Regularization scaling


Dual approach
If Ω is 1-strongly convex, then fΩ is a 1-smooth approximation of the
original function f . To control the smoothness of the approximation,
it suffices to regularize with εΩ for ε > 0, leading to a 1/ε-smooth
approximation fεΩ of f . Moreover, one can check that

fεΩ (v) = εfΩ (v/ε)


∇fεΩ (v) = ∇fΩ (v/ε).

Therefore, if we know how to compute fΩ , we can also compute fεΩ and


its gradient easily. Furthermore, the approximation error induced by
the smoothing can be quantified using Proposition 13.5 as we then have

f(u) − εUΩ ≤ fεΩ(u) ≤ f(u) − εLΩ,

provided that LΩ ≤ Ω(v) ≤ UΩ for all v ∈ dom(Ω).



Primal approach

Following Definition 13.2, if we use dual regularization εΩ, where ε > 0


controls the regularization strength, the corresponding primal regular-
ization is R = εΩ∗ (·/ε). That is, we have

fεΩ = f □εΩ∗ (·/ε).

In the particular case Ω(v) = ½∥v∥₂², we have

R(u) = (ε/2)∥u/ε∥₂² = (1/(2ε))∥u∥₂² = (1/ε)Ω(u).

We therefore get

fεΩ = f □ (1/ε)Ω = (1/ε)(εf □ Ω) = (1/ε) env_{εf}.

13.3.5 Generalized entropies

A natural choice of dual regularization Ω(π), when π ∈ △M is a discrete


probability distribution, is a negative entropy function, also known as
negentropy. Since negentropies play a major role in smoothed max
operators, we discuss them in detail here.

Information content and entropy

An entropy function measures the amount of “surprise” of a random


variable or equivalently of a distribution. To define an entropy, we must
first define the information content I(E) of an event E. The value
returned by such a function should be 0 if the probability of the event is
1, as there is no surprise. Conversely, information content should attain
its maximal value if the probability of the event is 0, as it is maximally
surprising. Furthermore, the more probable an event E is, the less
surprising it is. Therefore, when p(E) increases, I(E) should decrease.
Overloading the notation, we also write the information content of the
outcome y of a random variable Y as the information content of the
event {Y = y},
I(y) := I({Y = y}).

Given an information content function, we can then define the


entropy H(Y ) of a random variable Y ∈ Y as the expected information
content,
H(Y ) := E[I(Y )].
Different definitions of information content lead to different definitions
of entropy.

Shannon’s entropy
A definition of information content satisfying the criteria above is

I(E) := log(1/p(E)) = − log p(E).
Indeed, − log 1 = 0, − log 0 = ∞ and − log is a decreasing function over
(0, 1]. Using this information content definition leads to Shannon’s
entropy (Shannon, 1948)

H(Y) = E[I(Y)] = − ∑_{y∈Y} p(y) log p(y).

We can therefore define the Shannon entropy of a discrete probability


distribution π ∈ △M as

H(π) = − ∑_{i=1}^M π_i log π_i = −⟨π, log π⟩

and use the corresponding negentropy as regularization


Ω(π) = −H(π) = ⟨π, log π⟩.
The function is strongly convex w.r.t. ∥ · ∥1 over △M . However, it is not
strongly convex over RM+ , since this is not a bounded set; see for instance
(Blondel, 2019, Proposition 2). Since Ω is added to f ∗ in Eq. (13.3),
we can therefore use this choice of Ω to smooth out a function f if
dom(f ∗ ) ⊆ △M .

Gini’s entropy
As an alternative, we can define information content as

I(E) = ½(1 − p(E)).

Figure 13.5: Tsallis entropies of the distribution π = (π, 1 − π) ∈ △2 , for π ∈ [0, 1].
An entropy is a non-negative concave function that attains its maximum at the
uniform distribution, here (0.5, 0.5). A negative entropy, a.k.a. negentropy, can be used
as a dual regularization function Ω to smooth out a function f when dom(f ∗ ) ⊆ △M .

The ½ factor is for later mathematical convenience. This again satisfies the criteria of an information content function. Indeed, i) when p(E) = 1, I(E) = 0; ii) when p(E) = 0, I(E) attains its maximum of ½; iii) the function is decreasing w.r.t. p(E). Using this information content definition leads to Gini’s entropy a.k.a. Gini index (Gini, 1912)

H(Y) = E[I(Y)] = ½ ∑_{y∈Y} p(y)(1 − p(y)).

We can use Gini’s negative entropy to define for all π ∈ △M

Ω(π) = ½⟨π, π − 1⟩ = ½(∥π∥₂² − 1).
The function is strongly convex w.r.t. ∥ · ∥2 over RM . We can therefore
use this choice of Ω to smooth out a function f if dom(f ∗ ) ⊆ RM . This
means that the set of functions that we can smooth out with Gini
entropy is larger than the set of functions we can smooth out with
Shannon entropy.

Tsallis entropies
Given α ≥ 1, a more general information content definition is

I(E) = (1 − p(E)^{α−1}) / (α(α − 1)).

Figure 13.6: Contours of Tsallis entropies on the probability simplex.

Using this definition leads to the Tsallis entropy (Tsallis, 1988)

H(Y) = E[I(Y)] = (1/(α(α − 1))) ∑_{y∈Y} p(y)(1 − p(y)^{α−1}).

The Tsallis entropy recovers the Shannon entropy in the limit α → 1


and the Gini entropy when α = 2. We can use the Tsallis negative
entropy to define for all π ∈ △M

Ω(π) = (1/(α(α − 1))) ⟨π, π^{α−1} − 1⟩ = (1/(α(α − 1))) (∥π∥_α^α − 1),
where ∥v∥_p is the p-norm, for p ≥ 1,

∥v∥_p := (∑_{i=1}^M |v_i|^p)^{1/p},

so that

∥v∥_p^p = ∑_{i=1}^M |v_i|^p.
Tsallis entropies for α → 1 (Shannon entropy), α = 1.5 and α = 2 (Gini
entropy) are illustrated in Fig. 13.5 and Fig. 13.6.
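For reference, a short NumPy sketch (the function name tsallis_negentropy is ours) of the Tsallis negentropy Ω above, with the Shannon case handled as the α → 1 limit; α = 2 recovers Gini's negentropy.

    import numpy as np

    def tsallis_negentropy(pi, alpha=1.5):
        # Omega(pi) = (||pi||_alpha^alpha - 1) / (alpha * (alpha - 1)) on the simplex;
        # alpha -> 1 gives Shannon's negentropy, alpha = 2 gives Gini's negentropy.
        pi = np.asarray(pi, dtype=float)
        if np.isclose(alpha, 1.0):
            return np.sum(pi * np.log(np.clip(pi, 1e-30, None)))
        return (np.sum(pi ** alpha) - 1.0) / (alpha * (alpha - 1.0))

    print(tsallis_negentropy([0.7, 0.2, 0.1], alpha=2.0))  # Gini negentropy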

Definition and properties of generalized entropies


So far, we saw how to define an entropy as the expected information
content. However, generalized entropies (DeGroot, 1962; Grünwald and
Dawid, 2004) do not necessarily need to take this form. We follow the
definition of Blondel et al. (2020).

Definition 13.4 (Entropy function). A function H : △M → R+ is


an entropy if

1. H(π) = 0 if π ∈ {e1 , . . . , eM },

2. H is strictly concave,

3. H(P π) = H(π) for any permutation matrix P .

This definition implies that H is non-negative and is uniquely maxi-


mized by the uniform distribution (Blondel et al., 2020, Proposition 4).
This is indeed what we expect from an entropy function. An example is
the squared p-norm entropy (Blondel et al., 2020)

H(π) = ½ − ½∥π∥_p².
Since the squared p-norm is strongly convex for p ∈ (1, 2] (Ball et al.,
2002), this entropy is strongly concave for p ∈ (1, 2] and can therefore
be used to smooth out functions.
We now illustrate how to apply these techniques to compute smoothed
ReLUs and smoothed max operators, as well as continuous relaxations
of step functions and argmax operators.

13.4 Smoothed ReLU functions

To demonstrate the application of the smoothing techniques discussed in


this chapter, we begin by explaining how to smooth the ReLU function.
The ReLU function is defined by

relu(u) := { u if u ≥ 0; 0 otherwise } = max(u, 0).
We recall that in order to smooth a function f by the dual approach,
we calculate its conjugate f ∗ , add regularization Ω to it to obtain
fΩ∗ := f ∗ + Ω and then obtain fΩ by computing fΩ∗∗ .
Here, we wish to smooth out f = relu. Its convex conjugate is

relu∗(π) = ι_[0,1](π) = { 0 if π ∈ [0, 1]; ∞ if π ∉ [0, 1] }.

To see why, we observe that

relu(u) = max_{π∈[0,1]} u · π = max_{π∈{0,1}} u · π = { u if u ≥ 0; 0 otherwise }.   (13.4)

Indeed, since the objective is linear in π, the maximum is attained at


one of the extreme points of [0, 1], so that we can replace the constraint
π ∈ [0, 1] with π ∈ {0, 1}. This shows that the ReLU is exactly the
support function of [0, 1]. Since the conjugate of the support function
is the indicator function, we indeed obtain relu∗ = ι[0,1] . We therefore
have
relu∗Ω (π) = relu∗ (π) + Ω(π) = ι[0,1] (π) + Ω(π)
and for some choice of Ω, we need to be able to compute

reluΩ (u) = max u · π − (ι[0,1] (π) + Ω(π))


π∈R
= max u · π − Ω(π).
π∈[0,1]

The softplus

If we use the regularizer Ω(π) = π log π + (1 − π) log(1 − π), which


comes from using Shannon’s negentropy ⟨π, log π⟩ with π = (π, 1 − π),
we obtain
reluΩ (u) = softplus(u) = log(1 + exp(u)).
This result is a special case of Proposition 13.9.

The sparseplus

If we use the regularizer Ω(π) = π(π − 1), which comes from using Gini’s negentropy ½⟨π, π − 1⟩ with π = (π, 1 − π), we obtain

reluΩ(u) = sparseplus(u) = { 0 if u ≤ −1; ¼(u + 1)² if −1 < u < 1; u if u ≥ 1 }.

See Fig. 13.8 (left figure) for a comparison of softplus and sparseplus.
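Both smoothed ReLUs admit simple implementations. The following NumPy sketch (function names ours) uses an overflow-safe log-sum-exp for the softplus and the closed-form piecewise expression for the sparseplus.

    import numpy as np

    def softplus(u):
        # Smoothed ReLU from Shannon's negentropy: log(1 + exp(u)), overflow-safe.
        return np.logaddexp(0.0, u)

    def sparseplus(u):
        # Smoothed ReLU from Gini's negentropy: piecewise quadratic.
        u = np.asarray(u, dtype=float)
        return np.where(u <= -1.0, 0.0, np.where(u >= 1.0, u, 0.25 * (u + 1.0) ** 2))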

13.5 Smoothed max operators

As a more elaborate application of the smoothing techniques discussed


in this chapter, we explain how to smooth max operators. Smoothed
max operators include smoothed ReLU functions as a special case.

13.5.1 Definition and properties


With a slight notation overloading, given a vector u = (u1 , . . . , uM ) ∈
RM , we define its maximum as
max(u) := max_{j∈[M]} u_j.

To obtain a smooth approximation maxΩ of max, we again apply the


dual approach. The conjugate of max is
max∗ (π) = ι△M (π).
To see why, we observe that the vertices of the probability simplex △M are the standard basis vectors e_1, . . . , e_M. Since the objective is linear, we then have

max(u) = max_{π∈△M} ⟨u, π⟩ = max_{π∈{e_1,...,e_M}} ⟨u, π⟩.

In other words, the maximum operator is exactly the support function


of △M . Since the conjugate of the support function is the indicator
function, we indeed obtain max∗ = ι△M . We can therefore write
max∗Ω (π) = max∗ (π) + Ω(π) = Ω(π) + ι△M (π)
and
maxΩ(u) = (Ω + ι_△M)∗(u)
        = max_{π∈RM} ⟨u, π⟩ − (Ω(π) + ι_△M(π))
        = max_{π∈△M} ⟨u, π⟩ − Ω(π).

The smoothed max operator maxΩ can be useful in a neural network,


for example as a smoothed max pooling layer. Its properties have been
studied in (Mensch and Blondel, 2018, Lemma 1), as we recall here for
convenience.

Proposition 13.7 (Properties of maxΩ ). The following properties


hold.

1. Bounds: if LΩ ≤ Ω(π) ≤ UΩ for all π ∈ △M , then max(u) −


UΩ ≤ maxΩ (u) ≤ max(u) − LΩ for all u ∈ RM .

2. Monotonicity: if u ≤ v (element-wise), then maxΩ (u) ≤


maxΩ (v).

3. Commutativity: if Ω(P π) = Ω(π) for any permutation


matrix P and any π ∈ △M , then maxΩ (P u) = maxΩ (u) for
any permutation matrix P .

4. Distributivity of +: maxΩ (u + c 1) = maxΩ (u) + c for all


u ∈ RM and all c ∈ R.

These properties are leveraged in (Mensch and Blondel, 2018) to


create differentiable dynamic programs. We consider in the following two
possible choices of Ω leading to the softmax and sparsemax operators
illustrated in Fig. 13.7.

Smoothed min operators


The minimum operator can be expressed in terms of the maximum
operator, since for all u ∈ RM ,

min(u) = − max(−u).

Given a smoothed max operator maxΩ , we can therefore easily define a


smoothed min operator as

minΩ (u) := −maxΩ (−u).

13.5.2 Reduction to root finding


Computing maxΩ (u) for a general strongly-convex regularization Ω in-
volves the resolution of a maximum over probability simplex constraints.
For convenience, let us define the notation

δΩ(u) := (Ω + ι_{RM+})∗(u) = max_{v∈RM+} ⟨u, v⟩ − Ω(v).

The following proposition shows that we can reduce computing maxΩ


to solving a root equation involving δΩ .
Proposition 13.8 (Computing maxΩ as root finding). Suppose Ω is
strongly convex. For all u ∈ RM ,

maxΩ(u) = min_{τ∈R} τ + δΩ(u − τ1)
        = τ⋆ + δΩ(u − τ⋆1)

and
∇maxΩ (u) = ∇δΩ (u − τ ⋆ 1),
where τ ⋆ is the solution w.r.t. τ of the above min, which satisfies
the root equation

⟨∇δΩ (u − τ ⋆ 1), 1⟩ = 1.

Proof. The idea is to keep the non-negativity constraint explicit, but to


use a Lagrange multiplier for the equality constraint of the probability
simplex. We then have
maxΩ(u) = max_{v∈△M} ⟨u, v⟩ − Ω(v)
        = max_{v∈RM+} min_{τ∈R} ⟨u, v⟩ − Ω(v) − τ(⟨v, 1⟩ − 1)
        = min_{τ∈R} τ + max_{v∈RM+} ⟨u − τ1, v⟩ − Ω(v)
        = min_{τ∈R} τ + δΩ(u − τ1),

where we used that we can swap the min and the max, since (u, v) 7→
⟨u, v⟩ − Ω(v) is convex-concave and v ∈ △M is an affine constraint. The
gradient ∇δΩ (u) follows from Danskin’s theorem. The root equation
follows from computing the derivative of τ 7→ τ +δΩ (u−τ 1) and setting
it to zero.

13.5.3 The softmax


When Ω is Shannon’s negentropy, we obtain that maxΩ is the softmax,
already briefly discussed in Section 4.4.2.

Proposition 13.9 (Analytical expression of the softmax). When Ω(π) =


⟨π, log π⟩, we get

softmax(u) := maxΩ(u)
           = max_{π∈△M} ⟨u, π⟩ − Ω(π)
           = logsumexp(u)
           = log ∑_{j=1}^M e^{u_j}.

Proof. Since dom(Ω) = RM+, we have δΩ = Ω∗ (i.e., the non-negativity constraint is redundant). From Example 13.2, we therefore have δΩ(u) = ∑_{j=1}^M exp(u_j − 1). From Proposition 13.8, maxΩ(u) = τ⋆ + δΩ(u − τ⋆1) where τ⋆ satisfies ⟨∇δΩ(u − τ⋆1), 1⟩ = 1. Since ∇δΩ(u) = exp(u − 1), we need to solve ∑_{j=1}^M exp(u_j − 1 − τ) = 1. We therefore get τ⋆ + 1 = logsumexp(u) and therefore maxΩ(u) = logsumexp(u) − 1 + ∑_{j=1}^M exp(u_j − logsumexp(u)) = logsumexp(u).

Since − log M ≤ Ω(π) ≤ 0 for all π ∈ △M , following Proposi-


tion 13.7, we get for all u ∈ RM

max(u) ≤ softmax(u) ≤ max(u) + log M.
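In code, maxΩ with Shannon's negentropy is just a numerically stabilized log-sum-exp. The sketch below (the name softmax_operator is ours, chosen to avoid confusion with the softargmax of Section 13.7) subtracts the maximum before exponentiating.

    import numpy as np

    def softmax_operator(u):
        # max_Omega(u) = logsumexp(u): a scalar smoothed maximum (not a vector).
        u = np.asarray(u, dtype=float)
        m = np.max(u)                      # shift for numerical stability
        return m + np.log(np.sum(np.exp(u - m)))

    print(softmax_operator([2.0, 1.0, 0.0]))  # slightly above max(u) = 2.0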

A unique property of the softmax, which is not the case of all maxΩ
operators, is that it supports associativity.

Proposition 13.10 (Associativity of the softmax). For all a, b, c ∈


R,

softmax(softmax(a, b), c) = softmax(a, softmax(b, c)).

13.5.4 The sparsemax


Alternatively, choosing Ω to be Gini’s negentropy leads to the sparsemax
(Martins and Astudillo, 2016; Mensch and Blondel, 2018).

Proposition 13.11 (Variational formulation of sparsemax). When Ω(π) = ½⟨π, π − 1⟩, we have

sparsemax(u) := maxΩ(u)
             = max_{π∈△M} ⟨u, π⟩ − Ω(π)
             = ⟨u, π⋆⟩ − Ω(π⋆),

where

π⋆ = sparseargmax(u) := arg min_{π∈△M} ∥u − π∥₂².

Proof. This follows from the fact that Ω(π) is, up to a constant, equal to ½∥π∥₂², and completing the square.

Therefore, computing the sparsemax can use the sparseargmax (the


Euclidean projection onto the probability simplex) as a building block.
We discuss how to compute it in more detail in Section 13.7. Applying
Proposition 13.8 gives an alternative formulation.

Proposition 13.12 (Sparsemax as root finding). When Ω(π) = ½⟨π, π − 1⟩, we have

sparsemax(u) = maxΩ(u) = min_{τ∈R} τ + ½ ∑_{i=1}^M [u_i − τ]₊²

and τ⋆ satisfies

∑_{i=1}^M [u_i − τ⋆]₊ = 1.

Proof. First, we compute the expression of δΩ(u) = max_{v∈RM+} ⟨u, v⟩ − Ω(v). Setting the gradient of v ↦ ⟨u, v⟩ − Ω(v) to zero and clipping, we obtain v⋆ = [u]₊. Plugging v⋆ back, we obtain δΩ(u) = ½ ∑_{i=1}^M [u_i]₊². Using Proposition 13.8 proves the proposition’s first part. Setting the derivative w.r.t. τ to zero gives the second part.

Figure 13.7: Max, softmax and sparsemax functions. The max function has non-smooth contour lines (sets of points {u ∈ R3 : f(u) = c} for some constant c, represented by dashed gray lines): the gradient switches suddenly at the corners of these contour lines. This shows that the max function is not differentiable everywhere; namely, it is non-differentiable on the set of points {u ∈ R3 : u_i = u_j for some i ≠ j}. The contour lines of the softmax and sparsemax functions, on the other hand, are smooth, illustrating that these functions are smooth counterparts of the max function.

It can be shown (Duchi et al., 2008; Condat, 2016) that the exact
solution τ⋆ is obtained by

τ⋆ = (1/j⋆) (∑_{i=1}^{j⋆} u_{[i]} − 1),   (13.5)

where j⋆ is the largest j ∈ [M] such that

u_{[j]} − (1/j) (∑_{i=1}^{j} u_{[i]} − 1) > 0,

and where we used the notation u[1] ≥ u[2] ≥ · · · ≥ u[M ] . As an


alternative, we can also compute τ ⋆ approximately using a bisection or
by gradient descent w.r.t. τ .
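The bisection alternative mentioned above is straightforward to implement, since τ ↦ ∑_i [u_i − τ]₊ is non-increasing and crosses the value 1 between min(u) − 1 and max(u). A minimal NumPy sketch (the function name and tolerance are ours):

    import numpy as np

    def simplex_threshold_bisect(u, tol=1e-10):
        # Approximate tau* solving sum_i [u_i - tau]_+ = 1 by bisection.
        u = np.asarray(u, dtype=float)
        phi = lambda tau: np.sum(np.maximum(u - tau, 0.0))
        lo, hi = np.min(u) - 1.0, np.max(u)   # phi(lo) >= 1 and phi(hi) = 0
        while hi - lo > tol:
            mid = 0.5 * (lo + hi)
            lo, hi = (mid, hi) if phi(mid) > 1.0 else (lo, mid)
        return 0.5 * (lo + hi)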
Since 1/(2M) ≤ ½∥π∥₂² ≤ ½, we get −(M − 1)/(2M) ≤ Ω(π) ≤ 0 for all π ∈ △M. Following Proposition 13.7, we therefore get for all u ∈ RM

max(u) ≤ sparsemax(u) ≤ max(u) + (M − 1)/(2M).

13.5.5 Recovering smoothed ReLU functions


Using the vector u = (u, 0) ∈ R2 as input, the smoothed max operator
recovers the smoothed ReLU:
maxΩ ((u, 0)) = reluΨ (u),
where we defined Ψ(π) := Ω((π, 1 − π)). With Ω being Shannon’s
negentropy, we recover Ψ(π) = π log π + (1 − π) log(1 − π); with Ω being
Gini’s negentropy, we recover Ψ(π) = π(π − 1), that we used to smooth
the ReLU.

13.6 Relaxed step functions (sigmoids)

We now turn to creating continuous relaxations of step functions. The


binary step function, a.k.a. Heaviside step function, is defined by

step(u) := { 1 if u ≥ 0; 0 otherwise }.
From Eq. (13.4), its variational form is
step(u) = arg max_{π∈[0,1]} u · π.

We can therefore define the relaxation


stepΩ(u) := arg max_{π∈[0,1]} u · π − Ω(π).

Notice that, unlike the case of the smoothed ReLU, it is a regularized


argmax, not a regularized max. Following Remark 13.1, strongly convex
regularization Ω ensures that stepΩ (u) is a Lipschitz continuous function
of u and is therefore, at least, differentiable almost everywhere, unlike
step(u).

The logistic function


If we use the regularizer Ω(π) = π log π + (1 − π) log(1 − π), we obtain
the closed form
1 eu
stepΩ (u) = logistic(u) := = .
1+e −u 1 + eu
This function is differentiable everywhere.

The sparse sigmoid


As an alternative, if we use Ω(π) = π(π − 1), we obtain a piecewise linear sigmoid,

stepΩ(u) = sparsesigmoid(u) := { 0 if u ≤ −1; ½(u + 1) if −1 < u < 1; 1 if u ≥ 1 }.

Unlike the logistic function, it can reach the exact values 0 or 1. However,
the function has two kinks, where the function is non-differentiable.

Link between smoothed ReLU functions and sigmoids


It turns out that the three sigmoids we presented above (step, logistic,
sparsesigmoid) are all equal to the derivative of their corresponding
smoothed ReLU function:
step(u) = relu′ (u)
logistic(u) = softplus′ (u)
sparsesigmoid(u) = sparseplus′ (u)
and more generally
relu′Ω (u) = stepΩ (u).
This is a consequence of Danskin’s theorem; see Example 11.2. We
illustrate the smoothed ReLU functions and relaxed step functions
(sigmoids) in Fig. 13.8.
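Both relaxed step functions have one-line implementations; here is a NumPy sketch (function names ours).

    import numpy as np

    def logistic(u):
        # Relaxed step from Shannon's negentropy; derivative of the softplus.
        return 1.0 / (1.0 + np.exp(-np.asarray(u, dtype=float)))

    def sparse_sigmoid(u):
        # Relaxed step from Gini's negentropy; derivative of the sparseplus.
        return np.clip((np.asarray(u, dtype=float) + 1.0) / 2.0, 0.0, 1.0)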

13.7 Relaxed argmax operators

We now turn to argmax operators, which are a generalization of step


functions. With a slight notation overloading, let us now define
argmax(u) := ϕ(arg max_{j∈[M]} u_j),

where ϕ(j) = onehot(j) = ej is used to embed any integer j ∈ [M ] into


RM . Following the previous discussion, we have the variational form
argmax(u) = arg max_{π∈△M} ⟨u, π⟩ = arg max_{π∈{e_1,...,e_M}} ⟨u, π⟩,

Figure 13.8: Smoothed ReLU functions and relaxed step functions (sigmoids).
Differentiating the left functions gives the right functions.

where the second equality uses that a linear function is maximized at


one of the vertices of the simplex. This variational form suggests defining the relaxation

argmaxΩ(u) := arg max_{π∈△M} ⟨u, π⟩ − Ω(π).

Again, following Remark 13.1, argmaxΩ (u) is guaranteed to be, at least,


a differentiable almost everywhere function of u if Ω is strongly convex.
Similarly to sigmoids, it turns out that these mappings are equal to
the gradient of their corresponding smoothed max operator:

argmaxΩ (u) = ∇maxΩ (u).

This is again a consequence of Danskin’s theorem.

The softargmax

When using Shannon’s entropy Ω(π) = ⟨π, log π⟩, we obtain

argmaxΩ(u) = softargmax(u) = exp(u) / ∑_{j=1}^M exp(u_j),

which is differentiable everywhere.

Proof. We know that maxΩ (u) = logsumexp(u) and that ∇ maxΩ (u) =
argmaxΩ (u). Differentiating logsumexp(u) gives softargmax(u).

The sparseargmax

When using Gini’s entropy Ω(π) = ½⟨π, π − 1⟩, which is up to a constant equal to ½∥π∥₂², we obtain the sparseargmax (Martins and Astudillo, 2016)

argmaxΩ(u) = sparseargmax(u)
           := arg max_{π∈△M} ⟨u, π⟩ − ½⟨π, π − 1⟩
           = arg max_{π∈△M} ⟨u, π⟩ − ½∥π∥₂²
           = arg min_{π∈△M} ∥u − π∥₂²,

which is nothing but the Euclidean projection onto the probability


simplex (see also Section 16.3). The Euclidean projection onto the
probability simplex △M can be computed exactly using a median-
finding-like algorithm. The complexity is O(M ) expected time and
O(M log M ) worst-case time (Brucker, 1984; Michelot, 1986; Duchi
et al., 2008; Condat, 2016). Computing the Euclidean projection onto
the probability simplex boils down to computing τ ⋆ given in Eq. (13.5).
Once τ⋆ has been computed, we have

sparseargmax(u) = [u − τ⋆]₊.

As its name indicates, and as the above equation shows, sparseargmax


is sparse, but it is only differentiable almost everywhere. Note that
the operator is originally known as sparsemax (Martins and Astudillo,
2016), but this is a misnomer, as it is really an approximation of the
argmax. Therefore, in analogy with the softargmax, we use the name
sparseargmax. We compare the argmax, softargmax and sparseargmax in
Fig. 13.9 and Fig. 13.10.
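To summarize, both relaxed argmax operators can be implemented in a few lines of NumPy. The sketch below (function names ours) uses the standard max-subtraction trick for the softargmax and the sort-based computation of τ⋆ from Eq. (13.5) for the sparseargmax.

    import numpy as np

    def softargmax(u):
        # Gradient of logsumexp: strictly positive, sums to 1.
        e = np.exp(u - np.max(u))          # shift for numerical stability
        return e / np.sum(e)

    def sparseargmax(u):
        # Euclidean projection onto the probability simplex, via Eq. (13.5).
        u = np.asarray(u, dtype=float)
        u_sorted = np.sort(u)[::-1]        # u_[1] >= ... >= u_[M]
        cssv = np.cumsum(u_sorted) - 1.0
        j = np.arange(1, u.size + 1)
        k = np.max(j[u_sorted - cssv / j > 0])
        tau = cssv[k - 1] / k
        return np.maximum(u - tau, 0.0)

    print(sparseargmax([2.0, 1.2, -0.5]))  # sparse output: last entry is exactly 0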

Relaxed argmin operators

The argmin operator can be expressed in terms of the argmax operator,

arg min(u) = arg max(−u).



Figure 13.9: Values of argmax(u), softargmax(u), and sparseargmax(u) for u =


(u1 , u2 , 0), when varying u1 and u2 . The argmax is a piecewise constant, discontinuous
function. The softargmax is a continuous and differentiable everywhere function, but
it is always strictly positive and therefore dense. The sparseargmax is a continuous
function and its output can be sparse, but it is only a differentiable almost everywhere
function.

Figure 13.10: Same as Fig. 13.9 but using a 3D plot.

Given a relaxed argmax operator argmaxΩ , we can therefore define a


relaxed argmin by

argminΩ (u) := argmaxΩ (−u).

We then have for all u ∈ RM

argminΩ (u) = ∇minΩ (u).

13.8 Summary

• When a function f is non-differentiable (or worse, discontinuous),


a reasonable approach is to replace it by its smooth approximation
(or continuous relaxation).

• The first approach we reviewed is infimal convolution between f


and primal regularization R. The Moreau envelope is a special
case, obtained by using R = ½∥ · ∥₂².

• The second approach we reviewed is regularizing the convex con-


jugate f ∗ of f with some dual regularization Ω. We saw that the
primal and dual approaches are equivalent when R = Ω∗ .

• The Legendre-Fenchel transformation, a.k.a. convex conjugate,


can be seen as a dual representation of a function: instead of
representing f by its graph (u, f (u)) for u ∈ dom(f ), we can
represent it by the set of tangents with slope v and intercept

−f∗(v) for v ∈ dom(f∗). As its name indicates, it is convex, even


if the original function is not.

• We showed how to apply smoothing techniques to create smoothed


ReLU functions and smoothed max operators. We also showed that
taking their gradients allowed us to obtain generalized sigmoid
functions and argmax operators.
14
Smoothing by integration

In this chapter, we review smoothing techniques based on convolution.

14.1 Convolution

14.1.1 Convolution operators

The convolution between two functions f and g produces another


function, denoted f ∗ g. It is defined by

(f ∗ g)(µ) := ∫_{−∞}^{∞} f(u) g(µ − u) du,   (14.1)

assuming that the integral is well defined. It is therefore the integral of


the product of f and g after g is reflected about the y-axis and shifted.
It can be seen as a generalization of the moving average. Using the
change of variable z := µ − u, we can also write

(f ∗ g)(µ) = ∫_{−∞}^{∞} f(µ − z) g(z) dz = (g ∗ f)(µ).   (14.2)

The convolution operator is therefore commutative.


14.1.2 Convolution with a kernel

The convolution is frequently used together with a kernel κ to create a


smooth approximation f ∗ κ of f . The most frequently used kernel is
the Gaussian kernel with width σ, defined by

κσ(z) := (1/(√(2π) σ)) e^{−½(z/σ)²}.

This is the probability density function (PDF) of the normal distribution with zero mean and variance σ². The term 1/(√(2π) σ) is a normalization constant, ensuring that the kernel integrates to 1 for all σ. We therefore say that κσ is a normalized kernel.

Averaging perspective

Applying the definition of the convolution in Eq. (14.1), we obtain

(f ∗ κσ)(µ) := (1/(√(2π) σ)) ∫_{−∞}^{∞} f(u) e^{−½((µ−u)/σ)²} du
            = E_{U∼pµ,σ}[f(U)],

where

pµ,σ(u) := κσ(µ − u) = (1/(√(2π) σ)) e^{−½((µ−u)/σ)²}
is the PDF of the Gaussian distribution with mean µ and variance
σ 2 . Therefore, we can see f ∗ κσ as the expectation of f (u) over a
Gaussian centered around µ. This property is true for all translation-
invariant kernels that correspond to a location-scale family distribution
(e.g., the Laplace distribution). The convolution therefore performs an
averaging with all points, with points nearby µ given more weight by
the distribution. The parameter σ controls the importance we want to
give to farther points. We call this viewpoint averaging, as we replace
f (u) by E[f (U )].

Perturbation perspective

Conversely, using the alternative definition of the convolution operator


in Eq. (14.2), which stems from the commutativity of the convolution,

we have

(f ∗ κσ)(µ) = (1/(√(2π) σ)) ∫_{−∞}^{∞} f(µ − z) e^{−½(z/σ)²} dz
            = E_{Z∼p0,σ}[f(µ − Z)]
            = E_{Z∼p0,σ}[f(µ + Z)],

where, in the third line, we used that p0,σ is sign invariant, i.e., p0,σ (z) =
p0,σ (−z). This viewpoint shows that smoothing by convolution with
a Gaussian kernel can also be seen as injecting Gaussian noise or
perturbations to the function’s input.

Limit case

When σ → 0, the kernel κσ converges to a Dirac delta function,

lim κσ (z) = δ(z).


σ→0

Since the Dirac delta is the multiplicative identity of the convolution


algebra (this is also known as the sifting property), when σ → 0, f ∗ κσ
converges to f , i.e.,

lim (f ∗ κσ )(u) = f (u).


σ→0

14.1.3 Discrete convolution

Many times, we work with functions whose convolution does not have
an analytical form. In these cases, we can use a discrete convolution on
a grid of values. For two functions f and g defined over Z, the discrete
convolution is defined by

(f ∗ g)[i] := ∑_{j=−∞}^{∞} f[j] g[i − j].

As for its continuous counterpart, the discrete convolution is commuta-


tive, namely,

(f ∗ g)[i] = ∑_{j=−∞}^{∞} f[i − j] g[j] = (g ∗ f)[i].

Figure 14.1: Smoothing of the signal f[t] := t² + 0.3 sin(6πt) with a sampled and
renormalized Gaussian kernel.

When g has finite support over the set S := {−M, −M +1, . . . , 0, . . . , M −


1, M }, meaning that g[i] = 0 for all i ̸∈ S, a finite summation may be
used instead, i.e.,

(f ∗ g)[i] = ∑_{j=−M}^{M} f[i − j] g[j] = (g ∗ f)[i].

In practice, convolution between a discrete signal f : Z → R and a


continuous kernel κ : R → R is implemented by discretizing the kernel.
One of the simplest approaches consists in sampling points on an interval,
evaluating the kernel at these points and renormalizing the obtained
values, so that the sampled kernel sums to 1. This is illustrated with
the Gaussian kernel in Fig. 14.1. Since the Gaussian kernel decays
exponentially fast, we can choose a small interval around 0. For a survey
of other possible discretizations of the Gaussian kernel, see Getreuer
(2013).
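For illustration, a small NumPy sketch (function and variable names ours) of this discretization, applied to the signal of Fig. 14.1: the Gaussian kernel is sampled on a truncated window, renormalized to sum to 1, and applied with a discrete convolution.

    import numpy as np

    def gaussian_smooth(f_vals, t, sigma):
        # Discrete convolution with a sampled and renormalized Gaussian kernel.
        dt = t[1] - t[0]
        radius = int(np.ceil(4.0 * sigma / dt))        # truncate around 4 sigma
        z = np.arange(-radius, radius + 1) * dt
        kernel = np.exp(-0.5 * (z / sigma) ** 2)
        kernel /= kernel.sum()                          # renormalize to sum to 1
        return np.convolve(f_vals, kernel, mode="same")

    t = np.linspace(-3.0, 3.0, 601)
    smoothed = gaussian_smooth(t ** 2 + 0.3 * np.sin(6 * np.pi * t), t, sigma=0.25)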

14.1.4 Differentiation
Remarkably, provided that the two functions are integrable with inte-
grable derivatives, the derivative of the convolution satisfies
(f ∗ g)′ = (f ′ ∗ g) = (f ∗ g ′ ),
which simply stems from switching derivative and integral in the defini-
tion of the convolution. Moreover, we have the following proposition.

Proposition 14.1 (Differentiability of the convolution). If g is n-times


differentiable with compact support over R and f is locally inte-
grable over R, then f ∗ g is n-times differentiable over R.

14.1.5 Multidimensional convolution


So far, we studied the convolution of one-dimensional functions. The
definition can be naturally extended to multidimensional functions
f : RM → R and g : RM → R as

(f ∗ g)(µ) := ∫_{RM} f(u) g(µ − u) du,
assuming again that the integral exists. Typically, a Gaussian kernel with diagonal covariance matrix is used

κσ(z) := ∏_{j=1}^M (1/(√(2π) σ_j)) e^{−½(z_j/σ_j)²} = (1/((√(2π))^M σ^M)) e^{−½∥z∥₂²/σ²},   (14.3)
where, in the second equality, we assumed σ1 = · · · = σM . In an image
processing context, where M = 2, it is approximated using a discrete
convolution and it is called a Gaussian blur.

14.1.6 Link between convolution and infimal convolution


The infimal convolution we studied in Section 13.1 takes the form
(f □ g)(µ) := inf_{u∈RM} f(u) + g(µ − u).

In comparison, the classical convolution takes the form

(F ∗ G)(µ) := ∫_{RM} F(u) G(µ − u) du.

The two forms of convolution are clearly related. Infimal convolution


performs an infimum and uses the sum of f and g: it uses a min-
plus algebra. Classical convolution performs an integral and uses the
product of F and G: it uses a sum-product algebra.

14.1.7 The soft infimal convolution


The link between the infimal convolution and the classical convolution
can be further elucidated if we replace the infimum with a soft minimum
in the definition of the infimal convolution.

Definition 14.1 (Soft infimal convolution). The soft infimal convo-


lution between f : RM → R and g : RM → R is

(f □ε g)(µ) := softminε_{u∈RM} f(u) + g(µ − u),

where we defined the soft minimum (assuming that it exists) over


S of any function h : S → R as

softminε_{u∈S} h(u) := −ε log ∫_S exp(−h(u)/ε) du.

We recover the infimal convolution as ε → 0.

Computation using a convolution


We now show that we can rewrite the soft infimal convolution using
a classical convolution. Indeed, by using the exponential change of
variable (sometimes referred to as Cole-Hopf transformation in a
partial differential equation context)

Cε {f }(u) := exp(−f (u)/ε)


Cε−1 {F }(v) = −ε log F (v),

we can define each function in the exponential domain,

Fε := Cε {f }
Gε := Cε {g}
Hε := Cε {hε }.

It is easy to check that we then have


Hε (µ) = (Fε ∗ Gε )(µ).
Back to log domain, we obtain
hε (µ) = Cε−1 {Hε }(µ).
Combining the transformation and its inverse, we can write
hε (µ) = Cε−1 {Cε {f } ∗ Cε {g}}(µ).
What we have shown is that, after an exponential change of variable,
the soft infimal convolution can be reduced to the computation of a
convolution. This is useful as a discrete convolution on a grid of size n
can be computed in O(n log n).

14.1.8 The soft Moreau envelope


We saw in Section 13.1.2 that the infimal convolution between f and R(z) = ½∥z∥₂² is the Moreau envelope,

Mf(µ) := (f □ R)(µ) = inf_{u∈RM} f(u) + ½∥µ − u∥₂².

Replacing the infimal convolution with a soft infimal convolution, we can define the “soft” Moreau envelope,

Mfε(µ) := (f □ε R)(µ) = softminε_{u∈RM} f(u) + ½∥µ − u∥₂².
We emphasize that this operation is not the same as the convolution of f with a Gaussian kernel. Indeed, we have

Mfε(µ) = −ε log ∫_{RM} exp((−f(u) − ½∥µ − u∥₂²)/ε) du.
while

(f ∗ κσ)(µ) := ∫_{RM} f(u) κσ(µ − u) du,
where κσ is for instance defined in Eq. (14.3).
We saw that the Moreau envelope is a smooth function. One may
therefore ask what we gain from using a soft Moreau envelope. The
benefit can be computational, as the latter can be approximated using
a discrete convolution.

14.2 Fourier and Laplace transforms

Let us define the Fourier transform of f by

F(s) := F{f}(s) := ∫_{−∞}^{∞} f(t) e^{−i2πst} dt,   s ∈ R.

Note that F{f } is a function transformation: it transforms f into


another function F .

14.2.1 Convolution theorem


Now, consider the convolution
h(t) := (f ∗ g)(t).
If we define the three transformations
F := F{f }, G := F{g}, H := F{h},
the convolution theorem states that
H(s) = F{h}(s) = F (s) · G(s), s ∈ R.
Written differently, we have
F{f ∗ g} = F{f } · F {g}.
In words, in the Fourier domain, the convolution operation becomes a
multiplication. Conversely,
h(t) = (f ∗ g)(t) = F −1 {F · G}(t), t ∈ R.
The convolution theorem also holds if we replace the Fourier transform
with the Laplace transform or with the two-sided (bilateral) Laplace
transform.

14.2.2 Link between Fourier and Legendre transforms


In Section 13.2, we studied another function transformation: the convex
conjugate, also known as Legendre-Fenchel transform. We recap the
analogies between these transforms in Table 14.1. In particular, the
counterpart of
F{f ∗ g} = F{f } · F {g}.

Table 14.1: Analogy between Fourier and Legendre transforms. See Proposition 13.3
for more conjugate calculus rules.

                      Fourier F{f}                        Legendre f∗
Semiring              (+, ·)                              (min, +)
Scaling (a > 0)       f(t) = g(t/a)                       f(t) = a g(t/a)
                      F{f}(s) = a F{g}(as)                f∗(s) = a g∗(s)
Translation           f(t) = g(t − t0)                    f(t) = g(t − t0)
                      F{f}(s) = e^{−i2πt0 s} F{g}(s)      f∗(s) = g∗(s) + t0 s
Convolution           h = f ∗ g                           h = f □ g
                      F{h} = F{f} · F{g}                  h∗ = f∗ + g∗
Gaussian / quadratic  f(t) = e^{−at²}                     f(t) = (a/2) t²
                      F{f}(s) = √(π/a) e^{−π²s²/a}        f∗(s) = (1/(2a)) s²
Smoothing             f ∗ κσ                              f □ (1/(2ε)) ∥ · ∥₂²

for the infimal convolution is

(f □g)∗ = f ∗ + g ∗ .

In words, the Legendre-Fenchel transform is to the infimal convolution


what the Fourier transform is to the convolution.

14.2.3 The soft Legendre-Fenchel transform


We saw in Section 13.2 that the Legendre-Fenchel transform (convex
conjugate) of a function f : RM → R is

f∗(v) := max_{u∈RM} ⟨u, v⟩ − f(u).

If necessary, we can support constraints by including an indicator


function in the definition of f . The conjugate can be smoothed out using
a log-sum-exp, which plays the role of a soft maximum (Section 13.5).

Definition 14.2 (Soft convex conjugate).

fε∗(v) := softmaxε_{u∈RM} ⟨u, v⟩ − f(u),

where we defined the soft maximum (assuming that it exists) over


S of any function g : S → R as

softmaxε_{u∈S} g(u) := ε log ∫_S exp(g(u)/ε) du.

In the limit ε → 0, we recover the convex conjugate.

Computation using a convolution

We now show that this smoothed conjugate can be rewritten using a


convolution if we apply a bijective transformation to f .

Proposition 14.2 (Smoothed convex conjugate as convolution). The smoothed conjugate can be rewritten as

fε∗(v) = Qε⁻¹{ 1/(Qε{f} ∗ Gε) }(v)

where

Gε := Cε{ ½∥ · ∥₂² } = exp(−½∥ · ∥₂²/ε)
Qε{f} := Cε{ f(·) − ½∥ · ∥₂² } = exp((½∥ · ∥₂² − f(·))/ε)
Qε⁻¹{F} := ½∥ · ∥₂² − ε log(F(·)).

This insight was tweeted by Gabriel Peyré in April 2020.



Figure 14.2: Applying the smoothed conjugate twice gives a smoothed biconjugate
(convex envelope) of the function.

Proof.

fε∗(v) := ε log ∫ exp((⟨u, v⟩ − f(u))/ε) du
       = ε log ∫ exp(−(1/(2ε))∥u − v∥₂² + (1/(2ε))∥u∥₂² + (1/(2ε))∥v∥₂² − f(u)/ε) du
       = ε log ∫ exp(−(1/(2ε))∥u − v∥₂² + (1/(2ε))∥u∥₂² − f(u)/ε) du + ½∥v∥₂²
       = ε log ∫ Gε(v − u) Qε{f}(u) du + ½∥v∥₂²
       = ε log (Qε{f} ∗ Gε)(v) + ½∥v∥₂²
       = ½∥v∥₂² − ε log(1/(Qε{f} ∗ Gε))(v)
       = Qε⁻¹{ 1/(Qε{f} ∗ Gε) }(v)

What did we gain from this viewpoint? The convex conjugate can
often be difficult to compute in closed form. If we replace RM with a dis-
crete set S (i.e., a grid), we can then approximate the smoothed convex

conjugate in O(n log n), where n = |S|, using a discrete convolution,

(Qε{f} ∗ Gε)(v) ≈ ∑_{u∈S} Gε(v − u) Qε{f}(u)
               = Kq,

where K is the n×n Gaussian kernel matrix whose entries correspond to exp(−(1/(2ε))∥u − u′∥₂²) for u, u′ ∈ S and q is the n-dimensional vector whose entries correspond to exp((1/ε)(½∥u∥₂² − f(u))) for u ∈ S. This provides
a GPU-friendly alternative to the fast Legendre transform algorithm,
discussed in Section 13.2. Of course, due to the curse of dimensionality,
the technique is limited to functions defined on low-dimensional sets.
We illustrate in Fig. 14.2 the application of the technique to computing
an approximate biconjugate (convex envelope) of a function.
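The Kq formulation above translates directly into NumPy. In the sketch below (the function name and the choice of a shared grid for u and v are ours), beware that for small ε the entries of q can overflow, so it is best used with moderate ε.

    import numpy as np

    def soft_conjugate(f_vals, grid, eps=0.1):
        # f_eps*(v) = 0.5 v^2 + eps * log((Q_eps{f} * G_eps)(v)) on a shared 1D grid,
        # approximating the convolution by the matrix-vector product K q.
        du = grid[1] - grid[0]
        q = np.exp((0.5 * grid ** 2 - f_vals) / eps)                   # entries of q
        K = np.exp(-0.5 * (grid[:, None] - grid[None, :]) ** 2 / eps)  # kernel matrix
        return 0.5 * grid ** 2 + eps * np.log(K @ q * du)

    grid = np.linspace(-1.0, 1.0, 201)
    f_eps_star = soft_conjugate(np.abs(grid), grid, eps=0.1)  # soft conjugate of |u|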

Remark 14.1 (Link with the two-sided Laplace transform). For


one-dimensional functions, instead of using a convolution, we can
also write the soft convex conjugate as

fε∗(v) = ε log ∫_{−∞}^{∞} exp((uv − f(u))/ε) du
       = ε log B{e^{−f/ε}}(−v/ε)
       = −Cε⁻¹{B{Cε{f}}}(−v/ε),

where we defined the two-sided (bilateral) Laplace transform

B{g}(v) := ∫_{−∞}^{∞} e^{−uv} g(u) du

and where we assumed that the integral exists.

14.3 Examples

In this section, we review practical examples for which the convolution


with a Gaussian kernel enjoys an analytical solution.

14.3.1 Smoothed step function



Example 14.1 (Smoothed Heaviside). The Heaviside step function


is defined by

step(u) := h(u) := { 1 if u ≥ 0; 0 otherwise }.

With the Gaussian kernel, we therefore obtain

(h ∗ κσ)(µ) = ∫_{−∞}^{µ} κσ(z) h(µ − z) dz + ∫_{µ}^{∞} κσ(z) h(µ − z) dz
            = ∫_{−∞}^{µ} κσ(z) dz
            = Φσ(µ)
            = ½(1 + erf(µ/(√2 σ))),
where Φσ (µ) is the CDF of the Gaussian distribution with zero
mean and variance σ², and where we used the error function

erf(z) := (2/√π) ∫_0^z e^{−t²} dt,

both of which we already encountered in Chapter 3. Although there is


no closed form for the error function, it is commonly available in
numerical analysis software, such as SciPy.

14.3.2 Smoothed ReLU function

Example 14.2 (Smoothed ReLU). The ReLU is defined by

r(u) := { u if u ≥ 0; 0 otherwise } = u · h(u).
0 otherwise

Figure 14.3: Smoothing of the ReLU and Heaviside functions by convolution with
a Gaussian kernel, for three values of the width σ.

Similarly to the previous example, we obtain

(r ∗ κσ)(µ) = ∫_{−∞}^{µ} κσ(z) r(µ − z) dz
            = ∫_{−∞}^{µ} κσ(z)(µ − z) dz
            = µ ∫_{−∞}^{µ} κσ(z) dz − ∫_{−∞}^{µ} κσ(z) z dz
            = µ Φσ(µ) + σ² κσ(µ).

In the second integral, setting a := 1/(2σ²), we used

∫ z e^{−az²} dz = −(1/(2a)) ∫ e^t dt = −(1/(2a)) e^t + C = −(1/(2a)) e^{−az²} + C

and t := −az² ⇒ z dz = −(1/(2a)) dt.
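Both closed forms are easy to evaluate numerically. A short sketch (function names ours), using SciPy's error function:

    import numpy as np
    from scipy.special import erf

    def smoothed_step(mu, sigma):
        # (h * kappa_sigma)(mu) = Phi_sigma(mu), as in Example 14.1.
        return 0.5 * (1.0 + erf(mu / (np.sqrt(2.0) * sigma)))

    def smoothed_relu(mu, sigma):
        # (r * kappa_sigma)(mu) = mu * Phi_sigma(mu) + sigma^2 * kappa_sigma(mu).
        kappa = np.exp(-0.5 * (mu / sigma) ** 2) / (np.sqrt(2.0 * np.pi) * sigma)
        return mu * smoothed_step(mu, sigma) + sigma ** 2 * kappa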

To illustrate differentiation of the convolution, we show how to


differentiate the smoothed ReLU.

Example 14.3 (Differentiating the smoothed ReLU). Differentiating


the smoothed ReLU from Example 14.2, we obtain

(r ∗ κσ )′ = (r′ ∗ κσ ) = h ∗ κσ = Φσ .

Therefore, unsurprisingly, the derivative of the smoothed ReLU is


the smoothed Heaviside step function. Differentiating once again,

we obtain,

(r ∗ κσ )′′ = (h ∗ κσ )′ = (h′ ∗ κσ ) = δ ∗ κσ = κσ ,

where the derivative h′ is well-defined almost everywhere. We can


arrive at the same result by using that h ∗ κσ = Φσ and Φ′σ = κσ ,
since Φσ and κσ are the CDF and PDF of the Gaussian with zero
mean and σ 2 variance.

14.4 Perturbation of blackbox functions

In this section, we review how to approximately compute a convolution


with a kernel and its gradient using Monte-Carlo estimation.

14.4.1 Expectation in a location-scale family


A rather intuitive approach to smooth a function f : RM → R is to
average its values on an input µ, perturbed by some additive noise
Z ∼ p, for some noise distribution p. This defines the surrogate

fσ (µ) := EZ∼p [f (µ + σZ)].

The parameter σ controls the perturbation strength: as σ → 0, we


naturally recover f . An equivalent viewpoint is obtained by defining
the transformation (change of variables)

U := µ + σZ.

We then have
U ∼ pµ,σ ,
where pµ,σ is the location-scale family distribution generated by the noise
distribution p. It is the pushforward distribution of Z through the
transformation (see Section 12.4.4). In this notation, the initial noise
distribution p is then simply p = p0,1 . The perturbed function can then
be expressed from these two perspectives as

fσ (µ) = EZ∼p0,1 [f (µ + σ · Z)]


= EU ∼pµ,σ [f (U )]. (14.4)

Writing the expectation as the integral of a p.d.f, we naturally recover


the smoothing by convolution presented earlier,

fσ(µ) = ∫ f(µ + σz) p0,1(z) dz = (f ∗ κσ)(µ),

where we defined the kernel

κσ (z) := p0,σ (−z).

In the sequel, we assume that the noise distribution decomposes as

p0,1 (z) := exp(−ν(z))/C,

where ν(z) is the negative log-density of the noise distribution (up to a constant) and C is a


normalization constant. For instance, the Gaussian distribution with
diagonal covariance matrix and the corresponding Gaussian kernel are obtained with ν(z) = ½∥z∥₂² and C = (√(2π))^M.

Approximation by Monte-Carlo estimation


Instead of approximating the integral above (continuous convolution)
with a discrete convolution on a grid, as we did in Section 14.1.3,
the expectation perspective suggests that we can estimate fσ (µ) by
Monte-Carlo estimation: we simply draw samples from the distribution,
evaluate the function at these samples and average. Beyond mere Monte-
Carlo estimation, more elaborate approximation schemes are studied in
(Chaudhuri and Solar-Lezama, 2010).
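A minimal Monte-Carlo sketch (function name, sample size, and seed are ours) of this estimator, for Gaussian perturbations:

    import numpy as np

    def smoothed_value_mc(f, mu, sigma, num_samples=10_000, seed=0):
        # Estimate f_sigma(mu) = E_{Z ~ N(0, I)}[f(mu + sigma * Z)] by averaging.
        rng = np.random.default_rng(seed)
        mu = np.asarray(mu, dtype=float)
        z = rng.standard_normal((num_samples,) + mu.shape)
        return np.mean([f(mu + sigma * zi) for zi in z])

    # Example: smoothing the (non-differentiable) absolute value around 0.
    print(smoothed_value_mc(lambda u: np.sum(np.abs(u)), np.array([0.0]), sigma=0.1))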

14.4.2 Gradient estimation by reparametrization


If f is differentiable with finite expectation and Leibniz’s integration
rule holds, we have

∇fσ (µ) = EZ∼p0,1 [∇f (µ + σ · Z)]. (14.5)

If f is differentiable almost everywhere, the formula may still hold, see


Section 12.1. For example, if f is the ReLU, then ∇f is the Heaviside
step function, and we obtain the correct gradient of fσ using the formula

above; see Example 14.1. However, if f is not absolutely continuous,


the formula may not hold. For example, if f is the Heaviside function,
the right-hand side of (14.5) is 0 which does not match the gradient of
fσ ; see again Example 14.1.
From the second expression of fσ in (14.4), we can see the formula
of the gradient in (14.5) as a reparametrization trick U = µ + σZ; see
Section 12.4. Namely, we have

∇fσ (µ) = ∇µ EU ∼pµ,σ [f (U )]


= ∇µ EZ∼p0,1 [f (µ + σ · Z)]
= EZ∼p0,1 [∇µ f (µ + σ · Z)]
= EZ∼p0,1 [∇f (µ + σ · Z)]. (14.6)

14.4.3 Gradient estimation by SFE, Stein’s lemma


In some cases, we may not have access to ∇f or f may not be absolutely
continuous and therefore the formula in (14.5) cannot apply. For these
cases, we can use the score function estimator (SFE) from Section 12.3.
Here, for fσ (µ) = EU ∼pµ,σ [f (U )], we obtain

∇fσ (µ) = EU ∼pµ,σ [f (U )∇µ log pµ,σ (U )].

Since the PDF can be written as

pµ,σ(u) = (1/σ) p0,1((u − µ)/σ),
where
p0,1 (z) := exp(−ν(z))/C,
we obtain
∇µ log pµ,σ (u) = ∇ν((u − µ)/σ)/σ.
To summarize, we have shown that

∇fσ (µ) = EU ∼pµ,σ [f (U )∇ν((U − µ)/σ)/σ]


= EZ∼p0,1 [f (µ + σ · Z)∇ν(Z)/σ], (14.7)

where we used the change of variable Z = (U −µ)/σ. The same technique


can also be used if we want to estimate the gradient w.r.t. θ = (µ, σ) or

if we want to estimate the Jacobian of the expectation of a vector-valued


function.
In the particular case of Gaussian noise, since ∇ν(z) = z, we obtain

∇fσ (µ) = EZ∼p0,1 [f (µ + σ · Z)Z/σ].

This is known as Stein’s lemma. It should be noted that the above is


an unbiased estimator of the gradient of the smoothed function fσ , but
a biased estimator of the gradient of the original function f (assuming
that it exists). However, smoothing is usually a good thing, as it can
accelerate the convergence of gradient-based algorithms. Computing
the gradient of perturbed general programs is studied in detail in
(Kreikemeyer and Andelfinger, 2023).

14.4.4 Link between reparametrization and SFE

Using the log-derivative identity, we have for any distribution with


differentiable density p

E_{Z∼p}[h(Z)∇ log p(Z)] = ∫_{RM} h(z) (∇p(z)/p(z)) p(z) dz
                        = ∫_{RM} h(z) ∇p(z) dz.

Using integration by parts and assuming that h(z)p(z) goes to zero


when ∥z∥ → ∞, we have

∫_{RM} h(z) ∇p(z) dz = − ∫_{RM} p(z) ∇h(z) dz.

We have therefore the identity

EZ∼p [h(Z)∇ log p(Z)] = −EZ∼p [∇h(Z)].

Importantly, contrary to the SFE estimator from Section 12.3, this


identity uses gradients with respect to z, not with respect to the
parameters of the distribution. Nevertheless, using the reparametrization

h(z) := f (µ + σ · z), we have ∇h(z) = ∇f (µ + σ · z) · σ so that

∇fσ (µ) = ∇µ EU ∼pµ,σ [f (U )]


= EZ∼p0,1 [∇f (µ + σ · Z)] (reparametrization trick)
= −EZ∼p0,1 [h(Z)∇ log p(Z)/σ]
= EZ∼p0,1 [h(Z)∇ν(Z)/σ]
= EZ∼p0,1 [f (µ + σ · Z)∇ν(Z)/σ] (score function estimator)

Essentially, integration by parts allowed us to convert the reparametriza-


tion trick estimator into the SFE estimator. For more applications of
integration by parts in machine learning, see Francis Bach’s excellent
blog post.

14.4.5 Variance reduction and evolution strategies


As discussed in Chapter 12, the SFE suffers from high variance. We now
apply variance reduction techniques to it. To do so, we assume that
∇ν(Z) has zero mean for Z ∼ p0,1 . This assumption for example holds
for Gaussian or centered Gumbel noise distributions. This assumption
implies that

EZ∼p0,1 [f (µ)∇ν(Z)/σ] = f (µ)EZ∼p0,1 [∇ν(Z)/σ] = 0

and therefore

∇fσ (µ) = EZ∼p0,1 [(f (µ + σ · Z) − f (µ))∇ν(Z)/σ] . (14.8)

This is an example of control variate discussed in Section 12.3. This


can be interpreted as using a finite difference for computing a direc-
tional derivative in the random direction Z (see “limit case” below).
Inspired by a central finite difference, we can also use

∇fσ (µ) = EZ∼p0,1 [(f (µ + σ · Z) − f (µ − σ · Z))∇ν(Z)/(2σ)] . (14.9)

These estimators have been used as part of blackbox (zero-order) op-


timization algorithms, such as evolution strategies (Salimans et
al., 2017) or random gradient-free optimization (Nesterov and
Spokoiny, 2017). For quadratic functions, it is easy to show that the
second estimator achieves lower variance (Recht and Frostig, 2017).

Figure 14.4: Comparison of the score function estimator (SFE) with or without
variance reduction for blackbox gradient estimation (gradient error against the
number of samples). We show the error |∇f (µ) − ∇fσ (µ)| for f (u) := u3 and
fσ (µ) := E[f (µ + σZ)], where Z ∼ Normal(0, 1) and σ := 0.1. To estimate ∇fσ (µ),
we compare three estimators: the vanilla SFE Eq. (14.7), the SFE estimator with
forward difference (variance reduced) Eq. (14.8), and the SFE estimator with central
difference (variance reduced) Eq. (14.9). In all three cases, we approximate the
expectation by Monte-Carlo estimation using some number of samples. The
variance-reduced estimators not only achieve smaller error, they are also more
numerically stable as σ gets smaller.

The

idea of sampling both Z and −Z simultaneously is called antithetic


(Geweke, 1988) or mirrored sampling (Brockhoff et al., 2010). Evolution
strategies have also been used to obtain unbiased gradient estimators
of partially unrolled computational graphs (Vicol et al., 2021). We
empirically compare the SFE with or without variance reduction for
blackbox gradient estimation in Fig. 14.4.
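The three estimators compared in Fig. 14.4 can be sketched in a few lines of
NumPy, as below; the function and variable names are ours, and the cubic test
function matches the setting of the figure.

```python
import numpy as np

def sfe_estimators(f, mu, sigma, num_samples, seed=0):
    """Gaussian SFE estimates of the gradient of f_sigma at mu:
    vanilla (14.7), forward-difference control variate (14.8),
    and central (antithetic) difference (14.9)."""
    rng = np.random.default_rng(seed)
    z = rng.standard_normal(num_samples)
    vanilla = np.mean(f(mu + sigma * z) * z / sigma)
    forward = np.mean((f(mu + sigma * z) - f(mu)) * z / sigma)
    central = np.mean((f(mu + sigma * z) - f(mu - sigma * z)) * z / (2 * sigma))
    return vanilla, forward, central

# Setup of Fig. 14.4: f(u) = u**3, sigma = 0.1.
print(sfe_estimators(lambda u: u**3, mu=1.0, sigma=0.1, num_samples=10_000))
```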

14.4.6 Zero-temperature limit

We now discuss the limit case σ → 0. That is, we assume that we do


not want to perform smoothing and that ∇f exists. We recall that the
directional derivative of f at µ in the direction z is

∂f (µ)[z] = ⟨∇f (µ), z⟩ = limσ→0 [f (µ + σ · z) − f (µ)] /σ.

When σ → 0 and Z follows the standard Gaussian distribution, meaning


that ∇ν(z) = z, Eq. (14.8) therefore becomes

∇fσ (µ) = EZ∼p0,1 [∂f (µ)[Z]∇ν(Z)]
= EZ∼p0,1 [∂f (µ)[Z]Z]
= EZ∼p0,1 [⟨∇f (µ), Z⟩Z]
= EZ∼p0,1 [ZZ ⊤ ]∇f (µ)
= ∇f (µ).

This should not be too surprising, as we already know from the convo-
lution perspective that fσ (µ) = (f ∗ κσ )(µ) → f (µ) when σ → 0. This
recovers the randomized forward-mode estimator already presented in
Section 8.7.

14.5 Gumbel tricks

14.5.1 The Gumbel distribution


The Gumbel distribution is a distribution frequently used in extreme
value theory. It arises naturally as the distribution of the logarithm of a
negative exponential family. We consider its centered version with PDF

p(z) = exp(−ν(z)),

where
ν(z) = z + γ + exp(−(z + γ)),
and where γ ≈ 0.577 is Euler’s constant. We extend it to a multivariate
distribution by taking M independent centered Gumbel distributions
Z = (Z1 , . . . , ZM ) with associated location-scale family

µ + σZ ∼ pµ,σ ,

and p0,1 = p. As EZ∼p0,1 [∇ν(Z)] = 0, we can use the Gumbel noise as


an alternative to the Gaussian noise used in Section 14.4. Thankfully,
in particular cases, we can compute closed-form expressions of the
expectation.

Remark 14.2 (Sampling Gumbel noise). If U ∼ Uniform(0, 1), then


Z := − log(− log(U )) satisfies Z ∼ Gumbel(0, 1).

Remark 14.3 (Link between Gumbel and exponential distribution).


A random variable Z is distributed as a pµ,1 Gumbel distribution
if and only if exp(−Z) is distributed as an exponential distribution
Exp(exp(µ − γ)). To see this, one can simply compute the CDF
of exp(−Z) and recognize the CDF of Exp(exp(µ − γ)). There-
fore, when comparing Gumbel distributions we can use standard
properties of the exponential distribution.

14.5.2 Perturbed comparison

To start with, the Gumbel distribution can be used to smooth a binary


comparison like the greater than or equal operators. Recall that the
latter is defined for any µ1 , µ2 ∈ R as

gt(µ1 , µ2 ) := 1 if µ1 ≥ µ2 , and 0 if µ1 < µ2 , i.e., gt(µ1 , µ2 ) = step(µ1 − µ2 ),

where step is the Heaviside function. As shown below, by perturbing each


variable with Gumbel noise, we recover logistic(a − b) = 1/(1 + e−(a−b) )
as an approximation of step(a − b).

Proposition 14.3 (Gumbel trick for binary variables). Let


Z1 , Z2 ∼ Gumbel(0, 1) be independent random variables. The differ-
ence of their location-scale transform (Section 12.4.1) is distributed
according to a logistic distribution (Remark 3.1), i.e.,

µ1 + σZ1 − (µ2 + σZ2 ) ∼ Logistic(µ1 − µ2 , σ),

for µ1 , µ2 ∈ R and σ > 0. In particular, we have


E[gt(µ1 + σZ1 , µ2 + σZ2 )] = 1/(1 + e−(µ1 −µ2 )/σ ).

Proof. We first derive the CDF of µ1 + σZ1 − (µ2 + σZ2 ) as


P(µ1 + σZ1 − (µ2 + σZ2 ) ≤ t) = P (µ1 /σ + Z1 ≤ (µ2 + t)/σ + Z2 )
 
= P e−(µ1 /σ+Z1 ) ≥ e−((µ2 +t)/σ+Z2 ) .

By Remark 14.3, e−(µ1 /σ+Z1 ) ∼ Exp(exp(µ1 /σ − γ)), and similarly for


e−(µ2 +t)/σ+Z2 . Now one easily shows that if U ∼ Exp(u), V ∼ Exp(v)
independent, then P(U ≤ V ) = u/(u + v). Hence, we get
P(µ1 + σZ1 − (µ2 + σZ2 ) ≤ t) = e(µ2 +t)/σ−γ / (e(µ2 +t)/σ−γ + eµ1 /σ−γ )
= 1/(1 + e−(t−(µ1 −µ2 ))/σ ).

We recognize the CDF of the logistic distribution with mean µ1 − µ2


and scale σ, denoted Logistic(µ1 − µ2 , σ). For the last claim, we simply
have that
E[gt(µ1 + σZ1 , µ2 + σZ2 )] = E [step(µ1 + σZ1 − (µ2 + σZ2 ))]
= P(µ1 + σZ1 − (µ2 + σZ2 ) ≥ 0)
= 1/(1 + e−(µ1 −µ2 )/σ ).
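The proposition is easy to check numerically. Below is a minimal Monte-Carlo
sketch (NumPy; the function and variable names are ours): we perturb both
arguments with Gumbel noise and compare the empirical frequency with the
logistic closed form.

```python
import numpy as np

def perturbed_gt(mu1, mu2, sigma, num_samples=200_000, seed=0):
    """Monte-Carlo estimate of E[gt(mu1 + sigma*Z1, mu2 + sigma*Z2)]."""
    rng = np.random.default_rng(seed)
    # Gumbel(0, 1) noise; equivalently z = -log(-log(u)) for u ~ Uniform(0, 1)
    # as in Remark 14.2 (a common shift of z1 and z2 cancels in the difference).
    z1 = rng.gumbel(size=num_samples)
    z2 = rng.gumbel(size=num_samples)
    return np.mean(mu1 + sigma * z1 >= mu2 + sigma * z2)

mu1, mu2, sigma = 0.3, 0.1, 0.5
print(perturbed_gt(mu1, mu2, sigma))               # Monte-Carlo estimate
print(1.0 / (1.0 + np.exp(-(mu1 - mu2) / sigma)))  # logistic closed form
```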

14.5.3 Perturbed argmax


Suppose we want to smooth
y(u) := arg max ⟨y, u⟩ = ϕ(i(u)),
y∈{e1 ,...,eM }

where
i(u) := arg max ui
i∈[M ]

ϕ(i) := ei
with ϕ(i) the one-hot encoding of i ∈ [M ]. It turns out that the
function y(u) perturbed using Gumbel noise enjoys a closed form
expectation, which is nothing else than the softargmax.

Proposition 14.4 (Gumbel trick for categorical variables). Given M


independent Gumbel random variables Z ∼ p0,1 , define

Y := i(µ + σ · Z) ∈ [M ],

for µ ∈ RM and σ > 0. Then, Y is distributed according to

qµ,σ := Categorical(softargmax(µ/σ)).

Moreover, we have

yσ (µ) = EZ∼p0,1 [y(µ + σ · Z)]


= EY ∼qµ,σ [ϕ(Y )]
= softargmax(µ/σ).

Proof. For k ∈ [M ], we have that

P(Y = k) = P(arg max i∈[M ] {µi + σZi } = k)
= P(arg min i∈[M ] {e−µi /σ−Zi } = k).

By Remark 14.3, we have that e−µi /σ−Zi ∼ Exp(exp(µi /σ − γ)). One eas-
ily verifies as an exercise that, for U1 , . . . , UM independent exponential
variables with parameters u1 , . . . , uM , we have P(arg min i∈[M ] {Ui } = k) =
uk /(u1 + · · · + uM ). Hence, we get

P(Y = k) = exp(µk /σ) / (exp(µ1 /σ) + · · · + exp(µM /σ)),

that is,
Y ∼ Categorical(softargmax(µ/σ)).
The last claim follows from the distribution of Y and the definition of
ϕ.

14.5.4 Perturbed max


A similar result holds if we now wish to perturb the max instead of the
argmax.

Proposition 14.5 (Link to log-sum-exp). Given M independent Gum-


bel random variables Z ∼ p0,1 , and,

f (u) := max i∈[M ] ui ,

the random variable

V := f (µ + σ · Z)

is distributed according to

qµ,σ := pσLSE(µ/σ),σ .

Moreover, we have

fσ (µ) = EZ∼p0,1 [f (µ + σ · Z)] = EV ∼qµ,σ [V ] = σ · LSE(µ/σ).

Proof. We derive the CDF of f (µ + σ · Z) as

P(max i∈[M ] {µi + σZi } ≤ t) = P(min i∈[M ] {e−µi /σ−Zi } ≥ e−t/σ ).

We have e−µi /σ−Zi ∼ Exp(exp(µi /σ − γ)) and, for U1 , . . . , UM indepen-
dent exponential random variables with parameters u1 , . . . , uM , we have
min i∈[M ] Ui ∼ Exp(u1 + · · · + uM ). Hence,

P(max i∈[M ] {µi + σZi } ≤ t) = exp(−(exp(µ1 /σ − γ) + · · · + exp(µM /σ − γ)) exp(−t/σ))
= exp(− exp(−(t − σLSE(µ/σ))/σ − γ)).

We recognize the CDF of the centered Gumbel distribution with location-


scale parameters σLSE(µ/σ), σ.

For further reading on the Gumbel trick, see Tim Vieira’s great
blog.

14.5.5 Gumbel trick for sampling


The Gumbel trick is also useful in its own right for sampling with-
out computing the normalization constant of the softargmax. Indeed,

Proposition 14.4 ensures that if Z is Gumbel noise, then Y is dis-


tributed according to Categorical(softargmax(µ/σ)). Computing the
arg-maximum, as required to compute Y , can be done in one pass. There-
fore, we obtain a one-pass algorithm to sample directly from the logits
µ, without explicitly computing the probabilities softargmax(µ/σ).
One may wonder whether such a trick could also be used with the
normal distribution. Unfortunately, there is no closed form in this case
because it would require integrating the CDF of the maximum of M − 1
Gaussian distributions. However, other tricks can be defined such as
using Weibull distributions, see Balog et al. (2017).
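As an illustration, here is a minimal NumPy sketch of this one-pass sampling
scheme (with σ = 1 and names of our own choosing): perturbing the logits with
Gumbel(0, 1) noise and taking the argmax produces samples whose empirical
frequencies match softargmax(µ), as guaranteed by Proposition 14.4.

```python
import numpy as np

def gumbel_max_sample(logits, rng):
    """Sample a category from softargmax(logits) without normalizing:
    add Gumbel(0, 1) noise to the logits and take the argmax."""
    z = rng.gumbel(size=logits.shape)
    return int(np.argmax(logits + z))

rng = np.random.default_rng(0)
mu = np.array([2.0, 0.5, -1.0, 0.0])
counts = np.bincount([gumbel_max_sample(mu, rng) for _ in range(100_000)],
                     minlength=mu.size)
print(counts / counts.sum())          # empirical frequencies
print(np.exp(mu) / np.exp(mu).sum())  # softargmax(mu) for comparison
```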

14.5.6 Perturb-and-MAP
Previously, we discussed the Gumbel trick in the classification setting,
where Y = [M ]. In the structured prediction setting, outputs are
typically embedded in RM but the output space is very large. That is,
Y ⊆ RM but |Y| ≫ M . Structured outputs are then decoded using a
maximum a-posteriori (MAP) oracle

f (u) := max⟨y, u⟩
y∈Y

y(u) := arg max⟨y, u⟩.


y∈Y

For this setting, the perturbed versions of f and y,

fσ (µ) := EZ∼p0,1 [f (µ + σ · Z)]


yσ (µ) := EZ∼p0,1 [y(µ + σ · Z)],

no longer enjoy a closed form in general. However, we can approximate


them using Monte-Carlo estimation. For the gradient ∇fσ (µ), two
estimators exist (Abernethy et al., 2016; Berthet et al., 2020).

Proposition 14.6 (Gradient of perturbed max). Let Y ⊆ RM and


p0,1 be a noise distribution with density

p0,1 (z) := exp(−ν(z))/C.



Then, fσ (µ) is smooth, and its gradient is given by

∇fσ (µ) = EZ∼p0,1 [y(µ + σ · Z)]
= EZ∼p0,1 [f (µ + σ · Z)∇ν(Z)/σ]
∈ conv(Y).

We therefore have ∇fσ (µ) = yσ (µ).

The first estimator is simply a consequence of the reparametrization


trick seen in Eq. (14.6) and of y = ∇f , which follows from Danskin’s
theorem (see Section 11.2). The second estimator is just SFE seen in
Eq. (14.7). The first estimator usually has lower variance, as it uses
more information, namely that y = ∇f .
The Jacobian of yσ (µ) also has two estimators (Abernethy et al.,
2016; Berthet et al., 2020).

Proposition 14.7 (Jacobian of perturbed argmax). Under the same


notation as in Proposition 14.6, we have
∂yσ (µ) = EZ∼p0,1 [y(µ + σZ)∇ν(Z)⊤ /σ]
= EZ∼p0,1 [f (µ + σZ)(∇ν(Z)∇ν(Z)⊤ − ∇2 ν(Z))/σ 2 ].

The first estimator uses SFE. The second estimator is obtained by


differentiating through

yσ (µ) = ∇fσ (µ) = EZ∼p0,1 [f (µ + σ · Z)∇ν(Z)/σ].

The first estimator usually has lower variance. Note that we cannot use
the reparametrization trick this time, since y is discontinuous, contrary
to f .

Link between perturbation and regularization


As shown in (Berthet et al., 2020, Proposition 2.2), assuming Y is a
convex polytope with non-empty interior and p has a strictly positive
density, the function

fσ (µ) := EZ∼p0,1 [f (µ + σ · Z)] = EZ∼p0,1 [max y∈Y ⟨µ + σ · Z, y⟩]

is strictly convex and its convex conjugate fσ∗ (y) is Legendre-type. We


can therefore rewrite fσ (µ) from the regularization perspective as

fσ (µ) = max y∈Y ⟨µ, y⟩ − fσ∗ (y),

and ∇fσ (µ) = yσ (µ) is a mirror map, a one-to-one mapping from


RM to the interior of Y. Unfortunately, fσ∗ (y) does not enjoy a closed
form in general. Conversely, does every regularization have a corresponding
noise distribution? The converse is not true.

14.5.7 Gumbel-softmax

Suppose we want to smooth out the composition h(u) := g(y(u)) by

hσ (µ) := EZ∼p0,1 [g(y(µ + σZ))]

where
y(u) := arg max ⟨y, u⟩.
y∈{e1 ,...,eM }

This is useful to compute the expectation of a loss (instead of the


loss of an expectation). To compute the gradient of hσ (µ), we can
readily use SFE. However, we saw that it suffers from high variance.
Unfortunately, we cannot use the reparametrization trick here, since
y(u) is a discontinuous function.
The key idea of the Gumbel-softmax (Jang et al., 2016; Maddison
et al., 2016) is to replace y(u) with a softargmax (with temperature
parameter τ ) to define

hσ,τ (µ) := EZ∼p0,1 [g(softargmaxτ (µ + σZ))] .

Since the softargmax is a regularized argmax, we can see the Gumbel-


softmax approach as using both regularization and perturbation. The
key benefit is that we can now use the reparametrization trick to get an
unbiased estimator of ∇hσ,τ (µ). However, this will be a biased estimator
of ∇hσ (µ), the amount of bias being controlled by the temperature τ .
In particular, in the limit case τ → 0, we have hσ,τ (µ) → hσ (µ). One
caveat, however, is that the function g needs to be well defined on △M ,
instead of {e1 , . . . , eM }.

The use of the softargmax transformation defines a continuous


distribution (Jang et al., 2016; Maddison et al., 2016), that we now
explain with σ = 1.

Proposition 14.8 (Gumbel-softargmax / Concrete distributions). Let


us define the continuous random variable

T := softargmaxτ (µ + Z) ∈ △M ,

where Z ∼ p0,1 is a Gumbel random variable. Then T is distributed


according to a distribution with density
pµ,τ (t) := Γ(M ) τ^{M−1} (π1 /t1^τ + · · · + πM /tM^τ )^{−M} (π1 /t1^{τ+1} ) · · · (πM /tM^{τ+1} ),

where π := softargmax(µ).

We can extend the Gumbel softargmax to the structured setting by


replacing
y(u) := arg max y∈Y ⟨y, u⟩,

with its regularized variant (Paulus et al., 2020). Similarly as before,


one caveat is that g needs to be well defined on conv(Y) instead of Y.
Moreover, regularizing y is not always easy computationally.
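To make the mechanics concrete, here is a minimal NumPy sketch of the
Gumbel-softmax estimator of hσ,τ (µ) (our own function names, with a toy loss g
defined on the simplex): we draw Gumbel noise, push µ + σZ through a tempered
softargmax, and average g over the relaxed samples. Because the softargmax is
differentiable, the same reparametrization can be combined with automatic
differentiation to obtain the (biased, for ∇hσ ) gradient estimator discussed above.

```python
import numpy as np

def softargmax(x, tau=1.0):
    e = np.exp((x - np.max(x, axis=-1, keepdims=True)) / tau)
    return e / e.sum(axis=-1, keepdims=True)

def gumbel_softmax_estimate(g, mu, sigma=1.0, tau=0.5, num_samples=10_000, seed=0):
    """Monte-Carlo estimate of E[g(softargmax_tau(mu + sigma * Z))]."""
    rng = np.random.default_rng(seed)
    z = rng.gumbel(size=(num_samples, mu.size))
    relaxed = softargmax(mu + sigma * z, tau)   # soft one-hot samples on the simplex
    return np.mean([g(t) for t in relaxed])

mu = np.array([1.0, 0.0, -1.0])
g = lambda t: np.sum(t * np.array([0.0, 1.0, 2.0]))  # a toy loss defined on the simplex
print(gumbel_softmax_estimate(g, mu, tau=0.1))
```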

14.6 Summary

• We studied smoothing techniques based on function convolution


with a kernel. Due to the commutativity of the convolution, we
can alternatively see these as the expectation of the function,
perturbed with noise, assuming the kernel corresponds to the
PDF of some noise distribution.

• Their gradients can be estimated using the path gradient estimator


(PGE) or score function estimator (SFE), depending on whether
the gradient of the original function is available or not.

• We saw that Stein’s lemma is a special case of SFE used with


Gaussian noise. The so-called “evolution strategies” are just a

variant of that with variance reduction and can be interpreted as


randomized finite difference.

• When using Gumbel noise, we were able to derive closed-form


expressions for the expectation in specific cases: perturbed com-
parison, perturbed argmax and perturbed max.

• We also studied the connections between smoothing by optimiza-


tion and smoothing by integration. Infimal convolution is the
counterpart of convolution, and the Legendre-Fenchel transform
is the counterpart of Fourier and Laplace’s transforms. Infimal
convolution uses a min-plus algebra in the log domain, while
the convolution uses a sum-product algebra in the exponential
domain.
Part V

Optimizing differentiable
programs
15
Optimization basics

15.1 Objective functions

Consider a function L, for example evaluating the error or “loss” L(w)


achieved by a model with parameters w ∈ W, where W = RP . To find
the best possible model parameterization, we seek to minimize L(w),
that is, to compute approximately

L⋆ := inf w∈W L(w),

assuming that the infimum exists (i.e., L(w) is lower bounded). We will
denote a solution, if it exists, by

w⋆ ∈ arg min w∈W L(w) := {w ∈ W : L(w) = min w′ ∈W L(w′ )}.

In general, an analytical solution is not available and computing such


a minimum approximately requires an optimization algorithm. An
optimization algorithm is an iterative procedure, which, starting from an
initial point w0 , outputs after t iterations a point wt that approximates
the minimum of L up to some accuracy ε, i.e.,

L(wt ) − L⋆ ≤ ε. (15.1)


15.2 Oracles

To produce iterates w1 , w2 , . . . that converge to a minimum, the al-


gorithm naturally needs to have access to information about L. For
example, the algorithm needs a priori to be able to evaluate L to know
if it decreased its value or not. Such information about the function
is formalized by the notion of oracles (Nemirovski and Yudin, 1983).
Formally, oracles are procedures that an algorithm can call to access
information about the objective L(w) at any given point w ∈ W. We
usually mainly consider the following three oracles.
• Zero-order oracle: evaluating the function L(w) ∈ R.
• First-order oracle: evaluating the gradient ∇L(w) ∈ W for L
differentiable.
• Second-order oracle: evaluating the Hessian matrix ∇2 L(w),
or evaluating the Hessian-vector product (HVP) ∇2 L(w)v ∈ W,
for L twice differentiable and any vector v ∈ W.
Given an oracle O for a function L, we can formally define an optimiza-
tion algorithm as a procedure which computes the next iterate as a
function of all past and current information. Formally, an algorithm A
builds a sequence w1 , . . . , wt from a starting point w0 as

wt+1 := A(w0 , . . . , wt , O(w0 ), . . . , O(wt ), λ),

where λ ∈ Λ ⊆ RQ encapsulates some hyperparameters of the algorithm,


such as the stepsize. Oftentimes, algorithms build the next iterate simply
from the information collected at the current iterate, without using all
past iterates. That is, they take the form wt+1 = A(wt , O(wt ), λ). A
classical example is the gradient descent algorithm, that uses a first-order
oracle to compute iterates of the form

wt+1 := wt − γ∇L(wt ),

where the stepsize γ is a hyperparameter of the algorithm. The notion of


oracle therefore delineates different classes of algorithms. For instance,
we may consider zero-order algorithms or first-order algorithms.

15.3 Variational perspective of optimization algorithms

One of the most basic optimization algorithms is the proximal point


method, which produces wt+1 from wt by
wt+1 := arg min w∈W L(w) + (1/(2γ))∥w − wt ∥22 .
In words, the next iterate is produces by solving a trade-off between
minimizing the function L and staying close to wt . Unfortunately, the
optimization problem involved in performing this parameter update is
as difficult as the original optimization problem, making the proximal
point method impractical.
As we shall see in Chapter 16 and Chapter 17, many optimization
algorithms can be seen as an approximation of the proximal point
method, in the sense that they solve
wt+1 := arg min w∈W L̃(w, wt ) + (1/(2γ))∥w − wt ∥22 .
or more generally
wt+1 := arg min w∈W L̃(w, wt ) + (1/γ) d(w, wt ),

where L̃(w, wt ) is an approximation of L(w) around wt and d(w, w′ )


is some form of distance between w and w′ . Different choices of L̃ and
d lead to different optimization algorithms, and to different trade-offs.

15.4 Classes of functions

When studying algorithms theoretically, stronger results can often be


stated by restricting to certain classes of functions. We already covered
continuous and differentiable functions in Chapter 2. We review a few
important other classes in this section.

15.4.1 Lipschitz functions


Lipschitz continuity is a stronger form of continuity. Intuitively, a
Lipschitz continuous function is limited in how fast it can change.

Definition 15.1 (Lipschitz-continuous functions). A function g : W →


F is β-Lipschitz continuous if for all w, v ∈ W

∥g(w) − g(v)∥2 ≤ β∥w − v∥2 .

Note that the definition is valid even for vector-valued functions.

With respect to arbitrary norms

Thanks to dual norms reviewed in Section 18.1, we can state a more


general definition of Lipschitz continuity based on arbitrary norms,
instead of the 2-norm. Moreover, we may consider Lipschitz-continuity
over a subset of the input domain.

Definition 15.2 (Lipschitz continuous functions w.r.t. a norm). A func-


tion g : W → F is said to be β-Lipschitz w.r.t. a norm ∥ · ∥ over a
set C ⊆ W if for all w, v ∈ C

∥g(w) − g(v)∥∗ ≤ β∥w − v∥.

When ∥ · ∥ = ∥ · ∥2 , we recover Definition 15.1, since the 2-norm is


dual to itself.

15.4.2 Smooth functions

A differentiable function L is said to be β-smooth if its gradients are


β-Lipschitz continuous. Setting g(w) = ∇L(w) in Definition 15.1, we
obtain the following definition.

Definition 15.3 (Smooth functions). A differentiable function


L : W → R is β-smooth for β > 0 if for all w, v ∈ W

∥∇L(w) − ∇L(v)∥2 ≤ β∥w − v∥2 .

Smoothness ensures that the information provided by the gradient


at some w is meaningful in a neighborhood of w, since its variations are
upper-bounded. If the variations were not bounded, the gradient at v

arbitrarily close to w could drastically change, rendering the information


provided by a first-order oracle potentially useless.
Smoothness of a function can be interpreted as having a quadratic
upper bound on the function as formalized below.

Proposition 15.1 (Smooth functions). If a differentiable function


L : W → R is β-smooth then for all w, v ∈ W,
|L(w) − L(v) − ⟨∇L(v), w − v⟩| ≤ (β/2)∥w − v∥22 .

In particular, we have

L(w) ≤ L(v) + ⟨∇L(v), w − v⟩ + (β/2)∥w − v∥22 .

Proof. This is shown by bounding |L(v) − L(w) − ⟨∇L(w), v − w⟩|
using the integral representation of the objective along w − v, i.e.,
|L(v) − L(w) − ⟨∇L(w), v − w⟩| = |∫0^1 ⟨∇L(w + s(v − w)), v − w⟩ds −
⟨∇L(w), v − w⟩| ≤ ∫0^1 ∥∇L(w + s(v − w)) − ∇L(w)∥2 ∥v − w∥2 ds ≤
β∥w − v∥22 /2, where the last inequality follows from the smoothness
assumption and standard integration.

In other words, L(w) is upper-bounded and lower-bounded around


v by a quadratic function of w. We will see in Section 16.1 that this
characterization gives rise to a variational perspective on gradient
descent.

With respect to arbitrary norms

We can generalize the definition of smoothness in Definition 15.3 to


arbitrary norms.

Definition 15.4 (Smooth functions w.r.t. a norm). A function L :


W → R is β-smooth w.r.t. a norm ∥ · ∥ over a set C if for all
w, v ∈ C
∥∇L(w) − ∇L(v)∥∗ ≤ β∥w − v∥.

An equivalent characterization, generalizing Proposition 15.1 to


arbitrary norms, is given below (see, e.g. Beck (2017, Theorem 5.8)).

Proposition 15.2 (Smooth functions w.r.t. a norm). If a differentiable


function L : W → R is β-smooth w.r.t. a norm ∥ · ∥ over a set C,
then for all w, v ∈ C

|BL (w, v)| = |L(w) − L(v) − ⟨∇L(v), w − v⟩| ≤ (β/2)∥w − v∥2 ,

where Bf is the Bregman divergence generated by f (Definition 18.2).

15.4.3 Convex functions


A convex function is a function such that its value on the average of two
or more points is smaller than the average of the values of the functions
at these points. This is illustrated in Figure 15.2 and formalized below.

Definition 15.5 (Convex functions). A function L : W → R is said


to be convex if for all w, v ∈ W and τ ∈ [0, 1]

L(τ w + (1 − τ )v) ≤ τ L(w) + (1 − τ )L(v).

The function L is strictly convex if the above inequality is strict


for all w ̸= v.

The above characterization can easily be generalized to multiple


points. Namely, for w1 , . . . , wn ∈ W and τ1 , . . . , τn ≥ 0 such that
τ1 + · · · + τn = 1 (that is, τ1 , . . . , τn defines a probability distribution over
[n]), we have, if L is convex, that

L(τ1 w1 + · · · + τn wn ) ≤ τ1 L(w1 ) + · · · + τn L(wn ).

The point τ1 w1 + · · · + τn wn is called a convex combination. This can be seen
as comparing the function at the average point to the average of the
values at these points and can further be generalized to any random
variable.

Proposition 15.3 (Jensen’s inequality). A function L : W → R is


convex if it satisfies Jensen’s inequality, that is, for any random
variable W on W,

L(E[W ]) ≤ E[L(W )],

provided that the expectations are well-defined.

If the function considered is differentiable, an alternative charac-


terization of convexity is to observe how linear approximations of the
function lower bound the function. This is illustrated in Figure 15.2
and formalized below.
Definition 15.6 (Convex differentiable functions). A differentiable func-
tion L : W → R is convex if and only if for all w, v ∈ W

L(v) ≥ L(w) + ⟨∇L(w), v − w⟩.

The function L is strictly convex if and only if the above inequality


is strict for any w ̸= v.

The above characterization pinpoints the relevance of convex func-


tion in optimization: if we can find a point ŵ with null gradient, then
we know that we have found the minimum as we have
∇L(ŵ) = 0 =⇒ ∀v ∈ RP , L(v) ≥ L(ŵ) =⇒ L(ŵ) = L⋆ .
This means that by having access to the gradient of the function or an
approximation thereof, we have access to a sufficient criterion to know
whether we found a global minimum. In the case of a gradient descent
on a smooth function, convexity ensures convergence to a minimum at
a sublinear rate as detailed below.
Finally, if the function is twice differentiable, convexity of a function
can be characterized in terms of the Hessian of the function.
Proposition 15.4 (Convex twice differentiable functions). A twice dif-
ferentiable function L : W → R is convex if and only if its Hessian
is positive semi-definite,

∀w ∈ W, ∇2 L(w) ⪰ 0, i.e., ∀w, v ∈ W, ⟨v, ∇2 L(w)v⟩ ≥ 0.



The function L is strictly convex if and only if the Hessian is positive-


definite, ∀w ∈ W, ∇2 L(w) ≻ 0, i.e., ∀w, v ∈ W, ⟨v, ∇2 L(w)v⟩ >
0.

15.4.4 Strongly-convex functions


Convexity can also be strengthened by considering µ-strongly convex
functions.

Definition 15.7 (Strongly-convex functions). A function L : W →


R is µ-strongly convex for µ > 0 if for all w, v ∈ W and τ ∈ [0, 1]
L(τ w + (1 − τ )v) ≤ τ L(w) + (1 − τ )L(v) − (µ/2)τ (1 − τ )∥w − v∥22 .

A differentiable function L is µ-strongly convex if and only if for
all w, v ∈ W

L(v) ≥ L(w) + ⟨∇L(w), v − w⟩ + (µ/2)∥w − v∥22 .
A twice differentiable function is µ-strongly convex if and only if
its Hessian satisfies

∀w ∈ W, ∇2 L(w) ⪰ µ I, i.e., ∀w, v ∈ W, ⟨v, ∇2 L(w)v⟩ ≥ µ∥v∥22 .

The characterization of strong convexity for differentiable functions


states that L(w) is lower-bounded by a quadratic. This enables the
design of linearly convergent algorithms as explained later. We naturally
have the implications

L strongly convex =⇒ L strictly convex =⇒ L convex.

With respect to arbitrary norms


A function can be strongly convex w.r.t. an arbitrary norm, simply
by replacing the 2-norm in Definition 15.7 with that norm. For differ-
entiable strongly convex functions, we have the following alternative
characterization, generalizing Definition 15.7 to arbitrary norms.

Proposition 15.5 (Differentiable strongly-convex functions). If a dif-


ferentiable function L : W → R is µ-strongly convex w.r.t. a norm
∥ · ∥ over a set C, then for all w, v ∈ C
(µ/2)∥w − v∥2 ≤ L(w) − L(v) − ⟨∇L(v), w − v⟩ = BL (w, v).

Obviously, if a function L is µ-strongly convex, then, λL is (µλ)-


strongly convex. Because all norms are equivalent, if a function is
strongly convex w.r.t. a norm, it is also strongly-convex w.r.t. another
norm. However, stating the norm w.r.t. which strong convexity holds can
lead to better constant µ (the higher, the better in terms of convergence
rates of, e.g., a gradient descent). We also emphasize that it is important
to mention over which set strong convexity holds. We give examples
below.

Example 15.1 (Strongly convex functions). The function f (u) =
(1/2)∥u∥22 is 1-strongly convex w.r.t. ∥ · ∥2 over RM .
The function f (u) = ⟨u, log u⟩ is 1-strongly convex w.r.t. ∥ · ∥1
over △M . Applying Proposition 15.5, we obtain for all p, q ∈ △M

(1/2)∥p − q∥21 ≤ Bf (p, q) = KL(p, q),

which is known as Pinsker’s inequality. We empirically verify
the inequality in Fig. 15.1.
More generally, f (u) is (1/µ)-strongly convex w.r.t. ∥ · ∥1 over any
bounded set C ⊂ RM+ such that µ = supu∈C ∥u∥1 (Blondel, 2019).
However, it is not strongly convex over RM+ , as the latter set is not bounded.
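The inequality is easy to check numerically; below is a small NumPy sketch
mirroring the setting of Fig. 15.1 (function names are ours).

```python
import numpy as np

def kl(p, q):
    return np.sum(p * np.log(p / q))

# Numerical check of Pinsker's inequality on binary distributions,
# with p = (pi, 1 - pi) and q = (0.3, 0.7) as in Fig. 15.1.
q = np.array([0.3, 0.7])
for pi in np.linspace(0.01, 0.99, 9):
    p = np.array([pi, 1.0 - pi])
    lhs = 0.5 * np.sum(np.abs(p - q)) ** 2   # (1/2) * ||p - q||_1^2
    assert lhs <= kl(p, q) + 1e-12
print("Pinsker's inequality holds on the grid.")
```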

15.4.5 Nonconvex functions

In general, the minimum of a function necessarily has a null gradient,


that is,
w⋆ ∈ arg min L(w) =⇒ ∇L(w⋆ ) = 0.
w∈W

Figure 15.1: Graphical verification of Pinsker’s inequality, (1/2)∥p − q∥21 ≤ KL(p, q),
with p := (π, 1 − π) and q := (0.3, 0.7).

To see this, consider the function F : t → L(w⋆ − t∇L(w∗ )). If


∇L(w⋆ ) ̸= 0, then F ′ (0) = −∥∇L(w⋆ )∥22 ̸= 0. Therefore, there exists
a small t > 0 such that F (t) < F (0), i.e., L(w⋆ ) is not the minimum.
However, if the function is not convex, the converse is a priori not true:
finding a point that has a null gradient does not ensure that we have
found a global minimum as illustrated in Figure 15.3.
For non-convex functions, a point with null gradient is called a
stationary point. A stationary point may define a local maximum
or a local minimum. Formally, ŵ is a local minimum if

∃r > 0, s.t. ∀v ∈ W satisfying ∥v − ŵ∥ ≤ r, we have L(v) ≥ L(ŵ).

A local maximum is defined similarly, except that L(v) ≤ L(ŵ) in a


neighborhood of ŵ. For non-convex functions, convergence rates are
therefore generally expressed in terms of convergence of the norm of the
gradient ∥∇L(wt )∥2 towards 0. Such theoretical results do not ensure
convergence to the global minimum but rather convergence to a point
where no further progress may a priori be possible with just gradient
information.

Figure 15.2: Convex function: any secant is above the function, any tangent is
below the function, a point with zero gradient is a minimum.

Figure 15.3: Non-convex function: a point with zero gradient is not necessarily
the global minimum.

15.5 Performance guarantees

For a given class of functions, we can define the performance of an


algorithm as the number of iterations the algorithm would need to find
an ε-accurate solution as in Eq. (15.1). This is called the computational
complexity of the algorithm, denoted
t = T (ε).
Alternatively, the performance of an algorithm can be stated in terms
of convergence rate, i.e., the accuracy that the algorithm reaches
after t iterations,
ε = R(t),
where R is a decreasing positive function vanishing as t → +∞. Usu-
ally, R incorporates properties of the function minimized, such as its
smoothness constant β and information on the initial point, such as its
function value. The corresponding computational complexity T (ε) is
then given as the minimum number of iterations t such that R(t) ≤ ε,
T (ε) = min{t ∈ N : R(t) ≤ ε}.
Convergence rates can generally be classified by considering the
progress ratio on iteration t, defined by
ρt := R(t)/R(t − 1).

The asymptotic convergence rate is then defined by

ρ∞ := lim ρt .
t→+∞

We can classify the rates as follows.


1. Sublinear convergence rates, ρ∞ = 1: the longer the algorithm
runs, the slower it makes progress. That is, the relative progress
eventually tends to stall as t → +∞. Examples of R(t) in this
category include O(1/t), O(1/t2 ) or more generally O(1/tα ) for
some α > 0. This is equivalent to T (ε) = O(ε−1/α ).
2. Linear convergence rates, ρ∞ = c ∈ (0, 1): the algorithm
eventually reaches a state of constant relative progress at each
iteration, leading to an overall rate R(t) = O(exp(−ct)) for c
depending on the properties of the objective. This corresponds to
T (ε) = O(c−1 ln ε−1 ).
3. Superlinear convergence rates, ρ∞ = 0: the relative progress
is better at each new iteration. This can happen for, e.g., R(t) =
O(exp(−t2 )), leading to T (ε) = O(√(ln ε−1 )), or
R(t) = O(exp(− exp(t))), also called a quadratic rate, leading to
T (ε) = O(ln ln ε−1 ).
This is illustrated in Fig. 15.4.
Note that the term “linear” may be misleading as the rates are in
fact exponential. They are called “linear” because of their behavior in
log scale.

Figure 15.4: Left: convergence rates. Right: progress ratios. An algorithm with
sublinear convergence rates eventually stops making progress. An algorithm
with linear convergence rate eventually reaches a state of constant progress. An
algorithm with superlinear convergence rate makes faster progress after each iteration.

Upper and lower bounds

The best performance of a class of algorithms equipped with a given
oracle (e.g. first-order oracle) can be upper-bounded or lower-bounded.
This allows to show that an algorithm with access limited to a certain
type of oracle cannot theoretically do better than a certain number. For
example, the computational complexity to minimize β-smooth functions
restricted on [0, 1]P with first-order oracles is lower bounded by a quantity
of order (c/ε)P (Nemirovski and Yudin, 1983, p. 1.1.7). For example, with
P = 10 and ε = 10−3 , this gives about 1030 iterations. Note that these
results are pessimistic by construction. The actual performance of an
algorithm on a specific instance of this function class may be much better than

this worst-case scenario, as it is the case with popular algorithms such


as quasi-Newton methods. Better computational complexities can be
achieved by further restricting the class of functions to the set of convex
functions, which play a central role in optimization and many other
fields.

Zero-order vs. first-order

For the class of smooth strongly convex functions, the computational


complexity of the best first-order algorithm is (up to constant and
logarithmic factors) P times better than that of the best zero-order al-
gorithm (Nesterov, 2018; Nesterov and Spokoiny, 2017). This theoretical
comparison shows that, while zero-order optimization algorithms may
perform on par with first-order optimization algorithms for problems
with a low dimension P , they can be much slower for high dimensional
problems, i.e., P ≫ 1.
In different settings, for example with stochastic oracles (Duchi
et al., 2015) or for different classes of functions, slightly different com-
parisons may be achieved, such as a √P factor instead of P .
the same conclusion holds in the current frameworks considered: first-
order optimization algorithms can provide fast rates that are dimension
independent while the rates of zero-order optimization algorithms gen-

erally depend on the dimension of the problem, making them unfit for
high-dimensional problems.
This explains the immense success of first-order algorithms for train-
ing neural networks. Fortunately, using reverse-mode autodiff, as studied
in Chapter 8, it can be shown that computing a gradient has roughly
the same complexity as evaluating the function itself; see Section 8.3.3.

15.6 Summary

• The information available to us on a function can be formalized


by the notion of oracle. Zero-order oracles can only evaluate the
function; first-order oracles can also compute the gradient; second-
order oracles can also compute the Hessian or the Hessian-vector
product (HVP).

• Most optimization algorithms reviewed in this book can be viewed


from a variational perspective, in which the next iteration is
produced by optimizing a trade-off between an approximation of
the function and a proximity term. Different approximations and
different proximity terms lead to different algorithms.

• We also reviewed different classes of functions, and performance


guarantees.
16
First-order optimization

16.1 Gradient descent

Gradient descent is one of the simplest algorithms in our toolbox to


minimize a function. At each iteration, it moves along the negative
gradient direction, scaled by a stepsize γ:
wt+1 = wt − γ∇L(wt ). (16.1)
The path taken by a gradient descent on a simple quadratic is illustrated
in Fig. 16.1 for different choices of the stepsize.

16.1.1 Variational perspective


Consider the linear approximation of L(w) around wt ,
L(w) ≈ L(wt ) + ⟨∇L(wt ), w − wt ⟩.
One can easily check that the gradient descent update in Eq. (16.1) can
be rewritten as the solution of a minimization problem, namely,
wt+1 = arg min w∈W L(wt ) + ⟨∇L(wt ), w − wt ⟩ + (1/(2γ))∥w − wt ∥22 . (16.2)

Figure 16.1: Trajectory taken by a gradient descent on an objective f (w) =
0.05w12 + 0.5w22 with a small (left) or large (right) stepsize. In each case the iterates
follow the normal vectors to the contour lines (dashed lines), that is, the negative
gradients. A small stepsize gives a slow convergence but a larger stepsize induces
oscillations.

In words, a gradient descent update optimizes a trade-off between staying
close to the current wt , thanks to the proximity term (1/(2γ))∥w − wt ∥22 , and

minimizing the linearization of L around wt . Intuitively, by choosing γ


sufficiently small, we ensure that the minimizer of the regularized linear
approximation stays in a neighborhood where the linear approximation
is valid. This viewpoint is useful to motivate gradient descent extensions.
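As a concrete illustration, here is a minimal NumPy sketch of the update (16.1)
on the quadratic of Fig. 16.1 (function names are ours). The smoothness constant
of this objective is β = 1, so any stepsize γ ≤ 1 is covered by the analysis below,
while γ = 1.8 still converges but oscillates.

```python
import numpy as np

def gradient_descent(grad, w0, stepsize, num_iters=100):
    """Plain gradient descent, Eq. (16.1)."""
    w = np.asarray(w0, dtype=float)
    for _ in range(num_iters):
        w = w - stepsize * grad(w)
    return w

# Quadratic of Fig. 16.1: L(w) = 0.05 w1^2 + 0.5 w2^2, gradient (0.1 w1, w2),
# minimized at the origin.
grad = lambda w: np.array([0.1 * w[0], 1.0 * w[1]])
print(gradient_descent(grad, w0=[1.0, 0.25], stepsize=0.5))
print(gradient_descent(grad, w0=[1.0, 0.25], stepsize=1.8))  # oscillates in w2, still converges
```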

16.1.2 Convergence for smooth functions

As long as ∇L(wt ) ̸= 0, the function Lt (γ) := L(wt − γ∇L(wt )) has


a negative derivative at 0, i.e., L′t (0) = −∥∇L(wt )∥22 . Hence, as long
as ∇L(wt ) ̸= 0, there exists a stepsize ensuring a decrease in objective
values at each iterate. However, without further assumptions, such a
stepsize may depend on each iterate and may be infinitesimally small.
To quantify the convergence of gradient descent with a constant stepsize,
we restrict to the class of smooth functions. By applying Proposition 15.1
on the iterate of gradient descent, we obtain that

L(wt+1 ) ≤ L(wt ) − γ∥∇L(wt )∥22 + (βγ 2 /2)∥∇L(wt )∥22 .

Therefore, for β-smooth functions, by selecting γ ≤ 1/β, we get that

L(wt+1 ) − L(wt ) ≤ −(γ/2)∥∇L(wt )∥22 ,
which illustrates the main mechanism behind gradient descent: each
iteration decreases the objective by a constant times the norm of the

gradient of the current iterate. This equation can further be summed


over all iterates up to T . This telescopes the objective values, leading to
min t∈{0,...,T −1} ∥∇L(wt )∥22 ≤ (1/T ) (∥∇L(w0 )∥22 + · · · + ∥∇L(wT −1 )∥22 )
≤ (2/(γT ))(L(w0 ) − L(wT ))
≤ (2/(γT ))(L(w0 ) − L⋆ ),
where we recall that L⋆ is the infimum of L. Therefore, after sufficiently
many iterations, gradient descent finds a point whose gradient norm is
arbitrarily small.

Non-convex case
Without further assumptions, i.e., in the non-convex case, the above
result (i.e., convergence to a stationary point, measured by the gradient
norm) is the best we may get in theory. Denoting Ts (ε) the number
of iterations needed for a gradient descent to output a point that is
ε-stationary, i.e., ∥∇L(ŵ)∥2 ≤ ε, we have Ts (ε) ≤ O(ε−2 ).

Convex case
By adding a convexity assumption on the objective, we can use the lower
bound provided by the convexity assumption to ensure convergence to
a minimum. Namely, for a β-smooth and convex function f , and with
stepsize γ ≤ 1/β, we have that (Nesterov, 2018)
L(wT ) − L⋆ ≤ (1/(γT ))∥w0 − w⋆ ∥22 .
That is, we get a sublinear convergence rate, and the associated compu-
tational complexity to find a minimum is T (ε) = O(1/ε).

Strongly convex case


If we further strengthen the assumptions by considering β-smooth, µ-
strongly convex functions, the convergence rate of a gradient descent

can be shown to be (Nesterov, 2018), for any stepsize γ ≤ 1/β,


 
L(wT ) − L⋆ ≤ (1 − γµ)T (L(w0 ) − L⋆ )
≤ exp(−γµT )(L(w0 ) − L⋆ ).

That is, we obtain a linear convergence rate and the associated computa-
tional complexity is T (ε) = O(ln ε−1 ). The above convergence rates may
be further refined (Nesterov, 2018); we focused above on the simplest
result for clarity.
Strong convexity can also be replaced by a weaker assumption,
gradient-dominating property (Polyak, 1963), i.e., ∥∇L(v)∥22 ≥ c(L(v)−
L⋆ ) for some constant c and any v ∈ W. A convex, gradient-dominating
function can also be minimized at a linear rate.

16.1.3 Momentum and accelerated variants


We started with gradient descent as a simple example of first-order
optimization algorithm. However, different optimization algorithms can
be designed from the access to first-order oracles and the knowledge
of the class of functions considered. For example, consider quadratic
convex functions w ↦ (1/2)w⊤ Aw + b⊤ w, that are a basic example of
smooth strongly convex functions if A is positive definite. An optimal
method in this case is the heavy-ball method of Polyak (1964), that can
be written as

v t+1 := νv t − γ∇L(wt )
wt+1 := wt + v t+1 .

The heavy-ball method uses an additional variable v t , that can be


interpreted as the velocity of a ball driven by the negative gradient
to converge towards a minimum. Intuitively, this additional velocity
circumvents the oscillations that a gradient descent may present as
illustrated in Fig. 16.2 compared to Fig. 16.1. For ν = 0, we recover
usual gradient descent. For ν > 0, the velocities accumulate a form
of an inertia momentum, where ν is interpreted as the “mass” of the
ball. In terms of convergence rates, the heavy-ball method can be
shown to converge linearly similarly to gradient descent, but with a rate
O(exp(−T √(µ/β))) for appropriate choices of ν, γ. In comparison, by
choosing an optimal stepsize for the gradient descent, its convergence
rate is O(exp(−T µ/β)), which is provably worse, as we always have
µ/β ≤ 1.

Figure 16.2: Trajectory taken by a gradient descent with momentum. Compared to
gradient descent without momentum, for the same stepsize, the oscillations previously
observed in Fig. 16.1 are no longer present, and the algorithm then converges faster
to the minimum.
Beyond the case of quadratic functions, accelerated variants of
gradient descent for convex or strongly convex functions have been
developed by Nesterov (2018). Such variants have inspired the design
of optimization algorithms in stochastic settings presented below.
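Below is a minimal NumPy sketch of the heavy-ball updates, run with the stepsize
and momentum used in Fig. 16.2 on the same quadratic (names are ours).

```python
import numpy as np

def heavy_ball(grad, w0, stepsize, momentum, num_iters=100):
    """Polyak's heavy-ball method."""
    w = np.asarray(w0, dtype=float)
    v = np.zeros_like(w)
    for _ in range(num_iters):
        v = momentum * v - stepsize * grad(w)   # velocity update
        w = w + v                               # parameter update
    return w

# Same quadratic as Fig. 16.1 / Fig. 16.2: L(w) = 0.05 w1^2 + 0.5 w2^2.
grad = lambda w: np.array([0.1 * w[0], 1.0 * w[1]])
print(heavy_ball(grad, w0=[1.0, 0.25], stepsize=1.8, momentum=0.2))
```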

16.2 Stochastic gradient descent

In machine learning, we are usually interested in minimizing the ex-


pected loss of the model over the data distribution ρ:

min w∈W L(w) := ES∼ρ [L(w; S)].

For example, L is often set to L(w; S) := ℓ(Y, f (X, w)), where ℓ is a


loss function, f is a neural network and S = (X, Y ) is a random pair,
composed of an input X and an associated target Y , sampled from ρ.
In this setting, since the data distribution ρ is generally unknown and
may be infinite, we cannot exactly evaluate the expected loss L(w) or
its gradient ∇L(w).
In practice, we are often given a fixed dataset of n pairs si = (xi , yi ).
This is a special case of the expected loss setting, since this can be seen

as an empirical distribution ρ = ρn :

L(w) = ES∼ρn [L(w; S)] = (1/n)(L(w; (x1 , y1 )) + · · · + L(w; (xn , yn ))).

The gradient of L(w) is then


∇L(w) := (1/n)(∇L(w; (x1 , y1 )) + · · · + ∇L(w; (xn , yn ))).

In this case, we see that the full gradient ∇L(w), as needed by gradient
descent, is the average of the individual gradients. That is, the cost of
computing ∇L(w) is proportional to the number of training points n. For
very large n, i.e., a very large number of samples, this computational
cost can be prohibitive. Stochastic gradients circumvent this issue.

16.2.1 Stochastic gradients


Usually, even if we do not know ρ, we can sample from it, i.e., we have
access to samples S ∼ ρ. We can then use a stochastic gradient of
the form ∇L(w; S) as a random estimate of ∇L(w). This may look like
a rough estimate but, on average, this is a valid approximation since

ES∼ρ [∇L(w; S)] = ∇L(w).

We say that ∇L(w; S) is an unbiased estimator of ∇L(w). To fur-


ther improve the approximation, we may also consider mini-batch
estimates by sampling m ≪ n data points Si := (Xi , Yi ) and using
(1/m)(∇L(w; S1 ) + · · · + ∇L(w; Sm )), whose expectation still matches
∇L(w), while potentially reducing the approximation error by averaging
multiple stochastic gradients. Computationally, the main advantage is
that the cost is now proportional to m instead of n.
In whole generality, one can consider stochastic first-order oracles
defined below.

Definition 16.1 (Stochastic first-order oracles). A stochastic first-


order oracle of an expected objective L(w) is a random estimate
g(w; S) of ∇L(w) with S sampled according to some distribution

q. A stochastic gradient is said to be an unbiased estimator if

ES∼q [g(w; S)] = ∇L(w).

The variance of a stochastic gradient is


ES∼q [∥g(w; S) − ∇L(w)∥22 ].

When q = ρ, we recover stochastic gradients. When q is the product


of m independent samples according to ρ, we recover mini-batch stochas-
tic gradients. First-order stochastic optimization algorithms build upon
stochastic first-order oracles to approximately find the minimum of the
expected objective. In such a setting, the iterates of the algorithm are
by definition random. Convergence rates therefore need to be expressed
in probabilistic terms by considering for example the expected objective
value according to the randomness of the oracles.

16.2.2 Vanilla SGD


Equipped with a stochastic first-order oracle, such as (mini-batch)
stochastic gradients, we can define stochastic gradient descent as
wt+1 = wt − γg(wt ; S t ) where S t ∼ q.
We assume that S t is independent of wt . Compared to the usual gradient
descent, the main impediment of the stochastic setting is the additional
noise induced by the stochastic estimates: their variance.
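As an illustration, here is a minimal NumPy sketch of vanilla SGD on a toy
least-squares problem (function and variable names are ours). The stochastic
oracle is the gradient of the loss on a single data point, sampled by shuffling
the dataset at each epoch.

```python
import numpy as np

def sgd(grad_sample, w0, data, stepsize, num_epochs=10, seed=0):
    """Vanilla stochastic gradient descent.

    grad_sample(w, s) returns the stochastic gradient g(w; s) for one sample s.
    """
    rng = np.random.default_rng(seed)
    w = np.asarray(w0, dtype=float)
    for _ in range(num_epochs):
        for i in rng.permutation(len(data)):   # shuffle the samples
            w = w - stepsize * grad_sample(w, data[i])
    return w

# Least-squares example: L(w; (x, y)) = 0.5 * (<x, w> - y)^2.
rng = np.random.default_rng(1)
X, w_true = rng.normal(size=(200, 3)), np.array([1.0, -2.0, 0.5])
y = X @ w_true + 0.01 * rng.normal(size=200)
data = list(zip(X, y))
grad_sample = lambda w, s: (s[0] @ w - s[1]) * s[0]
print(sgd(grad_sample, np.zeros(3), data, stepsize=0.05))  # close to w_true
```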
For example, consider applying a stochastic gradient descent on the
expectation of β-smooth convex functions L(w; s) with unbiased oracles.
To harness the randomness of the iterates, consider after T iterations
outputting the average of the first T iterates, that is w̄T := (1/T )(w1 + · · · + wT ).
Moreover, suppose that the variance of the stochastic first-order oracles
is bounded by σ 2 for all minimizers w⋆ of L. Denoting by ES0 ,...,ST −1
the randomness associated to the stochastic oracles, we have then that,
for a stepsize γ ≤ 1/(4β) (Lan, 2012),

ES0 ,...,ST −1 [L(w̄T )] − L⋆ ≤ (1/(γT ))∥w0 − w⋆ ∥22 + 2γσ 2 .
The resulting convergence rate illustrates that a stochastic gradient
descent converges to the minimum of the expected objective up to a

constant term depending on the variance of the oracle and the stepsize.
One can diminish the variance by considering mini-batches: if the
variance of a single stochastic gradient is σ12 , considering a mini-batch
of m gradients reduces the variance of the corresponding oracle to
σm2 = σ12 /m. To decrease the additional term, one may also decrease
the stepsizes over the iterations. For example, by choosing a decreasing
stepsize like γ t = t−1/2 , the convergence rate is then of the order
O((∥w0 − w⋆ ∥22 + σ 2 ln t)/√t). The stepsize can also be selected as

a constant γ0 that decreases the average objective for the first T0


iterations and reduced by a multiplicative factor at regular intervals
like γj = ργj−1 for ρ ∈ (0, 1) to handle iterations between Tj , Tj+1 .
Alternative stepsize schedules such as a cosine decay (Loshchilov and
Hutter, 2016) have recently become popular.
The literature on alternative optimization schemes for stochastic
optimization is still rapidly evolving, with new heuristics regularly
proposed. We present below two popular techniques.

16.2.3 Momentum variants

Accelerated optimization algorithms developed in the deterministic


setting may be extended to the stochastic setting. For example, the
heavy-ball method can be adapted to the stochastic setting, leading to
stochastic gradient descent with momentum (Sutskever et al., 2013)
generally implemented as

v t+1 := νv t + g(wt ; S t )
wt+1 := wt − γv t+1 .

As mentioned earlier the momentum method can be modified to han-


dle non-quadratic smooth strongly convex functions. This leads to
Nesterov’s accelerated method in the deterministic setting. This has
been adapted to the stochastic setting with a so-called Nesterov momen-
tum (Sutskever et al., 2013)

v t+1 := νv t + g(wt + νv t ; S t )
wt+1 := wt − γv t+1 .

16.2.4 Adaptive variants

In any gradient descent-like algorithm, selecting the stepsize is key for


good performance. While a constant stepsize may be used if the function
is smooth, we may not know in advance the smoothness constant of the
objective, which means that additional procedures may be required to
select appropriately the stepsize. In the deterministic case, line-searches
such as the Armijo or Wolfe’s rules (Wright and Nocedal, 1999) can be
used to check whether the selected stepsize decreases sufficiently the
objective at each iteration. Such rules have been adapted in the stochastic
setting (Vaswani et al., 2019).
Another way to decrease the sensitivity of the algorithm with respect
to the stepsize has been to estimate first and second-order moments
of the gradients and use the latter as a form of preconditioning to
smooth the trajectory of the iterates. This led to the popular Adam
optimizer (Kingma and Ba, 2014). It takes the form,

mt+1 := ν1 mt + (1 − ν1 )g t
v t+1 := ν2 v t + (1 − ν2 )(g t )2
m̂t+1 := mt+1 /(1 − ν1t )
v̂ t+1 := v t+1 /(1 − ν2t )
wt+1 := wt − γ m̂t+1 /(√v̂ t+1 + ε),

where g t := g(wt ; S t ), (g t )2 denotes the element-wise square of g t and


ν1 , ν2 , γ, ε are hyper-parameters of the algorithm. Numerous variants
exist, such as varying the stepsize γ above along the iterations.
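A minimal NumPy sketch of one Adam update, directly transcribing the equations
above, is given below (names and the toy usage are our own choices, not a
reference implementation).

```python
import numpy as np

def adam_step(w, g, m, v, t, stepsize=1e-3, nu1=0.9, nu2=0.999, eps=1e-8):
    """One Adam update, following the equations above (t starts at 1)."""
    m = nu1 * m + (1 - nu1) * g                 # first-moment estimate
    v = nu2 * v + (1 - nu2) * g**2              # second-moment estimate
    m_hat = m / (1 - nu1**t)                    # bias corrections
    v_hat = v / (1 - nu2**t)
    w = w - stepsize * m_hat / (np.sqrt(v_hat) + eps)
    return w, m, v

# Minimal usage on a toy quadratic L(w) = 0.5 * ||w||^2, whose gradient is w.
w = np.array([1.0, -2.0])
m, v = np.zeros_like(w), np.zeros_like(w)
for t in range(1, 1001):
    w, m, v = adam_step(w, g=w, m=m, v=v, t=t, stepsize=0.05)
print(w)  # approximately the minimizer (0, 0)
```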

16.3 Projected gradient descent

Oftentimes, we seek to find the solution of a minimization problem


subject to constraints on the variables, of the form

min w∈C L(w), (16.3)

where C ⊆ W = RP is a set of constraints. We say that an approximate
solution ŵ to Eq. (16.3) is feasible if ŵ ∈ C. Naturally, the design
of algorithms for the constrained setting now depends, not only on
information about L, but also on information about C.
Similarly to L, different oracles can be considered about C. One of
the most commonly used oracle is the Euclidean projection

Definition 16.2 (Euclidean projection). The Euclidean projection


onto the set C is defined by

projC (w) := arg min v∈C ∥w − v∥22 .

This projection, which is well-defined when C is a convex set, can


be used in projected gradient descent, that we briefly review below.
Typically, the projection on a particular set C requires a dedicated
algorithm to compute it.
Other possible oracles are linear maximization oracles (LMO)
used in Frank-Wolfe algorithms and Bregman projection oracles,
used in mirror descent algorithms. The algorithm choice can be dictated
by what oracle about C is available.

16.3.1 Variational perspective


Projected gradient descent is a natural generalization of gradient descent,
based on the Euclidean projection oracle. Its iterates read

wt+1 := projC (wt − γ∇L(wt )).

At each iteration, we attempt to decrease the objective by moving along


the negative gradient direction, while ensuring that the next iterate
remains feasible, thanks to the projection step.
Similarly to the variational perspective of gradient descent in Eq. (16.2),
the projected gradient descent update is equivalent to
wt+1 = arg min w∈C L(wt ) + ⟨∇L(wt ), w − wt ⟩ + (1/(2γ))∥w − wt ∥22 .
This shows that projected gradient descent minimizes a trade-off between
staying close to wt and minimizing the linearization of L around wt ,
while staying in C.

In terms of convergence rates, they remain the same as gradient


descent (Nesterov, 2018). For example, projected gradient descent on a
smooth convex function still converges at a rate R(T ) = O(1/T ).
There are numerous extensions of vanilla projected gradient descent.
Similarly to gradient descent, the stepsize can be automatically adjusted
using linesearch techniques and there exist accelerated variants. If
we replace ∇L(w) with a stochastic gradient ∇L(w; S), we obtain a
stochastic projected gradient descent.

16.3.2 Optimality conditions


In the unconstrained case, a minimum necessarily has a zero gradient.
In the constrained setting, there may not be any feasible parameters
with zero gradient. Instead, the optimality of a point is characterized
by the fact that no better solution can be found by moving along the
gradient at that point, while staying in the constraints. Formally, it
means that for any γ > 0, a minimizer w⋆ of L on C satisfies

w⋆ = projC (w⋆ − γ∇L(w⋆ )).

It can be shown that this condition is equivalent (Nesterov, 2018) to

⟨∇L(w⋆ ), w − w⋆ ⟩ ≥ 0 ∀w ∈ C.

16.3.3 Commonly-used projections


We now briefly review a few useful Euclidean projections.

• If C = RP , we obviously have

projC (w) = w.

Therefore, in the unconstrained setting, projected gradient descent


indeed recovers gradient descent.

• If C = [a, b]P (box constraints), we have

projC (w) = clip(w, a, b) := min{max{w, a}, b}.

where the min and max are applied coordinate-wise.



• As a special case of the above, if C = RP+ (non-negative orthant),


projC (w) = max{w, 0},
also known as the non-negative part or ReLU.

• If C = △P (unit probability simplex),


projC (w) = max{w − τ 1, 0},
where τ ∈ R is a constant ensuring that projC (w) normalizes to 1.
It is known that τ can be found in O(P log P ) using a sort (see the
sketch after this list). This can be improved to O(P ) using a
median-finding-like algorithm.
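Below is a minimal NumPy sketch of the classical sort-based algorithm for the
simplex projection (names are ours; no claim is made that this is the exact
variant used elsewhere in the book).

```python
import numpy as np

def project_simplex(w):
    """Euclidean projection onto the unit probability simplex (sort-based)."""
    u = np.sort(w)[::-1]                          # sort in decreasing order
    cssv = np.cumsum(u) - 1.0                     # cumulative sums minus 1
    rho = np.nonzero(u - cssv / np.arange(1, len(w) + 1) > 0)[0][-1]
    tau = cssv[rho] / (rho + 1.0)                 # the threshold tau
    return np.maximum(w - tau, 0.0)

p = project_simplex(np.array([0.8, 1.2, -0.3]))
print(p, p.sum())   # a point of the simplex, summing to 1
```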

16.4 Proximal gradient method

The constrained setting (with C a convex set) can be recast as uncon-


strained optimization, by extending our analysis to functions taking
infinite values. Let us denote the indicator function of the set C by

ιC (w) := 0 if w ∈ C, and +∞ otherwise.

Clearly, the constrained problem in Eq. (16.3) can then be rewritten as


min L(w) + ιC (w).
w∈W

This suggests that constrained optimization is a special case of com-


posite objectives of the form
min L(w) + Ω(w),
w∈W

where Ω is a convex but potentially non-differentiable function. We


assume that we have access to an oracle associated with Ω called the
proximal operator.

Definition 16.3 (Proximal operator). The proximal operator asso-


ciated with Ω : W → R is
proxΩ (w) := arg min v∈W (1/2)∥w − v∥22 + Ω(v).

This leads to the proximal gradient method, reviewed below.

16.4.1 Variational perspective


With this method, the update reads

wt+1 = proxγΩ (wt − γ∇L(wt )).

This update again enjoys an intuitive variational perspective, namely,


wt+1 = arg min w∈W L(wt ) + ⟨∇L(wt ), w − wt ⟩ + (1/(2γ))∥w − wt ∥22 + Ω(w).

That is, we linearize L around wt , but keep Ω as is.


The proximal gradient method is popularly used when the objective
function contains a sparsity-inducing regularizer Ω. For example, for
the LASSO (Tibshirani, 1996), which aims at predicting targets y =
(y1 , . . . , yn )⊤ ∈ RN from observations X = (x1 , . . . , xn )⊤ ∈ RN ×P , we
set L(w) = (1/2)∥Xw − y∥22 and Ω(w) = λ∥w∥1 , where λ > 0 controls
the regularization strength. In this case, proxΩ is the so-called soft-
thresholding operator (see below).
Convergence guarantees of the proximal gradient method remain
the same as for gradient descent, such as a O(1/T ) rate for smooth
convex functions.

16.4.2 Optimality conditions


An optimal solution of the problem is characterized by the fixed point
equation
w⋆ = proxγΩ (w⋆ − γ∇L(w⋆ )),
for all γ > 0 (Nesterov, 2018). In other words, the proximal gradient
method (which includes gradient descent and projected gradient descent
as special cases), can be seen as fixed point iteration schemes. Such
a viewpoint suggests using acceleration methods from the fixed point
literature such as Anderson acceleration (Pollock and Rebholz, 2021).
It is also useful when designing implicit differentiation schemes as
presented in Chapter 8.

16.4.3 Commonly-used proximal operators


We now briefly review a few useful proximal operators.

• If Ω(w) = 0, we have

proxγΩ (w) = w.

Therefore, with this proximal operator, the proximal gradient


method recovers gradient descent.

• If Ω(w) = ιC (w), we have

proxγΩ (w) = projC (w).

Therefore, with this proximal operator, the proximal gradient


method recovers projected gradient descent.

• If Ω(w) = λ∥w∥1 , we have

proxγΩ (w) = (sign(w) · max(|w| − γλ, 0)),

where the operations are applied coordinate-wise. This is the


so-called soft-thresholding operator.

• If Ω(w) = λ Σ_{g∈G} ∥wg ∥2 , where G is a partition of [P ] and wg
denotes the subvector restricted to g, then we have


proxγΩ (w) = [max(1 − γλ/∥wg ∥2 , 0) wg ]_{g∈G} ,

which is used in the group lasso (Yuan and Lin, 2006) and can be
used to encourage group sparsity.

For a review of more proximal operators, see for instance (Bach et al.,
2012; Parikh, Boyd, et al., 2014).
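To make this concrete, the following NumPy sketch applies the proximal gradient method with the soft-thresholding operator to the LASSO objective (1/2)∥Xw − y∥22 + λ∥w∥1 . It is only an illustrative sketch (the function names and the fixed stepsize choice are ours, not from this book):

import numpy as np

def soft_threshold(v, tau):
    # Proximal operator of tau * ||.||_1 (soft-thresholding), applied coordinate-wise.
    return np.sign(v) * np.maximum(np.abs(v) - tau, 0.0)

def lasso_proximal_gradient(X, y, lam, num_iters=500):
    _, p = X.shape
    w = np.zeros(p)
    # Constant stepsize 1/L, with L the Lipschitz constant of the gradient of the smooth part.
    gamma = 1.0 / np.linalg.norm(X, ord=2) ** 2
    for _ in range(num_iters):
        grad = X.T @ (X @ w - y)                        # gradient of (1/2)||Xw - y||^2
        w = soft_threshold(w - gamma * grad, gamma * lam)
    return w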

16.5 Summary

• From a variational perspective, gradient descent is the algorithm


obtained when linearizing the objective function and using a
quadratic regularization term.

• Projected gradient descent is the algorithm obtained when there


is an additional constraint (the Euclidean projection naturally
appearing, due to the quadratic regularization term).

• When the objective is the sum of a differentiable function and


a non-differentiable function, proximal gradient is the algorithm
obtained when the differentiable function is linearized but the
non-differentiable function is kept as is.

• We also reviewed various stochastic gradient based algorithms,


including vanilla SGD, SGD with momentum and Adam.
17
Second-order optimization

We review in this chapter methods whose iterations take the form

wt+1 := wt − γ t B t ∇L(wt ),

where γ t is a stepsize and B t is a pre-conditioning matrix involving


second-order derivatives.

17.1 Newton’s method

17.1.1 Variational perspective

We saw in Eq. (16.2) that gradient descent can be motivated from


a variational perspective, in which we use a linear approximation of
the objective around the current iterate, obtained from the current
gradient. Similarly, if we have access not only to the gradient but also
to the Hessian of the objective, we can use a quadratic approximation
of the objective around the current iterate. More precisely, given a
function L(w), we may consider minimizing the second-order Taylor
approximation of L(w) around the current iterate wt ,
L(w) ≈ L(wt ) + ⟨∇L(wt ), w − wt ⟩ + (1/2) ⟨w − wt , ∇2 L(wt )(w − wt )⟩.


Newton’s method simply iteratively minimizes this quadratic approx-


imation around the current iterate wt , namely,
wt+1 = arg min_{w∈W} L(wt ) + ⟨∇L(wt ), w−wt ⟩ + (1/2) ⟨w−wt , ∇2 L(wt )(w−wt )⟩.   (17.1)

If the Hessian is positive definite at wt , which we denote by ∇2 L(wt ) ≻


0, then the minimum is well-defined and unique (this is for example the
case if L is strictly convex). The iterates can then be written analytically
as
wt+1 = wt − ∇2 L(wt )−1 ∇L(wt ).
If the Hessian is not positive definite, the minimum may not be defined.
Ignoring this issue and taking the analytical formulation could be
dangerous, as it could amount to computing the maximum of the
quadratic instead if, for example, the quadratic was strictly concave
(i.e., ∇2 L(w) ≺ 0).

17.1.2 Regularized Newton method

A simple technique to circumvent this issue consists in adding a regu-


larization term to the Hessian. Namely, from a variational viewpoint,
we can add a proximity term (1/2) ∥w − wt ∥22 , encouraging the next iterate to stay close to
the current wt . The iterates of this regularized Newton method then
take the form
wt+1 = arg min_{w∈W} L(wt ) + ⟨∇L(wt ), w−wt ⟩ + (1/2) ⟨w−wt , ∇2 L(wt )(w−wt )⟩ + (η t /2) ∥w − wt ∥22 ,
where η t controls the regularization strength. Assuming η t > 0 is strong
enough to make ∇2 L(wt ) + η t I positive-definite, we have

wt+1 = wt − dt ,

where we defined the direction

dt := (∇2 L(wt ) + η t I)−1 ∇L(wt ). (17.2)



Other techniques to circumvent this issue include using cubic regular-


ization and modifying the spectral decomposition of the Hessian, by
thresholding the eigenvalues or taking their absolute values. We refer
the interested reader to, e.g., (Nesterov, 2018; Wright and Nocedal,
1999) for more details.
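As a minimal sketch (ours, not from this book), a single regularized Newton update can be written as follows, assuming grad and hess are callables returning ∇L(w) and ∇2 L(w), and that w is a NumPy vector:

import numpy as np

def regularized_newton_step(w, grad, hess, eta=1e-3):
    g = grad(w)
    H = hess(w)
    # Solve (H + eta I) d = g rather than forming the inverse explicitly;
    # eta should be large enough to make H + eta I positive definite.
    d = np.linalg.solve(H + eta * np.eye(len(w)), g)
    return w - d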

17.1.3 Approximate direction


We observe a main impediment for implementing such a second-order
optimization algorithm: even if we had access to the Hessian of the
objective for free and this Hessian was positive definite, computing the
exact direction dt in Eq. (17.2) requires computing an inverse-Hessian
vector product (IHVP) with the gradient ∇L(wt ). Doing so exactly
requires solving a linear system

(∇2 L(wt ) + η t I)dt = ∇L(wt ),

which a priori takes O(P 3 ) time. In practice, however, we can compute


IHVPs approximately, as explained in Section 9.4.
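For instance, the direction can be approximated with the conjugate gradient method, which only requires Hessian-vector products. The sketch below is illustrative only (hvp is an assumed callable, e.g. obtained by automatic differentiation, and the names are ours):

import numpy as np

def newton_direction_cg(hvp, g, eta=0.0, max_iters=50, tol=1e-8):
    # Approximately solve (H + eta I) d = g by conjugate gradient, using only
    # matrix-vector products v -> hvp(v) + eta * v.
    matvec = lambda v: hvp(v) + eta * v
    d = np.zeros_like(g)
    r = g - matvec(d)                  # residual
    p = r.copy()
    rs = r @ r
    for _ in range(max_iters):
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        d += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return d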

17.1.4 Convergence guarantees


While implementing Newton’s method comes at a higher computational
cost, it can also benefit from faster convergence rates. Briefly, if Newton’s
method is initialized at a point w0 ∈ W close enough to the mini-
mizer w⋆ of a µ-strongly convex function with M -Lipschitz continuous
Hessian (namely ∥w0 − w⋆ ∥2 ≤ 2µ/(3M )), then Newton’s method converges
at a quadratic rate (Nesterov, 2018), that is, R(t) ≤ O(exp(− exp(t)))
(see Section 15.5 for a brief introduction to performance guarantees).
This is far superior to gradient descent. Such an efficiency motivated
the development of interior point methods, that have been a break-
through in constrained optimization, thanks to the use of log-barrier
penalties (Nesterov, 2018).

17.1.5 Linesearch
In practice, we may not have access to an initial point close enough
from the minimizer. In that case, even for strictly convex functions for

which Newton’s steps are well-defined, taking wt+1 = wt − dt may not


ensure a decrease of the objective values. Nevertheless, the direction dt
may define a descent direction as defined below.
Definition 17.1 (Descent direction). A point d ∈ W defines a de-
scent direction −d for an objective L at w, if there exists a
positive stepsize γ > 0 such that

L(w − γd) ≤ L(w).

If L is differentiable, −d is a descent direction if ⟨−d, ∇L(w)⟩ < 0.

For Newton’s method without regularization, dt = ∇2 L(wt )−1 ∇L(wt )


defines a descent direction −dt at wt , as long as ∇L(wt ) ̸= 0 and ∇2 L(wt ) ≻
0. If ∇2 L(wt ) ̸≻ 0, choosing η t > 0 such that ∇2 L(wt ) + η t I ≻ 0 also
ensures that dt = (∇2 L(wt ) + η t I)−1 ∇L(wt ) defines a descent direction
−dt (as long as ∇L(wt ) ̸= 0). Newton’s method is then generally equipped
with a linesearch method that attempts to take steps of the form
wt+1 = wt − γ t dt
with γ t chosen as the largest stepsize among {ρτ , τ ∈ N} for ρ ∈ (0, 1)
until a sufficient decrease of the objective is satisfied such as, for c ∈
(0, 1),
L(wt − γ t dt ) ≤ L(wt ) − cγ t ⟨∇L(wt ), ∇2 L(wt )−1 ∇L(wt )⟩.
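A minimal sketch of such a backtracking linesearch (ours, purely illustrative; L and grad are assumed callables, w and d are NumPy vectors, and d is a Newton-type direction) is:

def backtracking_step(L, grad, w, d, c=1e-4, rho=0.5, max_trials=50):
    g = grad(w)
    decrease = c * (g @ d)              # sufficient-decrease slope along -d
    gamma = 1.0
    for _ in range(max_trials):
        if L(w - gamma * d) <= L(w) - gamma * decrease:
            break                       # sufficient decrease achieved
        gamma *= rho                    # otherwise shrink the stepsize
    return w - gamma * d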
For strongly convex functions, such an implementation exhibits two
phases: a first phase during which Newton’s steps are “damped” by
using a stepsize γ t < 1 and a second phase of super-fast convergence
during which stepsizes γ t = 1 are taken, and the objective decreases very
fast. Even far from the optimum, Newton directions can advantageously
adapt to the local geometry of the objective to speed-up convergence
compared to a regular gradient descent as explained below.

17.1.6 Geometric interpretation


To understand the efficiency of Newton’s method compared to gradient
descent, consider the minimization of a simple quadratic
L(w) = (1/2) a w12 + (1/2) b w22

for a ≫ b ≥ 0, as illustrated in Fig. 17.1. A gradient descent moves


along the directions −∇L(w) = −(aw1 , bw2 )⊤ and its stepsize is limited
by the variations in the first coordinate, leading to some oscillations. If
we were simply rescaling the gradient coordinates by (1/a, 1/b), i.e., taking
steps of the form
wt+1 = wt − γ diag(a−1 , b−1 )∇L(wt ),
the variations in both coordinates would be normalized to one and the
stepsize could simply be chosen as γ = 1 to directly get w⋆ . In other
words, by adapting the geometry of the directions with the geometry
induced by the objective, we can circumvent the oscillations.
That’s exactly what Newton’s method does by modifying the gradi-
ent direction using the inverse of the Hessian. Formally, at iteration t,
consider the modified objective
L̃(v) = L(Av) for A = ∇2 L(wt )−1/2 ,
with L strictly convex and A the inverse matrix square root of the
Hessian. One easily verifies that a Newton step is equivalent to a
gradient step on L̃, that is,
v t+1 = v t − ∇L̃(v t ) ⇐⇒ wt+1 = wt − (∇2 L(wt ))−1 ∇L(wt )
where
wt = Av t = ∇2 L(wt )−1/2 v t .
In the geometry induced by A, the objective is generally better condi-
tioned as illustrated in Fig. 17.1. This explains the efficiency of Newton’s
method. In particular for any strongly convex quadratic, a Newton step
reaches the optimum in one iteration, while a gradient step can take
many more iterations.
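This behavior is easy to reproduce numerically. The following sketch (ours, purely illustrative) compares gradient descent and a single Newton step on this quadratic:

import numpy as np

a, b = 100.0, 1.0
H = np.diag([a, b])                        # Hessian of L(w) = (1/2) a w1^2 + (1/2) b w2^2
grad = lambda w: H @ w

w_gd = np.array([1.0, 1.0])
for _ in range(100):
    w_gd = w_gd - (1.0 / a) * grad(w_gd)   # stepsize limited by the largest curvature

w_newton = np.array([1.0, 1.0])
w_newton = w_newton - np.linalg.solve(H, grad(w_newton))   # one Newton step

# w_gd still has a visibly nonzero second coordinate after 100 iterations,
# while w_newton is exactly the minimizer w* = (0, 0).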

17.1.7 Stochastic Newton’s method


Consider now an expected loss
min L(w) := ES∼ρ [L(w; S)] .
w∈W

In that case, an estimate of the Hessian can be constructed just like for
the gradient using that
h i
ES∼ρ ∇2 L(w; S) = ∇2 L(w).

Figure 17.1: Left: Minimization of a quadratic L(w) = (1/2) a w12 + (1/2) b w22 by gradient
descent. For a ≫ b ≥ 0, gradient descent typically oscillates. Right: minimization
by Newton’s method amounts to changing the geometry of the problem to avoid
oscillations.

Denote then

g(w; S) ≈ ∇L(w), H(w; S ′ ) ≈ ∇2 L(w)

some stochastic estimates of the gradient and the Hessian, respectively,
with S, S ′ independently drawn from ρ or from mini-batch approxima-
tions with varying mini-batch sizes. One implementation of a stochastic
Newton method can then be

wt+1 = wt − γ t (H(wt ; S ′ ) + η t I)−1 g(wt ; S),

for η t ≥ 0 such that H(wt ; S ′ ) + η t I ≻ 0 and γ t fixed or chosen to


satisfy some sufficient decrease condition. We refer the interested reader
to, e.g., (Xu et al., 2020), for more details and variants.

17.2 Gauss-Newton method

Newton’s method (17.1) is usually not properly defined for non-convex


objective functions, since the Hessian may not be positive definite at the
current iterate. We saw in Section 9.2 that the Gauss-Newton matrix can
be used to define a positive-semidefinite approximation of the Hessian.
Here, we revisit the Gauss-Newton method from a variational and
partial linearization perspective. While the original Gauss-Newton
method originates from nonlinear least-squares, we will first describe
an extension to arbitrary convex loss functions, since it is both more
general and easier to explain.

17.2.1 With exact outer function


Consider a composite objective of the form

L(w) := ℓ(f (w)),

where ℓ : M → R is a convex function, such as a convex loss function


applied on a given sample, and f : W → M is a nonlinear function,
such as a neural network with parameters w ∈ W, evaluated on the
same sample. We saw that gradient descent and Newton’s method
amount to using linear and quadratic approximations of L(w) around
the current iterate wt , respectively. As a middle ground between the
two, the Gauss-Newton method uses the linearization of f around wt

f (w) ≈ f (wt ) + ∂f (wt )(w − wt )

but keeps ℓ as is to obtain the objective

wt+1 := arg min_{w∈W} ℓ(f (wt ) + ∂f (wt )(w − wt ))
      = arg min_{w∈W} ℓ(∂f (wt )w + f (wt ) − ∂f (wt )wt )
      = arg min_{w∈W} ℓ(J t w + δ t ),

where we defined the shorthands J t := ∂f (wt ) and δ t := f (wt ) −


∂f (wt )wt . We call ℓ(J t w + δ t ) the partial linearization of L = ℓ ◦ f
at wt , as opposed to the full linearization of L used in gradient descent.
Since the composition of a convex function and of linear function is
convex, this objective is convex even if L(w) is nonconvex. In practice,
we often add a proximity term as regularization to define
wt+1 := arg min_{w∈W} ℓ(J t w + δ t ) + (η t /2) ∥w − wt ∥22 .   (17.3)
We can see this update as an approximation of the proximal point
update
arg min_{w∈W} L(w) + (η t /2) ∥w − wt ∥22 ,
where L(w) has been replaced by its partial linearization. Solving
Eq. (17.3) using gradient-based solvers requires to compute the gradient

of w 7→ ℓ(J t w + δ t ), which is w 7→ (J t )∗ ∇ℓ(J t w + δ t ). Computing


this gradient by autodiff therefore requires to perform a forward pass
to compute the JVP J t w and a backward pass to compute the VJP
(J t )∗ ∇ℓ(z). See Section 2.3 for an introduction to these operators and
Chapter 8 for an introduction to autodiff.
The Gauss-Newton method with arbitrary convex outer loss is
often called modified Gauss-Newton (Nesterov, 2007) or prox-linear
(Drusvyatskiy and Paquette, 2019). The classical Gauss-Newton and
Levenberg-Marquardt (Levenberg, 1944; Marquardt, 1963) methods
originate from nonlinear least-squares and are recovered when ℓ(z) is
quadratic (Kelley, 1995), such as ℓ(z) := 12 ∥z −y∥22 , for y some reference
target. The Gauss-Newton method corresponds classically to not using
regularization (i.e., η t = 0) and the Levenberg-Marquardt method uses
regularization (usually called damping, potentially changing η t across
iterations). See e.g., (Messerer et al., 2021), for a survey of different
variants.
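For the classical nonlinear least-squares case ℓ(z) = (1/2)∥z − y∥22 , a single damped Gauss-Newton (Levenberg-Marquardt) step can be sketched as follows (ours, purely illustrative; f and jac are assumed callables returning f (w) and its Jacobian, and eta plays the role of the damping parameter η t ):

import numpy as np

def levenberg_marquardt_step(w, f, jac, y, eta=1e-3):
    J = jac(w)                          # Jacobian of f at w, of shape (Z, P)
    r = f(w) - y                        # residuals
    g = J.T @ r                         # gradient of (1/2)||f(w) - y||^2
    GN = J.T @ J                        # Gauss-Newton matrix (the Hessian of ell is the identity here)
    d = np.linalg.solve(GN + eta * np.eye(GN.shape[0]), g)
    return w - d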

17.2.2 With approximate outer function

Another variant of the Gauss-Newton method consists in replacing the


convex loss ℓ with its quadratic approximation around z t := f (wt ),
q t (z) := ℓ(z t ) + ⟨∇ℓ(z t ), z − z t ⟩ + (1/2) ⟨z − z t , ∇2 ℓ(z t )(z − z t )⟩ ≈ ℓ(z)
to define the update

wt+1 := arg min_{w∈W} q t (J t w + δ t ) + (η t /2) ∥w − wt ∥22 .

Notice that ℓ has been replaced by its quadratic approximation q t .


This objective is always a convex quadratic, unlike the objective
of the Newton method in Eq. (17.1), which is a priori a nonconvex
quadratic, if f is nonlinear. Simple calculations show that

wt+1 = arg min_{w∈W} L(wt ) + ⟨q t , J t (w−wt )⟩ + (1/2) ⟨w−wt , (J t )∗ Qt J t (w−wt )⟩ + (η t /2) ∥w − wt ∥22

where q t := ∇ℓ(f (wt )) ∈ M = RZ , Qt := ∇2 ℓ(f (wt )) ∈ RZ×Z . The


closed form solution is
wt+1 = wt − ((J t )∗ Qt J t + η t I)−1 (J t )∗ q t
= wt − (∇2GN (ℓ ◦ f )(wt ) + η t I)−1 ∇L(wt ),
where we used the (generalized) Gauss-Newton matrix of L = ℓ ◦ f ,
defined in Section 9.2.

17.2.3 Linesearch
Similarly to Newton’s method, the iterates of a Gauss-Newton method
may diverge when used alone. However, the direction −(∇2GN (ℓ◦f )(wt )+
η t I)−1 ∇L(wt ) defines a descent direction for any η t > 0 and can be
combined with a stepsize γ t (typically chosen using a linesearch) to
obtain iterates of the form
wt+1 = wt − γ t (∇2GN (ℓ ◦ f )(wt ) + η t I)−1 ∇L(wt ).

17.2.4 Stochastic Gauss-Newton


In deep learning, the objective generally consists in an expectation
over samples of the composition between a loss function and a network
function:
L(w) = ES∼ρ [L(w; S)] = E(X,Y )∼ρ [ℓ(f (w; X); Y )]
where S = (X, Y ) denotes a sample pair of input X with associated
label Y . In that case, as already studied in Section 9.2, the Gauss-
Newton matrix ∇2GN L is the expectation of the individual Gauss-Newton
matrices
∇2GN L(w; x, y) := ∂f (w; x)⊤ ∇2 ℓ(f (w; x))∂f (w; x),
∇2GN L(w) := E(X,Y )∼ρ [∇2GN L(w; x, y)].
We can estimate the gradient and the Gauss-Newton matrix by, respec-
tively, g(w; S) ≈ ∇L(w), and G(w; S ′ ) ≈ ∇2GN L(w) for S, S ′ ∼ ρ or
using mini-batch approximations. A stochastic Gauss-Newton method
therefore performs iterates
wt+1 := wt − γ t (G(wt ; S ′ ) + η t I)−1 g(wt ; S),
for η t ≥ 0 and γ t fixed or selected to satisfy some criterion.

17.3 Natural gradient descent

Natural gradient descent (Amari, 1998) follows a similar principle as


gradient descent: linearize the objective around the current iterate and
minimize this approximation together with a proximity term. It differs
from gradient descent in the choice of the proximity term: rather than
using a squared Euclidean distance between the parameters, it uses a
Kullback-Leibler divergence between the probability distributions
these parameters define.

Negative log-likelihood
We consider objectives of the form

min_{w∈W} L(w) = ES∼ρ [L(w; S)] = ES∼ρ [− log qw (S)] ,

where ρ is an unknown data distribution (but from which we can sam-


ple) and where qw is a probability distribution parameterized by w. As
reviewed in Chapter 3, the negative log-likelihood can be used as a loss
function (many loss functions can be seen from this perspective, includ-
ing the squared and logistic loss functions). In the unsupervised setting,
where S = Y , we simply use qw (Y ) as is. In the supervised setting,
where S = (X, Y ), we use the product rule P(X, Y ) = P(X)P(Y |X) to
parameterize qw (S) as

qw (x, y) := ρX (x)pθ (y),

where ρX is the marginal distribution for X, pθ (y) is the PMF/PDF


of a probability distribution and θ = f (w; x) is for instance a neural
network with parameters w ∈ W and input x ∈ X .

17.3.1 Variational perspective


Natural gradient descent is motivated by updates of the form

wt+1 = arg min_{w∈W} L(wt ) + ⟨∇L(wt ), w − wt ⟩ + KL(qwt , qw ),

where KL(p, q) := ∫ p(z) log (p(z)/q(z)) dz is the Kullback-Leibler (KL) diver-
gence. Unlike gradient descent, the proximity term is therefore between



the current distribution qwt and a candidate probability distribution


qw . The above problem is intractable in general, as the KL may not
have a closed form. Nevertheless, its quadratic approximation can be
shown (Amari, 1998) to admit a simple form,
KL(qwt , qw ) ≈ (1/2) ⟨w − wt , ∇2F L(wt )(w − wt )⟩
where we used the Fisher information matrix ∇2F L(w), studied in
Section 9.3. Equipped with this quadratic approximation of the KL
divergence, natural gradient descent amounts to compute iterates as

wt+1 := arg min_{w∈W} L(wt ) + ⟨∇L(wt ), w − wt ⟩ + (1/2) ⟨w − wt , ∇2F L(wt )(w − wt )⟩ + (η t /2) ∥w − wt ∥22 ,

where a quadratic proximity term was added to ensure a unique solution.
This is a strictly convex problem since ∇2F L(wt ) is positive semi-definite and η t > 0.
The closed-form solution is

wt+1 = wt − (∇2F L(wt ) + η t I)−1 ∇L(wt ).

Because the Gauss-Newton and Fisher information matrices are equiv-


alent when pθ is an exponential family distribution (Proposition 9.6),
the Gauss-Newton and natural gradient methods coincide in this case.

17.3.2 Stochastic natural gradient descent


In practice, we may not have access to ∇L(wt ) in closed form as it is
an expectation over ρ. Moreover, ∇2F L(wt ) may not be computable in
closed form either. To estimate the Fisher information matrix, we can
use the fact that (see Section 9.3), with the shorthand θ := f (w; X),

∇2F L(w) = EX∼ρX EY ∼pθ [∇L(w; X, Y ) ⊗ ∇L(w; X, Y )].

We can then build estimates g(wt ; S) ≈ ∇L(wt ) and F (wt ; S ′ ) ≈
∇2F L(wt ) for S sampled from ρ and S ′ sampled from qwt (x, y) =
ρX (x)pθ (y). A stochastic natural gradient descent can then be imple-
mented as

wt+1 = wt − γ t (F (wt ; S ′ ) + η t I)−1 g(wt ; S),



where γ t is a stepsize, possibly chosen by linesearch.


In deep learning, the product with the inverse Fisher or Gauss-
Newton matrices can remain costly to compute. Several approximations
have been proposed, such as KFAC (Martens and Grosse, 2015; Botev
et al., 2017), which uses a computationally efficient structural approxi-
mation to these matrices.

17.4 Quasi-Newton methods

17.4.1 BFGS
A celebrated example of quasi-Newton method is the BFGS method
(Broyden, 1970; Fletcher, 1970; Goldfarb, 1970; Shanno, 1970), whose
acronym follows from its author names. The rationale of the BFGS
update stems once again from a variational viewpoint. We wish to
build a simple quadratic model of the objective ht (w) = L(wt ) +
⟨∇L(wt ), w − wt ⟩ + 12 ⟨w − wt , Qt (w − wt )⟩ for some Qt built along
the iterations rather than taken as ∇2 L(wt ). One desirable property of
such quadratic model would be that its gradients at consecutive iterates
match the gradients of the original function, i.e., ∇ht (wt ) = ∇L(wt )
and ∇ht (wt−1 ) = ∇L(wt−1 ). A simpler condition, called the secant
condition consists in considering the differences of these vectors, that
is, ensuring that
∇ht (wt ) − ∇ht (wt−1 ) = ∇L(wt ) − ∇L(wt−1 )
⇐⇒ Qt (wt − wt−1 ) = ∇L(wt ) − ∇L(wt−1 )
⇐⇒ wt − wt−1 = B t (∇L(wt ) − ∇L(wt−1 )),
for B t = (Qt )−1 . Building B t , a surrogate of the inverse of the Hessian
satisfying the secant equation, can then be done as
   
B t+1 := (I − ρt st (y t )⊤ ) B t (I − ρt y t (st )⊤ ) + ρt st (st )⊤ ,
where
st := wt+1 − wt ,
y t := ∇L(wt+1 ) − ∇L(wt ),
ρt := 1/⟨st , y t ⟩.

A typical implementation of BFGS stores Bt ∈ RP ×P in memory, which


is prohibitive when P is large.

17.4.2 Limited-memory BFGS

In practice, the limited-memory counterpart of BFGS, called LBFGS (Liu


and Nocedal, 1989), is often preferred. The key observation of LBFGS
is that we do not need to materialize Bt in memory: we only need
to multiply it with the gradient ∇L(wt ). That is, we can see Bt as
a linear map. Fortunately, the product between B t and any vector v
can be computed efficiently if we store (s1 , y 1 , ρ1 ), . . . , (st , y t , ρt ) in
memory. In practice, a small history of past values is used to reduce
memory and computational cost. Because LBFGS has the benefits of
second-order-like methods with much reduced cost, it has become a de-
facto algorithm, outperforming most other algorithms for medium-scale
problems without particular structure (Liu and Nocedal, 1989).
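The key routine is the so-called two-loop recursion, which applies the implicit inverse-Hessian approximation to a vector using only the stored pairs. A minimal sketch (ours, illustrative only) is:

import numpy as np

def lbfgs_direction(g, s_list, y_list):
    # Apply the LBFGS inverse-Hessian approximation to g using the stored
    # pairs (s_i, y_i), without ever materializing a P x P matrix.
    q = g.copy()
    alphas = []
    for s, y in reversed(list(zip(s_list, y_list))):
        rho = 1.0 / (s @ y)
        alpha = rho * (s @ q)
        q -= alpha * y
        alphas.append(alpha)
    if s_list:                                  # common initial scaling heuristic
        s, y = s_list[-1], y_list[-1]
        q *= (s @ y) / (y @ y)
    for (s, y), alpha in zip(zip(s_list, y_list), reversed(alphas)):
        rho = 1.0 / (s @ y)
        beta = rho * (y @ q)
        q += (alpha - beta) * s
    return q                                    # approximates B^t g, used as the descent direction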

17.5 Approximate Hessian diagonal inverse preconditioners

One application of the approximations of the Hessian diagonal developed


in Section 9.7 is to obtain cheap approximations of the Hessian diagonal
inverse,
B t := diag(|H11t |−1 , . . . , |HP Pt |−1 ).
Such a scaling would for instance be sufficient to make the quadratic
example presented in Fig. 17.1 work. Many optimization algorithms,
including the popular ADAM, can be viewed as using a preconditioner
that approximates the inverse of the Hessian’s diagonal.

17.6 Summary

• We reviewed Newton’s method, the Gauss-Newton method, natu-


ral gradient descent, quasi-Newton methods and preconditioning
methods.

• We adopted a variational viewpoint, where the method’s next


iterate is computed as the solution of a trade-off between mini-

mizing an approximation of the function (linear, partially linear,


quadratic) and a proximity term (squared Euclidean, KL).

• All methods were shown to use iterates of the form

wt+1 := wt − γ t B t ∇L(wt )

but have different trade-offs between the cost it takes to evaluate


B t ∇L(wt ) and the richness of the information used about L.
18
Duality

In this chapter, we review duality principles in optimization.

18.1 Dual norms

We introduce in this section dual norms, since they are useful in this
book.

Definition 18.1 (Dual norms). Given a norm ∥u∥, its dual is

∥v∥∗ := max_{∥u∥≤1} ⟨u, v⟩.

Therefore, the dual norm of ∥ · ∥ is the support function of the


unit ball induced by the norm ∥ · ∥,

B∥·∥ := {u ∈ RD : ∥u∥ ≤ 1}.

We give examples of pairs of dual norms below.


Example 18.1 (Dual norm of p-norms). The p-norm is defined by


∥u∥p := ( Σ_{j=1}^D |uj |p )^{1/p} .

Its dual is ∥v∥q where q is such that 1/p + 1/q = 1. For instance, the
dual norm of the 2-norm is itself, since 1/2 + 1/2 = 1. The 1-norm and
the ∞-norm are dual of each other, since 1/1 + 1/∞ = 1.

The definition of dual norm implies a generalization of Cauchy–Schwarz’s


inequality: for all u, v ∈ RD

|⟨u, v⟩| ≤ ∥u∥∗ ∥v∥.

See, e.g., Beck (2017, Lemma 1.4).

Proposition 18.1 (Conjugate of norms and squared norms). We know


that the conjugate of the support function is the indicator function.
Therefore, if f (u) = ∥u∥, then

f ∗ (v) = ιB∥·∥∗ (v) = 0 if ∥v∥∗ ≤ 1, and ∞ otherwise.

On the other hand, if f (u) = (1/2) ∥u∥2 , then

f ∗ (v) = (1/2) ∥v∥2∗ .

18.2 Fenchel duality

We consider in this section standard objectives of the form

min_{w∈W} L(w) := min_{w∈W} ℓ(f (w)) + R(w),

where f : W → M, ℓ : M → R and R : W → R. We first show that


the minimization of this objective, called the primal, can be lower
bounded by a concave maximization objective, called the dual, even
if the primal is nonconvex.

Proposition 18.2 (Weak duality). Let f : W → M (potentially non-


linear), ℓ : M → R (potentially nonconvex) and R : W → R (poten-
tially nonconvex). Then

min_{w∈W} ℓ(f (w)) + R(w) ≥ max_{α∈M} −Rf (α) − ℓ∗ (−α),

where we used the conjugate

ℓ∗ (−α) := max_{θ∈M} ⟨−α, θ⟩ − ℓ(θ)

and the “generalized conjugate”

Rf (α) := max_{w∈W} ⟨α, f (w)⟩ − R(w).

Moreover, ℓ∗ and Rf are both convex functions.

We emphasize that the result in Proposition 18.2 is fully general, in


the sense that it does not assume the linearity of f or the convexity of ℓ
and R. The caveat, of course, is that Rf and ℓ∗ are difficult to compute
in general, if f is nonlinear, and if ℓ and R are nonconvex.

Proof.

min_{w∈W} ℓ(f (w)) + R(w)
= min_{w∈W, θ∈M} ℓ(θ) + R(w) s.t. θ = f (w)
= min_{w∈W, θ∈M} max_{α∈M} ℓ(θ) + R(w) + ⟨α, θ − f (w)⟩
≥ max_{α∈M} min_{w∈W, θ∈M} ℓ(θ) + R(w) + ⟨α, θ − f (w)⟩
= max_{α∈M} [min_{w∈W} ⟨α, −f (w)⟩ + R(w)] + [min_{θ∈M} ℓ(θ) + ⟨α, θ⟩]
= max_{α∈M} −[max_{w∈W} ⟨α, f (w)⟩ − R(w)] − [max_{θ∈M} ⟨−α, θ⟩ − ℓ(θ)]
= max_{α∈M} −Rf (α) − ℓ∗ (−α).

In the case when f (w) = Aw, where A is a linear map, and when
both ℓ and R are convex, we can state a much stronger result.

Proposition 18.3 (Strong duality). Let A be a linear map from W


to M. Let ℓ : M → R and R : W → R be convex functions. Let A∗
denote the adjoint of A (Section 2.3). Then,

min_{w∈W} ℓ(Aw) + R(w) = max_{α∈M} −R∗ (A∗ α) − ℓ∗ (−α).

Furthermore, the primal solution satisfies

w⋆ ∈ arg max_{w∈W} ⟨A∗ α⋆ , w⟩ − R(w).

When R is strictly convex, the primal solution is uniquely deter-


mined by
w⋆ = ∇R∗ (A∗ α⋆ ).

Proof. Since f (w) = Aw, we have

Rf (α) := max_{w∈W} ⟨α, f (w)⟩ − R(w)
= max_{w∈W} ⟨α, Aw⟩ − R(w)
= max_{w∈W} ⟨A∗ α, w⟩ − R(w)
= R∗ (A∗ α).

Furthermore, the inequality in the proof of Proposition 18.2 is an


equality, since the min max is that of a convex-concave function.

The maximization problem in Proposition 18.3 is called the Fenchel


dual. By strong duality, the value of the maximum and the value of the
minimum are equal. We can therefore choose to equivalently solve the
dual instead of the primal. This can be advantageous when the space
M is smaller than W.
We now apply the Fenchel dual to obtain the dual of regularized
multiclass linear classification.

Table 18.1: Examples of loss conjugates. For regression losses (squared, absolute),
where yi ∈ RM , we define ti = ϕ(yi ) = yi . For classification losses (logistic, per-
ceptron, hinge), where yi ∈ [M ], we define ti = ϕ(yi ) = eyi . To simplify some
expressions, we defined the change of variable µi := yi − αi .

                ℓi (θi )                                      ℓ∗i (−αi )
Squared:        (1/2) ∥θi − ti ∥22                            (1/2) ∥αi ∥22 − ⟨ti , αi ⟩
Absolute:       ∥θi − ti ∥1                                   ι[−1,1]M (αi ) − ⟨ti , αi ⟩
Logistic:       LSE(θi ) − ⟨θi , ti ⟩                         ⟨µi , log µi ⟩ + ι△M (µi )
Perceptron:     maxk∈[M ] θi,k − θi,yi                        ι△M (µi )
Hinge:          maxk∈[M ] ([k ̸= yi ] + θi,k ) − θi,yi        ι△M (µi ) − ⟨1 − ti , µi ⟩

Example 18.2 (Sum of separable loss functions). When the loss is


ℓ(θ) := Σ_{i=1}^N ℓi (θi ), where θ = Aw = (A1 w, . . . , AN w) ∈ MN and
Ai is a linear map from W to M, we obtain

min_{w∈W} Σ_{i=1}^N ℓi (Ai w) + R(w) = max_{α∈MN} −R∗ (A∗ α) − Σ_{i=1}^N ℓ∗i (−αi ),

where A∗ α = (A∗1 α1 , . . . , A∗N αN ). Typically, we define

Ai w := W xi ,

where W ∈ RM ×D is a reshaped version of w ∈ W, xi ∈ RD is a


training sample, and M is the number of classes. In this case, we
then have
A∗i αi = αi x⊤i .

Examples of loss function conjugates are given in Table 18.1.

18.3 Bregman divergences

Bregman divergences are a measure of difference between two points.

Definition 18.2 (Bregman divergence). The Bregman divergence gen-



erated by a differentiable convex function f : RD → R is

Bf (u, v) := f (u) − f (v) − ⟨∇f (v), u − v⟩


= ⟨∇f (v), v⟩ − f (v) − [⟨∇f (v), u⟩ − f (u)] ,

where u, v ∈ dom(f ).

Intuitively, the Bregman divergence is the difference between f (u)


and its linearization u 7→ f (v) + ⟨∇f (v), u − v⟩ around v. This is
illustrated in Fig. 18.1.

Example 18.3 (Examples of Bregman divergences). If f (u) = (1/2) ∥u∥22 ,
where dom(f ) = RD , then

Bf (u, v) = (1/2) ∥u − v∥22 ,

the squared Euclidean distance. If f (u) = ⟨u, log u⟩, where
dom(f ) = RD+ , then

Bf (u, v) = Σ_{j=1}^D uj log (uj /vj ) − Σ_{j=1}^D uj + Σ_{j=1}^D vj ,

the (generalized) Kullback-Leibler divergence.
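In code, the Bregman divergence is straightforward to evaluate given f and its gradient. The following NumPy sketch (ours, illustrative only) checks the two examples above:

import numpy as np

def bregman_divergence(f, grad_f, u, v):
    # Difference between f(u) and the linearization of f around v, evaluated at u.
    return f(u) - f(v) - grad_f(v) @ (u - v)

# f(u) = (1/2)||u||_2^2 generates the squared Euclidean distance.
sq = lambda u: 0.5 * (u @ u)
sq_grad = lambda u: u

# f(u) = <u, log u> generates the (generalized) Kullback-Leibler divergence.
negent = lambda u: u @ np.log(u)
negent_grad = lambda u: np.log(u) + 1.0

u, v = np.array([0.2, 0.8]), np.array([0.5, 0.5])
print(bregman_divergence(sq, sq_grad, u, v))          # equals (1/2)||u - v||_2^2
print(bregman_divergence(negent, negent_grad, u, v))  # equals KL(u, v) since both sum to 1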

Properties
Bregman divergences enjoy several useful properties.

Proposition 18.4 (Properties of Bregman divergences). Let f : RD →


R be a differentiable convex function.

1. Non-negativity: Bf (u, v) ≥ 0 for all u, v ∈ dom(f ).

2. Positivity: Bf (u, v) = 0 if and only if u = v (when f is


strictly convex).

3. Convexity: Bf (u, v) is convex in u.

4. Dual-space form: Bf (u, v) = Bf ∗ (b, a), where b = ∇f (v) ∈


dom(f ∗ ) and a = ∇f (u) ∈ dom(f ∗ ).

Figure 18.1: The Bregman divergence generated by f is the difference between
f (u) and its linearization around v.

Proof. The properties follow immediately from the convexity of f (u).

1. From Definition 15.6.

2. From the unicity of minimizers.

3. From the fact that u 7→ Bf (u, v) is the sum of f (u) and a linear
function of u.

The Bregman divergence can be used to define natural generaliza-


tions of the Euclidean projection and proximal operators, reviewed in
Section 16.3 and Section 16.4.
Definition 18.3 (Bregman proximal and projection operators). Let v ∈
dom(f ). The Bregman proximal operator is

bproxf,g (v) := arg min_{u∈dom(f )∩dom(g)} Bf (u, v) + g(u).

In particular, the Bregman projection onto C ⊆ dom(f ) is

bprojf,C (v) := arg min_{u∈C} Bf (u, v).

It turns out that these operators are intimately connected to the


gradient mapping of the convex conjugate.

Proposition 18.5 (Link with conjugate’s gradient). If Ω = f + g,


then for all θ ∈ dom(f ∗ )

∇Ω∗ (θ) = bproxf,g (∇f ∗ (θ)).

In particular, if Ω = f + ιC , then for all θ ∈ dom(f ∗ )

∇Ω∗ (θ) = bprojf,C (∇f ∗ (θ)).

We give two examples below.

Example 18.4 (Bregman projections on the simplex). If f (u) = (1/2) ∥u∥22 ,
then

bprojf,△D (v) = arg min_{u∈△D} (1/2) ∥u − v∥22 .

If f (u) = ⟨u, log u − 1⟩, then

bprojf,△D (v) = arg min_{u∈△D} KL(u, v) = softmax(θ),

where v = ∇f ∗ (θ) = exp(θ).

Therefore, the softmax can be seen as a projection onto the proba-


bility simplex in the Kullback-Leibler divergence sense!

18.4 Fenchel-Young loss functions

We end this chapter with a brief review of the Fenchel-Young family of


loss functions (Blondel et al., 2020), which includes all loss functions in
Table 18.1.
Definition 18.4 (Fenchel-Young loss). The Fenchel-Young loss func-
tion generated by Ω is

ℓΩ (θ, t) := Ω∗ (θ) + Ω(t) − ⟨θ, t⟩

where θ ∈ dom(Ω∗ ) and t ∈ dom(Ω).

Typically, we set θ = f (x, w), where f is a model prediction function


with parameters w and t = ϕ(y), where ϕ : Y → dom(Ω). For instance,

suppose we work with categorical outputs y ∈ [M ]. Then, we can set


ϕ(y) = ey , where ey is the one-hot encoding of y.
The important point to notice is that the Fenchel-Young loss is
defined over arguments in mixed spaces: θ belongs to the dual space,
while t belongs to the primal space. In fact, the Fenchel-Young loss
is intimately connected to the Bregman divergence, since BΩ (t, v) =
Ω∗ (θ) + Ω(t) − ⟨θ, t⟩, if we set θ = ∇Ω(v). The key properties of
Fenchel-Young loss functions are summarized below.

Proposition 18.6 (Properties of Fenchel-Young loss functions).

1. Non-negativity: ℓΩ (θ, t) ≥ 0 for all θ ∈ dom(Ω∗ ) and


t ∈ dom(Ω).

2. Positivity: ℓΩ (θ, t) = 0 if and only if ∇Ω∗ (θ) = t, assuming


Ω is strictly convex.

3. Convexity: ℓΩ (θ, t) is convex in θ (regardless of Ω) and in t


(if Ω is convex)

4. Relation with composite Bregman divergence:

0 ≤ BΩ (t, ∇Ω∗ (θ)) ≤ ℓΩ (θ, t),

where the middle term is possibly nonconvex in θ while the right-hand
side is convex in θ.

See Blondel et al. (2020) for an in-depth study of more properties.
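As an example, taking Ω to be the Shannon negative entropy restricted to the simplex gives Ω∗ = LSE, and the Fenchel-Young loss recovers the logistic loss of Table 18.1. A minimal sketch (ours, illustrative only, using SciPy's logsumexp) is:

import numpy as np
from scipy.special import logsumexp

def fenchel_young_logistic(theta, t):
    # Omega*(theta) = logsumexp(theta); Omega(t) = <t, log t>, which is 0 for a one-hot t.
    omega_t = np.sum(np.where(t > 0, t * np.log(np.clip(t, 1e-30, None)), 0.0))
    return logsumexp(theta) + omega_t - theta @ t

theta = np.array([2.0, -1.0, 0.5])
t = np.array([1.0, 0.0, 0.0])        # one-hot target: the loss reduces to LSE(theta) - theta_y
print(fenchel_young_logistic(theta, t))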

18.5 Summary

• The convex conjugate serves as a powerful abstraction in Fenchel


duality, decoupling the dual expression and function-specific
terms.

• The convex conjugate is also tightly connected to Bregman


divergences and can be used to derive the family of Fenchel-
Young loss functions, which can be seen as primal-dual Bregman
divergences.
References

Abadi, M., A. Agarwal, P. Barham, E. Brevdo, Z. Chen, C. Citro, G. S.


Corrado, A. Davis, J. Dean, M. Devin, et al. (2016). “Tensorflow:
Large-scale machine learning on heterogeneous distributed systems”.
arXiv preprint arXiv:1603.04467.
Abernethy, J., C. Lee, and A. Tewari. (2016). “Perturbation techniques
in online learning and optimization”. Perturbations, Optimization,
and Statistics. 233.
Aji, S. M. and R. J. McEliece. (2000). “The generalized distributive
law”. IEEE transactions on Information Theory. 46(2): 325–343.
Amari, S.-I. (1998). “Natural gradient works efficiently in learning”.
Neural computation. 10(2): 251–276.
Bach, F., R. Jenatton, J. Mairal, G. Obozinski, et al. (2012). “Optimiza-
tion with sparsity-inducing penalties”. Foundations and Trends® in
Machine Learning. 4(1): 1–106.
Ball, K., E. A. Carlen, and E. H. Lieb. (2002). “Sharp uniform convexity
and smoothness inequalities for trace norms”. Inequalities: Selecta
of Elliott H. Lieb: 171–190.
Ball, W. W. R. (1960). A short account of the history of mathematics.
Courier Corporation.
Balog, M., N. Tripuraneni, Z. Ghahramani, and A. Weller. (2017). “Lost
relatives of the Gumbel trick”. In: International Conference on
Machine Learning. PMLR. 371–379.


Barndorff-Nielsen, O. (2014). Information and exponential families: in


statistical theory. John Wiley & Sons.
Baston, R. A. and Y. Nakatsukasa. (2022). “Stochastic diagonal esti-
mation: probabilistic bounds and an improved algorithm”. arXiv
preprint arXiv:2201.10684.
Baum, L. E. and T. Petrie. (1966). “Statistical inference for probabilistic
functions of finite state Markov chains”. The annals of mathematical
statistics. 37(6): 1554–1563.
Baur, W. and V. Strassen. (1983). “The complexity of partial deriva-
tives”. Theoretical computer science. 22(3): 317–330.
Bauschke, H. H. and P. L. Combettes. (2017). Convex Analysis and
Monotone Operator Theory in Hilbert Spaces. 2nd ed. Springer.
Baydin, A. G., B. A. Pearlmutter, A. A. Radul, and J. M. Siskind.
(2018). “Automatic differentiation in machine learning: a survey”.
Journal of Marchine Learning Research. 18: 1–43.
Baydin, A. G., B. A. Pearlmutter, D. Syme, F. Wood, and P. Torr. (2022).
“Gradients without backpropagation”. arXiv preprint arXiv:2202.08587.
Beck, A. (2017). First-order methods in optimization. SIAM.
Beck, A. and M. Teboulle. (2012). “Smoothing and first order methods:
A unified framework”. SIAM Journal on Optimization. 22(2): 557–
580.
Becker, S. and Y. Le Cun. (1988). “Improving the convergence of back-
propagation learning with second order methods”. In: Proceedings
of the 1988 connectionist models summer school. 29–37.
Bekas, C., E. Kokiopoulou, and Y. Saad. (2007). “An estimator for the
diagonal of a matrix”. Applied numerical mathematics. 57(11-12):
1214–1229.
Bergstra, J., O. Breuleux, F. Bastien, P. Lamblin, R. Pascanu, G. Des-
jardins, J. Turian, D. Warde-Farley, and Y. Bengio. (2010). “Theano:
a CPU and GPU math expression compiler”. In: Proceedings of the
Python for scientific computing conference (SciPy). Vol. 4. No. 3.
Austin, TX. 1–7.
Berthet, Q., M. Blondel, O. Teboul, M. Cuturi, J.-P. Vert, and F. Bach.
(2020). “Learning with differentiable pertubed optimizers”. Advances
in neural information processing systems. 33: 9508–9519.

Blelloch, G. E. (1989). “Scans as primitive parallel operations”. IEEE


Transactions on computers. 38(11): 1526–1538.
Blondel, M. (2019). “Structured prediction with projection oracles”.
Advances in neural information processing systems. 32.
Blondel, M., Q. Berthet, M. Cuturi, R. Frostig, S. Hoyer, F. Llinares-
López, F. Pedregosa, and J.-P. Vert. (2021). “Efficient and Modular
Implicit Differentiation”. arXiv preprint arXiv:2105.15183.
Blondel, M., A. F. Martins, and V. Niculae. (2020). “Learning with
fenchel-young losses”. The Journal of Machine Learning Research.
21(1): 1314–1382.
Bolte, J., R. Boustany, E. Pauwels, and B. Pesquet-Popescu. (2022).
“On the complexity of nonsmooth automatic differentiation”. In:
The Eleventh International Conference on Learning Representations.
Bolte, J. and E. Pauwels. (2020). “A mathematical model for automatic
differentiation in machine learning”. Advances in Neural Information
Processing Systems. 33: 10809–10819.
Botev, A., H. Ritter, and D. Barber. (2017). “Practical Gauss-Newton
optimisation for deep learning”. In: International Conference on
Machine Learning. 557–565.
Boumal, N. (2023). An introduction to optimization on smooth manifolds.
Cambridge University Press.
Boyd, S. P. and L. Vandenberghe. (2004). Convex optimization. Cam-
bridge university press.
Bradbury, J., R. Frostig, P. Hawkins, M. J. Johnson, C. Leary, D.
Maclaurin, G. Necula, A. Paszke, J. VanderPlas, S. Wanderman-
Milne, and Q. Zhang. (2018). JAX: composable transformations of
Python+NumPy programs. Version 0.3.13. url: http://github.com/
google/jax.
Braun, M. and M. Golubitsky. (1983). Differential equations and their
applications. Vol. 2. Springer.
Brockhoff, D., A. Auger, N. Hansen, D. V. Arnold, and T. Hohm.
(2010). “Mirrored sampling and sequential selection for evolution
strategies”. In: Parallel Problem Solving from Nature, PPSN XI:
11th International Conference, Kraków, Poland, September 11-15,
2010, Proceedings, Part I 11. Springer. 11–21.

Broyden, C. G. (1970). “The convergence of a class of double-rank


minimization algorithms 1. general considerations”. IMA Journal
of Applied Mathematics. 6(1): 76–90.
Brucker, P. (1984). “An O(n) algorithm for quadratic knapsack prob-
lems”. Operations Research Letters. 3(3): 163–166.
Butcher, J. C. (2016). Numerical methods for ordinary differential
equations. John Wiley & Sons.
Cajori, F. (1993). A history of mathematical notations. Vol. 1. Courier
Corporation.
Céa, J. (1986). “Conception optimale ou identification de formes, calcul
rapide de la dérivée directionnelle de la fonction coût”. M2AN-
Modélisation mathématique et analyse numérique. 20(3): 371–402.
Chaudhuri, S. and A. Solar-Lezama. (2010). “Smooth interpretation”.
ACM Sigplan Notices. 45(6): 279–291.
Chen, R. T., Y. Rubanova, J. Bettencourt, and D. K. Duvenaud. (2018).
“Neural ordinary differential equations”. Advances in neural infor-
mation processing systems. 31.
Chen, X., N. Kayal, A. Wigderson, et al. (2011). “Partial derivatives in
arithmetic complexity and beyond”. Foundations and Trends® in
Theoretical Computer Science. 6(1–2): 1–138.
Clarke, F. H., Y. S. Ledyaev, R. J. Stern, and P. R. Wolenski. (2008).
Nonsmooth analysis and control theory. Vol. 178. Springer Science
& Business Media.
Clarke, F. H. (1975). “Generalized gradients and applications”. Trans-
actions of the American Mathematical Society. 205: 247–262.
Cohn, D. L. (2013). Measure theory. Vol. 5. Springer.
Condat, L. (2016). “Fast projection onto the simplex and the ℓ1 ball”.
Mathematical Programming. 158(1-2): 575–585.
Dangel, F., F. Kunstner, and P. Hennig. (2019). “Backpack: Packing
more into backprop”. arXiv preprint arXiv:1912.10985.
Davis, J. Q., K. Choromanski, J. Varley, H. Lee, J.-J. Slotine, V.
Likhosterov, A. Weller, A. Makadia, and V. Sindhwani. (2020).
“Time dependence in non-autonomous neural odes”. arXiv preprint
arXiv:2005.01906.
DeGroot, M. H. (1962). “Uncertainty, information, and sequential ex-
periments”. The Annals of Mathematical Statistics. 33(2): 404–419.

Dehghani, M., J. Djolonga, B. Mustafa, P. Padlewski, J. Heek, J. Gilmer,


A. P. Steiner, M. Caron, R. Geirhos, I. Alabdulmohsin, et al. (2023).
“Scaling vision transformers to 22 billion parameters”. In: Interna-
tional Conference on Machine Learning. PMLR. 7480–7512.
Deisenroth, M. P., A. A. Faisal, and C. S. Ong. (2020). Mathematics
for machine learning. Cambridge University Press.
Drusvyatskiy, D. and C. Paquette. (2019). “Efficiency of minimizing
compositions of convex functions and smooth maps”. Mathematical
Programming. 178: 503–558.
Duchi, J. C., M. I. Jordan, M. J. Wainwright, and A. Wibisono. (2015).
“Optimal rates for zero-order convex optimization: The power of two
function evaluations”. IEEE Transactions on Information Theory.
61(5): 2788–2806.
Duchi, J. C., S. Shalev-Shwartz, Y. Singer, and T. Chandra. (2008).
“Efficient projections onto the ℓ1 -ball for learning in high dimensions”.
In: Proc. of ICML.
Eisner, J. (2016). “Inside-outside and forward-backward algorithms are
just backprop (tutorial paper)”. In: Proceedings of the Workshop on
Structured Prediction for NLP. 1–17.
Elsayed, M. and A. R. Mahmood. (2022). “HesScale: Scalable Compu-
tation of Hessian Diagonals”. arXiv preprint arXiv:2210.11639.
Epperly, E. N., J. A. Tropp, and R. J. Webber. (2023). “XTrace: Making
the most of every sample in stochastic trace estimation”. arXiv
preprint arXiv:2301.07825.
Flanders, H. (1973). “Differentiation under the integral sign”. The
American Mathematical Monthly. 80(6): 615–627.
Fleming, W. H. and R. W. Rishel. (2012). Deterministic and stochastic
optimal control. Vol. 1. Springer Science & Business Media.
Fletcher, R. (1970). “A new approach to variable metric algorithms”.
The computer journal. 13(3): 317–322.
Foerster, J., G. Farquhar, M. Al-Shedivat, T. Rocktäschel, E. Xing,
and S. Whiteson. (2018). “Dice: The infinitely differentiable monte
carlo estimator”. In: International Conference on Machine Learning.
PMLR. 1529–1538.
Forney, G. D. (1973). “The viterbi algorithm”. Proceedings of the IEEE.
61(3): 268–278.

Franceschi, L., M. Donini, P. Frasconi, and M. Pontil. (2017). “For-


ward and reverse gradient-based hyperparameter optimization”. In:
International Conference on Machine Learning. PMLR. 1165–1173.
Frey, B. J., F. R. Kschischang, H.-A. Loeliger, and N. Wiberg. (1997).
“Factor graphs and algorithms”. In: Proceedings of the Annual Aller-
ton Conference on Communication Control and Computing. Vol. 35.
Citeseer. 666–680.
Frigyik, B. A., S. Srivastava, and M. R. Gupta. (2008). “An introduction
to functional derivatives”. Dept. Electr. Eng., Univ. Washington,
Seattle, WA, Tech. Rep. 1.
Frostig, R., M. J. Johnson, D. Maclaurin, A. Paszke, and A. Radul.
(2021). “Decomposing reverse-mode automatic differentiation”. arXiv
preprint arXiv:2105.09469.
Gautschi, W. (2011). Numerical analysis. Springer Science & Business
Media.
Getreuer, P. (2013). “A survey of Gaussian convolution algorithms”.
Image Processing On Line. 2013: 286–310.
Geweke, J. (1988). “Antithetic acceleration of Monte Carlo integration
in Bayesian inference”. Journal of Econometrics. 38(1-2): 73–89.
Gholaminejad, A., K. Keutzer, and G. Biros. (2019). “ANODE: Uncon-
ditionally Accurate Memory-Efficient Gradients for Neural ODEs”.
In: International Joint Conferences on Artificial Intelligence.
Gini, C. (1912). “Variabilità e mutabilità”. Reprinted in Memorie di
metodologica statistica (Ed. Pizetti E, Salvemini, T). Rome: Libreria
Eredi Virgilio Veschi.
Girard, A. (1989). “A fast ‘Monte-Carlo cross-validation’procedure for
large least squares problems with noisy data”. Numerische Mathe-
matik. 56: 1–23.
Goldfarb, D. (1970). “A family of variable-metric methods derived by
variational means”. Mathematics of computation. 24(109): 23–26.
Gomez, A. N., M. Ren, R. Urtasun, and R. B. Grosse. (2017). “The
reversible residual network: Backpropagation without storing acti-
vations”. Advances in neural information processing systems. 30.
Graves, A., G. Wayne, and I. Danihelka. (2014). “Neural turing ma-
chines”. arXiv preprint arXiv:1410.5401.

Greig, D. M., B. T. Porteous, and A. H. Seheult. (1989). “Exact max-


imum a posteriori estimation for binary images”. Journal of the
Royal Statistical Society Series B: Statistical Methodology. 51(2):
271–279.
Griewank, A. (1992). “Achieving logarithmic growth of temporal and spa-
tial complexity in reverse automatic differentiation”. Optimization
Methods and Software. 1(1): 35–54. doi: 10.1080/10556789208805505.
Griewank, A. (2003). “A mathematical view of automatic differentia-
tion”. Acta Numerica. 12: 321–398.
Griewank, A. (2012). “Who invented the reverse mode of differentiation”.
Documenta Mathematica, Extra Volume ISMP: 389–400.
Griewank, A. and A. Walther. (2008). Evaluating derivatives: principles
and techniques of algorithmic differentiation. SIAM.
Grimm, J., L. Pottier, and N. Rostaing-Schmidt. (1996). “Optimal time
and minimum space-time product for reversing a certain class of
programs”. PhD thesis. INRIA.
Grünwald, P. D. and A. P. Dawid. (2004). “Game theory, maximum
entropy, minimum discrepancy and robust Bayesian decision theory”.
Annals of Statistics: 1367–1433.
Hallman, E., I. C. Ipsen, and A. K. Saibaba. (2023). “Monte Carlo
methods for estimating the diagonal of a real symmetric matrix”.
SIAM Journal on Matrix Analysis and Applications. 44(1): 240–269.
He, K., X. Zhang, S. Ren, and J. Sun. (2016). “Deep residual learning
for image recognition”. In: Proceedings of the IEEE conference on
computer vision and pattern recognition. 770–778.
Helfrich, K., D. Willmott, and Q. Ye. (2018). “Orthogonal recurrent
neural networks with scaled Cayley transform”. In: International
Conference on Machine Learning. PMLR. 1969–1978.
Hestenes, M. R., E. Stiefel, et al. (1952). Methods of conjugate gradients
for solving linear systems. Vol. 49. No. 1. NBS Washington, DC.
Hewitt, E. (1948). “Rings of real-valued continuous functions. I”. Trans-
actions of the American Mathematical Society. 64(1): 45–99.
Hida, T. and M. Hitsuda. (1976). Gaussian processes. Vol. 120. American
Mathematical Soc.

Hiriart-Urruty, J.-B. and C. Lemaréchal. (1993). Convex analysis and


minimization algorithms II. Vol. 305. Springer science & business
media.
Hutchinson, M. F. (1989). “A stochastic estimator of the trace of the
influence matrix for Laplacian smoothing splines”. Communications
in Statistics-Simulation and Computation. 18(3): 1059–1076.
Jaggi, M. (2013). “Revisiting Frank-Wolfe: Projection-free sparse convex
optimization”. In: International conference on machine learning.
PMLR. 427–435.
Jang, E., S. Gu, and B. Poole. (2016). “Categorical reparameterization
with gumbel-softmax”. arXiv preprint arXiv:1611.01144.
Jayaram, B. and M. Baczynski. (2008). Fuzzy Implications. Vol. 231.
Springer Science & Business Media.
Kakade, S., S. Shalev-Shwartz, A. Tewari, et al. (2009). “On the duality
of strong convexity and strong smoothness: Learning appl ications
and matrix regularization”. Tech report. 2(1): 35.
Karpathy, A. (2017). “Software 2.0”.
Kelley, C. T. (1995). Iterative methods for linear and nonlinear equations.
SIAM.
Kingma, D. P. and J. Ba. (2014). “Adam: A method for stochastic
optimization”. arXiv preprint arXiv:1412.6980.
Kingma, D. P. and M. Welling. (2013). “Auto-encoding variational
bayes”. arXiv preprint arXiv:1312.6114.
Klir, G. and B. Yuan. (1995). Fuzzy sets and fuzzy logic. Vol. 4. Prentice
hall New Jersey.
Kobyzev, I., S. Prince, and M. A. Brubaker. (2019). “Normalizing flows:
Introduction and ideas”. stat. 1050: 25.
Kreikemeyer, J. N. and P. Andelfinger. (2023). “Smoothing methods
for automatic differentiation across conditional branches”. IEEE
Access.
Krieken, E., J. Tomczak, and A. Ten Teije. (2021). “Storchastic: A frame-
work for general stochastic automatic differentiation”. Advances in
Neural Information Processing Systems. 34: 7574–7587.
Kunstner, F., P. Hennig, and L. Balles. (2019). “Limitations of the em-
pirical Fisher approximation for natural gradient descent”. Advances
in neural information processing systems. 32.

Lafferty, J., A. McCallum, and F. C. Pereira. (2001). “Conditional


random fields: Probabilistic models for segmenting and labeling
sequence data”.
Lan, G. (2012). “An optimal method for stochastic composite optimiza-
tion”. Mathematical Programming. 133(1-2): 365–397.
LeCun, Y. (1988). “A theoretical framework for back-propagation”. In:
Proceedings of the 1988 connectionist models summer school. Vol. 1.
21–28.
LeCun, Y. (2018). “Deep Learning est mort. Vive Differentiable Pro-
gramming!”
Levenberg, K. (1944). “A method for the solution of certain non-linear
problems in least squares”. Quarterly of applied mathematics. 2(2):
164–168.
Liu, D. C. and J. Nocedal. (1989). “On the limited memory method for
large scale optimization”. Mathematical Programming. 45: 503–528.
Liu, H., Z. Li, D. Hall, P. Liang, and T. Ma. (2023). “Sophia: A Scal-
able Stochastic Second-order Optimizer for Language Model Pre-
training”. arXiv preprint arXiv:2305.14342.
Loeliger, H.-A. (2004). “An introduction to factor graphs”. IEEE Signal
Processing Magazine. 21(1): 28–41.
Loshchilov, I. and F. Hutter. (2016). “SGDR: Stochastic gradient de-
scent with warm restarts”. In: International Conference on Learning
Representations.
Lucet, Y. (1997). “Faster than the fast Legendre transform, the linear-
time Legendre transform”. Numerical Algorithms. 16: 171–185.
Maclaurin, D., D. Duvenaud, and R. P. Adams. (2015). “Autograd:
Effortless gradients in numpy”. In: ICML 2015 AutoML workshop.
Vol. 238. No. 5.
Maddison, C. J., A. Mnih, and Y. W. Teh. (2016). “The concrete
distribution: A continuous relaxation of discrete random variables”.
arXiv preprint arXiv:1611.00712.
Marquardt, D. W. (1963). “An algorithm for least-squares estimation
of nonlinear parameters”. Journal of the society for Industrial and
Applied Mathematics. 11(2): 431–441.

Martens, J. (2020). “New insights and perspectives on the natural


gradient method”. Journal of Machine Learning Research. 21(1):
5776–5851.
Martens, J. and R. Grosse. (2015). “Optimizing neural networks with
Kronecker-factored approximate curvature”. In: International con-
ference on machine learning. 2408–2417.
Martins, A. and R. Astudillo. (2016). “From softmax to sparsemax: A
sparse model of attention and multi-label classification”. In: Inter-
national conference on machine learning. PMLR. 1614–1623.
Martins, J. R., P. Sturdza, and J. J. Alonso. (2003). “The complex-
step derivative approximation”. ACM Transactions on Mathematical
Software (TOMS). 29(3): 245–262.
Meent, J.-W. van de, B. Paige, H. Yang, and F. Wood. (2018). “An intro-
duction to probabilistic programming”. arXiv preprint arXiv:1809.10756.
Mensch, A. and M. Blondel. (2018). “Differentiable dynamic program-
ming for structured prediction and attention”. In: International
Conference on Machine Learning. PMLR. 3462–3471.
Messerer, F., K. Baumgärtner, and M. Diehl. (2021). “Survey of sequen-
tial convex programming and generalized Gauss-Newton methods”.
ESAIM: Proceedings and Surveys. 71: 64–88.
Meyer, R. A., C. Musco, C. Musco, and D. P. Woodruff. (2021).
“Hutch++: Optimal stochastic trace estimation”. In: Symposium on
Simplicity in Algorithms (SOSA). SIAM. 142–155.
Michelot, C. (1986). “A finite algorithm for finding the projection of a
point onto the canonical simplex of Rn ”. Journal of Optimization
Theory and Applications. 50(1): 195–200.
Mohamed, S., M. Rosca, M. Figurnov, and A. Mnih. (2020). “Monte
carlo gradient estimation in machine learning”. The Journal of
Machine Learning Research. 21(1): 5183–5244.
Mohri, M., F. Pereira, and M. Riley. (2008). “Speech recognition with
weighted finite-state transducers”. Springer Handbook of Speech
Processing: 559–584.
Morgenstern, J. (1985). “How to compute fast a function and all its
derivatives: A variation on the theorem of Baur-Strassen”. ACM
SIGACT News. 16(4): 60–62.

Morrey Jr, C. B. (2009). Multiple integrals in the calculus of variations.


Springer Science & Business Media.
Murphy, K. P. (2022). Probabilistic Machine Learning: An introduction.
MIT Press. url: http://probml.github.io/book1.
Murphy, K. P. (2023). Probabilistic Machine Learning: Advanced Topics.
MIT Press. url: http://probml.github.io/book2.
Mutze, U. (2013). “An asynchronous leapfrog method II”. arXiv preprint
arXiv:1311.6602.
Nemirovski, A. and D. Yudin. (1983). “Problem complexity and method
efficiency in optimization”.
Nesterov, Y. (2005). “Smooth minimization of non-smooth functions”.
Mathematical programming. 103: 127–152.
Nesterov, Y. (2007). “Modified Gauss–Newton scheme with worst case
guarantees for global performance”. Optimisation methods and soft-
ware. 22(3): 469–483.
Nesterov, Y. (2018). Lectures on convex optimization. Vol. 137. Springer.
Nesterov, Y. and V. Spokoiny. (2017). “Random gradient-free minimiza-
tion of convex functions”. Foundations of Computational Mathemat-
ics. 17: 527–566.
Papamakarios, G., E. Nalisnick, D. J. Rezende, S. Mohamed, and
B. Lakshminarayanan. (2021). “Normalizing flows for probabilistic
modeling and inference”. The Journal of Machine Learning Research.
22(1): 2617–2680.
Parikh, N., S. Boyd, et al. (2014). “Proximal algorithms”. Foundations
and trends® in Optimization. 1(3): 127–239.
Paszke, A., S. Gross, F. Massa, A. Lerer, J. Bradbury, G. Chanan, T.
Killeen, Z. Lin, N. Gimelshein, L. Antiga, A. Desmaison, A. Kopf, E.
Yang, Z. DeVito, M. Raison, A. Tejani, S. Chilamkurthy, B. Steiner,
L. Fang, J. Bai, and S. Chintala. (2019). “PyTorch: An Imperative
Style, High-Performance Deep Learning Library”. In: Advances in
Neural Information Processing Systems 32. 8024–8035.
Paulus, M., D. Choi, D. Tarlow, A. Krause, and C. J. Maddison. (2020).
“Gradient estimation with stochastic softmax tricks”. Advances in
Neural Information Processing Systems. 33: 5691–5704.

Petersen, F., C. Borgelt, H. Kuehne, and O. Deussen. (2021). “Learning


with algorithmic supervision via continuous relaxations”. Advances
in Neural Information Processing Systems. 34: 16520–16531.
Peyré, G. (2020). “Mathematical foundations of data sciences”. Rn. 1:
2.
Peyré, G. and M. Cuturi. (2019). “Computational optimal transport:
With applications to data science”. Foundations and Trends® in
Machine Learning. 11(5-6): 355–607.
Pollock, S. and L. G. Rebholz. (2021). “Anderson acceleration for con-
tractive and noncontractive operators”. IMA Journal of Numerical
Analysis. 41(4): 2841–2872.
Polyak, B. (1963). “Gradient methods for the minimisation of function-
als”. USSR Computational Mathematics and Mathematical Physics.
3(4): 864–878.
Polyak, B. T. (1964). “Some methods of speeding up the convergence
of iteration methods”. Ussr computational mathematics and mathe-
matical physics. 4(5): 1–17.
Pontryagin, L. S. (1985). “The mathematical theory of optimal processes
and differential games”. Trudy Mat. Inst. Steklov. 169: 119–158.
Rabiner, L. R. (1989). “A tutorial on hidden Markov models and selected
applications in speech recognition”. Proceedings of the IEEE. 77(2):
257–286.
Rademacher, H. (1919). “Über partielle und totale differenzierbarkeit
von Funktionen mehrerer Variabeln und über die Transformation
der Doppelintegrale”. Mathematische Annalen. 79(4): 340–359.
Radul, A., A. Paszke, R. Frostig, M. Johnson, and D. Maclaurin. (2022).
“You only linearize once: Tangents transpose to gradients”. arXiv
preprint arXiv:2204.10923.
Recht, B. (2016). “Mates of Costate”.
Recht, B. and R. Frostig. (2017). “Nesterov’s Punctuated Equilibrium”.
Rezende, D. J., S. Mohamed, and D. Wierstra. (2014). “Stochastic back-
propagation and approximate inference in deep generative models”.
In: International conference on machine learning. PMLR. 1278–
1286.
Rockafellar, R. T. and R. J.-B. Wets. (2009). Variational analysis.
Vol. 317. Springer Science & Business Media.

Rodriguez, O. H. and J. M. Lopez Fernandez. (2010). “A semiotic


reflection on the didactics of the chain rule”. The Mathematics
Enthusiast. 7(2): 321–332.
Roulet, V. and Z. Harchaoui. (2022). “Differentiable programming à la
Moreau”. In: ICASSP 2022-2022 IEEE International Conference
on Acoustics, Speech and Signal Processing (ICASSP). IEEE. 3498–
3502.
Saad, Y. and M. H. Schultz. (1986). “GMRES: A generalized minimal
residual algorithm for solving nonsymmetric linear systems”. SIAM
Journal on scientific and statistical computing. 7(3): 856–869.
Salimans, T., J. Ho, X. Chen, S. Sidor, and I. Sutskever. (2017). “Evo-
lution strategies as a scalable alternative to reinforcement learning”.
arXiv preprint arXiv:1703.03864.
Sander, M. E., P. Ablin, M. Blondel, and G. Peyré. (2021a). “Momentum
residual neural networks”. In: International Conference on Machine
Learning. PMLR. 9276–9287.
Sander, M. E., P. Ablin, M. Blondel, and G. Peyré. (2021b). “Momentum
residual neural networks”. In: International Conference on Machine
Learning. PMLR. 9276–9287.
Satterthwaite, F. (1942). “Generalized poisson distribution”. The Annals
of Mathematical Statistics. 13(4): 410–417.
Schlag, I., K. Irie, and J. Schmidhuber. (2021). “Linear transformers
are secretly fast weight programmers”. In: International Conference
on Machine Learning. PMLR. 9355–9366.
Schölkopf, B. and A. J. Smola. (2002). Learning with kernels: support
vector machines, regularization, optimization, and beyond. MIT
Press.
Schulman, J., N. Heess, T. Weber, and P. Abbeel. (2015). “Gradient
estimation using stochastic computation graphs”. Advances in Neural
Information Processing Systems. 28.
Schwartz, J. (1954). “The formula for change in variables in a multiple
integral”. The American Mathematical Monthly. 61(2): 81–85.
Schwarz, H. (1873). “Communication”. Archives des Sciences Physiques
et Naturelles. 48: 38–44.
Sengupta, S., M. J. Harris, M. Garland, and J. D. Owens. (2010).
“Efficient Parallel Scan Algorithms for Manycore GPUs.”
Shanno, D. F. (1970). “Conditioning of quasi-Newton methods for
function minimization”. Mathematics of Computation. 24(111): 647–
656.
Shannon, C. E. (1948). “A mathematical theory of communication”.
The Bell System Technical Journal. 27(3): 379–423.
Shawe-Taylor, J. and N. Cristianini. (2004). Kernel methods for pattern
analysis. Cambridge University Press.
Squire, W. and G. Trapp. (1998). “Using complex variables to estimate
derivatives of real functions”. SIAM Review. 40(1): 110–112.
Stoer, J., R. Bulirsch, R. Bartels, W. Gautschi, and C. Witzgall. (1980).
Introduction to numerical analysis. Springer.
Stumm, P. and A. Walther. (2010). “New algorithms for optimal online
checkpointing”. SIAM Journal on Scientific Computing. 32(2): 836–
854.
Sutskever, I., J. Martens, G. Dahl, and G. Hinton. (2013). “On the
importance of initialization and momentum in deep learning”. In:
International Conference on Machine Learning. PMLR. 1139–1147.
Sutton, C., A. McCallum, et al. (2012). “An introduction to conditional
random fields”. Foundations and Trends® in Machine Learning. 4(4):
267–373.
Sutton, R. S., D. McAllester, S. Singh, and Y. Mansour. (1999). “Policy
gradient methods for reinforcement learning with function approxi-
mation”. Advances in Neural Information Processing Systems. 12.
Taylor, M. (2002). “Differential forms and the change of variable for-
mula for multiple integrals”. Journal of Mathematical Analysis and
Applications. 268(1): 378–383.
Tibshirani, R. (1996). “Regression shrinkage and selection via the lasso”.
Journal of the Royal Statistical Society: Series B (Methodological).
58(1): 267–288.
Tignol, J.-P. (2015). Galois’ theory of algebraic equations. World Scien-
tific Publishing Company.
Tsallis, C. (1988). “Possible generalization of Boltzmann-Gibbs statis-
tics”. Journal of Statistical Physics. 52: 479–487.
van Krieken, E. (2024). “Optimisation in Neurosymbolic Learning Sys-
tems”. PhD thesis. Vrije Universiteit Amsterdam.
Vaswani, A., N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez,
Ł. Kaiser, and I. Polosukhin. (2017). “Attention is all you need”.
Advances in Neural Information Processing Systems. 30.
Vaswani, S., A. Mishkin, I. Laradji, M. Schmidt, G. Gidel, and S.
Lacoste-Julien. (2019). “Painless stochastic gradient: Interpolation,
line-search, and convergence rates”. Advances in Neural Information
Processing Systems. 32.
Verdu, S. and H. V. Poor. (1987). “Abstract dynamic programming
models under commutativity conditions”. SIAM Journal on Control
and Optimization. 25(4): 990–1006.
Vicol, P., L. Metz, and J. Sohl-Dickstein. (2021). “Unbiased gradient
estimation in unrolled computation graphs with persistent evolu-
tion strategies”. In: International Conference on Machine Learning.
PMLR. 10553–10563.
Virtanen, P., R. Gommers, T. E. Oliphant, M. Haberland, T. Reddy,
D. Cournapeau, E. Burovski, P. Peterson, W. Weckesser, J. Bright,
S. J. van der Walt, M. Brett, J. Wilson, K. J. Millman, N. Mayorov,
A. R. J. Nelson, E. Jones, R. Kern, E. Larson, C. J. Carey, İ. Polat,
Y. Feng, E. W. Moore, J. VanderPlas, D. Laxalde, J. Perktold,
R. Cimrman, I. Henriksen, E. A. Quintero, C. R. Harris, A. M.
Archibald, A. H. Ribeiro, F. Pedregosa, P. van Mulbregt, and SciPy
1.0 Contributors. (2020). “SciPy 1.0: Fundamental Algorithms for
Scientific Computing in Python”. Nature Methods. 17: 261–272.
Viterbi, A. (1967). “Error bounds for convolutional codes and an asymp-
totically optimum decoding algorithm”. IEEE Transactions on
Information Theory. 13(2): 260–269.
van der Vorst, H. A. (1992). “Bi-CGSTAB: A Fast and Smoothly
Converging Variant of Bi-CG for the Solution of Nonsymmetric
Linear Systems”. SIAM Journal on Scientific and Statistical
Computing. 13(2): 631–644. url: http://dx.doi.org/10.1137/0913035.
Wainwright, M. J. and M. I. Jordan. (2008). “Graphical models, expo-
nential families, and variational inference”. Foundations and Trends®
in Machine Learning. 1(1–2): 1–305.
Wang, Q., P. Moin, and G. Iaccarino. (2009). “Minimal repetition
dynamic checkpointing algorithm for unsteady adjoint calculation”.
SIAM Journal on Scientific Computing. 31(4): 2549–2567.
Wei, C., S. Kakade, and T. Ma. (2020). “The implicit and explicit
regularization effects of dropout”. In: International Conference on
Machine Learning. PMLR. 10181–10192.
Werbos, P. J. (1990). “Backpropagation through time: what it does and
how to do it”. Proceedings of the IEEE. 78(10): 1550–1560.
Werbos, P. J. (1994). The roots of backpropagation: from ordered deriva-
tives to neural networks and political forecasting. Vol. 1. John Wiley
& Sons.
Wright, S. and J. Nocedal. (1999). Numerical optimization. Springer.
Xu, P., F. Roosta, and M. W. Mahoney. (2020). “Second-order opti-
mization for non-convex machine learning: An empirical study”. In:
Proceedings of the 2020 SIAM International Conference on Data
Mining. SIAM. 199–207.
Yuan, M. and Y. Lin. (2006). “Model selection and estimation in re-
gression with grouped variables”. Journal of the Royal Statistical
Society: Series B (Statistical Methodology). 68(1): 49–67.
Zhang, A., Z. C. Lipton, M. Li, and A. J. Smola. (2021). “Dive into
deep learning”. arXiv preprint arXiv:2106.11342.
Zhou, X. (2018). “On the Fenchel duality between strong convexity and
Lipschitz continuous gradient”. arXiv preprint arXiv:1803.06573.
Zhuang, J., N. C. Dvornek, S. Tatikonda, and J. S. Duncan. (2021).
“MALI: A memory efficient and reverse accurate integrator for neural
ODEs”. arXiv preprint arXiv:2102.04668.
Ziegler, D. M., N. Stiennon, J. Wu, T. B. Brown, A. Radford, D. Amodei,
P. Christiano, and G. Irving. (2019). “Fine-tuning language models
from human preferences”. arXiv preprint arXiv:1909.08593.
