
arXiv:2403.14606v1 [cs.LG] 21 Mar 2024

The Elements of
Differentiable Programming

Mathieu Blondel
Google DeepMind
[email protected]
Vincent Roulet
Google DeepMind
[email protected]

Draft (last update: March 21, 2024)


Contents

1 Introduction 4
1.1 What is differentiable programming? . . . . . . . . . . . . 4
1.2 Book goals and scope . . . . . . . . . . . . . . . . . . . . 6
1.3 Intended audience . . . . . . . . . . . . . . . . . . . . . . 7
1.4 How to read this book? . . . . . . . . . . . . . . . . . . . 7
1.5 Related work . . . . . . . . . . . . . . . . . . . . . . . . . 7

I Fundamentals 9

2 Differentiation 10
2.1 Univariate functions . . . . . . . . . . . . . . . . . . . . . 10
2.1.1 Derivatives . . . . . . . . . . . . . . . . . . . . . . 10
2.1.2 Calculus rules . . . . . . . . . . . . . . . . . . . . 13
2.1.3 Leibniz’s notation . . . . . . . . . . . . . . . . . . 15
2.2 Multivariate functions . . . . . . . . . . . . . . . . . . . . 16
2.2.1 Directional derivatives . . . . . . . . . . . . . . . . 16
2.2.2 Gradients . . . . . . . . . . . . . . . . . . . . . . 17
2.2.3 Jacobians . . . . . . . . . . . . . . . . . . . . . . 20
2.3 Linear differentiation maps . . . . . . . . . . . . . . . . . 26
2.3.1 The need for linear maps . . . . . . . . . . . . . . 26
2.3.2 Euclidean spaces . . . . . . . . . . . . . . . . . . . 27
2.3.3 Linear maps and their adjoints . . . . . . . . . . . 28
2.3.4 Jacobian-vector products . . . . . . . . . . . . . . 29
2.3.5 Vector-Jacobian products . . . . . . . . . . . . . . 30
2.3.6 Chain rule . . . . . . . . . . . . . . . . . . . . . . 31
2.3.7 Functions of multiple inputs (fan-in) . . . . . . . . 32
2.3.8 Functions of multiple outputs (fan-out) . . . . . . 34
2.3.9 Extensions to non-Euclidean linear spaces . . . . . 34
2.4 Second-order differentiation . . . . . . . . . . . . . . . . . 36
2.4.1 Second derivatives . . . . . . . . . . . . . . . . . . 36
2.4.2 Second directional derivatives . . . . . . . . . . . . 36
2.4.3 Hessians . . . . . . . . . . . . . . . . . . . . . . . 37
2.4.4 Hessian-vector products . . . . . . . . . . . . . . . 39
2.4.5 Second-order Jacobians . . . . . . . . . . . . . . . 39
2.5 Higher-order differentiation . . . . . . . . . . . . . . . . . 40
2.5.1 Higher-order derivatives . . . . . . . . . . . . . . . 40
2.5.2 Higher-order directional derivatives . . . . . . . . . 41
2.5.3 Higher-order Jacobians . . . . . . . . . . . . . . . 41
2.5.4 Taylor expansions . . . . . . . . . . . . . . . . . . 42
2.6 Differential geometry . . . . . . . . . . . . . . . . . . . . 43
2.6.1 Differentiability on manifolds . . . . . . . . . . . . 43
2.6.2 Tangent spaces and pushforward operators . . . . . 44
2.6.3 Cotangent spaces and pullback operators . . . . . 45
2.7 Generalized derivatives . . . . . . . . . . . . . . . . . . . 48
2.7.1 Rademacher’s theorem . . . . . . . . . . . . . . . 49
2.7.2 Clarke derivatives . . . . . . . . . . . . . . . . . . 49
2.8 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 52

3 Probabilistic learning 54
3.1 Probability distributions . . . . . . . . . . . . . . . . . . . 54
3.1.1 Discrete probability distributions . . . . . . . . . . 54
3.1.2 Continuous probability distributions . . . . . . . . 55
3.2 Maximum likelihood estimation . . . . . . . . . . . . . . . 56
3.2.1 Negative log-likelihood . . . . . . . . . . . . . . . 56
3.2.2 Consistency w.r.t. the Kullback-Leibler divergence . 56
3.3 Probabilistic supervised learning . . . . . . . . . . . . . . 57
3.3.1 Conditional probability distributions . . . . . . . . 57
3.3.2 Inference . . . . . . . . . . . . . . . . . . . . . . . 57
3.3.3 Binary classification . . . . . . . . . . . . . . . . . 58
3.3.4 Multiclass classification . . . . . . . . . . . . . . . 60
3.3.5 Regression . . . . . . . . . . . . . . . . . . . . . . 61
3.3.6 Multivariate regression . . . . . . . . . . . . . . . 62
3.3.7 Integer regression . . . . . . . . . . . . . . . . . . 63
3.3.8 Loss functions . . . . . . . . . . . . . . . . . . . . 63
3.4 Exponential family distributions . . . . . . . . . . . . . . . 65
3.4.1 Definition . . . . . . . . . . . . . . . . . . . . . . 65
3.4.2 The log-partition function . . . . . . . . . . . . . . 67
3.4.3 Maximum entropy principle . . . . . . . . . . . . . 68
3.4.4 Maximum likelihood estimation . . . . . . . . . . . 69
3.4.5 Probabilistic learning with exponential families . . . 69
3.5 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 71

II Differentiable programs 72

4 Parameterized programs 73
4.1 Representing computer programs . . . . . . . . . . . . . . 73
4.1.1 Computation chains . . . . . . . . . . . . . . . . . 73
4.1.2 Directed acyclic graphs . . . . . . . . . . . . . . . . 74
4.1.3 Computer programs as DAGs . . . . . . . . . . . . 76
4.1.4 Arithmetic circuits . . . . . . . . . . . . . . . . . . 78
4.2 Feedforward networks . . . . . . . . . . . . . . . . . . . . 79
4.3 Multilayer perceptrons . . . . . . . . . . . . . . . . . . . . 79
4.3.1 Combining affine layers and activations . . . . . . . 79
4.3.2 Link with generalized linear models . . . . . . . . . 80
4.4 Activation functions . . . . . . . . . . . . . . . . . . . . . 81
4.4.1 Scalar-to-scalar nonlinearities . . . . . . . . . . . . 81
4.4.2 Vector-to-scalar nonlinearities . . . . . . . . . . . . 81
4.4.3 Scalar-to-scalar probability mappings . . . . . . . . 82
4.4.4 Vector-to-vector probability mappings . . . . . . . 83
4.5 Residual neural networks . . . . . . . . . . . . . . . . . . 85
4.6 Recurrent neural networks . . . . . . . . . . . . . . . . . . 86
4.6.1 Vector to sequence . . . . . . . . . . . . . . . . . 86
4.6.2 Sequence to vector . . . . . . . . . . . . . . . . . 88
4.6.3 Sequence to sequence (aligned) . . . . . . . . . . . 88
4.6.4 Sequence to sequence (unaligned) . . . . . . . . . 88
4.7 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 89

5 Control flows 90
5.1 Comparison operators . . . . . . . . . . . . . . . . . . . . 90
5.2 Soft inequality operators . . . . . . . . . . . . . . . . . . 92
5.2.1 Heuristic definition . . . . . . . . . . . . . . . . . 92
5.2.2 Stochastic process perspective . . . . . . . . . . . 92
5.3 Soft equality operators . . . . . . . . . . . . . . . . . . . 93
5.3.1 Heuristic definition . . . . . . . . . . . . . . . . . 93
5.3.2 Gaussian process perspective . . . . . . . . . . . . 94
5.4 Logical operators . . . . . . . . . . . . . . . . . . . . . . 95
5.5 Continuous extensions of logical operators . . . . . . . . . 96
5.5.1 Probabilistic continuous extension . . . . . . . . . 96
5.5.2 Triangular norms and co-norms . . . . . . . . . . . 98
5.6 If-else statements . . . . . . . . . . . . . . . . . . . . . . 98
5.6.1 Differentiating through branch variables . . . . . . 99
5.6.2 Differentiating through predicate variables . . . . . 100
5.6.3 Continuous relaxations . . . . . . . . . . . . . . . 101
5.7 Else-if statements . . . . . . . . . . . . . . . . . . . . . . 102
5.7.1 Encoding K branches . . . . . . . . . . . . . . . . 103
5.7.2 Conditionals . . . . . . . . . . . . . . . . . . . . . 104
5.7.3 Differentiating through branch variables . . . . . . 105
5.7.4 Differentiating through predicate variables . . . . . 106
5.7.5 Continuous relaxations . . . . . . . . . . . . . . . 106
5.8 For loops . . . . . . . . . . . . . . . . . . . . . . . . . . . 108
5.9 Scan functions . . . . . . . . . . . . . . . . . . . . . . . . 109
5.10 While loops . . . . . . . . . . . . . . . . . . . . . . . . . 110
5.10.1 While loops as cyclic graphs . . . . . . . . . . . . 110
5.10.2 Unrolled while loops . . . . . . . . . . . . . . . . . 111
5.10.3 Markov chain perspective . . . . . . . . . . . . . . 113
5.11 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 116
III Differentiating through programs 117

6 Finite differences 118


6.1 Forward differences . . . . . . . . . . . . . . . . . . . . . 118
6.2 Backward differences . . . . . . . . . . . . . . . . . . . . 119
6.3 Central differences . . . . . . . . . . . . . . . . . . . . . . 120
6.4 Higher-accuracy finite differences . . . . . . . . . . . . . . 121
6.5 Higher-order finite differences . . . . . . . . . . . . . . . . 122
6.6 Complex-step derivatives . . . . . . . . . . . . . . . . . . 123
6.7 Complexity . . . . . . . . . . . . . . . . . . . . . . . . . . 124
6.8 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 124

7 Automatic differentiation 126


7.1 Computation chains . . . . . . . . . . . . . . . . . . . . . 126
7.1.1 Forward-mode . . . . . . . . . . . . . . . . . . . . 127
7.1.2 Reverse-mode . . . . . . . . . . . . . . . . . . . . 129
7.1.3 Complexity of entire Jacobians . . . . . . . . . . . 134
7.2 Feedforward networks . . . . . . . . . . . . . . . . . . . . 136
7.2.1 Computing the adjoint . . . . . . . . . . . . . . . 136
7.2.2 Computing the gradient . . . . . . . . . . . . . . . 137
7.3 Computation graphs . . . . . . . . . . . . . . . . . . . . . 139
7.3.1 Forward-mode . . . . . . . . . . . . . . . . . . . . 139
7.3.2 Reverse-mode . . . . . . . . . . . . . . . . . . . . 140
7.3.3 Complexity, the Baur-Strassen theorem . . . . . . . 140
7.4 Implementation . . . . . . . . . . . . . . . . . . . . . . . 141
7.4.1 Primitive functions . . . . . . . . . . . . . . . . . 141
7.4.2 Closure under function composition . . . . . . . . 142
7.4.3 Examples of JVPs and VJPs . . . . . . . . . . . . 143
7.4.4 Automatic linear transposition . . . . . . . . . . . 144
7.5 Checkpointing . . . . . . . . . . . . . . . . . . . . . . . . 145
7.5.1 Recursive halving . . . . . . . . . . . . . . . . . . 146
7.5.2 Dynamic programming . . . . . . . . . . . . . . . 148
7.5.3 Online checkpointing . . . . . . . . . . . . . . . . 150
7.6 Reversible layers . . . . . . . . . . . . . . . . . . . . . . . 151
7.6.1 General case . . . . . . . . . . . . . . . . . . . . . 151
7.6.2 Case of orthonormal JVPs . . . . . . . . . . . . . 151
7.7 Randomized forward-mode estimator . . . . . . . . . . . . 152
7.8 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 152

8 Second-order automatic differentiation 154


8.1 Hessian-vector products . . . . . . . . . . . . . . . . . . . 154
8.1.1 Four possible methods . . . . . . . . . . . . . . . 154
8.1.2 Complexity . . . . . . . . . . . . . . . . . . . . . . 155
8.2 Gauss-Newton matrix . . . . . . . . . . . . . . . . . . . . 159
8.2.1 An approximation of the Hessian . . . . . . . . . . 159
8.2.2 Gauss-Newton chain rule . . . . . . . . . . . . . . 160
8.2.3 Gauss-Newton vector product . . . . . . . . . . . . 160
8.2.4 Gauss-Newton matrix factorization . . . . . . . . . 161
8.2.5 Stochastic setting . . . . . . . . . . . . . . . . . . 162
8.3 Fisher information matrix . . . . . . . . . . . . . . . . . . 162
8.3.1 Definition using the score function . . . . . . . . . 162
8.3.2 Link with the Hessian . . . . . . . . . . . . . . . . 163
8.3.3 Equivalence with the Gauss-Newton matrix . . . . 163
8.4 Inverse-Hessian vector product . . . . . . . . . . . . . . . 165
8.4.1 Definition as a linear map . . . . . . . . . . . . . . 165
8.4.2 Implementation with matrix-free linear solvers . . . 165
8.4.3 Complexity . . . . . . . . . . . . . . . . . . . . . . 166
8.5 Second-order backpropagation . . . . . . . . . . . . . . . 167
8.5.1 Second-order Jacobian chain rule . . . . . . . . . . 167
8.5.2 Computation chains . . . . . . . . . . . . . . . . . 169
8.5.3 Fan-in and fan-out . . . . . . . . . . . . . . . . . 170
8.6 Block diagonal approximations . . . . . . . . . . . . . . . 171
8.6.1 Feedforward networks . . . . . . . . . . . . . . . . 171
8.6.2 Computation graphs . . . . . . . . . . . . . . . . . 173
8.7 Diagonal approximations . . . . . . . . . . . . . . . . . . 173
8.7.1 Computation chains . . . . . . . . . . . . . . . . . 174
8.7.2 Computation graphs . . . . . . . . . . . . . . . . . 175
8.8 Randomized estimators . . . . . . . . . . . . . . . . . . . 176
8.8.1 Girard-Hutchinson estimator . . . . . . . . . . . . 176
8.8.2 Bartlett estimator for the factorization . . . . . . . 177
8.8.3 Bartlett estimator for the diagonal . . . . . . . . . 178
8.9 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 179
9 Inference in graphical models as differentiation 180
9.1 Chain rule of probability . . . . . . . . . . . . . . . . . . . 180
9.2 Conditional independence . . . . . . . . . . . . . . . . . . 181
9.3 Inference problems . . . . . . . . . . . . . . . . . . . . . . 182
9.3.1 Joint probability distributions . . . . . . . . . . . . 182
9.3.2 Likelihood . . . . . . . . . . . . . . . . . . . . . . 182
9.3.3 Maximum a-posteriori inference . . . . . . . . . . . 182
9.3.4 Marginal inference . . . . . . . . . . . . . . . . . . 183
9.3.5 Expectation, convex hull, marginal polytope . . . . 183
9.3.6 Complexity of brute force . . . . . . . . . . . . . . 185
9.4 Markov chains . . . . . . . . . . . . . . . . . . . . . . . . 185
9.4.1 The Markov property . . . . . . . . . . . . . . . . 186
9.4.2 Time-homogeneous Markov chains . . . . . . . . . 188
9.4.3 Higher-order Markov chains . . . . . . . . . . . . . 189
9.5 Bayesian networks . . . . . . . . . . . . . . . . . . . . . . 189
9.5.1 Expressing variable dependencies using DAGs . . . 189
9.5.2 Parameterizing Bayesian networks . . . . . . . . . 190
9.5.3 Ancestral sampling . . . . . . . . . . . . . . . . . 191
9.6 Markov random fields . . . . . . . . . . . . . . . . . . . . 191
9.6.1 Expressing factors using undirected graphs . . . . . 191
9.6.2 MRFs as exponential family distributions . . . . . . 192
9.6.3 Conditional random fields . . . . . . . . . . . . . . 194
9.6.4 Sampling . . . . . . . . . . . . . . . . . . . . . . . 194
9.7 Inference on chains . . . . . . . . . . . . . . . . . . . . . 194
9.7.1 The forward-backward algorithm . . . . . . . . . . 195
9.7.2 The Viterbi algorithm . . . . . . . . . . . . . . . . 196
9.8 Inference on trees . . . . . . . . . . . . . . . . . . . . . . 198
9.9 Inference as differentiation . . . . . . . . . . . . . . . . . 199
9.9.1 Inference as gradient of the log-partition . . . . . . 199
9.9.2 Semirings and softmax operators . . . . . . . . . . 200
9.9.3 Inference as backpropagation . . . . . . . . . . . . 202
9.10 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 204

10 Differentiating through optimization 205


10.1 Implicit functions . . . . . . . . . . . . . . . . . . . . . . 205
10.1.1 Optimization problems . . . . . . . . . . . . . . . 206
10.1.2 Nonlinear equations . . . . . . . . . . . . . . . . . 206
10.1.3 Application to bilevel optimization . . . . . . . . . 206
10.2 Envelope theorems . . . . . . . . . . . . . . . . . . . . . . 207
10.2.1 Danskin’s theorem . . . . . . . . . . . . . . . . . . 208
10.2.2 Rockafellar’s theorem . . . . . . . . . . . . . . . . 209
10.3 Implicit function theorem . . . . . . . . . . . . . . . . . . 210
10.3.1 Univariate functions . . . . . . . . . . . . . . . . . 210
10.3.2 Multivariate functions . . . . . . . . . . . . . . . . 211
10.3.3 JVP and VJP of implicit functions . . . . . . . . . 213
10.3.4 Proof of the implicit function theorem . . . . . . . 214
10.4 Adjoint state method . . . . . . . . . . . . . . . . . . . . 214
10.4.1 Differentiating nonlinear equations . . . . . . . . . 214
10.4.2 Relation with envelope theorems . . . . . . . . . . 216
10.4.3 Proof using the method of Lagrange multipliers . . 216
10.4.4 Proof using the implicit function theorem . . . . . 217
10.4.5 Reverse mode as adjoint method with backsubstitution . . 217
10.5 Inverse function theorem . . . . . . . . . . . . . . . . . . 220
10.5.1 Differentiating inverse functions . . . . . . . . . . 220
10.5.2 Link with the implicit function theorem . . . . . . 220
10.5.3 Proof of inverse function theorem . . . . . . . . . 221
10.6 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 222

11 Differentiating through integration 224


11.1 Differentiation under the integral sign . . . . . . . . . . . 224
11.2 Differentiating through expectations . . . . . . . . . . . . 225
11.2.1 The easy case . . . . . . . . . . . . . . . . . . . . 226
11.2.2 Exact gradients . . . . . . . . . . . . . . . . . . . 226
11.2.3 Application to expected loss functions . . . . . . . 227
11.2.4 Application to experimental design . . . . . . . . . 228
11.3 Score function estimators, REINFORCE . . . . . . . . . . 229
11.3.1 Scalar-valued functions . . . . . . . . . . . . . . . 229
11.3.2 Variance reduction . . . . . . . . . . . . . . . . . . 231
11.3.3 Vector-valued functions . . . . . . . . . . . . . . . 233
11.3.4 Second derivatives . . . . . . . . . . . . . . . . . . 234
11.4 Path gradient estimators, reparametrization trick . . . . . 235
11.4.1 Location-scale transforms . . . . . . . . . . . . . . 235
11.4.2 Inverse transforms . . . . . . . . . . . . . . . . . . 236
11.4.3 Pushforward operators . . . . . . . . . . . . . . . 238
11.4.4 Change-of-variables theorem . . . . . . . . . . . . 240
11.5 Stochastic programs . . . . . . . . . . . . . . . . . . . . . 240
11.5.1 Stochastic computation graphs . . . . . . . . . . . 241
11.5.2 Examples . . . . . . . . . . . . . . . . . . . . . . 243
11.5.3 Unbiased gradient estimators . . . . . . . . . . . . 245
11.5.4 Local vs. global expectations . . . . . . . . . . . . 247
11.6 Differential equations . . . . . . . . . . . . . . . . . . . . 248
11.6.1 Parameterized differential equations . . . . . . . . 248
11.6.2 Continuous adjoint method . . . . . . . . . . . . . 251
11.6.3 Gradients via the continuous adjoint method . . . . 252
11.6.4 Gradients via reverse-mode on discretization . . . . 254
11.6.5 Reversible discretization schemes . . . . . . . . . . 255
11.6.6 Proof of the continuous adjoint method . . . . . . 257
11.7 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 259

IV Smoothing programs 261

12 Smoothing by optimization 262


12.1 Primal approach . . . . . . . . . . . . . . . . . . . . . . . 262
12.1.1 Infimal convolution . . . . . . . . . . . . . . . . . 262
12.1.2 Moreau envelope . . . . . . . . . . . . . . . . . . 263
12.2 Legendre–Fenchel transforms, convex conjugates . . . . . . 265
12.2.1 Definition . . . . . . . . . . . . . . . . . . . . . . 265
12.2.2 Closed-form examples . . . . . . . . . . . . . . . . 266
12.2.3 Properties . . . . . . . . . . . . . . . . . . . . . . 267
12.2.4 Conjugate calculus . . . . . . . . . . . . . . . . . 269
12.2.5 Fast Legendre transform . . . . . . . . . . . . . . 270
12.3 Dual approach . . . . . . . . . . . . . . . . . . . . . . . . 270
12.3.1 Duality between strong convexity and smoothness . 270
12.3.2 Smoothing by dual regularization . . . . . . . . . . 271
12.3.3 Equivalence between primal and dual regularizations 273
12.4 Examples . . . . . . . . . . . . . . . . . . . . . . . . . . . 273
12.4.1 Smoothed ReLU functions . . . . . . . . . . . . . 273
12.4.2 Smoothed max operators . . . . . . . . . . . . . . 274
12.4.3 Relaxed step functions (sigmoids) . . . . . . . . . 276
12.4.4 Relaxed argmax operators . . . . . . . . . . . . . . 277
12.5 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 278

13 Smoothing by integration 279


13.1 Convolution . . . . . . . . . . . . . . . . . . . . . . . . . 279
13.1.1 Convolution operators . . . . . . . . . . . . . . . . 279
13.1.2 Convolution with a kernel . . . . . . . . . . . . . . 280
13.1.3 Discrete convolution . . . . . . . . . . . . . . . . . 281
13.1.4 Differentiation . . . . . . . . . . . . . . . . . . . . 283
13.1.5 Multidimensional convolution . . . . . . . . . . . . 283
13.1.6 Link between convolution and infimal convolution . 283
13.2 Fourier and Laplace transforms . . . . . . . . . . . . . . . 284
13.2.1 Convolution theorem . . . . . . . . . . . . . . . . 284
13.2.2 Link between Fourier and Legendre transforms . . . 285
13.2.3 The soft Legendre-Fenchel transform . . . . . . . . 285
13.3 Examples . . . . . . . . . . . . . . . . . . . . . . . . . . . 289
13.3.1 Smoothed step function . . . . . . . . . . . . . . . 289
13.3.2 Smoothed ReLU function . . . . . . . . . . . . . . 290
13.4 Perturbation of blackbox functions . . . . . . . . . . . . . 291
13.4.1 Expectation in a location-scale family . . . . . . . 291
13.4.2 Gradient estimation by reparametrization . . . . . 292
13.4.3 Gradient estimation by SFE, Stein’s lemma . . . . 293
13.4.4 Link between reparametrization and SFE . . . . . . 294
13.4.5 Variance reduction and evolution strategies . . . . 295
13.4.6 Zero-temperature limit . . . . . . . . . . . . . . . 296
13.5 Gumbel tricks . . . . . . . . . . . . . . . . . . . . . . . . 296
13.5.1 The Gumbel distribution . . . . . . . . . . . . . . 296
13.5.2 Perturbed comparison . . . . . . . . . . . . . . . . 297
13.5.3 Perturbed argmax . . . . . . . . . . . . . . . . . . 298
13.5.4 Perturbed max . . . . . . . . . . . . . . . . . . . . 300
13.5.5 Gumbel trick for sampling . . . . . . . . . . . . . . 301
13.5.6 Perturb-and-MAP . . . . . . . . . . . . . . . . . . 301
13.5.7 Gumbel-softmax . . . . . . . . . . . . . . . . . . . 303
13.6 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 304
V Optimizing differentiable programs 306

14 Optimization basics 307


14.1 Objective functions . . . . . . . . . . . . . . . . . . . . . 307
14.2 Oracles . . . . . . . . . . . . . . . . . . . . . . . . . . . . 308
14.3 Variational perspective of optimization algorithms . . . . . 309
14.4 Classes of functions . . . . . . . . . . . . . . . . . . . . . 309
14.4.1 Lipschitz functions . . . . . . . . . . . . . . . . . 309
14.4.2 Smooth functions . . . . . . . . . . . . . . . . . . 310
14.4.3 Convex functions . . . . . . . . . . . . . . . . . . 312
14.4.4 Strongly-convex functions . . . . . . . . . . . . . . 314
14.4.5 Nonconvex functions . . . . . . . . . . . . . . . . 315
14.5 Performance guarantees . . . . . . . . . . . . . . . . . . . 316
14.6 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 319

15 First-order optimization 320


15.1 Gradient descent . . . . . . . . . . . . . . . . . . . . . . . 320
15.1.1 Variational perspective . . . . . . . . . . . . . . . 320
15.1.2 Convergence for smooth functions . . . . . . . . . 321
15.1.3 Momentum and accelerated variants . . . . . . . . 323
15.2 Stochastic gradient descent . . . . . . . . . . . . . . . . . 323
15.2.1 Stochastic gradients . . . . . . . . . . . . . . . . . 324
15.2.2 Vanilla SGD . . . . . . . . . . . . . . . . . . . . . 325
15.2.3 Momentum variants . . . . . . . . . . . . . . . . . 326
15.2.4 Adaptive variants . . . . . . . . . . . . . . . . . . 327
15.3 Projected gradient descent . . . . . . . . . . . . . . . . . 328
15.3.1 Variational perspective . . . . . . . . . . . . . . . 328
15.3.2 Optimality conditions . . . . . . . . . . . . . . . . 329
15.3.3 Commonly-used projections . . . . . . . . . . . . . 329
15.4 Proximal gradient method . . . . . . . . . . . . . . . . . . 330
15.4.1 Variational perspective . . . . . . . . . . . . . . . 331
15.4.2 Optimality conditions . . . . . . . . . . . . . . . . 331
15.4.3 Commonly-used proximal operators . . . . . . . . . 332
15.5 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 333
16 Second-order optimization 334
16.1 Newton’s method . . . . . . . . . . . . . . . . . . . . . . 334
16.1.1 Variational perspective . . . . . . . . . . . . . . . 334
16.1.2 Regularized Newton method . . . . . . . . . . . . 335
16.1.3 Approximate direction . . . . . . . . . . . . . . . . 336
16.1.4 Convergence guarantees . . . . . . . . . . . . . . . 336
16.1.5 Linesearch . . . . . . . . . . . . . . . . . . . . . . 336
16.1.6 Geometric interpretation . . . . . . . . . . . . . . 337
16.1.7 Stochastic Newton’s method . . . . . . . . . . . . 338
16.2 Gauss-Newton method . . . . . . . . . . . . . . . . . . . 339
16.2.1 With exact outer function . . . . . . . . . . . . . . 340
16.2.2 With approximate outer function . . . . . . . . . . 341
16.2.3 Linesearch . . . . . . . . . . . . . . . . . . . . . . 342
16.2.4 Stochastic Gauss-Newton . . . . . . . . . . . . . . 342
16.3 Natural gradient descent . . . . . . . . . . . . . . . . . . 343
16.3.1 Variational perspective . . . . . . . . . . . . . . . 343
16.3.2 Stochastic natural gradient descent . . . . . . . . . 344
16.4 Quasi-Newton methods . . . . . . . . . . . . . . . . . . . 345
16.4.1 BFGS . . . . . . . . . . . . . . . . . . . . . . . . 345
16.4.2 Limited-memory BFGS . . . . . . . . . . . . . . . 346
16.5 Approximate Hessian diagonal inverse preconditioners . . 346
16.6 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 346

17 Duality 348
17.1 Dual norms . . . . . . . . . . . . . . . . . . . . . . . . . 348
17.2 Fenchel duality . . . . . . . . . . . . . . . . . . . . . . . . 349
17.3 Bregman divergences . . . . . . . . . . . . . . . . . . . . 352
17.4 Fenchel-Young loss functions . . . . . . . . . . . . . . . . 355
17.5 Summary . . . . . . . . . . . . . . . . . . . . . . . . . . . 356

References 357
The Elements of
Differentiable Programming
Mathieu Blondel¹ and Vincent Roulet¹
¹Google DeepMind

ABSTRACT
Artificial intelligence has recently experienced remarkable advances, fueled by large models, vast datasets, accelerated hardware, and, last but not least, the transformative power of differentiable programming. This new programming paradigm enables end-to-end differentiation of complex computer programs (including those with control flows and data structures), making gradient-based optimization of program parameters possible.
As an emerging paradigm, differentiable programming builds upon several areas of computer science and applied mathematics, including automatic differentiation, graphical models, optimization and statistics. This book presents a comprehensive review of the fundamental concepts useful for differentiable programming. We adopt two main perspectives, that of optimization and that of probability, with clear analogies between the two.
Differentiable programming is not merely the differentiation of programs, but also the thoughtful design of programs intended for differentiation. By making programs differentiable, we inherently introduce probability distributions over their execution, providing a means to quantify the uncertainty associated with program outputs.
Notation

Table 1: Naming conventions

Notation Description
X ⊆ RD Input space (e.g., features)
Y ⊆ RM Output space (e.g., classes)
Sk ⊆ RDk Output space on layer or state k
W ⊆ RP Weight space
Λ ⊆ RQ Hyperparameter space
Θ ⊆ RR Distribution parameter space, logit space
N Number of training samples
T Number of optimization iterations
x∈X Input vector
y∈Y Target vector
sk ∈ Sk State vector k
w∈W Network (model) weights
λ∈Λ Hyperparameters
θ∈Θ Distribution parameters, logits
π ∈ [0, 1] Probability value
π ∈ △M Probability vector


Table 2: Naming conventions (continued)

Notation Description
f Network function
f (·; x) Network function with x fixed
L Objective function
ℓ Loss function
κ Kernel function
ϕ Output embedding, sufficient statistic
step Heaviside step function
logisticσ Logistic function with temperature σ
logistic Shorthand for logistic1
pθ Model distribution with parameters θ
ρ Data distribution over X × Y
ρX Data distribution over X
µ, σ 2 Mean and variance
Z Random noise variable
1
Introduction

1.1 What is differentiable programming?

A computer program is a sequence of elementary instructions for performing a task. In traditional programming, the program is typically hand-written by a programmer. However, for certain tasks, such as image recognition or text generation, hand-writing a program to perform such tasks is nearly impossible.
This has motivated the need for statistical approaches based on
machine learning. With differentiable programming, while the overall
structure of the program is typically designed by a human, parameters of
the program (such as weights in a neural network) can be automatically
adjusted to achieve a task or optimize a criterion. This paradigm has
also been referred to as “software 2.0”. We give an informal definition.

Definition 1.1 (Differentiable programming). Differentiable programming is a programming paradigm in which complex computer programs (including those with control flows and data structures) can be differentiated end-to-end automatically, enabling gradient-based optimization of parameters in the program.

In differentiable programming, a program is also defined as the composition of elementary operations, forming a computation graph.


The key difference with classical computer programming is that the
program can be differentiated end-to-end, using automatic differenti-
ation (autodiff). Typically, it is assumed that the program defines a
mathematically valid function (a.k.a. pure function): the function
should return identical values for identical arguments and should not
have any side effects. Moreover, the function should have well-defined
derivatives, ensuring that it can be used in a gradient-based optimiza-
tion algorithm. Therefore, differentiable programming is not only the art
of differentiating through programs but also of designing meaningful
differentiable programs.

Why are derivatives important?

Machine learning typically boils down to optimizing a certain objective function, which is the composition of a loss function and a model
(network) function. Derivative-free optimization is called zero-order
optimization. It only assumes that we can evaluate the objective
function that we wish to optimize. Unfortunately, it is known to suffer
from the curse of dimensionality, i.e., it only scales to small di-
mensional problems, such as less than 10 dimensions. Derivative-based
optimization, on the other hand, is much more efficient and can scale to
millions or billions of parameters. Algorithms that use first and second
derivatives are known as first-order and second-order algorithms,
respectively.

Why is autodiff important?

Before the autodiff revolution, researchers and practitioners needed to manually implement the gradient of the functions they wished to
optimize. Manually deriving gradients can become very tedious for
complicated functions. Moreover, every time the function is changed
(for example, for trying out a new idea), the gradient needs to be re-
derived. Autodiff is a game changer because it allows users to focus on
quickly and creatively experimenting with functions for their tasks.

Differentiable programming is not just deep learning

While there is clearly overlap between deep learning and differentiable programming, their focus is different. Deep learning studies artificial
neural networks composed of multiple layers, able to learn intermedi-
ate representations of the data. Neural network architectures have
been proposed with various inductive biases. For example, convolu-
tional neural networks are designed for images and transformers are
designed for sequences. On the other hand, differentiable programming
studies the techniques to differentiate through complex programs. It
is useful beyond deep learning: for instance in reinforcement learning,
probabilistic programming and scientific computing in general.

Differentiable programming is not just autodiff

While autodiff is a key ingredient of differentiable programming, it is not the only one. Differentiable programming is also concerned with the design of principled differentiable operations. In fact, much research on differentiable programming has been devoted to making classical computer programming operations compatible with autodiff. As we
shall see, many differentiable relaxations can be interpreted in a proba-
bilistic framework. A core theme of this book is the interplay between
optimization, probability and differentiation. Differentiation is useful
for optimization and conversely, optimization can be used to design
differentiable operators.

1.2 Book goals and scope

The present book aims to provide a comprehensive introduction to differentiable programming with an emphasis on core mathematical tools.

• In Part I, we review fundamentals: differentiation and probabilistic learning.

• In Part II, we review differentiable programs. This includes neural networks, sequence networks and control flows.

• In Part III, we review how to differentiate through programs. This includes automatic differentiation, but also differentiating through optimization and integration (in particular, expectations).

• In Part IV, we review smoothing programs. We focus on two main techniques: infimal convolution, which comes from the world of optimization, and convolution, which comes from the world of integration. We also strive to spell out the connections between them.

• In Part V, we review optimizing programs: basic optimization concepts, first-order algorithms, second-order algorithms and duality.

Our goal is to present the fundamental techniques useful for differentiable programming, not to survey how these techniques have been used in various applications.

1.3 Intended audience

This book is intended to be a graduate-level introduction to differentiable programming. Our pedagogical choices are made with the machine
learning community in mind. Some familiarity with calculus, linear
algebra, probability theory and machine learning is beneficial.

1.4 How to read this book?

This book does not need to be read linearly chapter by chapter. When
needed, we indicate at the beginning of a chapter what chapters are
recommended to be read as a prerequisite.

1.5 Related work

Differentiable programming builds upon a variety of connected topics. We review in this section relevant textbooks, tutorials and software. Standard textbooks on backpropagation and automatic differentiation are those of Werbos (1994) and Griewank and Walther (2008).

A tutorial with a focus on machine learning is provided by Baydin et al. (2018). Automatic differentiation is also reviewed as part of more
general textbooks, such as those of Deisenroth et al. (2020), Murphy
(2022) (from a linear algebra perspective) and Murphy (2023) (from a
functional perspective; autodiff section authored by Roy Frostig). The
present book was also influenced by Peyré (2020)’s textbook on data
science. The history of reverse-mode autodiff is reviewed by Griewank
(2012).
A tutorial on different perspectives of backpropagation is “There and
Back Again: A Tale of Slopes and Expectations” (link), by Deisenroth
and Ong. A tutorial on implicit differentiation is “Deep Implicit Layers -
Neural ODEs, Deep Equilibrium Models, and Beyond” (link), by Kolter,
Duvenaud, and Johnson.
The standard reference on inference in graphical models and its
connection with exponential families is that of Wainwright and Jor-
dan (2008). Differentiable programming is also related to probabilistic
programming; see, e.g., Meent et al. (2018).
A review of smoothing from the infimal convolution perspective is
provided by Beck and Teboulle (2012). A standard textbook on convex
optimization is that of Nesterov (2018). A textbook on first-order
optimization methods is that of Beck (2017).
Autodiff implementations that accelerated the autodiff revolution
in machine learning are Theano (Bergstra et al., 2010) and Autograd
(Maclaurin et al., 2015). Major modern implementations of autodiff
include Tensorflow (Abadi et al., 2016), JAX (Bradbury et al., 2018),
and PyTorch (Paszke et al., 2019). We in particular acknowledge the
JAX team for influencing our view of autodiff.
Part I

Fundamentals
2
Differentiation

In this chapter, we review key differentiation concepts. In particular, we emphasize the fundamental role played by linear maps.

2.1 Univariate functions

2.1.1 Derivatives

Before studying derivatives, we briefly recall the definition of function continuity.

Definition 2.1 (Continuous function). A function f : R → R is continuous at a point w ∈ R if

lim_{v→w} f(v) = f(w).

A function f is said to be continuous if it is continuous at all points in its domain.

In the following, we use Landau’s little o notation. We write

g(v) = o(f(v)) as v → w

if

lim_{v→w} |g(v)| / |f(v)| = 0.

That is, the function f dominates g in the limit v → w. For example, f is continuous at w if and only if

f(w + δ) = f(w) + o(1) as δ → 0.

We now explain derivatives. Consider a function f : R → R. As illustrated in Fig. 2.1, its value on an interval [w0, w0 + δ] can be
approximated by the secant between its values f (w0 ) and f (w0 + δ),
a linear function with slope (f (w0 + δ) − f (w0 ))/δ. In the limit of an
infinitesimal variation δ around w0 , the secant converges to the tangent
of f at w0 and the resulting slope defines the derivative of f at w0 . The
definition below formalizes this intuition.

Definition 2.2 (Derivative). The derivative of f : R → R at w ∈ R is defined as

f′(w) := lim_{δ→0} (f(w + δ) − f(w))/δ,    (2.1)

provided that the limit exists. If f′(w) is well-defined at a particular w, we say that the function f is differentiable at w.

Here, and in the following definitions, if f is differentiable at any w ∈ R, we say that it is differentiable everywhere or differentiable for short. If f is differentiable at a given w, then it is necessarily continuous at w.
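As an informal numerical check of Definition 2.2 (our own illustration, not part of the original text), one can evaluate the difference quotient at smaller and smaller δ and watch it approach the derivative; here we take f(w) = w^3 at w = 2, whose derivative is 3w^2 = 12.

# Sketch: the difference quotient (f(w + delta) - f(w)) / delta from
# Definition 2.2 approaches f'(w) as delta shrinks.

def f(w):
    return w ** 3

def difference_quotient(f, w, delta):
    return (f(w + delta) - f(w)) / delta

w = 2.0
for delta in [1e-1, 1e-3, 1e-5]:
    # Approaches 3 * w ** 2 = 12 as delta decreases.
    print(delta, difference_quotient(f, w, delta))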

Proposition 2.1 (Differentiability implies continuity). If f : R → R is differentiable at w ∈ R, then it is continuous at w ∈ R.

Proof. In little o notation, f is differentiable at w if there exists f′(w) ∈ R such that

f(w + δ) = f(w) + f′(w)δ + o(δ) as δ → 0.

Since f′(w)δ + o(δ) = o(1) as δ → 0, f is continuous at w.

Figure 2.1: A function f can be locally approximated around a point w0 by a secant, a linear function w ↦ aw + b with slope a and intercept b, crossing f at w0 with value u0 = f(w0) and crossing at w0 + δ with value uδ = f(w0 + δ). Using u0 = aw0 + b and uδ = a(w0 + δ) + b, we find that its slope is a = (f(w0 + δ) − f(w0))/δ and the intercept is b = f(w0) − aw0. The derivative f′(w0) of a function f at a point w0 is then defined as the limit of the slope a when δ → 0. It is the slope of the tangent of f at w0. The value f(w) of the function at w can then be locally approximated around w0 by w ↦ f′(w0)w + f(w0) − f′(w0)w0 = f(w0) + f′(w0)(w − w0).

In addition to enabling the construction of a linear approximation of f in a neighborhood of w, since it is the slope of the tangent of f at
w, the derivative f ′ informs us about the monotonicity of f around
w. If f ′ (w) is positive, the function is increasing around w. Conversely,
if f ′ (w) is negative, the function is decreasing. Such information can be
used to develop iterative algorithms seeking to minimize f by computing
iterates of the form wt+1 = wt − γf ′ (wt ) for γ > 0, which move along
descent directions of f around wt .
For several elementary functions such as wn , ew , ln w, cos w or sin w,
their derivatives can be obtained directly by applying the definition of
the derivative in Eq. (2.1) as illustrated in Example 2.1.

Example 2.1 (Derivative of power function). Consider f(w) = w^n for w ∈ R, n ∈ N \ {0}. For any δ ∈ R, we have

(f(w + δ) − f(w))/δ = ((w + δ)^n − w^n)/δ
                    = (∑_{k=0}^{n} (n choose k) δ^k w^{n−k} − w^n)/δ
                    = ∑_{k=1}^{n} (n choose k) δ^{k−1} w^{n−k}
                    = (n choose 1) w^{n−1} + ∑_{k=2}^{n} (n choose k) δ^{k−1} w^{n−k},

where, in the second line, we used the binomial theorem. Since (n choose 1) = n and lim_{δ→0} ∑_{k=2}^{n} (n choose k) δ^{k−1} w^{n−k} = 0, we get f′(w) = n w^{n−1}.
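Such elementary derivatives can also be double-checked with a computer algebra system; the following sketch (an illustration of ours, assuming the SymPy library is available) reproduces the result of Example 2.1 symbolically.

import sympy

w = sympy.Symbol("w", positive=True)
n = sympy.Symbol("n", integer=True, positive=True)

# Symbolic derivative of the power function w**n.
derivative = sympy.simplify(sympy.diff(w ** n, w))
print(derivative)  # n * w**(n - 1)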

Remark 2.1 (Functions on a subset U of R). For simplicity, we presented the definition of the derivative for a function defined on the whole set of real numbers R. If a function f : U → R is defined on a subset U ⊆ R of the real numbers, as is the case for f(w) = √w defined on U = R+, the derivative of f at w ∈ U is defined by the limit in (2.1) provided that the function f is well defined on a neighborhood of w, that is, there exists r > 0 such that w + δ ∈ U for any |δ| ≤ r. The function f is then said to be differentiable everywhere, or differentiable for short, if it is differentiable at any point w in the interior of U, the set of points w ∈ U such that {w + δ : |δ| ≤ r} ⊆ U for r sufficiently small. For points lying at the boundary of U (such as a and b if U = [a, b]), one may define the right and left derivatives of f at a and b, meaning that the limit is taken by approaching a from the right or b from the left.

2.1.2 Calculus rules

For a given w ∈ R and two functions f : R → R and g : R → R, the derivatives of elementary operations on f and g, such as their sums, products or compositions, can easily be derived from the definition of the derivative, under appropriate conditions on the differentiability properties of f and g at w. For example, if the derivatives of f and g exist at w, then the derivatives of their weighted sums and of their products exist and lead to the following rules:

∀a, b ∈ R, (af + bg)′(w) = af′(w) + bg′(w)    (Linearity)
(fg)′(w) = f′(w)g(w) + f(w)g′(w),    (Product rule)

where (fg)(w) = f(w)g(w). The linearity can be verified directly from the linearity of limits. For the product rule, in little o notation, we have, as δ → 0,

(fg)(w + δ) = (f(w) + f′(w)δ + o(δ))(g(w) + g′(w)δ + o(δ))
            = f(w)g(w) + f′(w)g(w)δ + f(w)g′(w)δ + o(δ),

hence the result.


If the derivatives of g at w and of f at g(w) exist, then the derivative
of the composition (f ◦ g)(w) := f (g(w)) at w exist and is given by

(f ◦ g)′ (w) = f ′ (g(w))g ′ (w). (Chain rule)

We prove this result more generally in Proposition 2.2. As seen in the


sequel, the linearity and the product rule can be seen as byproducts of
the chain rule, making the chain rule the cornerstone of differentiation.
For now, consider a function that can be expressed as sums, products
or compositions of elementary functions such as f (w) = ew ln w + cos w2 .
Its derivative can be computed by applying the aforementioned rules
on the decomposition of f into elementary operations and functions, as
illustrated in Example 2.2.

Example 2.2 (Applying rules of differentiation). Consider f(w) = e^w ln w + cos(w^2). The derivative of f on w > 0 can be computed step by step as follows, denoting sq(w) := w^2,

f′(w) = (exp · ln)′(w) + (cos ◦ sq)′(w)    (Linearity)
(exp · ln)′(w) = exp′(w) · ln(w) + exp(w) · ln′(w)    (Product rule)
(cos ◦ sq)′(w) = cos′(sq(w)) sq′(w)    (Chain rule)
exp′(w) = exp(w), ln′(w) = 1/w,    (Elem. func.)
sq′(w) = 2w, cos′(w) = − sin(w).    (Elem. func.)

We obtain then that f′(w) = e^w ln w + e^w/w − 2w sin(w^2).

Such a process is purely mechanical and lends itself to an automated procedure, which is the main idea of automatic differentiation presented in Chapter 7.
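As a small illustration of that automation (our own sketch, assuming the JAX library is installed), automatic differentiation reproduces the derivative computed by hand in Example 2.2.

import jax
import jax.numpy as jnp

def f(w):
    # f(w) = e^w ln(w) + cos(w^2), as in Example 2.2.
    return jnp.exp(w) * jnp.log(w) + jnp.cos(w ** 2)

def f_prime_manual(w):
    # Derivative obtained by hand via the linearity, product and chain rules.
    return jnp.exp(w) * jnp.log(w) + jnp.exp(w) / w - 2 * w * jnp.sin(w ** 2)

w = 1.5
print(jax.grad(f)(w))      # derivative by automatic differentiation
print(f_prime_manual(w))   # same value, up to floating-point error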

2.1.3 Leibniz’s notation


The notion of derivative was first introduced independently by Newton and Leibniz in the 18th century (Ball, 1960). The latter considered derivatives as the quotient of infinitesimal variations. Namely, denoting u = f(w) a variable depending on w through f, Leibniz considered the derivative of f as the quotient

f′ = du/dw with f′(w) = du/dw |_w,

where du and dw denote infinitesimal variations of u and w respectively, and the symbol |_w denotes the evaluation of the derivative at a given point w. This notation simplifies the statement of the chain rule first discovered by Leibniz (Rodriguez and Lopez Fernandez, 2010), as we have, for v = g(w) and u = f(v),

du/dw = du/dv · dv/dw.

This hints that derivatives are multiplied when considering compositions. At evaluation, the chain rule in Leibniz notation recovers the formula presented above as

du/dw |_w = du/dv |_{g(w)} · dv/dw |_w = f′(g(w)) g′(w) = (f ◦ g)′(w).

The ability of Leibniz’s notation to capture the chain rule as a mere product of quotients made it popular throughout the centuries, especially in mechanics (Ball, 1960). The rationale behind Leibniz’s notation, that is, the concept of “infinitesimal variations”, was questioned by later mathematicians for its potential logical issues (Ball, 1960). The notation f′(w), first introduced by Euler and further popularized by Lagrange (Cajori, 1993), has since taken over in numerous mathematical textbooks. The concept of infinitesimal variations has been rigorously defined by considering the set of hyperreal numbers. They extend the set of real numbers by considering each number as a sum of a non-infinitesimal part and an infinitesimal part (Hewitt, 1948). The formalism of infinitesimal variations further underlies the development of automatic differentiation algorithms through the concept of dual numbers.

2.2 Multivariate functions

2.2.1 Directional derivatives


Let us now consider a function f : RP → R of multiple inputs w =
(w1 , . . . , wP ) ∈ RP . The most important example in machine learning
is a function which, to the parameters w ∈ RP of a neural network,
associates a loss value in R. Variations of f need to be defined along
specific directions, such as the variation f (w + δv)−f (w) of f around
w ∈ RP in the direction v ∈ RP by an amount δ > 0. This consideration
naturally leads to the definition of the directional derivative.
Definition 2.3 (Directional derivative). The directional derivative of f at w in the direction v is given by

∂f(w)[v] := lim_{δ→0} (f(w + δv) − f(w))/δ,

provided that the limit exists.

One example of directional derivative consists in computing the derivative of a function f at w in any of the canonical directions

e_i := (0, . . . , 0, 1, 0, . . . , 0),

where the 1 is at the i-th coordinate.

This allows us to define the notion of partial derivatives, denoted, for i ∈ [P],

∂_i f(w) := ∂f(w)[e_i] = lim_{δ→0} (f(w + δe_i) − f(w))/δ.

This is also denoted in Leibniz’s notation as ∂_i f(w) = ∂f(w)/∂w_i or ∂_i f(w) = ∂_{w_i} f(w). By moving along only the i-th coordinate of the function, the partial derivative is akin to using the function ϕ(ω_i) = f(w_1, . . . , ω_i, . . . , w_P) around ω_i, keeping all other coordinates fixed at their values.
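In code, directional and partial derivatives can be evaluated without forming all coordinates explicitly; the sketch below (ours, assuming JAX) computes a directional derivative ∂f(w)[v] exactly with a Jacobian-vector product and compares it to a finite-difference approximation.

import jax
import jax.numpy as jnp

def f(w):
    # A simple multivariate function f : R^3 -> R.
    return jnp.sum(jnp.sin(w) * w)

w = jnp.array([0.1, 0.2, 0.3])
v = jnp.array([1.0, -1.0, 0.5])   # direction

# Exact directional derivative via a Jacobian-vector product.
_, dir_deriv = jax.jvp(f, (w,), (v,))

# Finite-difference approximation of the same quantity.
delta = 1e-4
approx = (f(w + delta * v) - f(w)) / delta

print(dir_deriv, approx)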

2.2.2 Gradients
We now introduce the gradient vector, which gathers the partial deriva-
tives. We first recall the definitions of linear map and linear form.

Definition 2.4 (Linear map, linear form). A function l : RP → RM is a linear map if, for any a1, a2 ∈ R and v1, v2 ∈ RP,

l[a1 v1 + a2 v2] = a1 l[v1] + a2 l[v2].

A linear map with values in R, l : RP → R, is called a linear form.

Linearity plays a crucial role in the differentiability of a function.

Definition 2.5 (Differentiability, single-output case). A function f : RP → R is differentiable at w ∈ RP if its directional derivative is defined along any direction, linear in any direction, and if

lim_{∥v∥2→0} |f(w + v) − f(w) − ∂f(w)[v]| / ∥v∥2 = 0.

We can now introduce the gradient.

Definition 2.6 (Gradient). The gradient of a differentiable function f : RP → R at a point w ∈ RP is defined as the vector of partial derivatives

∇f(w) := (∂_1 f(w), . . . , ∂_P f(w))^⊤ = (∂f(w)[e_1], . . . , ∂f(w)[e_P])^⊤.

By linearity, the directional derivative of f at w in the direction v = ∑_{i=1}^P v_i e_i is then given by

∂f(w)[v] = ∑_{i=1}^P v_i ∂f(w)[e_i] = ⟨v, ∇f(w)⟩.

In the definition above, the fact that the gradient can be used to
compute the directional derivative is a mere consequence of the linearity.
However, in more abstract cases presented in later sections, the gradient
is defined through this property.
As a simple example, any linear function of the form f(w) = a^⊤w = ∑_{i=1}^P a_i w_i is differentiable, as we have (a^⊤(w + v) − a^⊤w − a^⊤v)/∥v∥2 = 0 for any v, and in particular for ∥v∥2 → 0. Moreover, its gradient is naturally given by ∇f(w) = a.
Generally, to show that a function is differentiable and find its
gradient, one approach is to approximate f (w + v) around v = 0. If we
can find a vector g such that
f (w + v) = f (w) + ⟨g, v⟩ + o(∥v∥2 ),
then f is differentiable at w since ⟨g, ·⟩ is linear. Moreover, g is then
the gradient of f at w.

Remark 2.2 (Gateaux and Fréchet differentiability). Multiple definitions of differentiability exist. The one presented in Definition 2.5 is about Fréchet differentiable functions. Alternatively, if f : RP → R has well-defined directional derivatives along all directions, then the function is Gateaux differentiable. Note that the existence of directional derivatives in all directions is not a sufficient condition for the function to be differentiable. In other words, any Fréchet differentiable function is Gateaux differentiable, but the converse is not true. As a counter-example, one can verify that the function f(x1, x2) = x1^3/(x1^2 + x2^2) is Gateaux differentiable at 0 but not (Fréchet) differentiable at 0 (because the directional derivative at 0 is not linear).
Some authors also require Gateaux differentiable functions to have linear directional derivatives along any direction. These are still not Fréchet differentiable functions. Indeed, the limit in Definition 2.5 is over any vectors tending to 0 (potentially in a pathological way), while directional derivatives look at such limits uniquely in terms of a single direction.
In the remainder of this chapter, all definitions of differentiability are in terms of Fréchet differentiability.

Example 2.3 illustrates how to compute the gradient of the logistic loss and validate its differentiability.

Example 2.3 (Gradient of logistic loss). Consider the logistic loss ℓ(θ, y) := −y^⊤θ + log ∑_{i=1}^M e^{θ_i}, which measures the prediction error of the logits θ ∈ RM w.r.t. the correct label y ∈ {e_1, . . . , e_M}. Let us compute the gradient of this loss w.r.t. θ for fixed y, i.e., we want to compute the gradient of f(θ) := ℓ(θ, y). Let us decompose f as f = l + logsumexp with l(θ) := ⟨−y, θ⟩ and

logsumexp(θ) := log ∑_{i=1}^M e^{θ_i},

the log-sum-exp function. The function l is linear, hence differentiable with gradient ∇l(θ) = −y. We therefore focus on logsumexp. Denoting exp(θ) = (exp(θ_1), . . . , exp(θ_M)), and using that exp(x) = 1 + x + o(x) and log(1 + x) = x + o(x), we get

logsumexp(θ + v) = log(⟨exp(θ + v), 1⟩)
                 = log(⟨exp(θ), 1⟩ + ⟨exp(θ), v⟩ + o(∥v∥2))
                 = log(⟨exp(θ), 1⟩) + ⟨exp(θ)/⟨exp(θ), 1⟩, v⟩ + o(∥v∥2).

The above decomposition of logsumexp(θ + v) shows that it is differentiable, and that ∇logsumexp(θ) = softargmax(θ), where

softargmax(θ) := (e^{θ_1}/∑_{j=1}^M e^{θ_j}, . . . , e^{θ_M}/∑_{j=1}^M e^{θ_j}).

In total, we then get that ∇f(θ) = −y + softargmax(θ).
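The closed form of Example 2.3 is easy to verify numerically; in the sketch below (ours, assuming JAX, and using jax.nn.softmax as the softargmax), the autodiff gradient of the logistic loss matches −y + softargmax(θ).

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def logistic_loss(theta, y):
    # l(theta, y) = -<y, theta> + logsumexp(theta), as in Example 2.3.
    return -jnp.dot(y, theta) + logsumexp(theta)

theta = jnp.array([0.5, -1.0, 2.0])
y = jnp.array([0.0, 0.0, 1.0])   # one-hot label e_3

grad_autodiff = jax.grad(logistic_loss)(theta, y)
grad_closed_form = -y + jax.nn.softmax(theta)

print(jnp.allclose(grad_autodiff, grad_closed_form))  # True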



Linearity of gradients
The notion of differentiability for multi-input functions naturally inherits from the linearity of derivatives of single-input functions. For any u1, . . . , uM ∈ R and any multi-input functions f1, . . . , fM differentiable at w, the function u1 f1 + . . . + uM fM is differentiable at w and its gradient is

∇(u1 f1 + . . . + uM fM)(w) = u1 ∇f1(w) + . . . + uM ∇fM(w).

Why is the gradient useful?


The gradient defines the steepest ascent direction of f from w. To see why, notice that

argmax_{v∈RP, ∥v∥2≤1} ∂f(w)[v] = argmax_{v∈RP, ∥v∥2≤1} ⟨v, ∇f(w)⟩ = ∇f(w)/∥∇f(w)∥2,

where we assumed ∇f(w) ≠ 0. The gradient ∇f(w) is orthogonal to the level set of the function (the set of points w sharing the same value f(w)) and points towards higher values of f, as illustrated in Fig. 2.2. Conversely, the negated gradient −∇f(w) points towards lower values of f. This observation motivates the development of optimization algorithms such as gradient descent. It is based on iteratively performing the update w_{t+1} = w_t − γ∇f(w_t), for γ > 0. It therefore seeks a minimizer of f by moving along the direction of steepest descent around w_t, given, up to a multiplicative factor, by −∇f(w_t).
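As a minimal illustration of this idea (our own sketch, assuming JAX, not a prescription from the text), the following code iterates w_{t+1} = w_t − γ∇f(w_t) on a simple quadratic whose minimizer is (1, −2).

import jax
import jax.numpy as jnp

def f(w):
    # A smooth objective minimized at w = (1, -2).
    return jnp.sum((w - jnp.array([1.0, -2.0])) ** 2)

grad_f = jax.grad(f)

w = jnp.zeros(2)
gamma = 0.1                      # step size
for _ in range(50):
    w = w - gamma * grad_f(w)    # move along the steepest descent direction

print(w)   # close to (1, -2)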

2.2.3 Jacobians
Let us now consider a multi-output function f : RP → RM defined by f(w) := (f_1(w), . . . , f_M(w)), where f_j : RP → R. A typical example in machine learning is a neural network. The notion of directional derivative can be extended to such a function by defining it as the vector composed of the coordinate-wise directional derivatives:

∂f(w)[v] := lim_{δ→0} (f(w + δv) − f(w))/δ = (lim_{δ→0} (f_1(w + δv) − f_1(w))/δ, . . . , lim_{δ→0} (f_M(w + δv) − f_M(w))/δ) ∈ RM,

Figure 2.2: The gradient of a function f : R^2 → R at (w1, w2) is the normal vector to the tangent space of the level set L_{f(w1,w2)} = {(w1′, w2′) : f(w1′, w2′) = f(w1, w2)} and points towards points with higher function values.

Figure 2.3: The directional derivative of a parameterized curve f : R → R^2 at w is the tangent to the curve at the point f(w) ∈ R^2.

where the limits (provided that they exist) are applied coordinate-wise. The directional derivative of f in the direction v ∈ RP is therefore the vector that gathers the directional derivatives of each f_j, i.e., ∂f(w)[v] = (∂f_j(w)[v])_{j=1}^M. In particular, we can define the partial derivatives of f at w as the vectors

∂_i f(w) := ∂f(w)[e_i] = (∂_i f_1(w), . . . , ∂_i f_M(w))^⊤ ∈ RM.
As for the usual definition of the derivative, the directional derivative
can provide a linear approximation of a function around a current input
as illustrated in Fig. 2.3 for a parameterized curve f : R → R2 .
Just as in the single-output case, differentiability is defined not only
as the existence of directional derivatives in any direction but also by
the linearity in the chosen direction.

Definition 2.7 (Differentiability, multi-output case). A function f : RP → RM is (Fréchet) differentiable at a point w ∈ RP if its directional derivative is defined along any direction, linear along any direction, and

lim_{∥v∥2→0} ∥f(w + v) − f(w) − ∂f(w)[v]∥2 / ∥v∥2 = 0.

The partial derivatives of all function coordinates are gathered in the Jacobian matrix.

Definition 2.8 (Jacobian). The Jacobian of a differentiable function f : RP → RM at w is defined as the matrix of all partial derivatives of all coordinate functions, provided they exist: the matrix ∂f(w) ∈ R^{M×P} whose entry in row i and column j is ∂_j f_i(w). The Jacobian can be represented by stacking columns of partial derivatives or rows of gradients,

∂f(w) = (∂_1 f(w), . . . , ∂_P f(w)) = (∇f_1(w), . . . , ∇f_M(w))^⊤.

By linearity, the directional derivative of f at w along any input direction v = ∑_{i=1}^P v_i e_i ∈ RP is then given by

∂f(w)[v] = ∑_{i=1}^P v_i ∂_i f(w) = ∂f(w)v ∈ RM.

Notice that we use bold ∂ to indicate the Jacobian matrix. The Jacobian matrix naturally generalizes the concepts of derivatives and gradients presented earlier. As for the single-input case, to show that a function is differentiable, one approach is to approximate f(w + v) around v = 0. If we find a linear map l such that

f(w + v) = f(w) + l[v] + o(∥v∥2),

then f is differentiable at w. Moreover, if l is represented by a matrix J such that l[v] = Jv, then J = ∂f(w).
As a simple example, any linear function f(w) = Aw for A ∈ R^{M×P} is differentiable, since all its coordinate-wise components are single-output linear functions, and the Jacobian of f at any w is given by ∂f(w) = A.
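A quick numerical confirmation of this fact (our sketch, assuming JAX): the Jacobian of w ↦ Aw returned by automatic differentiation is the matrix A itself, at any w.

import jax
import jax.numpy as jnp

A = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])            # A in R^{2x3}

f = lambda w: A @ w                          # linear map f(w) = Aw

w = jnp.array([0.5, -1.0, 2.0])
print(jnp.allclose(jax.jacobian(f)(w), A))   # True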

Remark 2.3 (Special cases of the Jacobian). For single-output functions f : RP → R, i.e., M = 1, the Jacobian matrix reduces to a row vector identified as the transpose of the gradient, i.e.,

∂f(w) = ∇f(w)^⊤ ∈ R^{1×P}.

For a single-input function f : R → RM, the Jacobian reduces to a single column vector of directional derivatives, denoted

∂f(w) = f′(w) := (f_1′(w), . . . , f_M′(w))^⊤ ∈ R^{M×1}.

For a single-input single-output function f : R → R, the Jacobian reduces to the derivative of f, i.e.,

∂f(w) = f′(w) ∈ R.

Example 2.4 illustrates the form of the Jacobian matrix for the
element-wise application of a differentiable function such as the softplus
activation. This example already shows that the Jacobian takes a simple
diagonal matrix form. As a consequence, the directional derivative
associated with this function is simply given by an element-wise product
rather than a full matrix-vector product as suggested in Definition 2.8.
We will revisit this point in Section 2.3.

Example 2.4 (Jacobian matrix of the softplus activation). Consider the element-wise application of the softplus, defined for w ∈ RP by

f(w) := (σ(w1), . . . , σ(wP))⊤ ∈ RP, where σ(w) := log(1 + e^w).

Since σ is differentiable, each coordinate of this function is differentiable and the overall function is differentiable. The jth coordinate of f is independent of the ith coordinate of w for i ≠ j, so ∂i fj(w) = 0 for i ≠ j. For i = j, the result boils down to the derivative of σ at wj. That is, ∂j fj(w) = σ′(wj), where σ′(w) = e^w/(1 + e^w). The Jacobian of f is therefore the diagonal matrix

∂f(w) = diag(σ′(w1), . . . , σ′(wP)) ∈ RP×P,

with σ′(w1), . . . , σ′(wP) on the diagonal and zeros elsewhere.
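As a numerical illustration, a minimal sketch in JAX (assuming the library is available; the values are arbitrary) of the diagonal structure of this Jacobian and of the fact that the associated directional derivative is an element-wise product:

```python
import jax
import jax.numpy as jnp

def softplus_layer(w):
    return jnp.log1p(jnp.exp(w))                    # element-wise softplus

w = jnp.array([0.5, -1.0, 2.0])
v = jnp.array([1.0, 2.0, 3.0])

J = jax.jacobian(softplus_layer)(w)                 # dense P x P Jacobian
sigma_prime = jnp.exp(w) / (1.0 + jnp.exp(w))       # sigma'(w_j)

print(jnp.allclose(J, jnp.diag(sigma_prime)))       # diagonal structure
_, jvp = jax.jvp(softplus_layer, (w,), (v,))
print(jnp.allclose(jvp, sigma_prime * v))           # JVP = element-wise product
```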

Variations along outputs

Rather than considering variations of f along an input direction v ∈ RP, we may also consider the variations of f along an output direction u ∈ RM, namely, computing the gradients ∇(u⊤f)(w) of the single-output functions

(u⊤f)(w) := u⊤f(w) ∈ R.

In particular, we may consider computing the gradients ∇fj(w) of each function coordinate fj = ej⊤f at w, where ej is the jth canonical vector in RM. The infinitesimal variations of f at w along any output direction u = Σ_{j=1}^M uj ej ∈ RM are given by

∇(u⊤f)(w) = Σ_{j=1}^M uj ∇fj(w) = ∂f(w)⊤u ∈ RP,

where ∂f(w)⊤ is the Jacobian's transpose. Using the definition of the derivative as a limit, we obtain, for i ∈ [P],

∇i(u⊤f)(w) = [∂f(w)⊤u]i = lim_{δ→0} u⊤(f(w + δei) − f(w)) / δ.

Chain rule
Equipped with a generic definition of differentiability and the associated
objects, gradients and Jacobians, we can now generalize the chain rule,
previously introduced for single-input single-output functions.

Proposition 2.2 (Chain rule). Consider f : RP → RM and g : RM → RR. If f is differentiable at w ∈ RP and g is differentiable at f(w) ∈ RM, then the composition g ◦ f is differentiable at w ∈ RP and its Jacobian is given by

∂(g ◦ f)(w) = ∂g(f(w))∂f(w).

Proof. We progressively approximate (g ◦ f)(w + v) using the differentiability of f at w and g at f(w),

g(f(w + v)) = g(f(w) + ∂f(w)v + o(∥v∥))
            = g(f(w)) + ∂g(f(w))∂f(w)v + o(∥v∥).

Hence, g ◦ f is differentiable at w with Jacobian ∂g(f(w))∂f(w).

Proposition 2.2 can be seen as the cornerstone of any derivative computation. For example, it can be used to rederive the linearity or the product rule associated with the derivatives of single-input single-output functions.
When g is scalar-valued, combined with Remark 2.3, we obtain a simple expression for ∇(g ◦ f).

Proposition 2.3 (Chain rule, scalar-valued case). Consider f : RP → RM and g : RM → R. The gradient of the composition is given by

∇(g ◦ f)(w) = ∂f(w)⊤∇g(f(w)).
This is illustrated with linear regression in Example 2.5.

Example 2.5 (Linear regression). Consider the squared residuals of a linear regression of N inputs x1, . . . , xN ∈ RD onto N targets y1, . . . , yN ∈ R with a vector w ∈ RD, that is, f(w) = ∥Xw − y∥2² = Σ_{i=1}^N (xi⊤w − yi)² for X = (x1, . . . , xN)⊤ ∈ RN×D and y = (y1, . . . , yN)⊤ ∈ RN.
The function f can be decomposed into a linear mapping f1(w) = Xw and a squared error f2(p) = ∥p − y∥2², so that f = f2 ◦ f1. We can then apply the chain rule in Proposition 2.3 to get

∇f(w) = ∂f1(w)⊤∇f2(f1(w)),

provided that f1 and f2 are differentiable at w and f1(w), respectively. The function f1 is linear, hence differentiable, with Jacobian ∂f1(w) = X. On the other hand, the partial derivatives of f2 are given by ∂j f2(p) = 2(pj − yj) for j ∈ {1, . . . , N}. Therefore, f2 is differentiable at any p and its gradient is ∇f2(p) = 2(p − y). By combining the Jacobian of f1 and the gradient of f2, we then get the gradient of f as

∇f(w) = 2X⊤(f1(w) − y) = 2X⊤(Xw − y).
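As a check, a minimal sketch in JAX (assuming the library is available; the data are randomly generated) comparing the gradient returned by automatic differentiation with the closed form 2X⊤(Xw − y):

```python
import jax
import jax.numpy as jnp

X = jax.random.normal(jax.random.PRNGKey(0), (5, 3))   # N = 5, D = 3
y = jax.random.normal(jax.random.PRNGKey(1), (5,))
w = jnp.array([0.1, -0.5, 0.3])

def f(w):
    return jnp.sum((X @ w - y) ** 2)                    # squared residuals

g_autodiff = jax.grad(f)(w)
g_closed_form = 2.0 * X.T @ (X @ w - y)
print(jnp.allclose(g_autodiff, g_closed_form))          # True
```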

2.3 Linear differentiation maps

The Jacobian matrix is useful as a representation of the partial derivatives. However, the core idea underlying the definition of differentiable functions, as well as their implementation in an autodiff framework, lies in the access to two key linear maps. These two maps encode infinitesimal variations along input or output directions and are referred to, respectively, as the Jacobian-vector product (JVP) and the vector-Jacobian product (VJP). This section formalizes these notions in the context of Euclidean spaces.

2.3.1 The need for linear maps

So far, we have focused on functions f : RP → RM that take a vector as input and produce a vector as output. However, functions that use matrix or even tensor inputs/outputs are commonplace in neural networks. For example, consider the function of matrices of the form f(W) = Wx, where x ∈ RD and W ∈ RM×D. This function takes a matrix as input, not a vector. Of course, a matrix W ∈ RM×D can always be “flattened” into a vector w ∈ RMD, by stacking the columns of W. We denote this operation by w = vec(W) and its inverse by W = vec−1(w). We can then equivalently write f(W) as f̃(w) = f(vec−1(w)) = vec−1(w)x, so that the previous framework applies. However, we will now see that this would be inefficient.

Indeed, the resulting Jacobian of f̃ at any w is a matrix of size M × MD which, after some computation, can be observed to be mostly filled with zeros. Getting the directional derivative of f at W ∈ RM×D in a direction V ∈ RM×D would then consist in (i) vectorizing V into v = vec(V), (ii) computing the matrix-vector product ∂f̃(w)v, at a cost of M × MD = M²D operations if the zero entries of the Jacobian are ignored, and (iii) reshaping the result into a matrix.
On the other hand, since f is linear in its matrix input, we can infer that the directional derivative of f at any W ∈ RM×D in any direction V ∈ RM×D is simply given by the function itself applied to V. Namely, we have ∂f(W)[V] = f(V) = Vx, which is simple to implement and only requires MD operations. Note that the cost would have been the same had we exploited the zero entries of ∂f̃(w) to skip the corresponding multiplications. The point here is that, by considering the operations associated with the differentiation of a function as linear maps, rather than using the associated representation as a Jacobian matrix, we can streamline the implementation and exploit the structure of the underlying input or output space. To that end, we now recall the main abstractions necessary to extend the previous definitions to the context of Euclidean spaces.
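As a concrete illustration, a minimal sketch in JAX (assuming the library is available; the shapes and values are arbitrary) contrasting the lazy linear-map view with materializing the Jacobian for f(W) = Wx:

```python
import jax
import jax.numpy as jnp

M, D = 4, 3
x = jnp.arange(1.0, D + 1.0)              # x in R^D
W = jnp.ones((M, D))
V = jnp.full((M, D), 0.5)                  # direction in the matrix input space

def f(W):
    return W @ x                           # f(W) = W x

# Lazy linear map: no M x (M D) Jacobian is ever formed.
_, jvp = jax.jvp(f, (W,), (V,))
print(jnp.allclose(jvp, V @ x))            # True: the JVP is f(V) = V x

# Materializing the Jacobian instead yields an M x M x D array,
# mostly filled with zeros.
J = jax.jacobian(f)(W)
print(J.shape)                              # (4, 4, 3)
```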

2.3.2 Euclidean spaces

Linear spaces, a.k.a. vector spaces, are spaces equipped with (and closed under) an addition rule compatible with multiplication by a scalar (we limit ourselves to the field of reals). Namely, in a vector space E, there exist operations + and ·, such that for any u, v ∈ E and a ∈ R, we have u + v ∈ E and a · u ∈ E. Euclidean spaces are linear spaces equipped with a basis e1, . . . , eP ∈ E. Any element v ∈ E can be decomposed as v = Σ_{i=1}^P vi ei for some unique scalars v1, . . . , vP ∈ R. A canonical example of Euclidean space is the set RP of all vectors of size P that we already covered. The set RP1×P2 of matrices of size P1 × P2 is also naturally a Euclidean space, generated by the canonical matrices Eij ∈ {0, 1}P1×P2 for i ∈ [P1], j ∈ [P2], filled with zeros except at the (i, j)th entry, which is one. For example, W ∈ RP1×P2 can be written W = Σ_{i,j=1}^{P1,P2} Wij Eij.

Euclidean spaces are naturally equipped with a notion of inner product, defined below.

Definition 2.9 (Inner product). An inner product on a vector space E is a function ⟨·, ·⟩ : E × E → R that is

• bilinear, that is, x 7→ ⟨x, w⟩ and y 7→ ⟨v, y⟩ are linear for any w, v ∈ E,

• symmetric, that is, ⟨w, v⟩ = ⟨v, w⟩ for any w, v ∈ E,

• positive definite, that is, ⟨w, w⟩ ≥ 0 for any w ∈ E, and ⟨w, w⟩ = 0 if and only if w = 0.

An inner product defines a norm ∥w∥ := √⟨w, w⟩.

The norm induced by an inner product defines a distance ∥w − v∥ between w, v ∈ E, and therefore a notion of convergence.
For vectors, where E = RP, the inner product is the usual one, ⟨w, v⟩ = Σ_{i=1}^P wi vi. For matrices, where E = RP1×P2, the inner product is the so-called Frobenius inner product. It is defined for any W, V ∈ RP1×P2 by

⟨W, V⟩ := ⟨vec(W), vec(V)⟩ = Σ_{i,j=1}^{P1,P2} Wij Vij = tr(W⊤V),

where tr(Z) := Σ_{i=1}^P Zii is the trace operator defined for square matrices Z ∈ RP×P. For tensors of order R, which generalize matrices to E = RP1×...×PR, the inner product is defined similarly for W, V ∈ RP1×...×PR by

⟨W, V⟩ := ⟨vec(W), vec(V)⟩ = Σ_{i1,...,iR=1}^{P1,...,PR} Wi1...iR Vi1...iR,

where Wi1...iR is the (i1, . . . , iR)th entry of W.

2.3.3 Linear maps and their adjoints


The notion of linear map defined in Definition 2.4 naturally extends
to Euclidean spaces. Namely, a function l : E → F from a Euclidean
space E onto a Euclidean space F is a linear map if for any w, v ∈ E


and a, b ∈ R, we have l[aw + bv] = a · l[w] + b · l[v]. When E = RP and
F = RM , there always exists a matrix A ∈ RM ×P such that l[x] = Ax.
Therefore, we can think of A as the “materialization” of l.
We can define the adjoint operator of a linear map.

Definition 2.10 (Adjoint operator). Given two Euclidean spaces E and F equipped with inner products ⟨·, ·⟩E and ⟨·, ·⟩F, the adjoint of a linear map l : E → F is the unique linear map l∗ : F → E such that for any v ∈ E and u ∈ F,

⟨l[v], u⟩F = ⟨v, l∗[u]⟩E.

The adjoint can be thought of as the counterpart of the matrix transpose for linear maps. When l[v] = Av, we have l∗[u] = A⊤u since ⟨l[v], u⟩F = ⟨Av, u⟩F = ⟨v, A⊤u⟩E = ⟨v, l∗[u]⟩E.

2.3.4 Jacobian-vector products


We now define the directional derivative using linear maps, leading to the notion of Jacobian-vector product (JVP). This can be used to facilitate the treatment of functions on matrices, or for further extensions to infinite-dimensional spaces. In the following, E and F denote two Euclidean spaces equipped with inner products ⟨·, ·⟩E and ⟨·, ·⟩F. We start by defining differentiability in general Euclidean spaces.

Definition 2.11 (Differentiability in Euclidean spaces). A function f : E → F is differentiable at a point w ∈ E if the directional derivative along v ∈ E,

∂f(w)[v] := lim_{δ→0} (f(w + δv) − f(w)) / δ = l[v],

is well-defined for any v ∈ E, linear in v, and if

lim_{∥v∥2→0} ∥f(w + v) − f(w) − l[v]∥2 / ∥v∥2 = 0.

We can now formally define the Jacobian-vector product.

Definition 2.12 (Jacobian-vector product). For a differentiable function f : E → F, the linear map ∂f(w) : E → F, mapping v to ∂f(w)[v], is called the Jacobian-vector product (JVP), by analogy with Definition 2.8. Note that ∂f : E → (E → F).

Strictly speaking, v belongs to E. Therefore it may not be a vector, if for instance E is a set of real matrices. We adopt the name JVP, as it is standard in the literature.

Recovering the gradient

Previously, we saw that for differentiable functions with vector input and scalar output, the directional derivative is equal to the inner product between the direction and the gradient. The same applies when considering differentiable functions from a Euclidean space with single outputs, except that the gradient is now an element of the input space and the inner product is the one associated with the input space.

Proposition 2.4 (Gradient). If a function f : E → R is differentiable at w ∈ E, then there exists ∇f(w) ∈ E, called the gradient of f at w, such that the directional derivative of f at w along any input direction v ∈ E is given by

∂f(w)[v] = ⟨∇f(w), v⟩E.

In Euclidean spaces, the existence of the gradient can simply be shown by decomposing the directional derivative along a basis of E. Such a definition generalizes to infinite-dimensional spaces (e.g., Hilbert spaces), as seen in Section 2.3.9.

2.3.5 Vector-Jacobian products

For functions with vector inputs and vector outputs, we already discussed infinitesimal variations along output directions. The same approach applies to Euclidean spaces and is tied to the adjoint of the JVP, as detailed in Proposition 2.5.

Proposition 2.5 (Vector-Jacobian product). If a function f : E → F is differentiable at w ∈ E, then its infinitesimal variation along an output direction u ∈ F is given by the adjoint map ∂f(w)∗ : F → E of the JVP, called the vector-Jacobian product (VJP). It satisfies

∇⟨u, f⟩F(w) = ∂f(w)∗[u],

where we denoted ⟨u, f⟩F(w) := ⟨u, f(w)⟩F. Note that ∂f(·)∗ : E → (F → E).

Proof. The chain rule presented in Proposition 2.2 naturally generalizes to Euclidean spaces (see Proposition 2.6). Since ⟨u, ·⟩F is linear, its directional derivative is itself. Therefore, the directional derivative of ⟨u, f⟩F is

∂(⟨u, f⟩F)(w)[v] = ⟨u, ∂f(w)[v]⟩F = ⟨∂f(w)∗[u], v⟩E.

As this is true for any v ∈ E, ∂f(w)∗[u] is the gradient of ⟨u, f⟩F per Proposition 2.4.
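As an illustration, a minimal sketch in JAX (assuming the library is available; the function and directions are arbitrary) of the identity ∇⟨u, f⟩(w) = ∂f(w)∗[u], using jax.vjp to access the VJP:

```python
import jax
import jax.numpy as jnp

def f(w):
    return jnp.array([w[0] * w[1], jnp.sin(w[1]), w[0] ** 2])   # f: R^2 -> R^3

w = jnp.array([1.0, 2.0])
u = jnp.array([1.0, -1.0, 0.5])                                  # output direction

out, vjp_fun = jax.vjp(f, w)              # vjp_fun(u) evaluates the adjoint at u
vjp = vjp_fun(u)[0]

grad_inner = jax.grad(lambda w: jnp.dot(u, f(w)))(w)             # gradient of <u, f>
print(jnp.allclose(vjp, grad_inner))                             # True
```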

2.3.6 Chain rule


The chain rule presented before in terms of Jacobian matrices can readily be formulated in terms of linear maps, to take advantage of the implementations of the JVP and VJP as linear maps.

Proposition 2.6 (Chain rule, general case). Consider f : E → F and g : F → G for E, F, G some Euclidean spaces. If f is differentiable at w ∈ E and g is differentiable at f(w) ∈ F, then the composition g ◦ f is differentiable at w ∈ E. Its JVP is given by

∂(g ◦ f)(w)[v] = ∂g(f(w))[∂f(w)[v]]

and its VJP is given by

∂(g ◦ f)(w)∗[u] = ∂f(w)∗[∂g(f(w))∗[u]].

The proof follows the one of Proposition 2.2. When the last function is scalar-valued, which is often the case in machine learning, we obtain the following simplified result.

Proposition 2.7 (Chain rule, scalar case). Consider f : E → F and g : F → R. The gradient of the composition is given by

∇(g ◦ f)(w) = ∂f(w)∗[∇g(f(w))].

2.3.7 Functions of multiple inputs (fan-in)


Oftentimes, the inputs of a function do not belong to only one Euclidean space but to a product of them. An example is f(x, W) = Wx, which is defined on E = RD × RM×D. In such a case, it is convenient to generalize the notion of partial derivatives to handle blocks of inputs.
Consider a function f(w1, . . . , wS) defined on E = E1 × . . . × ES, where wi ∈ Ei. We denote the partial derivative with respect to the ith input wi along vi ∈ Ei as ∂i f(w1, . . . , wS)[vi]. Equipped with this notation, we can analyze how JVPs or VJPs are decomposed along several inputs.

Proposition 2.8 (Multiple inputs). Consider a differentiable function of the form f(w) = f(w1, . . . , wS) with signature f : E → F, where w := (w1, . . . , wS) ∈ E and E := E1 × · · · × ES. Then the JVP with the input direction v = (v1, . . . , vS) ∈ E is given by

∂f(w)[v] = ∂f(w1, . . . , wS)[v1, . . . , vS] = Σ_{i=1}^S ∂i f(w1, . . . , wS)[vi] ∈ F.
The VJP with the output direction u ∈ F is given by

∂f(w)∗[u] = ∂f(w1, . . . , wS)∗[u] = (∂1 f(w1, . . . , wS)∗[u], . . . , ∂S f(w1, . . . , wS)∗[u]) ∈ E.

Example 2.6 (Matrix-vector product). Consider f(x, W) = Wx, where W ∈ RM×D and x ∈ RD. This corresponds to setting E = E1 × E2 = RD × RM×D and F = RM. For the JVP, letting v ∈ RD and V ∈ RM×D, we obtain

∂f(x, W)[v, V] = Wv + Vx ∈ F.

We can also access the individual JVPs as

∂1 f(x, W)[v] = Wv ∈ F,
∂2 f(x, W)[V] = Vx ∈ F.

For the VJP, letting u ∈ RM, we obtain

∂f(x, W)∗[u] = (W⊤u, ux⊤) ∈ E.

We can access the individual VJPs by

∂1 f(x, W)∗[u] = W⊤u ∈ E1,
∂2 f(x, W)∗[u] = ux⊤ ∈ E2.
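As an illustration of Proposition 2.8 and Example 2.6, a minimal sketch in JAX (assuming the library is available; the shapes and values are arbitrary), where jax.jvp sums the per-input contributions and jax.vjp returns one block per input:

```python
import jax
import jax.numpy as jnp

def f(x, W):
    return W @ x                                   # f: R^D x R^{M x D} -> R^M

D, M = 3, 2
x = jnp.arange(1.0, D + 1.0)
W = jnp.ones((M, D))
v = jnp.full((D,), 0.1)                            # direction for x
V = jnp.full((M, D), 0.2)                          # direction for W
u = jnp.array([1.0, -1.0])                         # output direction

_, jvp = jax.jvp(f, (x, W), (v, V))
print(jnp.allclose(jvp, W @ v + V @ x))            # JVP sums the per-input terms

out, vjp_fun = jax.vjp(f, x, W)
vjp_x, vjp_W = vjp_fun(u)                          # one VJP per input block
print(jnp.allclose(vjp_x, W.T @ u), jnp.allclose(vjp_W, jnp.outer(u, x)))
```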

Remark 2.4 (Nested inputs). It is sometimes convenient to group inputs in meaningful parts. For instance, if the input is naturally broken down into two parts x = (x1, x2), where x1 is a text part and x2 is an image part, and the network parameters are naturally grouped into three layers w = (w1, w2, w3), we can write f(x, w) = f((x1, x2), (w1, w2, w3)). This is mostly a convenience and we can again reduce it to a function of a single input, thanks to the linear map perspective in Euclidean spaces.

Remark 2.5 (Hiding away inputs). It will often be convenient to ignore inputs when differentiating. We use the semicolon for this purpose. For instance, a function of the form L(w; x, y) (notice the semicolon) has signature L : W → R because we treat x and y as constants. Therefore, the gradient is ∇L(w; x, y) ∈ W. On the other hand, the function L(w, x, y) (notice the comma) has signature L : W × X × Y → R so its gradient is ∇L(w, x, y) ∈ W × X × Y. If we need to access partial gradients, we use indexing, e.g., ∇1 L(w, x, y) ∈ W or ∇w L(w, x, y) ∈ W when there is no ambiguity.

2.3.8 Functions of multiple outputs (fan-out)

Similarly, it is often convenient to deal with functions of multiple outputs.

Proposition 2.9 (Multiple outputs). Consider a differentiable function of the form f(w) = (f1(w), . . . , fT(w)), with signatures f : E → F and fi : E → Fi, where F := F1 × · · · × FT. Then the JVP with the input direction v ∈ E is given by

∂f(w)[v] = (∂f1(w)[v], . . . , ∂fT(w)[v]) ∈ F.

The VJP with the output direction u = (u1, . . . , uT) ∈ F is

∂f(w)∗[u] = ∂f(w)∗[u1, . . . , uT] = Σ_{i=1}^T ∂fi(w)∗[ui] ∈ E.

Combined with the chain rule, we obtain that the Jacobian of

h(w) := g(f(w)) = g(f1(w), . . . , fT(w))

is ∂h(w) = Σ_{i=1}^T ∂i g(f(w)) ◦ ∂fi(w) and therefore the JVP is

∂h(w)[v] = Σ_{i=1}^T ∂i g(f(w))[∂fi(w)[v]].

2.3.9 Extensions to non-Euclidean linear spaces


We focused on Euclidean spaces, i.e., linear spaces with a finite basis. However, the notions introduced earlier can be defined in more generic spaces.
For example, directional derivatives (see Definition 2.11) can
be defined in any linear space equipped with a norm and complete
with respect to this norm. Such spaces are called Banach spaces.
Completeness is a technical assumption that requires that any Cauchy
sequence converges (a Cauchy sequence is a sequence whose elements
become arbitrarily close to each other as the sequence progresses). A
function f : E → F defined from a Banach space E onto a Banach space
F is then called Gateaux differentiable if its directional derivative is
defined along any direction (where limits are defined w.r.t. the norm in
F). Some authors also require the directional derivative to be linear to
define a Gateaux differentiable function.
Fréchet differentiability can also naturally be generalized to Banach spaces. The only difference is that, in generic Banach spaces, the linear map l satisfying Definition 2.11 must be continuous, i.e., there must exist C > 0 such that ∥l[v]∥ ≤ C∥v∥ for all v ∈ E, where the norms are those of F and E, respectively.
The definitions of gradient and VJPs require in addition a notion of
inner product. They can be defined in Hilbert spaces, that is, linear
spaces equipped with an inner product and complete with respect to
the norm induced by the inner product (they could also be defined
in a Banach space by considering operations in the dual space, see,
e.g. (Clarke et al., 2008)). The existence of the gradient is ensured by
Riesz’s representation theorem which states that any continuous
linear form in a Hilbert space can be represented by the inner product
with a vector. Since for a differentiable function f : E → R, the JVP
∂f (w) : E → R is a linear form, Riesz’s representation theorem ensures
the existence of the gradient as the element g ∈ E such that ∂f (w)v =
⟨g, v⟩ for any v ∈ E. The VJP is also well-defined as the adjoint of the
JVP w.r.t. the inner product of the Hilbert space.
As an example, the space of square-integrable functions on R is a Hilbert space equipped with the inner product ⟨a, b⟩ := ∫ a(x)b(x)dx. Here, we cannot find a finite number of functions that can express all possible functions on R. Therefore, this space is not a mere Euclidean space. Nevertheless, we can consider functions on this Hilbert space (called functionals to distinguish them from the elements of the space).
The associated directional derivatives and gradients can be defined and are called, respectively, the functional derivative and the functional gradient; see, e.g., Frigyik et al. (2008) and references therein.

2.4 Second-order differentiation

2.4.1 Second derivatives


For a single-input, single-output differentiable function f : R → R,
its derivative at any point is itself a function f ′ : R → R. We may
then consider the derivative of the derivative at any point: the second
derivative.

Definition 2.13 (Second derivative). The second derivative f^(2)(w) of a differentiable function f : R → R at w ∈ R is defined as the derivative of f′ at w, that is,

f^(2)(w) := lim_{δ→0} (f′(w + δ) − f′(w)) / δ,

provided that the limit is well-defined. If the second derivative of a function f is well-defined at w, the function is said to be twice differentiable at w.

If f has a small second derivative at a given w, the derivative around


w is almost constant. That is, the function behaves like a line around
w, as illustrated in Fig. 2.4. Hence, the second derivative is usually
interpreted as the curvature of the function at a given point.

2.4.2 Second directional derivatives


For a multi-input function f : RP → R, we saw that the directional
derivative encodes infinitesimal variations of f along a given direction.
To analyze the second derivative, the curvature of the function at a
given point w, we can consider the variations along a pair of directions,
as defined below.

Definition 2.14 (Second directional derivative). The second directional derivative of f : RP → R at w ∈ RP along v, v′ ∈ RP is defined as the directional derivative of w 7→ ∂f(w)[v] along v′, that is,

∂²f(w)[v, v′] := lim_{δ→0} (∂f(w + δv′)[v] − ∂f(w)[v]) / δ,

provided that ∂f(w)[v] is well-defined around w and that the limit exists.

Figure 2.4: Points at which the second derivative is small are points around which the function is well approximated by its tangent line. On the other hand, points with a large second derivative tend to be badly approximated by the tangent line.

Of particular interest are the variations of a function along the canonical directions: the second partial derivatives, defined as

∂²ij f(w) := ∂²f(w)[ei, ej],

for ei, ej the ith and jth canonical directions in RP, respectively. In Leibniz notation, the second partial derivatives are denoted

∂²ij f(w) = ∂²f(w) / (∂wi ∂wj).

2.4.3 Hessians
For a multi-input function, twice differentiability is simply defined as
the differentiability of any directional derivative ∂f (w)[v] w.r.t. w.
Definition 2.15 (Twice differentiability). A function f : RP → R is twice differentiable at w ∈ RP if it is differentiable and ∂f : RP → (RP → R) is also differentiable at w.

As a result, the second directional derivative is a bilinear form.

Definition 2.16 (Bilinear map, bilinear form). A function b : RP × RP → RM is a bilinear map if b[v, ·] : RP → RM is linear for any v and b[·, v′] is linear for any v′. That is,

b[v, v′] = Σ_{i=1}^P vi b[ei, v′] = Σ_{i=1}^P Σ_{j=1}^P vi v′j b[ei, ej],

for v = Σ_{i=1}^P vi ei and v′ = Σ_{i=1}^P v′i ei. A bilinear map with values in R, b : RP × RP → R, is called a bilinear form.

The second directional derivative can be computed along any two directions from the knowledge of the second partial derivatives, gathered in the Hessian.

Definition 2.17 (Hessian). The Hessian of a twice differentiable function f : RP → R at w is the P × P matrix gathering all second partial derivatives,

∇²f(w) := [ ∂²11 f(w)  . . .  ∂²1P f(w)
                ..      . . .      ..
            ∂²P1 f(w)  . . .  ∂²PP f(w) ] ∈ RP×P,

provided that all second partial derivatives are well-defined.
The second directional derivative at w is bilinear in the directions v = Σ_{i=1}^P vi ei and v′ = Σ_{i=1}^P v′i ei. Therefore,

∂²f(w)[v, v′] = Σ_{i,j=1}^P vi v′j ∂²f(w)[ei, ej] = ⟨v, ∇²f(w)v′⟩.

Given the gradient of f, the Hessian is equivalent to the transpose of the Jacobian of the gradient. By slightly generalizing the notation ∇ to denote the transpose of the Jacobian of a function (which matches its definition for single-output functions), we have that the Hessian can be expressed as ∇²f(w) = ∇(∇f)(w), which justifies its notation.
Similarly as for the differentiability of a function f, twice differentiability of f at w is equivalent to having the second partial derivatives not only defined but also continuous in a neighborhood of w. Remarkably, by requiring twice differentiability, i.e., continuous second partial derivatives, the Hessian is guaranteed to be symmetric (Schwarz, 1873).

Proposition 2.10 (Symmetry of the Hessian). If a function f : RP → R is twice differentiable at w, then its Hessian ∇²f(w) is symmetric, that is, ∂²ij f(w) = ∂²ji f(w) for any i, j ∈ {1, . . . , P}.

The symmetry of the Hessian means that it can alternatively be written as ∇²f(w) = (∂²ji f(w))_{i,j=1}^P = ∂(∇f)(w), i.e., the Jacobian of the gradient of f.

2.4.4 Hessian-vector products


Similarly to the Jacobian, we can exploit the formal definition of the Hessian as a bilinear form to extend its definition to Euclidean spaces. In particular, we can define the notion of Hessian-vector product.

Definition 2.18 (Hessian-vector product). If a function f : E → R, defined on a Euclidean space E with inner product ⟨·, ·⟩, is twice differentiable at w ∈ E, then for any v ∈ E there exists ∇²f(w)[v], called the Hessian-vector product (HVP) of f at w along v, such that for any v′ ∈ E,

∂²f(w)[v, v′] = ⟨v′, ∇²f(w)[v]⟩.

In particular, for E = RP, the HVP is ∇²f(w)[v] = (∂²f(w)[v, ei])_{i=1}^P.

From an autodiff point of view, the HVP can be implemented in four different ways, as explained in Section 8.1.
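For instance, one of these implementations, forward-over-reverse, differentiates the gradient along the direction v; a minimal sketch in JAX (assuming the library is available; the function below is arbitrary):

```python
import jax
import jax.numpy as jnp

def f(w):
    return jnp.sum(jnp.log1p(jnp.exp(w)) ** 2)      # a twice-differentiable scalar function

def hvp(f, w, v):
    # Forward-mode JVP of the reverse-mode gradient: no Hessian is materialized.
    return jax.jvp(jax.grad(f), (w,), (v,))[1]

w = jnp.array([0.3, -0.7, 1.2])
v = jnp.array([1.0, 0.0, -1.0])
print(jnp.allclose(hvp(f, w, v), jax.hessian(f)(w) @ v))   # matches the dense Hessian
```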

2.4.5 Second-order Jacobians


The previous definitions naturally extend to multi-output functions f : E → F, where f := (f1, . . . , fM), fj : E → Fj and F := F1 × · · · × FM. The second directional derivative is defined by gathering the second directional derivatives of all coordinates, that is, for w, v, v′ ∈ E,

∂²f(w)[v, v′] = (∂²fj(w)[v, v′])_{j=1}^M ∈ F.

The function f is twice differentiable if and only if all its coordinates are twice differentiable. The second directional derivative is then a bilinear map. We can then compute second directional derivatives as

∂²f(w)[v, v′] = Σ_{i,j=1}^P vi v′j ∂²f(w)[ei, ej] = (⟨v, ∇²fk(w)v′⟩)_{k=1}^M.

When E = RP and Fj = R, so that F = RM, the bilinear map can be materialized as a tensor

∂²f(w) = (∂²f(w)[ei, ej])_{i,j=1}^P ∈ RM×P×P,

the “second-order Jacobian” of f. However, similarly to the Hessian, it is usually more convenient to apply the bilinear map to prescribed vectors v and v′ than to materialize it as a tensor.

2.5 Higher-order differentiation

2.5.1 Higher-order derivatives

Derivatives can be extended to any order. Formally, the nth derivative can be defined inductively as follows for a single-input, single-output function.

Definition 2.19 (nth order derivative). The nth derivative f^(n) of a function f : R → R at w ∈ R is defined as

f^(n)(w) := (f^(n−1))′(w) = lim_{δ→0} (f^(n−1)(w + δ) − f^(n−1)(w)) / δ,

provided that f^(n−1) is differentiable around w and that the limit exists. In such a case, the function is said to be n times differentiable at w.

2.5.2 Higher-order directional derivatives


For a multi-input function f , we can naturally extend the notion of
directional derivative as follows.

Definition 2.20 (nth order directional derivative). The nth directional derivative of f : RP → R at w ∈ RP along v1, . . . , vn is defined as

∂^n f(w)[v1, . . . , vn] := ∂(∂^{n−1} f(·)[v1, . . . , vn−1])(w)[vn]
  = lim_{δ→0} (∂^{n−1} f(w + δvn)[v1, . . . , vn−1] − ∂^{n−1} f(w)[v1, . . . , vn−1]) / δ.

A multi-input function f is n times differentiable if it is (n − 1) times differentiable and its (n − 1)th directional derivative along any directions is differentiable. As a consequence, the nth directional derivative is a multilinear form.

Definition 2.21 (Multilinear map, multilinear form). A function c : ⊗_{i=1}^n RP → RM is a multilinear map if it is linear in each coordinate given all others fixed, that is, if vj 7→ c[v1, . . . , vj, . . . , vn] is linear in vj for any j ∈ [n]. It is a multilinear form if it has values in R.

The nth order directional derivative is then given by

∂^n f(w)[v1, . . . , vn] = Σ_{i1,...,in=1}^P v1,i1 · · · vn,in ∂^n f(w)[ei1, . . . , ein].

Its materialization is an nth order tensor

∇^n f(w) = (∂^n f(w)[ei1, . . . , ein])_{i1,...,in=1}^P ∈ RP×···×P.

2.5.3 Higher-order Jacobians


All the above definitions extend directly to the case of multi-output functions f : E → F, where F = F1 × · · · × FM. The nth directional derivative gathers the nth directional derivatives of all coordinates,

∂^n f(w)[v1, . . . , vn] = (∂^n fj(w)[v1, . . . , vn])_{j=1}^M.

The function f is then n times differentiable if it is (n − 1) times differentiable and its (n − 1)th directional derivative along any directions is differentiable. As a consequence, the nth directional derivative is a multilinear map. It can be decomposed into partial derivatives as

∂^n f(w)[v1, . . . , vn] = Σ_{i1,...,in=1}^P v1,i1 · · · vn,in ∂^n f(w)[ei1, . . . , ein].

When E = RP and F = RM, it is materialized by an (n + 1)th order tensor

∂^n f(w) = (∂^n fj(w)[ei1, . . . , ein])_{j=1, i1,...,in=1}^{M, P,...,P} ∈ RM×P×···×P.

2.5.4 Taylor expansions


With Landau's little-o notation, we have seen that if a function is differentiable, it can be approximated by a linear function in v,

f(w + v) = f(w) + ⟨∇f(w), v⟩ + o(∥v∥2).

Such an expansion of the function up to its first derivative is called the first-order Taylor expansion of f around w.
If the function f is twice differentiable, we can approximate it by a quadratic in v, leading to the second-order Taylor expansion of f around w,

f(w + v) = f(w) + ⟨∇f(w), v⟩ + (1/2)⟨v, ∇²f(w)v⟩ + o(∥v∥2²).

Compared to the first-order Taylor approximation, it is naturally more accurate around w, as reflected by the fact that ∥v∥2³ ≤ ∥v∥2² for ∥v∥2 ≤ 1.
More generally, we can build the nth order Taylor expansion of an n times differentiable function f : RP → RM around w ∈ RP as

f(w + v) = f(w) + ∂f(w)[v] + (1/2)∂²f(w)[v, v] + . . . + (1/n!)∂^n f(w)[v, . . . , v] + o(∥v∥2^n),

where the last directional derivative is evaluated along n copies of v.
Note that, using the change of variable w′ = w + v ⇐⇒ v = w′ − w, it is often convenient to write the nth order Taylor expansion of f(w′) around w as

f(w′) = f(w) + Σ_{j=1}^n (1/j!) ∂^j f(w)[w′ − w, . . . , w′ − w] + o(∥w′ − w∥2^n),

where each jth directional derivative is evaluated along j copies of w′ − w.

Taylor expansions will prove useful in Chapter 6 for computing deriva-


tives by finite differences.
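As a small numerical illustration, a minimal sketch in JAX (assuming the library is available; the function and perturbation are arbitrary) comparing the errors of the first-order and second-order Taylor expansions:

```python
import jax
import jax.numpy as jnp

def f(w):
    return jnp.sum(jnp.exp(w) + w ** 2)

w = jnp.array([0.1, -0.2])
v = jnp.array([0.05, 0.03])                        # small perturbation

lin = f(w) + jnp.dot(jax.grad(f)(w), v)            # first-order expansion
quad = lin + 0.5 * jnp.dot(v, jax.hessian(f)(w) @ v)   # second-order expansion

print(jnp.abs(f(w + v) - lin))    # error of the first-order expansion
print(jnp.abs(f(w + v) - quad))   # smaller error of the second-order expansion
```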

2.6 Differential geometry

In this chapter, we progressively generalized the notion of derivative from real numbers to vectors and to variables living in a linear space (a.k.a. vector space), either finite-dimensional or infinite-dimensional. We can further generalize these notions by considering a local notion of linearity. This is formalized by smooth manifolds in differential geometry, whose terminology is commonly adopted in the automatic differentiation literature and software. In this section, we give a brief overview of derivatives on smooth manifolds (simply referred to as manifolds), and refer to Boumal (2023) for a complete introduction.

2.6.1 Differentiability on manifolds


Essentially, a manifold is a set that can be locally approximated by
a Euclidean space. The most common example is a sphere like the
Earth. Seen from the Moon, the Earth is not a plane, but locally, at a
human level, it can be seen as a flat surface. Euclidean spaces are also
trivial examples of manifold. A formal characterization of the sphere
as a manifold is presented in Example 2.7. For now, we may think of
a “manifold” as some set (e.g., the sphere) contained in some ambient
Euclidean space; note however that manifolds can be defined generally
without being contained in a Euclidean space (Boumal, 2023, Chapter
8). Differentiability in manifolds is simply inherited from the notion of
differentiability in the ambient Euclidean space.
Definition 2.22 (Differentiability of restricted functions). Let M and


N be manifolds. A function f : M → N defined from M ⊆ E to
N ⊆ F, with E and F Euclidean spaces, is differentiable if f is
the restriction of a differentiable function f¯ : E → F, so that f
coincides with f¯ on M.

Our objective is to formalize the directional derivatives and gradients


for functions defined on manifolds. This formalization leads to the
definitions of tangent spaces and cotangent spaces, and the associated
generalizations of JVP and VJP operators as pushforward and pullback
operators, respectively.

2.6.2 Tangent spaces and pushforward operators


To generalize the notion of directional derivatives of a function f , the
one property we want to preserve is the chain rule. Rather than starting
from the variations of f at a given point along a direction, we start with
the variations of f along curves. Namely, on a manifold like the sphere
S P in RP , we can look at curves α : R → S P passing by w ∈ S P at
time 0, that is, α(0) = w. For single-input functions like α, we denoted
for simplicity α′ (0) := (α1′ (0), . . . , αP′ (0)). The directional derivative of
f must typically serve to define the derivative of f ◦ α at 0, such that
(f ◦ α)′ (0) = ∂f (w)[α′ (0)]. In the case of the sphere, as illustrated in
Fig. 2.5, the derivative α′ (0) of a curve α passing through a point w is
always tangent to the sphere at w. The tangent plane to the sphere
at w, then captures all possible relevant vectors to pass to the JVP
we are building. To define the directional derivative of a function f on
a manifold, we therefore restrict ourselves to an operator defined on
the tangent space Tw M, whose definition below is simplified for our
purposes.

Definition 2.23 (Tangent space). The tangent space of a mani-


fold M at w ∈ M is defined as

Tw M := {v = α′ (0) for any α : R → M differentiable s.t. α(0) = w}.

In the case of the sphere in Fig. 2.5, the tangent space is a plane, that
is, a Euclidean space. This property is generally true: tangent spaces

Figure 2.5: A differentiable function f defined from a sphere M to a sphere N


defines a push-forward operator that maps tangent vectors (derivatives of functions
on the sphere passing by w) in the tangent space Tw M of M at w to tangent vectors
of N at f (w) in the tangent space Tf (w) N .

are Euclidean spaces such that we will be able to define directional


derivatives as linear operators. Now, if f is differentiable and goes from
a manifold M to a manifold N , then f ◦ α is a differentiable curve in N .
Therefore, (f ◦ α)′ (0) is the derivative of a curve passing through f (w)
at 0 and is tangent to N at f (w). Hence, the directional derivative
of f : M → N at w can be defined as a function from the tangent
space Tw M of M at w onto the tangent space Tf (w) N of N at f (w).
Overall, we built the directional derivative (JVP) by considering how a
composition of f with any curve α pushes forward the derivative of α
into the derivative of f ◦ α. The resulting JVP is called a pushforward
operator in differential geometry.

Definition 2.24 (Pushforward operator). Given two manifolds M and N, the pushforward operator of a differentiable function f : M → N at w ∈ M is the linear map ∂f(w) : Tw M → Tf(w) N defined by

∂f(w)[v] := (f ◦ α)′(0),

for any v ∈ Tw M such that v = α′(0), for a differentiable curve α : R → M passing by w at 0, i.e., α(0) = w.

2.6.3 Cotangent spaces and pullback operators


To generalize the JVP, we composed f : M → N with any single-input
function α : R → M giving values on the manifold. The derivative of
any such α is then pushed forward from Tw M to Tf (w) N by the action


of f . To define the VJP, we take a symmetric approach. We consider all
single-output differentiable functions β : N → R defined on y ∈ N with
y = f (w) for some w ∈ M. We then want to pull back the derivatives
of β when precomposing it by f . Therefore, the space on which the
VJP acts is the space of directional derivatives of any β : N → R at y,
defining the cotangent space.

Definition 2.25 (Cotangent space). The cotangent space of a


manifold N at y ∈ N is defined as

Ty∗ N = {u = ∂β(y) for any β : N → R differentiable}


= {u : Ty N → R for any linear map u},

Note that elements of the cotangent space are linear mappings, not
vectors. This distinction is important to define the pullback operator
as an operator on functions as done in measure theory. From a linear
algebra viewpoint, the cotangent space is exactly the dual space of
Ty N , that is, the set of linear maps from Ty N to R, called linear forms.
As Ty N is a Euclidean space, its dual space Ty∗ N is also a Euclidean
space. The pullback operator is then defined as the operator that gives
access to directional derivatives of β ◦ f given the directional derivative
of β at f (w).

Definition 2.26 (Pullback operator). Given two manifolds M and


N , the pullback operator of a differentiable function f : M → N
at w ∈ M is the linear map ∂f (w)⋆ : Tf∗(w) N → Tw∗ M defined by

∂f(w)⋆u := ∂(β ◦ f)(w),

for any u ∈ T∗f(w) N such that ∂β(f(w)) = u, for a differentiable function β : N → R.

Contrary to the pushforward operator that acts on vectors, the


pullback operator acts on linear forms. Hence, the slight difference in
notation between ∂f (w)⋆ and ∂f (w)∗ , the adjoint operator of ∂f (w).
To properly define the adjoint operator ∂f (w)∗ , we need a notion
of inner product. Since tangent spaces are Euclidean spaces, we can
define an inner product ⟨·, ·⟩w for each Tw M and w ∈ M, making M


a Riemannian manifold. Equipped with these inner products, the
cotangent space can be identified with the tangent space, and we can
define gradients.

Definition 2.27 (Gradients in Riemannian manifolds). Let M be a


Riemannian manifold equipped with inner products ⟨·, ·⟩w . For any
cotangent vector u ∈ Tw∗ M, with w ∈ M, there exists a unique
tangent vector u ∈ Tw M such that

∀v ∈ Tw M, u[v] = ⟨u, v⟩w .

In particular for any differentiable function f : M → R, we can


define the gradient of f as the unique tangent vector ∇f (w) ∈
Tw M such that

∀v ∈ Tw M, ∂f (w)[v] = ⟨∇f (w), v⟩.

Therefore, rather than pulling back directional derivatives, we can


pull back gradients. The corresponding operator is then naturally the
adjoint ∂f (w)∗ of the pushforward operator. Namely, given two Rieman-
nian manifolds M and N , and a differentiable function f : M → N ,
we have
(∂f (w)⋆ u) [v] = ⟨∂f (w)∗ [u], v⟩ for any v ∈ Tw M
for u = ⟨·, u⟩ ∈ Tf∗(w) N represented by u ∈ Tf (w) N .

Example 2.7 (The sphere as a manifold). The sphere S P in RP is defined as the set of points w ∈ RP satisfying c(w) := ⟨w, w⟩ − 1 = 0, with JVP ∂c(w)[v] = 2⟨w, v⟩.
For any v = (v1, . . . , vP−1) ∈ RP−1 close enough to the first P − 1 coordinates of a point w on the sphere, we can define ψ1(v) := √(1 − ⟨v, v⟩), such that ψ(v) := (v1, . . . , vP−1, ψ1(v)) satisfies ⟨ψ(v), ψ(v)⟩ = 1, that is, c(ψ(v)) = 0. With the help of the mapping ψ−1 from a neighborhood of w in the sphere to RP−1, we can see the sphere locally as a space of dimension P − 1.
The tangent space can be naturally characterized in terms of
Function                  f          M → N
Pushforward               ∂f(w)      Tw M → Tf(w) N
Pullback                  ∂f(w)⋆     T∗f(w) N → T∗w M
Adjoint of pushforward    ∂f(w)∗     Tf(w) N → Tw M

Table 2.1: For a differentiable function f defined from a manifold M onto a manifold N, the JVP is generalized by the notion of pushforward ∂f(w). The counterpart of the pushforward is the pullback operation ∂f(w)⋆, which acts on linear forms on the tangent spaces. For Riemannian manifolds, the pullback operation can be identified with the adjoint operator ∂f(w)∗ of the pushforward operator, as any linear form is represented by a vector.

the constraining function c. Namely, a curve α : R → M such that α(0) = w satisfies c(α(δ)) = 0 for any δ ∈ R. Hence, differentiating this implicit equation at 0, we have

0 = (c ◦ α)′(0) = ∂c(w)[α′(0)].

That is, α′(0) is in the null space of ∂c(w), denoted

Null(∂c(w)) := {v ∈ RP : ∂c(w)[v] = 0}.

The tangent space of M at w is then

Tw M = Null(2⟨w, ·⟩)
= {v ∈ RP : ⟨w, v⟩ = 0}

We recover naturally that the tangent space is a Euclidean space


of dimension P − 1 defined as the set of points orthogonal to w.

2.7 Generalized derivatives

While we largely focus on differentiable functions in this book, it is important to characterize non-differentiable functions. We distinguish here two cases: continuous functions and non-continuous functions. For the former case, there exist generalizations of the notions of directional derivative, gradient and Jacobian, presented below. For non-continuous functions, even if derivatives exist almost everywhere, they may be uninformative. For example, piecewise constant functions, encountered e.g. in control flows (Chapter 5), are almost everywhere differentiable but with zero derivatives. In such cases, surrogate functions can be defined to ensure the differentiability of a program (Part IV).

2.7.1 Rademacher’s theorem


We first recall the definition of a locally Lipschitz continuous function.

Definition 2.28 (Locally Lipschitz continuous function). A function f : E → F is Lipschitz continuous if there exists C ≥ 0 such that for any x, y ∈ E,

∥f(x) − f(y)∥ ≤ C∥x − y∥.

A function f : E → F is locally Lipschitz continuous if for any x ∈ E, there exists a neighborhood U of x such that f restricted to U is Lipschitz continuous.

Rademacher's theorem (Rademacher, 1919) ensures that such a function is differentiable almost everywhere; see also Morrey Jr (2009) for a standard proof.

Proposition 2.11 (Rademacher's theorem). Let E and F denote Euclidean spaces. If f : E → F is locally Lipschitz continuous, then f is almost everywhere differentiable, that is, the set of points in E at which f is not differentiable has (Lebesgue) measure zero.

2.7.2 Clarke derivatives


Rademacher’s theorem hints that the definitions of directional deriva-
tives, gradients and Jacobians may be generalized to locally Lipschitz
continuous functions. This is what Clarke (1975) did in his seminal
work, which laid the foundation of nonsmooth analysis. The first
building block is a notion of generalized directional derivative.
Definition 2.29 (Clarke generalized directional derivative). The Clarke generalized directional derivative of a locally Lipschitz continuous function f : E → R at w ∈ E in the direction v ∈ E is

∂C f(w)[v] := lim sup_{u→w, δ↘0} (f(u + δv) − f(u)) / δ,

provided that the limit exists, where δ ↘ 0 means that δ approaches 0 by non-negative values and the limit superior is defined as

lim sup_{x→a} f(x) := lim_{ε→0} sup{f(x) : x ∈ B(a, ε) \ {a}},

for B(a, ε) := {x ∈ E : ∥x − a∥ ≤ ε} the ball centered at a of radius ε.
There are two differences with the usual definition of a directional
derivative: (i) we considered slopes of the function in a neighborhood of
the point rather than at the given point, (ii) we took a limit superior
rather than a usual limit. The first point is rather natural in the light
of Rademacher’s theorem: we can properly characterize variations on
points where the function is differentiable, therefore we may take the
limits of these slopes as a candidate slope for the point of interest. The
second point is more technical but essential: it allows us to characterize
the directional derivative as the supremum of some linear forms (Clarke
et al., 2008). These linear forms in turn define a set of generalized
gradients (Clarke et al., 2008, Chapter 2).

Definition 2.30 (Clarke generalized gradient). A Clarke generalized gradient of a locally Lipschitz continuous function f : E → R at w ∈ E is a point g ∈ E such that, for all v ∈ E,

∂C f(w)[v] ≥ ⟨g, v⟩.

The set of Clarke generalized gradients is called the Clarke subdifferential of f at w.

Definition 2.29 and Definition 2.30 can be used in non-Euclidean spaces, such as Banach or Hilbert spaces (Clarke et al., 2008). In Euclidean spaces, the Clarke generalized gradients can be characterized more simply thanks to Rademacher's theorem (Clarke et al., 2008, Theorem 8.1). Namely, they can be defined as convex combinations of limits of gradients of f evaluated at sequences in E \ Ω that converge to w (Proposition 2.12). In the following, the convex hull of a set S ⊆ E, that is, the set of convex combinations of elements of S, is denoted

conv(S) := {λ1 s1 + . . . + λm sm : m ∈ N, λi ≥ 0, Σ_{i=1}^m λi = 1, si ∈ S}.

Proposition 2.12 (Characterization of Clarke generalized gradients). Let f : E → R be locally Lipschitz continuous and denote by Ω the set of points at which f is not differentiable (Proposition 2.11). An element g ∈ E is a Clarke generalized gradient of f at w ∈ E if and only if

g ∈ conv({ lim_{n→+∞} ∇f(vn) : (vn)_{n=1}^{+∞} s.t. vn ∈ E \ Ω, vn → w }).
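As a classical illustration, consider the absolute value function f(w) = |w| on R. It is 1-Lipschitz and differentiable everywhere except at 0, with ∇f(v) = sign(v) for v ≠ 0. Any sequence vn → 0 with vn ∈ R \ {0} therefore yields gradients in {−1, +1}, and Proposition 2.12 gives that the Clarke subdifferential of f at 0 is conv({−1, +1}) = [−1, 1], while at any w ≠ 0 it reduces to the singleton {sign(w)}.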

The Jacobian of a function f : E → F between two Euclidean spaces can be generalized similarly (Clarke et al., 2008, Section 3.3).

Definition 2.31 (Clarke generalized Jacobian). Let f : E → F be locally Lipschitz continuous and denote by Ω the set of points at which f is not differentiable (Proposition 2.11). A Clarke generalized Jacobian of f at w ∈ E is an element J of

conv({ lim_{n→+∞} ∂f(vn) : (vn)_{n=1}^{+∞} s.t. vn ∈ E \ Ω, vn → w }).

For a continuously differentiable function f : E → F or f : E → R, there is a unique generalized gradient, that is, the usual gradient (Clarke et al., 2008, Proposition 3.1, page 78). The chain rule can be generalized to these objects (Clarke et al., 2008). Recently, Bolte and Pauwels (2020) and Bolte et al. (2022) further generalized Clarke gradients through the definition of conservative gradients, in order to define automatic differentiation schemes for nonsmooth functions.
2.8 Summary

The usual definition of derivatives of real-valued univariate functions extends to multivariate functions f : RP → R through the notion of directional derivative ∂f(w)[v] at w ∈ RP in the direction v ∈ RP. To take advantage of the representation w = Σ_{j=1}^P wj ej in the canonical basis {e1, . . . , eP}, the definition of differentiable functions requires the linearity of the directional derivative w.r.t. the direction v. This requirement gives rise to the notion of gradient ∇f(w) ∈ RP, the vector that gathers the partial derivatives and further defines the steepest ascent direction at w. For vector-input, vector-output functions f : RP → RM, the directional derivative leads to the definition of the Jacobian matrix ∂f(w) ∈ RM×P, the matrix which gathers all partial derivatives (notice that we use bold ∂). The chain rule is then the product of Jacobian matrices.
These notions can be extended to general Euclidean spaces, such as the spaces of matrices or tensors. For functions of the form f : E → R, the gradient is ∇f(w) ∈ E. More generally, for functions of the form f : E → F, the Jacobian ∂f(w) can be seen as a linear map (notice the non-bold ∂). The directional derivative at w ∈ E naturally defines a linear map l[v] = ∂f(w)[v], where ∂f(w) : E → F is called the Jacobian-vector product (JVP) and captures the infinitesimal variation at w ∈ E along the input direction v ∈ E.
Its adjoint ∂f(w)∗ : F → E defines another linear map l[u] = ∂f(w)∗[u], called the vector-Jacobian product (VJP), which captures the infinitesimal variation at w ∈ E along the output direction u ∈ F. The chain rule is then the composition of these linear maps. For the particular case when we compose a scalar-valued function ℓ (such as a loss function) with a vector-valued function f (such as a network function), the gradient is given by ∇(ℓ ◦ f)(w) = ∂f(w)∗[∇ℓ(f(w))]. This is why being able to apply the adjoint to a gradient, which as we shall see can be done with reverse-mode autodiff, is so pervasive in machine learning.
The definitions of JVP and VJP operators can further be generalized in the context of differential geometry. In that framework, the JVP amounts to the pushforward operator, which acts on tangent vectors. The VJP amounts to the pullback operator, which acts on cotangent vectors.
We also saw that the Hessian matrix of a function f(w) from RP to R is denoted ∇²f(w) ∈ RP×P. It is symmetric if the second partial derivatives are continuous. Seen as a linear map, the Hessian leads to the notion of Hessian-vector product (HVP), which we saw can be reduced to the JVP or the VJP of ∇f(w).
The main take-away message of this chapter is that computing the directional derivative or the gradient of a composition of functions does not require computing intermediate Jacobians, but only evaluating the linear maps (JVPs or VJPs) associated with the intermediate functions. The goal of automatic differentiation, presented in Chapter 7, is precisely to provide an efficient implementation of these maps for computation chains or, more generally, for computation graphs.
chains or more generally for computation graphs.
3
Probabilistic learning

In this chapter, we review how to perform probabilistic learning. We


also introduce exponential family distributions, as they play a key role
in this book.

3.1 Probability distributions

3.1.1 Discrete probability distributions

A discrete probability distribution over a set Y is specified by its


probability mass function (PMF) p : Y → [0, 1]. The probability of
y ∈ Y is then defined by

P(Y = y) := p(y),

where Y denotes a random variable. When Y follows a distribution p,


we write Y ∼ p (with some abuse of notation, we use the same letter p
to denote the distribution and the PMF). The expectation of ϕ(Y ),
where Y ∼ p and ϕ : Y → RM , is then

E[ϕ(Y)] = Σ_{y∈Y} p(y)ϕ(y),

its variance (for one-dimensional variables) is

V[ϕ(Y)] = E[(ϕ(Y) − E[ϕ(Y)])²] = Σ_{y∈Y} p(y)(ϕ(y) − E[ϕ(Y)])²,

and its mode is

arg max_{y∈Y} p(y).

The Kullback-Leibler (KL) divergence (also known as relative entropy) between two discrete distributions over Y, with associated PMFs p and q, is the statistical “distance” defined by

KL(p, q) := Σ_{y∈Y} p(y) log(p(y)/q(y)) = E_{Y∼p}[log(p(Y)/q(Y))].

3.1.2 Continuous probability distributions


A continuous probability distribution over Y is specified by its proba-
bility density function (PDF) p : Y → R+ . The probability of A ⊆ Y
is then Z
P(Y ∈ A) = p(y)dy.
A
The definitions of expectation, variance and KL divergence are defined
analogously to the discrete setting, simply replacing y∈Y with Y .
P R

Speficially, the expectation of ϕ(Y ) is


Z
E[ϕ(Y )] = p(y)ϕ(y)dy,
Y

the variance is
Z
V[ϕ(Y )] = E[(ϕ(Y ) − E[ϕ(Y )])2 ] = p(y)(ϕ(y) − E[ϕ(Y )])2 dy
Y

and the KL divergence is


p(y) p(Y )
Z    
KL(p, q) := p(y) log dy = EY ∼p log .
Y q(y) q(Y )
The mode is defined as the arg maximum of the PDF.
When Y = R, we can also define the cumulative distribution
function (CDF)
Z b
P(Y ≤ b) = p(y)dy.
−∞
56 Probabilistic learning

The probability of Y lying in the semi-closed interval (a, b] is then


P(a < Y ≤ b) = P(Y ≤ b) − P(Y ≤ a).

3.2 Maximum likelihood estimation

3.2.1 Negative log-likelihood


We saw that a probability distribution over Y is specified by p(y), which
is called the probability mass function (PMF) for discrete variables
or the probability density function (PDF) for continuous variables. In
practice, the true distribution p generating the data is unknown and we
wish to approximate it with a distribution pλ , with parameters λ ∈ Λ.
Given a finite set of observations y1 , . . . , yN , how do we fit λ ∈ Λ to
the data? This can be done by maximizing the likelihood of the data,
i.e., we seek to solve
N
b N := arg max pλ (yi ).
Y
λ
λ∈Λ i=1

This is known as maximum likelihood estimation (MLE). Because


the log function is monotonically increasing, this is equivalent to mini-
mizing the negative log-likelihood, i.e., we have
N
b N = arg min − log pλ (yi ).
X
λ
λ∈Λ i=1

Example 3.1 (MLE for the normal distribution). Suppose we set pλ to the normal distribution with parameters λ = (µ, σ), i.e.,

pλ(y) := (1/(σ√(2π))) exp(−(1/2)((y − µ)/σ)²).

Then, given observations y1, . . . , yN, the MLE estimators for µ and σ² are the sample mean and the sample variance, respectively.
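Because the negative log-likelihood is differentiable in λ, the MLE can also be obtained numerically by gradient descent; the following is a minimal sketch using JAX (the observations and step size are arbitrary, and a closed-form solution of course exists in this case):

```python
import jax
import jax.numpy as jnp

y = jnp.array([1.2, 0.7, 2.3, 1.9, 0.4])              # hypothetical observations

def nll(params):
    mu, log_sigma = params
    sigma = jnp.exp(log_sigma)                          # keep sigma positive
    return jnp.sum(0.5 * ((y - mu) / sigma) ** 2 + jnp.log(sigma)
                   + 0.5 * jnp.log(2.0 * jnp.pi))

params = jnp.array([0.0, 0.0])
for _ in range(2000):                                   # plain gradient descent
    params = params - 0.01 * jax.grad(nll)(params)

mu_hat, sigma_hat = params[0], jnp.exp(params[1])
# mu_hat should approach y.mean() and sigma_hat**2 the sample variance y.var().
```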

3.2.2 Consistency w.r.t. the Kullback-Leibler divergence


It is well-known that the MLE estimator is consistent in the sense of the Kullback-Leibler divergence. That is, denoting the true distribution by p and

λ∞ := arg min_{λ∈Λ} KL(p, pλ) = arg min_{λ∈Λ} E_{Y∼p}[log(p(Y)/pλ(Y))],

then λ̂N → λ∞ in expectation over the observations, as N → ∞. This can be seen by using

KL(p, pλ) ≈ (1/N) Σ_{i=1}^N log(p(yi)/pλ(yi)) = (1/N) Σ_{i=1}^N (log p(yi) − log pλ(yi))

and the law of large numbers.

3.3 Probabilistic supervised learning

3.3.1 Conditional probability distributions


Many times in machine learning, instead of a probability P(Y = y) for
some y ∈ Y, we wish to define a conditional probability P(Y = y|X =
x), for some input x ∈ X . This can be achieved by reduction to an
unconditional probability distribution,
P(Y = y|X = x) := pλ (y)
where
λ := f (x, w)
and f is a model function with model parameters w ∈ W. That is,
rather than being a deterministic function from X to Y, f is a function
from X to Λ, the distribution parameters of the output distribution
associated with the input.
We emphasize that λ could be a single parameter or a collection of parameters. For instance, in the Bernoulli distribution, λ = π, while in the univariate normal distribution, λ = (µ, σ).
In Section 3.4 and throughout this book, we will also use the notation
pθ instead of pλ when θ are the canonical parameters of an exponential
family distribution (Section 3.4).

3.3.2 Inference
The main advantage of this probabilistic approach is that our prediction
model is much richer than if we just learned a function from X to Y.
We now have access to the whole distribution over possible outcomes in


Y and can compute various statistics:
• Probability: P(Y = y|X = x) or P(Y ∈ A|X = x),

• Expectation: E[ϕ(Y )|X = x] for some function ϕ,

• Variance: V[ϕ(Y )|X = x],

• Mode: arg maxy∈Y pλ (y).


We now review probability distributions useful for binary classification,
multiclass classification, regression, multivariate regression, and integer
regression. In the following, to make the notation more lightweight, we
omit the dependence on x.

3.3.3 Binary classification


For binary outcomes, where Y = {0, 1}, we can use a Bernoulli distribution with parameter

λ := π ∈ [0, 1].

When a random variable Y is distributed according to a Bernoulli distribution with parameter π, we write

Y ∼ Bernoulli(π).

The PMF of this distribution is

pλ(y) := π if y = 1, and pλ(y) := 1 − π if y = 0.

The Bernoulli distribution is a binomial distribution with a single trial. Since y ∈ {0, 1}, the PMF can be rewritten as

pλ(y) = π^y (1 − π)^{1−y}.

The mean is

E[Y] = π = P(Y = 1)

and the variance is

V[Y] = π(1 − π) = P(Y = 1)P(Y = 0).

Parameterization using a sigmoid

Since the parameter π of a Bernoulli distribution needs to belong to [0, 1], we typically use a sigmoid function, such as the logistic function, as the output layer:

π := f(x, w) := logistic(g(x, w)),

where g : X × W → R is for example a neural network and

logistic(a) := 1/(1 + exp(−a)) ∈ (0, 1).

See also Section 4.4.3 for more details. When g is linear in w, this is known as binary logistic regression.
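For concreteness, a minimal sketch in JAX (the data below are arbitrary) of the negative log-likelihood of this model when g is linear, together with the classical form of its gradient, X⊤(π − y):

```python
import jax
import jax.numpy as jnp

# Hypothetical data: N = 4 points in D = 2 dimensions with binary labels.
X = jnp.array([[0.5, 1.0], [-1.0, 0.3], [2.0, -0.5], [0.1, 0.8]])
y = jnp.array([1.0, 0.0, 1.0, 0.0])

def nll(w):
    # pi_i = logistic(<x_i, w>); negative log-likelihood of the Bernoulli model.
    logits = X @ w
    return -jnp.sum(y * jax.nn.log_sigmoid(logits)
                    + (1.0 - y) * jax.nn.log_sigmoid(-logits))

w = jnp.zeros(2)
g = jax.grad(nll)(w)
print(jnp.allclose(g, X.T @ (jax.nn.sigmoid(X @ w) - y)))   # True
```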

Remark 3.1 (Link with the logistic distribution). The logistic distribution with mean and scale parameters µ and σ is a continuous probability distribution with PDF

pµ,σ(u) := (1/σ) p0,1((u − µ)/σ), where p0,1(z) := exp(−z)/(1 + exp(−z))².

The corresponding CDF is

P(U ≤ u) = ∫_{−∞}^u pµ,σ(t)dt = logistic((u − µ)/σ).

Therefore, if

U ∼ Logistic(µ, σ) and Y ∼ Bernoulli(logistic((u − µ)/σ)),

then

P(Y = 1) = P(U ≤ u).

Here, U can be interpreted as a latent continuous variable and u as a threshold.
3.3.4 Multiclass classification


For categorical outcomes with M possible choices, where Y = [M ],
we can use a categorical distribution with parameters

λ := π ∈ △M ,

where we define the probability simplex

△M := {π ∈ RM
+ : ⟨π, 1⟩ = 1},

i.e., the set of valid discrete probability distributions. When Y follows


a categorical distribution with parameter π, we write

Y ∼ Categorical(π).

The PMF of the categorical distribution is

pλ (y) := ⟨π, ϕ(y)⟩ = πy ,

where
ϕ(y) := ey
is the standard basis vector for the coordinate y ∈ [M ].
Since Y is a categorical variable, it does not make sense to compute
the expectation of Y but we can compute that of ϕ(Y ) = eY ,

EY ∼pλ [ϕ(Y )] = π.

That is, the mean and the probability distribution (represented by the
vector π) are the same in this case.

Parameterization using a softargmax


Since the parameter vector π of a categorical distribution needs to
belong to △M , we typically use a softargmax as the output layer:

π := f (x, w) := softargmax(g(x, w)),

where g : X × W → RM is for example a neural network and


softargmax(u) := exp(u) / ∑_j exp(uj) ∈ relint(△M).

The output of the softargmax is in the relative interior of △M, relint(△M) =
△M ∩ R^M_{>0}. That is, the produced probabilities are always strictly positive.
The categorical distribution is a multinomial distribution with
a single trial. When g is linear in w, this is therefore known as multi-
class or multinomial logistic regression, though strictly speaking a
multinomial distribution could use more than one trial.

3.3.5 Regression

For real outcomes, where Y = R, we can use, among other choices, a


normal distribution with parameters

λ := (µ, σ),

where µ ∈ R is the mean parameter and σ ∈ R+ is the standard


deviation parameter. The PDF is

pλ(y) := (1/(σ√(2π))) exp(−(y − µ)²/(2σ²)).

The expectation is
EY ∼pλ [Y ] = µ.

One advantage of the probabilistic perspective, however, is that we are


not limited to predicting the mean. We can also compute the CDF

P(Y ≤ y) = (1/2) [1 + erf((y − µ)/(σ√2))],     (3.1)

where we used the error function


erf(z) := (2/√π) ∫_0^z e^{−t²} dt.

This function is available in most scientific computing libraries, such as


SciPy. From the CDF, we also easily obtain

P(a < Y ≤ b) = (1/2) [erf((b − µ)/(σ√2)) − erf((a − µ)/(σ√2))].
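As a sketch (not an excerpt from the text), these quantities are easy to evaluate with scipy.special.erf; the values of µ, σ, a, b below are arbitrary.

import numpy as np
from scipy.special import erf

def normal_cdf(y, mu, sigma):
    # P(Y <= y) for Y ~ Normal(mu, sigma^2), using Eq. (3.1).
    return 0.5 * (1.0 + erf((y - mu) / (sigma * np.sqrt(2.0))))

def normal_interval_proba(a, b, mu, sigma):
    # P(a < Y <= b) = CDF(b) - CDF(a).
    return normal_cdf(b, mu, sigma) - normal_cdf(a, mu, sigma)

mu, sigma = 1.0, 2.0  # arbitrary distribution parameters
print(normal_cdf(1.0, mu, sigma))                    # 0.5 by symmetry
print(normal_interval_proba(-1.0, 3.0, mu, sigma))   # about 0.68 (within one sigma)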

Parameterization
Typically, in regression, the mean is output by a model, while the
standard deviation σ is kept fixed (typically set to 1). Since µ is uncon-
strained, we can simply set

µ := f (x, w) ∈ R,

where f : X × W → R is for example a neural network. That is, the


output of f is the mean of the distribution,

EY ∼pλ [Y ] = µ = f (x, w).

We can also use µ to predict P(Y ≤ y) or P(a < Y ≤ b), as shown


above.

3.3.6 Multivariate regression


More generally, for multivariate outcomes, where Y = RM , we can
use a multivariate normal distribution with parameters

λ := (µ, Σ),

where µ ∈ RM is the mean and Σ ∈ RM ×M is the covariance matrix.


The PDF is
pλ(y) := (1/√((2π)^M |Σ|)) exp(−(1/2) ⟨y − µ, Σ^{−1}(y − µ)⟩).

Using a diagonal covariance matrix is equivalent to using M independent


normal distributions for each Yj , for j ∈ [M ]. The expectation is

EY ∼pλ [Y ] = µ.

Parameterization
Typically, in multivariate regression, the mean is output by a model,
while the covariance matrix is kept fixed (typically set to the identity
matrix). Since µ is again unconstrained, we can simply set

µ := f (x, w) ∈ RM .

More generally, we can parametrize the function f so as to output both


the mean µ and the covariance matrix Σ, i.e.,
(µ, Σ) := f (x, w) ∈ RM × RM ×M .
The function f must be designed such that Σ is symmetric and positive
semi-definite. This is easy to achieve for instance by parametrizing
Σ = SS ⊤ for some matrix S.

3.3.7 Integer regression


For integer outcomes, where Y = N, we can use, among other choices,
a Poisson distribution with mean parameter λ > 0. The PMF is
pλ(y) := λ^y exp(−λ) / y!.
The Poisson distribution implies that the index of dispersion (the
ratio between variance and mean) is 1, since
E[Y ] = V[Y ] = λ.
When this assumption is inappropriate, one can use generalized Poisson
distributions (Satterthwaite, 1942).

Parameterization using an exponential


Since the parameter λ of a Poisson distribution needs to be strictly
positive, we typically use an exponential function as output layer:
λ := f (x, w) := exp(g(x, w)) > 0,
where g : X × W → R.

3.3.8 Loss functions


In the conditional setting, we can now use maximum likelihood estima-
tion (MLE) to estimate the model parameters w ∈ W of f . Given a
set of input-output pairs (x1 , y1 ), . . . , (xN , yN ), we choose the model
parameters that maximize the likelihood of the data,
ŵ := arg max_{w∈W} ∏_{i=1}^N pλi(yi),

where
λi := f (xi , w).
Again, this is equivalent to minimizing the negative log-likelihood,
ŵ = arg min_{w∈W} ∑_{i=1}^N − log pλi(yi).

Interestingly, MLE allows us to recover several popular loss functions.

• For the Bernoulli distribution with parameter λi = πi = logistic(g(xi , w)),


we have

− log pλi (yi ) = − [yi log πi + (1 − yi ) log(1 − πi )] ,

which is the binary logistic loss function.

• For the categorical distribution with parameters λi = πi =
softargmax(ui), where ui := g(xi, w), we have

− log pλi(yi) = log ∑_{j=1}^M exp(ui,j) − ui,yi
             = logsumexp(ui) − ⟨ui, eyi⟩,

which is the multiclass logistic loss function, also known as the
cross-entropy loss.

• For the normal distribution with mean λi = µi = f(xi, w) and
fixed variance σi², we have

− log pλi(yi) = (1/(2σi²))(yi − µi)² + (1/2) log σi² + (1/2) log(2π),

which is, up to constants and with unit variance, the squared loss
function.

• For the Poisson distribution with mean λi = exp(µi), where µi := g(xi, w), we have

− log pλi(yi) = −(yi µi − exp(µi) − log(yi!)),

which is the Poisson loss function.
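The following NumPy/SciPy sketch spells out these four negative log-likelihood losses for a single example. The function names and the zero-indexed class labels are our own illustrative conventions, not notation from the text.

import numpy as np
from scipy.special import expit, logsumexp, gammaln

def binary_logistic_loss(y, score):
    # -log p(y) for a Bernoulli with pi = logistic(score), y in {0, 1}.
    pi = expit(score)
    return -(y * np.log(pi) + (1 - y) * np.log(1 - pi))

def multiclass_logistic_loss(y, scores):
    # -log p(y) for a categorical with pi = softargmax(scores), y in {0, ..., M-1}.
    return logsumexp(scores) - scores[y]

def squared_loss(y, mu):
    # -log p(y) for a Normal(mu, 1), up to the constant 0.5 * log(2 * pi).
    return 0.5 * (y - mu) ** 2

def poisson_loss(y, log_rate):
    # -log p(y) for a Poisson with lambda = exp(log_rate); gammaln(y + 1) = log(y!).
    return np.exp(log_rate) - y * log_rate + gammaln(y + 1)

print(binary_logistic_loss(1, 0.3))
print(multiclass_logistic_loss(2, np.array([0.1, -0.5, 1.2])))
print(squared_loss(1.5, 0.7))
print(poisson_loss(3, 0.8))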



3.4 Exponential family distributions

3.4.1 Definition
The exponential family is a class of probability distributions, whose
PMF or PDF can be written in the form
pθ(y) = h(y) exp [⟨θ, ϕ(y)⟩] / exp(A(θ))
      = h(y) exp [⟨θ, ϕ(y)⟩ − A(θ)],
where θ are the natural or canonical parameters of the distribution.
The function h is known as the base measure. The function ϕ is the
sufficient statistic: it holds all the information about y and is used
to embed y in a vector space. The function A is the log-partition
or log-normalizer (see below for details). All the distributions we
reviewed in Section 3.3 belong to the exponential family. With some
abuse of notation, we use pλ for the distribution in original form and
pθ for the distribution in exponential family form. As we will see,
we can go from θ to λ and vice-versa. We illustrate how to rewrite a
distribution in exponential family form below.
Example 3.2 (Bernoulli distribution). The PMF of the Bernoulli
distribution with parameter λ = π equals

pλ(y) := π^y (1 − π)^{1−y}
       = exp(log(π^y (1 − π)^{1−y}))
       = exp(y log(π) + (1 − y) log(1 − π))
       = exp(log(π/(1 − π)) y + log(1 − π))
       = exp(θy − log(1 + exp(θ)))
       = exp(θy − softplus(θ))
       =: pθ(y).

Therefore, Bernoulli distributions belong to the exponential family,
with natural parameter θ = logit(π) := log(π/(1 − π)).
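A quick numerical check of this rewriting (a sketch; the value of π is arbitrary and softplus is implemented by hand):

import numpy as np
from scipy.special import logit  # logit(pi) = log(pi / (1 - pi))

def softplus(t):
    return np.log1p(np.exp(t))

pi = 0.3
theta = logit(pi)  # natural parameter
for y in (0, 1):
    original = pi ** y * (1 - pi) ** (1 - y)
    exp_family = np.exp(theta * y - softplus(theta))
    print(y, original, exp_family)  # the two parameterizations agree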
We rewrite the previously-described distributions in exponential
family form in Table 3.1. This list is non-exhaustive: there are many
more distributions in the exponential family!

Table 3.1: Examples of distributions in the exponential family.

        Bernoulli                          Categorical
Y       {0, 1}                             [M]
λ       π = logistic(θ)                    π = softargmax(θ)
θ       logit(π)                           log π (up to a constant shift c 1)
ϕ(y)    y                                  e_y
A(θ)    softplus(θ) = −log(1 − π)          logsumexp(θ)
h(y)    1                                  1

        Normal (location only)             Normal (location-scale)
Y       R                                  R
λ       µ = θσ                             (µ, σ²) = (−θ1/(2θ2), −1/(2θ2))
θ       µ/σ                                (µ/σ², −1/(2σ²))
ϕ(y)    y/σ                                (y, y²)
A(θ)    θ²/2 = µ²/(2σ²)                    −θ1²/(4θ2) − (1/2) log(−2θ2) = µ²/(2σ²) + log σ
h(y)    exp(−y²/(2σ²)) / (σ√(2π))          1/√(2π)

        Multivariate normal                Poisson
Y       R^M                                N
λ       (µ, Σ) = (−(1/2)θ2⁻¹θ1, −(1/2)θ2⁻¹)    λ = exp(θ)
θ       (Σ⁻¹µ, −(1/2)Σ⁻¹)                  log λ
ϕ(y)    (y, yy⊤)                           y
A(θ)    −(1/4)θ1⊤θ2⁻¹θ1 − (1/2) log|−2θ2|      exp(θ)
        = (1/2)µ⊤Σ⁻¹µ + (1/2) log|Σ|
h(y)    (2π)^{−M/2}                        1/y!

3.4.2 The log-partition function


The log-partition function A is the logarithm of the distribution’s
normalization factor. That is,

A(θ) := log ∑_{y∈Y} h(y) exp [⟨θ, ϕ(y)⟩]

for discrete random variables and


A(θ) := log ∫_Y h(y) exp [⟨θ, ϕ(y)⟩] dy

for continuous random variables. We denote the set of valid parameters

Θ := {θ ∈ RM : A(θ) < +∞} ⊆ RM .

We can conveniently rewrite A(θ) for discrete random variables as

A(θ) = logsumexp(B(θ)) := log ∑_{y∈Y} exp([B(θ)]_y),

and similarly for continuous variables. Here, we defined the linear map

B(θ) := (⟨θ, ϕ(y)⟩ + log h(y))y∈Y .

Since A(θ) is the composition of logsumexp, a convex function, and of


B, a linear map, we immediately obtain the following proposition.

Proposition 3.1 (Convexity of the log-partition). A(θ) is a convex


function.

A major property of the log-partition function is that its gradient


coincides with the expectation of ϕ(Y ) according to pθ .

Proposition 3.2 (Gradient of the log-partition).

µ(θ) := ∇A(θ) = EY ∼pθ [ϕ(Y )] ∈ M.

Proof. The result follows directly from

∇A(θ) = ∂B(θ)∗ ∇logsumexp(B(θ)) = (ϕ(y))_{y∈Y} softargmax(B(θ)).



The gradient ∇A(θ) is therefore often called the mean function.


The set of achievable means µ(θ) is defined by
M := conv(ϕ(Y)) := {Ep [ϕ(Y )] : p ∈ P(Y)},
where conv(S) is the convex hull of S and P(Y) is the set of valid
probability distributions over Y.
Similarly, the Hessian ∇2 A(θ) coincides with the covariance matrix
of ϕ(Y ) according to pθ (Wainwright and Jordan, 2008, Chapter 3).
When the exponential family is minimal, which means that the
parameters θ uniquely identify the distribution, it is known that ∇A
is a one-to-one mapping from Θ to M. That is, µ(θ) = ∇A(θ) and
θ = (∇A)−1 (µ(θ)).
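For the Bernoulli distribution, A(θ) = softplus(θ) and ∇A(θ) = logistic(θ) = π, which is indeed E[ϕ(Y)] = E[Y]. The following sketch checks Proposition 3.2 numerically with a finite-difference approximation of ∇A; the value of θ and the step size are arbitrary.

import numpy as np
from scipy.special import expit

def log_partition(theta):
    # A(theta) = softplus(theta) for the Bernoulli distribution.
    return np.log1p(np.exp(theta))

theta = 0.7
eps = 1e-6
grad_fd = (log_partition(theta + eps) - log_partition(theta - eps)) / (2 * eps)
mean = expit(theta)  # E[phi(Y)] = E[Y] = pi = logistic(theta)
print(grad_fd, mean)  # the two values match up to finite-difference error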

3.4.3 Maximum entropy principle


Suppose we observe the empirical mean µ̂ = (1/N) ∑_{i=1}^N ϕ(yi) ∈ M of
some observations y1 , . . . , yN . How do we find a probability distribution


achieving this mean? Clearly, such a distribution may not be unique.
One way to choose among all possible distributions is by using the
maximum entropy principle. Let us define the Shannon entropy by
H(p) := − ∑_{y∈Y} p(y) log p(y)

for discrete variables and by


H(p) := − ∫_Y p(y) log p(y) dy

for continuous variables. This captures the level of “uncertainty” in


p, i.e., it is maximized when the distribution is uniform. Then, the
maximum entropy distribution satisfying the first-order moment
condition (i.e., whose expectation matches the empirical mean) is
p⋆ := arg max_{p∈P(Y)} H(p)  s.t.  E_{Y∼p}[ϕ(Y)] = µ̂.

It can be shown that the maximum entropy distribution is necessarily


in the exponential family with sufficient statistics defined by ϕ and its
canonical parameters θ coincide with the Lagrange multipliers of
the above constraint (Wainwright and Jordan, 2008, Section 3.1).

3.4.4 Maximum likelihood estimation


Similarly as in Section 3.2, to fit the parameters θ ∈ Θ of an exponential
family distribution to some observations y1 , . . . , yN , we can use the
MLE principle, i.e.,
θ̂ = arg max_{θ∈Θ} ∏_{i=1}^N pθ(yi) = arg min_{θ∈Θ} ∑_{i=1}^N − log pθ(yi).

Fortunately, for exponential family distributions, the log probabil-


ity/density enjoys a particularly simple form.

Proposition 3.3 (Negative log-likelihood). The negative log-likelihood


of an exponential family distribution is

− log pθ (y) = A(θ) − ⟨θ, ϕ(y)⟩ − log h(y).

Its gradient is

−∇θ log pθ (y) = ∇A(θ) − ϕ(y) = EY ∼pθ [ϕ(Y )] − ϕ(y)

and its Hessian is

−∇2θ log pθ (y) = ∇2 A(θ),

which is independent of y.

It follows from Proposition 3.1 that θ 7→ − log pθ (y) is convex.


Interestingly, we see that the gradient is the residual between the
expectation according to the model and the observation.
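As an illustration (a sketch, not an excerpt from the text), gradient descent on this convex objective for a Bernoulli model amounts to repeatedly subtracting the residual between the model mean ∇A(θ) = logistic(θ) and the empirical mean of the observations. The data, step size, and iteration count below are arbitrary.

import numpy as np
from scipy.special import expit, logit

y = np.array([1, 0, 1, 1, 0, 1, 1, 0])  # arbitrary binary observations
theta, step_size = 0.0, 0.5

for _ in range(200):
    grad = expit(theta) - np.mean(y)  # average of the per-example gradients
    theta -= step_size * grad

print(theta, logit(np.mean(y)))  # gradient descent recovers theta_hat = logit(mean(y))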

3.4.5 Probabilistic learning with exponential families


In the supervised probabilistic learning setting, we wish to estimate a
conditional distribution of the form pw (y | x). Given a model function
f , such as a neural network, a common approach for defining such a
conditional distribution is by reduction to the unconditional setting,
pw (y | x) := pθ (y) where θ := f (x, w).
In other words, the role of f is to produce the parameters of pθ given
some input x. It is a function from X × W to Θ. Note that f must be

designed such that it produces an output in


Θ := {θ ∈ RM : A(θ) < +∞}.
For instance, as we previously discussed, for a multivariate normal
distribution, where θi = (µi , Σi ) = f (xi , w), we need to ensure that Σi
is a positive semidefinite matrix.
Given input-output pairs {(xi, yi)}_{i=1}^N, we then seek to find the
parameters w of f(x, w) by minimizing

arg min_{w∈W} ∑_{i=1}^N − log pθi(yi) = arg min_{w∈W} ∑_{i=1}^N A(θi) − ⟨θi, ϕ(yi)⟩
where θi := f (xi , w). While − log pθ (y) is a convex function of θ for
exponential family distributions, we emphasize that − log pf (x,w) (y) is
typically a nonconvex function of w, when f is a nonlinear function,
such as a neural network.

Inference
Once we found w by minimizing the objective function above, there are
several possible strategies to perform inference for a new input x.
• Expectation. When the goal is to compute the expectation of
ϕ(Y ), we can use ∇A(f (x, w)). That is, we compute the distribu-
tion parameters associated with x by θ = f (x, w) and then we
compute the mean by µ = ∇A(θ). When f is linear in w, the
composition ∇A ◦ f is called a generalized linear model.

• Probability. When the goal is to compute the probability of a


certain y, we can compute the distribution parameters associated
with x by θ = f (x, w) and then we can compute P(Y = y|X =
x) = pθ (y). In the particular case of the categorical distribution
(of which the Bernoulli distribution is a special case), we point
out again that the mean and the probability vector coincide:
µ = p = ∇A(θ) = softargmax(θ) ∈ △M .

• Other statistics. When the goal is to compute other quantities,


such as the variance or the CDF, we can convert the natural pa-
rameters θ to the original distribution parameters λ (see Table 3.1

for examples). Then, we can use established formulas for the


distribution in original form, to compute the desired quantities.

3.5 Summary

In this chapter, we started by reviewing discrete and continuous


probability distributions. We saw how to fit distribution parameters to
data using the maximum likelihood estimation (MLE) principle and
saw its connection with the Kullback-Leibler divergence. Instead of
designing a model function from the input space X to the output space
Y, we saw that we can perform probabilistic supervised learning
by designing a model function from X to distribution parameters Λ.
Leveraging the so-obtained parametric conditional distribution then
allowed us to compute, not only output probabilities, but also various
statistics such as the mean and the variance of the outputs. Finally,
we reviewed the exponential family, a principled generalization of
numerous distributions, which we saw is tightly connected with the
maximum entropy principle. Importantly, the approaches described
in this chapter produce perfectly valid computation graphs, mean-
ing that we can combine them with neural networks and we can use
automatic differentiation, to compute their derivatives.
Part II

Differentiable programs
4
Parameterized programs

Neural networks can be thought as parameterized programs: programs


with learnable parameters. In this chapter, we begin by reviewing how to
represent programs mathematically. We then review several key neural
network architectures and components.

4.1 Representing computer programs

4.1.1 Computation chains

To begin with, we consider simple programs that apply a sequence of


functions f1 , . . . , fK to an input s0 ∈ S0 . We call such programs com-
putation chains. For example, an image may go through a sequence
of transformations such as cropping, rotation, normalization, and so on.
In neural networks, the transformations are typically parametrized and
the parameters are learned, leading to feedforward networks, presented
in Section 4.2. Another example of sequence of functions is a for loop,
presented in Section 5.8.



Figure 4.1: A computation chain is a sequence of function compositions. In the


graph above, each intermediate node represents a single function. The first node
represents the input, the last node the output. Edges represent the dependencies of
the functions with respect to previous outputs or to the initial input.

Formally, a computation chain can be written as

s0 ∈ S0
s1 := f1 (s0 ) ∈ S1
..
.
sK := fK (sK−1 ) ∈ SK
f (s0 ) := sK . (4.1)

Here, s0 is the input, sk ∈ Sk is an intermediate state of the program,


and sK ∈ SK is the final output. Of course, the domain (input space)
of fk must be compatible with the image (output space) of fk−1 . That is,
we should have fk : Sk−1 → Sk . We can write the program equivalently
as

f (s0 ) = (fK ◦ · · · ◦ f2 ◦ f1 )(s0 )


= fK (. . . f2 (f1 (s0 ))).

A computation chain can be represented by a directed graph, shown


in Fig. 4.1. The edges in the chain define a total order. The order is
total, since two nodes are necessarily linked to each other by a path.

4.1.2 Directed acylic graphs


In generic programs, intermediate functions may depend, not only on
the previous function output, but on the outputs of several different
functions. Such dependences are best expressed using graphs.
A directed graph G = (V, E) is defined by a set of vertices or
nodes V and a set of edges E defining directed dependencies between

Figure 4.2: Example of a directed acyclic graph. Here the nodes are V =
{0, 1, 2, 3, 4}, the edges are E = {(0, 1), (0, 2), (0, 3), (1, 3), (2, 3), (1, 4), (3, 4)}. Parents
of the node 3 are parents(3) = {0, 1, 2}. Children of node 1 are children(1) = {3, 4}.
There is a unique root, 0, and a unique leaf, 4; 0 → 3 → 4 is a path from 0 to 4. This
is an acyclic graph since there is no cycle (i.e., a path from a node to itself). We can
order nodes 0 and 3 as 0 ≤ 3 since there is no path from 3 to 0. Similarly, we can
order 1 and 2 as 1 ≤ 2 since there is no path from 2 to 1. Two possible topological
orders of the nodes are (0, 1, 2, 3, 4) and (0, 2, 1, 3, 4).

vertices. An edge (i, j) ∈ E is an ordered pair of vertices i ∈ V and


j ∈ V. It is also denoted i → j, to indicate that j depends on i. For
representing inputs and outputs, it will be convenient to use incoming
half-edges → j and outgoing half-edges i →.

In a graph G = (V, E), the parents of a vertex j is the set of nodes


pointing to j, denoted parents(j) := {i : i → j}. The children of
a vertex i is the set of nodes i is pointing to, that is, children(i) :=
{j : i → j}. Vertices without parents are called roots and vertices
without children are called leaves.

A path from i to j is defined by a sequence of vertices j1 , . . . , jm ,


potentially empty, such that i → j1 → . . . → jm → j. An acyclic graph
is a graph such that there exists no vertex i with a path from i to i. A
directed acyclic graph (DAG) is a graph that is both directed and
acyclic.

The edges of a DAG define a partial order of the vertices, denoted


i ⪯ j if there exists a path from i to j. The order is partial, since
two vertices may not necessarily be linked to each other by a path.
Nevertheless, we can define a total order called a topological order:
any order such that i ≤ j if and only if there is no path from j to i.


Figure 4.3: Representation of f (x1 , x2 ) = x2 ex1 x1 + x2 ex1 as a DAG, with
functions as nodes and variables as edges. The function is decomposed as 8 functions
in topological order.

4.1.3 Computer programs as DAGs


We assume that a program defines a mathematically valid function (a.k.a.
pure function): the program should return identical values for identical
arguments and should not have any side effects. We also assume that the
program halts, i.e., that it terminates in a finite number of steps. As
such a program is made of a finite number of intermediate functions and
intermediate variables, the dependencies between functions and variables
can be expressed using a directed acyclic graph (DAG). Without loss of
generality, we make the following simplifying assumptions:
1. There is a single input s0 ∈ S0 .

2. There is a single output sK ∈ SK .

3. Each intermediate function fk in the program outputs a single


variable sk ∈ Sk .
We number the nodes as V := {0, 1, . . . , K}. Node 0 is the root, corre-
sponding to the input s0 ∈ S0 . Node K is the leaf, corresponding to the
final output sK ∈ SK . Because of the third assumption above, apart
from s0 , each variable sk is in bijection with a function fk . Therefore,
node 0 represents the input s0 , and nodes 1, . . . , K represent both a
function fk and an output variable sk .
Edges in the DAG represent dependences. The parents i1 , . . . , ipk :=
parents(k) of node k, where pk := |parents(k)|, indicate the variables
sparents (k) := si1 , . . . , sipk that the function fk needs to perform its com-
putation. Put differently, the parents i1 , . . . , ipk indicate the functions
fi1 , . . . , fipk that need to be evaluated, prior to evaluating fk .

Algorithm 4.1 Executing a program


Functions: f1 , . . . , fK in topological order
Input: input s0 ∈ S0
1: for k := 1, . . . , K do
2: Retrieve parent nodes i1 , . . . , ipk := parents(k)
3: Compute sk := fk (sparents(k) ) := fk (si1 , . . . , sipk )
4: Output: f (s0 ) := sK

An example of computation graph in our formalism is presented


in Fig. 4.3.

Executing a program

To execute a program, we need to ensure that we evaluate the intermedi-


ate functions in the correct order. Therefore, we assume that the nodes
0, 1, . . . , K are in a topological order (if this is not the case, we need
to perform a topological sort first). We can then execute a program by
evaluating for k ∈ [K]

sk := fk (sparents(k) ) := fk (si1 , . . . , sipk ) ∈ Sk .

Note that we can either view fk as a single-input function of sparents(k) ,


which is a tuple of elements, or as a multi-input function of si1 , . . . , sipk .
The two views are essentially equivalent.
The procedure for executing a program is summarized in Algo-
rithm 4.1.
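A minimal Python sketch of Algorithm 4.1, in which the program is given by a dictionary of functions and a parent list in topological order. The concrete example program below, computing (s0 + 1) · exp(s0), is only an illustration.

import math

def execute_program(s0, funcs, parents):
    # funcs[k] and parents[k] describe node k, for k = 1, ..., K,
    # listed in topological order; node 0 holds the input s0.
    states = {0: s0}
    K = len(funcs)
    for k in range(1, K + 1):
        args = [states[i] for i in parents[k]]
        states[k] = funcs[k](*args)
    return states[K]

# Example DAG: s1 = s0 + 1, s2 = exp(s0), s3 = s1 * s2.
funcs = {1: lambda a: a + 1.0, 2: math.exp, 3: lambda a, b: a * b}
parents = {1: [0], 2: [0], 3: [1, 2]}
print(execute_program(2.0, funcs, parents))  # (2 + 1) * exp(2)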

Dealing with multiple program inputs or outputs

When a program has multiple inputs, we can always group them into
s0 ∈ S0 as s0 = (s0,1 , . . . , s0,N0 ) with S0 = (S0,1 × · · · × S0,N0 ), since
later functions can always filter out what elements of s0 they need.
Likewise, if an intermediate function fk has multiple outputs, we can
always group them as a single output sk = (sk,1 , . . . , sk,Nk ) with Sk =
(Sk,1 × · · · × Sk,Nk ), since later functions can filter out the elements of
sk that they need.

Figure 4.4: Two possible representations of a program. Left: Functions and output
variables are represented by the same nodes. Right: functions and variables are
represented by a disjoint set of nodes.

Alternative representation: bipartite graphs


In our formalism, because a function fk always has a single output
sk , a node k can be seen as representing both the variable sk and
the function fk . Alternatively, as shown in Fig. 4.4, we can represent
variables and functions as separate nodes, that is, using a bipartite
graph. This formalism is akin to factor graphs (Frey et al., 1997;
Loeliger, 2004) used in probabilistic modeling, but with directed edges.
One advantage of this formalism is that it allows functions to explicitly
have multiple outputs. We focus on our formalism for simplicity.

4.1.4 Arithmetic circuits


Arithmetic circuits are one of the simplest examples of computation
graph, originating from computational complexity theory. Formally,
an arithmetic circuit over a field F, such as the reals R, is a directed
acyclic graph (DAG) whose root nodes are elements of F and whose
functions fk are either + or ×. The latter are often called gates.
Contrary to the general computation graph case, because each fk is
either + or ×, it is important to allow the graph to have several root
nodes. Root nodes can be either variables or constants, and should
belong to F.
Arithmetic circuits can be used to compute polynomials. There
are potentially multiple arithmetic circuits for representing a given
polynomial. One important question is then to find the most efficient
arithmetic circuit for computing a given polynomial. To compare arith-
metic circuits representing the same polynomial, an intuitive notion of
complexity is the circuit size, as defined below.

Definition 4.1 (Circuit and polynomial sizes). The size S(C) of a


circuit C is the number of edges in the directed acyclic graph
representing C. The size S(f ) of a polynomial f is the smallest S(C)
among all C representing f .

For more information on arithmetic circuits, we refer the reader to


the excellent book of Chen et al. (2011).

4.2 Feedforward networks

A feedforward network can be seen as a computation chain with pa-


rameterized functions fk ,
s0 := x
s1 := f1 (s0 , w1 )
s2 := f2 (s1 , w2 )
..
.
sK := fK (sK−1 , wK ),
for a given input x ∈ X and learnable parameters w1 , . . . , wK ∈
W1 × · · · × WK . Each function fk is called a layer and each sk ∈ Sk
can be seen as an intermediate representation of the input x. The
dimensionality of Sk is known as the width (or number of hidden units)
of layer k. A feedforward network defines a function sK =: f (x, w) from
X × W to SK , where w := (w1 , . . . , wK ) ∈ W := W1 × . . . × WK .
Given such a parameterized program, we can learn the parameters
by adjusting w to fit some data. For instance, given a dataset of (xi , yi )
pairs, we may minimize the squared loss ∥yi −f (xi , w)∥22 on average over
the data, w.r.t. w. The minimization of such a loss requires accessing
its gradients with respect to w.

4.3 Multilayer perceptrons

4.3.1 Combining affine layers and activations


In the previous section, we did not specify how to parametrize the
feedforward network. A typical parametrization, called the multilayer

perceptron (MLP), uses fully-connected (also called dense) layers of


the form
sk = fk (sk−1 , wk ) := ak (Wk sk−1 + bk ),
where we defined the tuple wk := (Wk , bk ) and where we assumed that
Wk and bk are a matrix and vector of appropriate size. We can further
decompose the layer into two functions. The function s 7→ Wk s + bk
is called an affine layer. The function v 7→ ak (v) is a parameter-free
nonlinearity, often called an activation function (see Section 4.4).
More generally, we may replace the matrix-vector product Wk sk−1
by any parametrized linear function of sk−1 . For example, convolu-
tional layers use the convolution of an input sk−1 with some filters
Wk .
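A sketch of a multilayer perceptron forward pass in NumPy. The layer widths, the ReLU activation choice, and the random initialization are illustrative assumptions, not prescriptions from the text.

import numpy as np

def relu(u):
    return np.maximum(u, 0.0)

def mlp_forward(x, weights, biases):
    # Alternates affine layers s -> W s + b with element-wise activations.
    s = x
    for k, (W, b) in enumerate(zip(weights, biases)):
        s = W @ s + b
        if k < len(weights) - 1:  # no activation on the last (output) layer
            s = relu(s)
    return s

rng = np.random.default_rng(0)
dims = [4, 8, 8, 3]  # input width 4, two hidden layers of width 8, output width 3
weights = [rng.normal(size=(dims[k + 1], dims[k])) for k in range(len(dims) - 1)]
biases = [np.zeros(dims[k + 1]) for k in range(len(dims) - 1)]
print(mlp_forward(rng.normal(size=4), weights, biases))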
Remark 4.1 (Dealing with multiple inputs). Sometimes, it is neces-
sary to deal with networks of multiple inputs. For example, sup-
pose we want to design a function g(x1 , x2 , wg ), where x1 ∈ X1
and x2 ∈ X2 . A simple way to do so is to use the concatenation
x := x1 ⊕ x2 ∈ X1 ⊕ X2 as input to a network f (x, wf ). Alterna-
tively, instead of concatenating x1 and x2 at the input layer, they
can be concatenated after having been through one or more hidden
layers.

4.3.2 Link with generalized linear models


When the depth is K = 1 (only one layer), the output of an MLP is
s1 = a1 (W1 x + b1 ).
This is called a generalized linear model (GLM); see Section 3.4.
Therefore, MLPs include GLMs as a special case. In particular, when
a1 is the softargmax (see Section 4.4), we obtain (multiclass) logistic
regression. For general depth K, the output of an MLP is
sK = aK (WK sK−1 + bK ).
This can be seen as a GLM on top of a learned representation sK−1
of the input x. This is the main appeal of MLPs: they learn the feature
representation and the output model at the same time! We will see that
MLPs can also be used as subcomponents in other architectures.

4.4 Activation functions

As we saw in Section 4.3, feedforward networks typically use an acti-


vation function ak at each layer. In this section, we describe various
nonlinearities from scalar to scalar or from vector to scalar. We also
present probability mappings that can be used as such activations.

4.4.1 Scalar-to-scalar nonlinearities


The ReLU (rectified linear unit) is a popular nonlinearity defined as
the non-negative part of its input

u, u≥0
relu(u) := max(u, 0) = .
0, u<0
It is a piecewise linear function and includes a kink at u = 0. A multilayer
perceptron with ReLU activations is called a rectifier neural network.
The layers take the form
sk = relu(Ak sk−1 + bk ),
where the ReLU is applied element-wise. The ReLU can be replaced
with a smooth approximation (i.e., without kinks), called the softplus
softplus(u) := log(1 + eu ).
Unlike the ReLU, it is always strictly positive. Other smoothed variants
of the ReLU are possible, see Section 12.4.1.

4.4.2 Vector-to-scalar nonlinearities


It is often useful to reduce vectors to a scalar value. This scalar value
can be seen as a “statistic”, summarizing the vector. A simple way to
do so is to use the maximum value, also known as max pooling. Given
a vector u ∈ RM , it is defined as
u ↦ max_{j∈[M]} uj.

As a smooth approximation of it, we can use the log-sum-exp function,


which is known to behave like a “soft” maximum
logsumexp(u) := softmax(u) := log ∑_{j=1}^M e^{uj}.

The log-sum-exp can be seen as a generalization of the softplus, as we


have for all u ∈ R

logsumexp((u, 0)) = softplus(u).

A numerically stable implementation of the log-sum-exp is given by

logsumexp(u) = logsumexp(u − c 1) + c,

where c := maxj∈[M ] uj .
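A sketch of this numerically stable recipe, together with the corresponding stable softargmax (scipy.special also provides logsumexp and softmax directly); the input vector is an arbitrary example chosen to overflow a naive implementation.

import numpy as np

def logsumexp(u):
    # Stable log-sum-exp: subtract the maximum before exponentiating.
    c = np.max(u)
    return np.log(np.sum(np.exp(u - c))) + c

def softargmax(u):
    # Stable softargmax, i.e. the gradient of logsumexp (see Section 4.4.4).
    z = np.exp(u - np.max(u))
    return z / np.sum(z)

u = np.array([1000.0, 1001.0, 1002.0])  # naive exp(u) would overflow
print(logsumexp(u))
print(softargmax(u))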

4.4.3 Scalar-to-scalar probability mappings


Oftentimes we want to map some real value to a number between 0 and
1, that can represent the probability of an event. For that purpose, we
generally use sigmoids. A sigmoid is a function with a characteristic
“S”-shaped curve. These functions are used to squash real values to
some interval. An example is the binary step function, also known as
Heaviside step function,

1, u≥0
step(u) := .
0, u<0

It is a mapping from R to {0, 1}. Unfortunately, it has a discontinuity: a


jump in its graph at u = 0. Moreover, because the function is constant
at all other points, it has zero derivative at these points, which makes it
difficult to use as part of a neural network trained with backpropagation.
A better sigmoid is the logistic function, which is a mapping from
R to (0, 1) and is defined as
logistic(u) := 1/(1 + e^{−u}) = e^u/(1 + e^u) = 1/2 + (1/2) tanh(u/2).
It maps (−∞, 0) to (0, 0.5), [0, +∞) to [0.5, 1) and it satisfies logistic(0) =
0.5. It can therefore be seen as mapping from real values to probability
values. The logistic can be seen as a differentiable approximation to the

discontinuous binary step function step(u). The logistic function can


be shown to be the derivative of softplus, i.e., for all u ∈ R

softplus′ (u) = logistic(u).

Two important properties of the logistic function are that for all u ∈ R

logistic(−u) = 1 − logistic(u)

and

logistic′ (u) = logistic(u) · logistic(−u)


= logistic(u) · (1 − logistic(u)).

Other sigmoids are possible; see Section 12.4.3.

4.4.4 Vector-to-vector probability mappings


It is often useful to transform a real vector into a vector of probabilities.
This is a mapping from R^M to the probability simplex, defined by

△M := {π ∈ R^M : ∀j ∈ [M], πj ≥ 0, ∑_{j=1}^M πj = 1}.

One such mapping is the argmax operator, defined by


argmax(u) := ϕ(arg max_{j∈[M]} uj),

where ϕ(j) denotes the one-hot encoding of an integer j ∈ [M ], that is,

ϕ(j) := (0, . . . , 0, 1, 0, . . . , 0) = ej ∈ {0, 1}^M, with the single 1 in coordinate j.

This mapping puts all the probability mass onto a single coordinate
(in case of ties, we pick a single coordinate arbitrarily). Unfortunately,
this mapping is a discontinuous function. As a differentiable everywhere
relaxation, we can use the softargmax defined by
softargmax(u) := exp(u) / ∑_{j=1}^M exp(uj).

This operator is commonly known in the literature as softmax but this


is a misnomer: this operator really defines a differentiable relaxation
of the argmax. The output of the softargmax belongs to the relative
interior of the probability simplex, meaning that it can never reach the
borders of the simplex. If we denote π = softargmax(u), this means
that πj ∈ (0, 1), that is, πj can never be exactly 0 or 1. The softargmax
is the gradient of log-sum-exp,

∇logsumexp(u) = softargmax(u).

The softargmax can be seen as a generalization of the logistic function,


as we have for all u ∈ R

[softargmax((u, 0))]1 = logistic(u).

Remark 4.2 (Degrees of freedom and invertibility of softargmax). The
softargmax operator satisfies, for all u ∈ R^M and c ∈ R,

π := softargmax(u) = softargmax(u + c 1).

This means that the softargmax operator has M − 1 degrees of
freedom and is a non-invertible function. However, due to the
above property, without loss of generality, we can impose u⊤1 =
∑_{i=1}^M ui = 0 (if this is not the case, we simply do ui ← ui − ū,
where ū := (1/M) ∑_{j=1}^M uj). Using this constraint together with

log πi = ui − log ∑_{j=1}^M exp(uj),

we then obtain

∑_{i=1}^M log πi = −M log ∑_{j=1}^M exp(uj),

so that

ui = [softargmax^{−1}(π)]i = log πi − (1/M) ∑_{j=1}^M log πj.

4.5 Residual neural networks

We now discuss another feedforward network parametrization: residual


neural networks. Consider a feedforward network with K + 1 layers
f1 , . . . , fK , fK+1 . Surely, as long as fK+1 can exactly represent the
identity function, the set of functions that this feedforward network
can express should be a superset of the functions that f1 , . . . , fK can
express. In other words, depth should in theory not hurt the expressive
power of feedforward networks. Unfortunately, the assumption that each
fk can exactly represent the identity function may not hold in practice.
This means that deeper networks can sometimes be more difficult to
train than shallower ones, making the accuracy saturate or degrade as
a function of depth.
The key idea of residual neural networks (He et al., 2016) is to design
layers fk , called residual blocks, that make it easier to represent the
identity function. Formally, a residual block takes the form

sk = fk (sk−1 , wk ) := sk−1 + hk (sk−1 , wk ).

The function hk is called residual, since it models the difference sk −


sk−1 . The addition with sk−1 is often called a skip connection. As
long as it is easy to adjust wk so that hk (sk−1 , wk ) = 0, fk can freely
become the identity function. For instance, if we use

hk (sk−1 , wk ) := Ck ak (Wk sk−1 + bk ) + dk ,

where wk = (Wk , bk , Ck , dk ), it suffices to set Ck and dk to a zero


matrix and vector. Residual blocks are known to remedy the so-called
vanishing gradient problem.
Many papers and software packages include an additional activation
and instead define the residual block as

sk = fk (sk−1 , wk ) := ak (sk−1 + hk (sk−1 , wk )),

where ak is typically chosen to be the ReLU activation. Whether to


include this additional activation or not is essentially a modelling choice.
In practice, residual blocks may also include additional operations such
as batch norm and convolutional layers.
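A sketch of a residual block with the parametrization above; the shapes and the random initialization are illustrative. Setting C and d to zero makes the block the identity, which is the point of the construction.

import numpy as np

def relu(u):
    return np.maximum(u, 0.0)

def residual_block(s, W, b, C, d):
    # s_k = s_{k-1} + h(s_{k-1}), with h(s) = C relu(W s + b) + d.
    return s + C @ relu(W @ s + b) + d

rng = np.random.default_rng(0)
dim, hidden = 5, 16
W, b = rng.normal(size=(hidden, dim)), np.zeros(hidden)
C, d = np.zeros((dim, hidden)), np.zeros(dim)  # zero residual: identity block
s = rng.normal(size=dim)
print(np.allclose(residual_block(s, W, b, C, d), s))  # True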

4.6 Recurrent neural networks

Recurrent neural networks (RNNs) are a class of neural networks that


operate on sequences of vectors, either as input, output or both. Their
actual parametrization depends on the setup but the core idea is to
maintain a state vector that is updated from step to step by a recursive
function that uses shared parameters across steps. Unrolling this
recursion defines a valid computational graph, as we will see in Chapter 7.
We distinguish between the following setups illustrated in Fig. 4.5:

• Vector to sequence: f d : RD × RP → RL×M

• Sequence to vector: f e : RL×D × RP → RM

• Sequence to sequence (aligned): f a : RL×D × RP → RL×M



• Sequence to sequence (unaligned): f u : RL×D × RP → RL′×M

where L stands for length. Note that we use the same number of
parameters P for each setup for notational convenience, but this of
course does not need to be the case. Throughout this section, we use
the notation p1:L := (p1 , . . . , pL ) for a sequence of L vectors.

4.6.1 Vector to sequence


In this setting, we define a decoder function p1:L = f d (x, w) from an
input vector x ∈ RD and parameters w ∈ RP to an output sequence
p1:L ∈ RL×M . This is for instance useful for image caption generation,
where a sentence (a sequence of word embeddings) is generated from
an image (a vector of pixels). Formally, we may define p1:L := f d (x, w)
through the recursion

zl := g(x, zl−1 , wg )   l ∈ [L]
pl := h(zl , wh )   l ∈ [L].

where w := (wg , wh , z0 ). The goal of g is to update the current decoder


state zl given the input x, and the previous decoder state zl−1 . The
goal of h is to generate the output pl given the current decoder state
zl . Importantly, the parameters of g and h are shared across steps.

Figure 4.5: Recurrent neural network architectures. (a) One to many (decoder).
(b) Many to one (encoder). (c) Sequence to sequence (aligned). (d) Sequence to
sequence (unaligned).



Typically, g and h are parametrized using one-hidden-layer MLPs. Note


that g has multiple inputs; we discuss how to deal with such cases in
Section 4.3.

4.6.2 Sequence to vector


In this setting, we define an encoder function p = f e (x1:L , w) from an
input sequence x1:L ∈ RL×D and parameters w ∈ RP to an output
vector p ∈ RM . This is for instance useful for sequence classification,
such as sentiment analysis. Formally, we may define p := f e (x1:L , w)
using the recursion

sl := γ(xl , sl−1 , wγ )   l ∈ [L]


p = pooling(s1:L )

where w := (wγ , s0 ). The goal of γ is similar to that of g, except that it updates
encoder states and does not take previous predictions as input. The
pooling function is typically parameter-less. Its goal is to reduce a
sequence to a vector. Examples include using the last state, the average
of states and the coordinate-wise maximum of states.
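A sketch of such an encoder, with γ taken to be the update sl = tanh(U xl + V sl−1 + b) and average pooling over the states; this particular update rule and the dimensions below are illustrative assumptions, not the only possible choice.

import numpy as np

def rnn_encode(x_seq, U, V, b, s0):
    # Shared-parameter recursion s_l = tanh(U x_l + V s_{l-1} + b), then pooling.
    s = s0
    states = []
    for x in x_seq:
        s = np.tanh(U @ x + V @ s + b)
        states.append(s)
    return np.mean(states, axis=0)  # average pooling over the L states

rng = np.random.default_rng(0)
L, D, M = 6, 4, 8  # sequence length, input width, state width
U, V, b = rng.normal(size=(M, D)), rng.normal(size=(M, M)), np.zeros(M)
x_seq = rng.normal(size=(L, D))
print(rnn_encode(x_seq, U, V, b, np.zeros(M)).shape)  # (8,)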

4.6.3 Sequence to sequence (aligned)


In this setting, we define a function p1:L = f a (x1:L , w) from an in-
put sequence x1:L ∈ RL×D and parameters w ∈ RP to an output
sequence p1:L ∈ RL×M , which we know to be of the same length.
An example of application is part-of-speech tagging, where the goal is
to assign each word xl to a part-of-speech (noun, verb, adjective, etc).
Formally, we may define p1:L = f a (x1:L , w) as

sl := γ(xl , sl−1 , wγ ) l ∈ [L]


pl = h(sl , wh ) l ∈ [L]

where w := (wγ , wh , s0 ). The functions γ and h are the same as before.

4.6.4 Sequence to sequence (unaligned)


In this setting, we define a function p1:L′ = f u (x1:L , w) from an input
sequence x1:L ∈ RL×D and parameters w ∈ RP to an output sequence


p1:L′ ∈ RL′×M , which potentially has a different length. An example
of application is machine translation, where the sentences in the source
and target languages do not necessarily have the same length. Typically,
p1:L′ = f u (x1:L , w) is defined as the following two steps
c := f e (x1:L , we )
p1:L′ := f d (c, wd )
where w := (we , wd ), and where we reused the previously-defined
encoder fe and decoder fd . Putting the two steps together, we obtain
sl := γ(xl , sl−1 , wγ ) l ∈ [L]
c = pooling(s1:L )
zl := g(c, pl−1 , zl−1 , wg ) l ∈ [L′ ]
pl := h(zl , wh ) l ∈ [L′ ].
This architecture is aptly named the encoder-decoder architecture.
Note that we denoted the length of the target sequence as L′ . However,
in practice, the target length can be input dependent and is often not
known ahead of time. To deal with this issue, the vocabulary (of size D
in our notation) is typically augmented with an “end of sequence” (EOS)
token so that, at inference time, we know when to stop generating the
output sequence. One disadvantage of this encoder-decoder architecture,
however, is that all the information about the input sequence is contained
in the context vector c, which can therefore become a bottleneck. For
this reason, this architecture has been largely replaced with attention
mechanisms and transformers, which we study in the next sections.

4.7 Summary

In this chapter, we reviewed how to mathematically represent a program


as a directed acyclic graph. We then presented various popular neural
network architectures, that we view as parameterized programs.
For instance, the feedforward network can be seen as a parameterized
computation chain. Multilayer perceptrons (MLPs), residual neural
networks (ResNets) and convolutional neural networks (CNNs) are par-
ticular parametrizations of this class of network. Finally, we reviewed
recurrent neural networks, for working with sequential data.
5
Control flows

Control flows, such as conditionals or loops, are an essential part of


computer programming, as they allow us to express complex programs.
It is therefore natural to ask whether these constructs can be included
in a differentiable program. This is what we study in this chapter.

5.1 Comparison operators

Control flows rely on comparison operators, a.k.a. relational oper-


ators. Formally, we can define a comparison operator π = op(u1 , u2 )
as a function from u1 ∈ R and u2 ∈ R to π ∈ {0, 1}. The binary
(Boolean) output π can then be used within a conditional statement
(see Section 5.6, Section 5.7) to decide whether to execute one branch
or another. We define the following operators:

• greater than:

gt(u1 , u2 ) := 1 if u1 ≥ u2 , 0 otherwise
             = step(u1 − u2 )


• less than:

lt(u1 , u2 ) := 1 if u1 ≤ u2 , 0 otherwise
             = 1 − gt(u1 , u2 )
             = step(u2 − u1 )

• equal:

eq(u1 , u2 ) := 1 if |u1 − u2 | = 0, 0 otherwise
             = gt(u2 , u1 ) · gt(u1 , u2 )
             = step(u2 − u1 ) · step(u1 − u2 )

• not equal:

neq(u1 , u2 ) := 1 if |u1 − u2 | > 0, 0 otherwise
              = 1 − eq(u1 , u2 )
              = 1 − step(u2 − u1 ) · step(u1 − u2 ),

where step : R → {0, 1} is the Heaviside step function

step(u) := 1 if u ≥ 0, 0 otherwise.

The Heaviside step function is piecewise constant. At u = 0, the function


is discontinuous. At u ̸= 0, it is continuous and has null derivative.
Since the comparison operators we presented are all expressed in terms
of the step function, they are all continuous and differentiable almost
everywhere, with null derivative. Therefore, while their derivatives are
well-defined almost everywhere, they are uninformative and prevent
gradient backpropagation.

5.2 Soft inequality operators

5.2.1 Heuristic definition


To obtain a continuous relaxation of inequality operators, we can
heuristically replace the step function in the expression of greater than
and less than by a sigmoid, such as the logistic function or the Gaussian’s
CDF. We then obtain
gt(µ1 , µ2 ) = step(µ1 − µ2 ) ≈ sigmoid(µ1 − µ2 )
lt(µ1 , µ2 ) = step(µ2 − µ1 ) ≈ sigmoid(µ2 − µ1 ) = 1 − sigmoid(µ1 − µ2 ).
In the limit, we have that sigmoid(µ1 − µ2 ) → 1 when µ1 − µ2 → ∞: a
sigmoid therefore outputs a probability of 1 only when µ1 and µ2 are
infinitely apart. Besides the logistic function and the Gaussian’s CDF,
other sigmoid functions are possible, as discussed in Section 12.4.3. In
particular, with sparse sigmoids, there exists a finite value τ such that
µ1 − µ2 ≥ τ =⇒ sigmoid(µ1 − µ2 ) = 1.
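A sketch of the hard and relaxed inequality operators; the logistic choice of sigmoid is one option among those discussed above, and the test values are arbitrary.

import numpy as np
from scipy.special import expit

def hard_gt(u1, u2):
    # step(u1 - u2): 1 if u1 >= u2, 0 otherwise.
    return np.where(u1 - u2 >= 0, 1.0, 0.0)

def soft_gt(u1, u2):
    # Logistic relaxation of the greater-than operator.
    return expit(u1 - u2)

def soft_lt(u1, u2):
    return 1.0 - soft_gt(u1, u2)

print(hard_gt(0.1, 0.0), soft_gt(0.1, 0.0))   # 1.0 vs. roughly 0.52
print(hard_gt(5.0, 0.0), soft_gt(5.0, 0.0))   # 1.0 vs. roughly 0.99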

5.2.2 Stochastic process perspective


For the logistic function and the Gaussian’s CDF, another equivalent
approach is to perturb each variable in the inequality operator with a
random variable and take the expectation. For example, for the greater
than operator, this gives the approximation
gt(µ1 , µ2 ) ≈ E[gt(µ1 + Z1 , µ2 + Z2 )],
where Z1 and Z2 are two independent random variables. At a high
level, we replaced the real numbers µ1 and µ2 by a stochastic process
(U1 , U2 ), where U1 := µ1 + Z1 and U2 := µ2 + Z2 .
Choosing Z1 and Z2 distributed as Gumbel noise, we get (see Propo-
sition 13.3)
E[gt(µ1 + Z1 , µ2 + Z2 )] = logistic(µ1 − µ2 ),
where

logistic(u) := 1/(1 + e^{−u}) ∈ (0, 1).
Choosing Z1 and Z2 to be distributed as Gaussian noise, we get

E[gt(µ1 + Z1 , µ2 + Z2 )] = Φ((µ1 − µ2 )/√2),

where Φ is the cumulative distribution function (CDF) of the standard


normal distribution, defined in Eq. (3.1).

5.3 Soft equality operators

5.3.1 Heuristic definition


The equality operator can be written as

eq(µ1 , µ2 ) = δ(µ1 − µ2 )

where

δ(u) := 1 if u = 0, 0 if u ≠ 0.
It equals 1 when µ1 = µ2 and is zero everywhere else. A natural idea to
obtain a continuous and smooth relaxation is therefore to replace δ by
a bell-shaped function. For this purpose, we can use a kernel, such as
the PDF of a distribution centered at 0. Formally, we may define
eq(µ1 , µ2 ) ≈ κ(µ1 , µ2 ) / κ(0, 0),
where
κ(µ1 , µ2 ) := p0,1 (µ1 − µ2 ),
for p0,1 the PDF of a zero-mean unit-scale distribution in the location-
scale family. The normalization κ(0, 0) ensures that the soft equality
operator is 1 at µ1 = µ2 . For instance, we can use the Gaussian kernel
κ(µ1 , µ2 ) := e^{−(µ1 −µ2 )²/2}

or the logistic kernel


κ(µ1 , µ2 ) := exp(µ2 − µ1 ) / (1 + exp(µ2 − µ1 ))²
             = 1 / (2 + e^{µ2 −µ1 } + e^{µ1 −µ2 })
             = (1/4) sech²((µ2 − µ1 )/2),

where we defined the hyperbolic secant sech(u) := 2 exp(u)/(exp(2u)+1).


The soft equality operators obtained with these kernels are illustrated

in Fig. 5.1. The soft equality operator obtained with the logistic kernel
coincides with the expression Petersen et al. (2021) arrive at, in a
different manner.
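A sketch of the two normalized kernels used as soft equality operators; the normalization constants are κ(0, 0) = 1 for the Gaussian kernel and κ(0, 0) = 1/4 for the logistic kernel, and the test values are arbitrary.

import numpy as np

def soft_eq_gaussian(mu1, mu2):
    # Gaussian kernel, already normalized since kappa(0, 0) = 1.
    return np.exp(-0.5 * (mu1 - mu2) ** 2)

def soft_eq_logistic(mu1, mu2):
    # Logistic kernel kappa(mu1, mu2) = p_{0,1}(mu1 - mu2), normalized by kappa(0, 0) = 1/4.
    z = mu1 - mu2
    kappa = np.exp(-z) / (1.0 + np.exp(-z)) ** 2
    return kappa / 0.25

print(soft_eq_gaussian(0.0, 0.0), soft_eq_logistic(0.0, 0.0))  # both equal 1
print(soft_eq_gaussian(1.0, 0.0), soft_eq_logistic(1.0, 0.0))  # both strictly below 1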

5.3.2 Gaussian process perspective

To handle the equality operator from a perturbation perspective, we


cannot simply perturb with two independent random noise variables,
as we did for the inequality operator, because we would get E[eq(µ1 +
Z1 , µ2 + Z2 )] = P(µ1 + Z1 = µ2 + Z2 ) = 0 for any µ1 and µ2 .
We can nevertheless generalize the previous approach by using a
Gaussian process (Hida and Hitsuda, 1976). A Gaussian process
on R associates a multivariate Gaussian random variable (U1 , . . . , UK )
to a collection of real numbers (µ1 , . . . , µK ). The Gaussian process
is characterized by the mean function E[Ui ] = µi and its covariance
function, a.k.a. , kernel, that defines the covariance between two points
as Cov(Ui , Uj ) := κ(µi , µj ). Equipped with such mapping from real
numbers to random variables, we need a measure of equality between
random variables. A natural measure for Gaussian random variables is
their correlation

corr(Ui , Uj ) := Cov(Ui , Uj ) / √(Var(Ui ) Var(Uj )) = κ(µi , µj ) / κ(0, 0) ∈ [0, 1].

From a high-level perspective, we can see the covariance as an inner


product between random variables and so the correlation is a natural
alignment measure, just as (x, y) 7→ ⟨x, y⟩/∥x∥2 ∥y∥2 between two
vectors.
In the case K = 2, we recover the previous heuristically-defined soft
equality operator,

eq(µ1 , µ2 ) ≈ corr(U1 , U2 ) = κ(µ1 , µ2 ) / κ(0, 0).

Both the Gaussian and logistic kernels arise naturally as the density
of µ1 + Z1 − (µ2 + Z2 ) at 0, with Z1 and Z2 Gaussian or Gumbel random
variables.

Figure 5.1: Soft equality and soft greater than operators can be defined as normalized
kernels (PDF) and as CDF functions, respectively. The panels “Soft equal zero” and
“Soft greater than zero” compare the hard operators with their logistic and Gaussian
relaxations.

5.4 Logical operators

Logical operators can be used to perform Boolean algebra. Formally,


we can define them as functions from {0, 1} × {0, 1} to {0, 1}. The and
(logical conjunction a.k.a. logical product), or (logical disjunction a.k.a.
logical addition) and not (logical negation a.k.a. logical complement)
operators, for example, are defined by

1 if π = π ′ = 1
and(π, π ′ ) :=
0 otherwise

1 if 1 ∈ {π, π ′ }
or(π, π ′ ) :=
0 otherwise

0 if π = 1
not(π) := .
1 if π = 0

Classical properties of these operators include


• Commutativity:
and(π, π ′ ) = and(π ′ , π)
or(π, π ′ ) = or(π ′ , π)

• Associativity:
and(π, and(π ′ , π ′′ )) = and(and(π, π ′ ), π ′′ )
or(π, or(π ′ , π ′′ )) = or(or(π, π ′ ), π ′′ )

• Distributivity of and over or:

and(π, or(π ′ , π ′′ )) = or(and(π, π ′ ), and(π, π ′′ ))

• Neutral element:

and(π, 1) = π
or(π, 0) = π

• De Morgan’s laws:

not(or(π, π ′ )) = and(not(π), not(π ′ ))


not(and(π, π ′ )) = or(not(π), not(π ′ )).

More generally, for a binary vector π = (π1 , . . . , πK ) ∈ {0, 1}K ,


we can define all (universal quantification, ∀) and any (existential
quantification, ∃) operators, which are functions from {0, 1}K to {0, 1},
as

all(π) := 1 if π1 = · · · = πK = 1, 0 otherwise

and

any(π) := 1 if 1 ∈ {π1 , . . . , πK }, 0 otherwise.

5.5 Continuous extensions of logical operators

5.5.1 Probabilistic continuous extension


We can equivalently write the and, or and not operators as

and(π, π ′ ) = π · π ′
or(π, π ′ ) = π + π ′ − π · π ′
not(π) = 1 − π.

These are extensions of the previous definitions: we can use them as


functions from [0, 1]×[0, 1] → [0, 1], as illustrated in Fig. 5.2. This means
that we can use the soft comparison operators defined in Section 5.2 to


Figure 5.2: The Boolean and and or operators are functions from {0, 1} × {0, 1}
to {0, 1} but their continuous extensions and(π, π ′ ) := π · π ′ as well as or(π, π ′ ) :=
π + π ′ − π · π ′ define a function from [0, 1] × [0, 1] to [0, 1].

obtain π, π ′ ∈ [0, 1]. Likewise, we can define continuous extensions of


all and any, which are functions from [0, 1]K to [0, 1], as

all(π) = ∏_{i=1}^K πi
any(π) = 1 − ∏_{i=1}^K (1 − πi ).

From a probabilistic perspective, if we let Y and Y ′ be two


independent random variables distributed according to Bernoulli dis-
tributions with parameter π and π ′ , then

and(π, π ′ ) = P(Y = 1 ∩ Y ′ = 1) = P(Y = 1) · P(Y ′ = 1)


or(π, π ′ ) = P(Y = 1 ∪ Y ′ = 1)
= P(Y = 1) + P(Y ′ = 1) − P(Y = 1 ∩ Y ′ = 1)
= P(Y = 1) + P(Y ′ = 1) − P(Y = 1)P(Y ′ = 1)
not(π) = P(Y ̸= 1) = 1 − P(Y = 1).

In probability theory, these correspond to the product rule of two


independent variables, the addition rule, and the complement rule.
Likewise, if we let Y = (Y1 , . . . , YK ) ∈ {0, 1}K be a random variable
distributed according to a multivariate Bernoulli distribution with

parameters π = (π1 , . . . , πK ), then


all(π) = P(Y1 = 1 ∩ · · · ∩ YK = 1) = ∏_{i=1}^K P(Yi = 1)
any(π) = P(Y1 = 1 ∪ · · · ∪ YK = 1) = 1 − ∏_{i=1}^K (1 − P(Yi = 1)).
This is the chain rule of probability for K independent variables, and
the addition rule for K variables.
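A sketch of these probabilistic extensions operating on soft truth values in [0, 1]; the function names and the test values are illustrative.

import numpy as np

def soft_and(p, q):
    return p * q

def soft_or(p, q):
    return p + q - p * q

def soft_not(p):
    return 1.0 - p

def soft_all(pis):
    return np.prod(pis)

def soft_any(pis):
    return 1.0 - np.prod(1.0 - np.asarray(pis))

print(soft_and(0.9, 0.8), soft_or(0.9, 0.8), soft_not(0.9))
print(soft_all([0.9, 0.8, 0.95]), soft_any([0.1, 0.2, 0.05]))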

5.5.2 Triangular norms and co-norms


More generally, in the fuzzy logic literature (Jayaram and Baczynski,
2008), the concepts of triangular norms and co-norms have been intro-
duced to provide continuous relaxations of the and and or operators,
respectively.
Definition 5.1 (Triangular norms and conorms). A triangular norm,
a.k.a. t-norm, is a function from [0, 1] × [0, 1] to [0, 1] which is
commutative, associative, neutral w.r.t. 1 and is monotone, meaning
that T (π, π ′ ) ≤ T (τ, τ ′ ) for all π ≤ τ and π ′ ≤ τ ′ . A triangular
conorm, a.k.a. t-conorm, is defined similarly but is neutral w.r.t. 0.

The previously-defined probabilistic extensions of and and or are


examples of triangular norms and conorms. More examples are given in
Table 5.1. Thanks to the associative property of these operators, we can
generalize them to vectors π ∈ [0, 1]K to define continuous extensions
of the all and any operators, as shown in Table 5.2. For more examples
and analysis, see for instance van Krieken (2024, Chapters 2 and 3).

5.6 If-else statements

An if-else statement executes different code depending on a condition.


Formally, we can define the ifelse : {0, 1} × V × V → V function by

ifelse(π, v1 , v0 ) := v1 if π = 1, v0 if π = 0     (5.1)
                    = π · v1 + (1 − π) · v0 .

Table 5.1: Examples of triangular norms and conorms, which are continuous
relaxations of the and and or operators, respectively. More instances can be obtained
by smoothing out the min and max operators.

                t-norm                     t-conorm
Probabilistic   π · π ′                    π + π ′ − π · π ′
Extremum        min(π, π ′ )               max(π, π ′ )
Łukasiewicz     max(π + π ′ − 1, 0)        min(π + π ′ , 1)

Table 5.2: Continuous extensions of the all and any operators.

                All (∀)                              Any (∃)
Probabilistic   ∏_{i=1}^K πi                         1 − ∏_{i=1}^K (1 − πi )
Extremum        min(π1 , . . . , πK )                max(π1 , . . . , πK )
Łukasiewicz     max(∑_{i=1}^K πi − (K − 1), 0)       min(∑_{i=1}^K πi , 1)

The π variable is called the predicate. It is a binary (Boolean) variable,


making the function ifelse undefined if π ̸∈ {0, 1}. The function is
therefore discontinuous and nondifferentiable w.r.t. π ∈ {0, 1}. On the
other hand, v0 ∈ V and v1 ∈ V, which correspond to the false and
true branches, can be continuous variables. If π = 1, the function
is linear w.r.t. v1 and constant w.r.t. v0 . Conversely, if π = 0, the
function is linear w.r.t. v0 and constant w.r.t. v1 . We now discuss how
to differentiate through ifelse.

5.6.1 Differentiating through branch variables

For π ∈ {0, 1} fixed, ifelse(π, v1 , v0 ) is a valid function w.r.t. v1 ∈ V


and v0 ∈ V, and can therefore be used as a node in a computational
graph (Section 7.3). Due to the linearity w.r.t. v1 and v0 , we obtain
that the Jacobians w.r.t. v1 and v0 are

0 if π = 1
∂v0 ifelse(π, v1 , v0 ) :=
I if π = 0
= (1 − π) · I

and

∂v1 ifelse(π, v1 , v0 ) := I if π = 1, 0 if π = 0
                         = π · I,
where I is the identity matrix of appropriate size. Most of the time,
if-else statements are composed with other functions. Let g1 : U1 → V
and g0 : U0 → V be differentiable functions. We then define v1 := g1 (u1 )
and v0 := g0 (u0 ), where u1 ∈ U1 and u0 ∈ U0 . The composition of
ifelse, g1 and g0 is then the function f : {0, 1} × U1 × U0 → V defined by
f (π, u1 , u0 ) := ifelse(π, g1 (u1 ), g0 (u0 ))
= π · g1 (u1 ) + (1 − π) · g0 (u0 ).
We obtain that the Jacobians are
∂u1 f (π, u1 , u0 ) = π · ∂g1 (u1 )
and
∂u0 f (π, u1 , u0 ) = (1 − π)∂g0 (u0 ).
As long as g1 and g0 are differentiable functions, we can therefore
differentiate through the branch variables u1 and u0 without any issue.
More problematic is the predicate variable π, as we now discuss.

5.6.2 Differentiating through predicate variables


The predicate variable π is binary and therefore cannot be differentiated
directly. However, π can be the output of a comparison operator. For
example, suppose we want to express the function f : R × U1 × U0 → V
defined by

fh (p, u1 , u0 ) := g1 (u1 ) if p ≥ 0, g0 (u0 ) otherwise.

Using our notation, this can be rewritten as


fh (p, u1 , u0 ) := ifelse(gt(p, 0), g1 (u1 ), g0 (u0 ))
= ifelse(step(p), g1 (u1 ), g0 (u0 ))
= step(p)g1 (u1 ) + (1 − step(p))g0 (u0 ).

The Heaviside step function has a discontinuity at p = 0, but it is


continuous and differentiable with derivative step′ (p) = 0 for all p ̸= 0.
The function fh therefore has null derivative w.r.t. p ̸= 0,

∂p fh (p, u1 , u0 ) = ∂1 fh (p, u1 , u0 )
= step′ (p)(g1 (u1 ) − g0 (u0 ))
= 0.

In other words, while fh has well-defined derivatives w.r.t. p for p ̸= 0,


the derivatives are uninformative. As another example, let us now
consider the function

gh (u1 , u0 ) := fh (t(u1 ), u1 , u0 ),

for some differentiable function t. This time, u1 influences both the


predicate and the true branch. Then, using Proposition 2.8, we obtain

∂u1 gh (u1 , u0 ) = ∂t(u1 )∂1 fh (t(u1 ), u1 , u0 ) + ∂2 fh (t(u1 ), u1 , u0 )


= ∂2 fh (t(u1 ), u1 , u0 ).

In other words, the derivatives of the predicate t(u1 ) do not influence


the derivatives of gh .

5.6.3 Continuous relaxations


Fortunately, we recall that

ifelse(π, v1 , v0 ) = π · v1 + (1 − π) · v0 .

This function is perfectly well-defined, even if π ∈ [0, 1], instead of


π ∈ {0, 1}. That is, this definition is an extension of Eq. (5.1) from the
discrete set {0, 1} to the continuous unit segment [0, 1]. We saw that

gt(a, b) ≈ sigmoid(a − b) ∈ [0, 1],

where sigmoid is for instance the logistic function or the Gaussian CDF.
If we now define

fs (p, u1 , u0 ) := ifelse(sigmoid(p), g1 (u1 ), g0 (u0 ))


= sigmoid(p)g1 (u1 ) + (1 − sigmoid(p))g0 (u0 ), (5.2)

the Jacobian becomes

∂p fs (p, u1 , u0 ) = sigmoid′ (p)(g1 (u1 ) − g0 (u0 )).

If sigmoid = logistic, the Jacobian is non-null everywhere, allowing


gradients to backpropagate through the computational graph. This is
an example of smoothing by regularization, as studied in Chapter 12.
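A sketch of the hard and relaxed if-else statements; g1 and g0 are placeholder branch functions, and the logistic sigmoid is one possible choice of relaxation.

import numpy as np
from scipy.special import expit

def g1(u):  # placeholder "true" branch
    return u ** 2

def g0(u):  # placeholder "false" branch
    return -u

def hard_ifelse(p, u1, u0):
    # Executes a single branch; zero derivative with respect to p.
    return g1(u1) if p >= 0 else g0(u0)

def soft_ifelse(p, u1, u0):
    # Evaluates both branches and blends them with a logistic weight.
    pi = expit(p)
    return pi * g1(u1) + (1.0 - pi) * g0(u0)

print(hard_ifelse(0.1, 2.0, 3.0), soft_ifelse(0.1, 2.0, 3.0))
print(hard_ifelse(-5.0, 2.0, 3.0), soft_ifelse(-5.0, 2.0, 3.0))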

Probabilistic perspective
From a probabilistic perspective, we can view Eq. (5.2) as the expecta-
tion of gi (ui ), where i ∈ {0, 1} is a binary random variable distributed
according to a Bernoulli distribution with parameter π = sigmoid(p):

fs (p, u1 , u0 ) = Ei∼Bernoulli(sigmoid(p)) [gi (ui )] .

Taking the expectation over the two possibles branches makes the
function differentiable with respect to p. Of course, this comes at the cost
of evaluating both branches, instead of a single one. The probabilistic
perspective suggests that we can also compute the variance if needed as

Vi∼Bernoulli(sigmoid(p)) [gi (ui )] = Ei∼Bernoulli(sigmoid(p)) [(fs (p, u1 , u0 ) − gi (ui ))2 ].

Another perspective (Petersen et al., 2021) is based on the logistic


distribution. Indeed, if P is a random variable following a logistic
distribution with mean p and scale 1, we saw in Remark 3.1 that the
CDF is P(P ≤ 0) = logistic(−p) = 1 − logistic(p) and therefore

fs (p, u1 , u0 ) = ifelse(logistic(p), g1 (u1 ), g0 (u0 ))


= logistic(p)g1 (u1 ) + (1 − logistic(p))g0 (u0 )
= P(P > 0) · g1 (u1 ) + P(P ≤ 0) · g0 (u0 ).

5.7 Else-if statements

In the previous section, we focused on if-else statements: conditionals


with only two branches. We now generalize our study to conditionals
including else-if statements, which have K branches.

5.7.1 Encoding K branches

For conditionals with only 2 branches, we encoded the branch that


the conditional needs to take using the binary variable π ∈ {0, 1}. For
conditionals with K branches, we need a way to encode which of the K
branches the conditional needs to take. To do so, we can use a vector
π ∈ {e1 , . . . , eK }, where ei denotes the standard basis vector (a.k.a.
one-hot vector)

ei := (0, . . . , 0, 1, 0, . . . , 0),

a vector with a single one in the coordinate i and K − 1 zeros. The


vector ei is the encoding of a categorical variable i ∈ [K].

Combining booleans

To form such a vector π ∈ {e1 , . . . , eK }, we can combine the previously-


defined comparison and logical operators to define π = (π1 , . . . , πK ).
However, we need to ensure that only one πi is non-zero. We give an
example in Example 5.1.

Argmax and argmin operators

Another way to form π is to use the argmax and argmin operators

argmax(p) := arg max_{π∈{e1 ,...,eK }} ⟨π, p⟩
argmin(p) := arg min_{π∈{e1 ,...,eK }} ⟨π, p⟩ = argmax(−p).

They can be seen as a natural generalization of the greater than and


less than operators. In case of ties, we break them arbitrarily.

5.7.2 Conditionals
We can now express a conditional statement as the function
cond : {e1 , . . . , eK } × V K → V defined by

cond(π, v1 , . . . , vK ) := {v1 if π = e1 ; . . . ; vK if π = eK }     (5.3)
                        = Σ_{i=1}^K πi vi .
Similarly as for the ifelse function, the cond function is discontinuous
and nondifferentiable w.r.t. π ∈ {e1 , . . . , eK }. However, given π = ei
fixed for some i, the function is linear in vi and constant in vj for j ̸= i.
We illustrate how to express a simple example, using this formalism.

Example 5.1 (Soft-thresholding operator). The soft-thresholding op-


erator (see also Section 15.4) is a commonly-used operator to pro-
mote sparsity. It is defined by

SoftThreshold(u, λ) := {0 if |u| ≤ λ; u − λ if u ≥ λ; u + λ if u ≤ −λ}.

To express it in our formalism, we can define π ∈ {e1 , e2 , e3 } using


comparison operators as

π := (lt(|u|, λ), gt(u, λ), lt(u, −λ))


= (step(λ − |u|), step(u − λ), step(−u − λ)).

Equivalently, we can also define π using an argmax operator as

π := argmax((λ − |u|, u − λ, −u − λ)).

In case of ties, which happens at |u| = λ, we keep only one non-zero


coordinate in π. We can then rewrite the operator as

SoftThreshold(u, λ) = cond(π, 0, u − λ, u + λ).



Figure 5.3: Example of conditional with three branches: the soft-thresholding


operator. It is a piecewise linear function (dotted black line). Using a softargmax,
we can induce a categorical probability distribution over the three branches. The
expected value (blue line) can be seen as a smoothed out version of the operator.
The induced distribution allows us to also compute the standard deviation.

As we will see, replacing argmax with softargmax induces a cat-


egorical distribution over the three possible branches. The mean
value can be seen as a smoothed out version of the operator, and we
can also compute the standard deviation, as illustrated in Fig. 5.3.
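As an illustration, the following Python/JAX sketch implements this smoothed soft-thresholding; jax.nn.softmax plays the role of the softargmax, and the temperature argument (our addition) controls the amount of smoothing.

    import jax.numpy as jnp
    from jax.nn import softmax   # softargmax in the terminology of this book

    def smoothed_soft_threshold(u, lam, temperature=1.0):
        # Hard selection would take argmax of (lam - |u|, u - lam, -u - lam).
        scores = jnp.array([lam - jnp.abs(u), u - lam, -u - lam]) / temperature
        pi = softmax(scores)                          # weights over the 3 branches
        branches = jnp.array([0.0, u - lam, u + lam])
        mean = jnp.dot(pi, branches)                  # smoothed output (mean in Fig. 5.3)
        std = jnp.sqrt(jnp.dot(pi, (branches - mean) ** 2))
        return mean, std

    mean, std = smoothed_soft_threshold(0.5, 1.0)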

5.7.3 Differentiating through branch variables


For π fixed, cond(π, v1 , . . . , vK ) is a valid function w.r.t. vi , and can
therefore again be used as a node in a computational graph. Due to the
linearity w.r.t. vi , we obtain that the Jacobian w.r.t. vi is

∂vi cond(π, v1 , . . . , vK ) := {I if π = ei ; 0 if π ̸= ei }.
Let gi : Ui → V be a differentiable function and ui ∈ Ui . If we define
the composition
f (π, u1 , . . . , uK ) := cond(π, g1 (u1 ), . . . , gK (uK )),
we then obtain that the Jacobian w.r.t. ui is

∂ui f (π, u1 , . . . , uK ) := {∂gi (ui ) if π = ei ; 0 if π ̸= ei }.

As long as the gi functions are differentiable, we can therefore differen-


tiate through the branch variables ui without any issue.

5.7.4 Differentiating through predicate variables


As we saw, π can be obtained by combining comparison and logical
operators, or it can be obtained by argmax and argmin operators.
We illustrate here why these operators are problematic. For example,
suppose we want to express the function

fa (p, u1 , . . . , uK ) := {g1 (u1 ) if p1 = maxj pj ; . . . ; gK (uK ) if pK = maxj pj }.

In our notation, this can be expressed as

fa (p, u1 , . . . , uK ) := cond(argmax(p), g1 (u1 ), . . . , gK (uK )).

As for the ifelse case, the Jacobian w.r.t. p is null almost everywhere,

∂p fa (p, u1 , . . . , uK ) = 0.

5.7.5 Continuous relaxations


Similarly to the Heaviside step function, the argmax and argmin func-
tions are piecewise constant, with discontinuities in case of ties. Their
Jacobians are zero almost everywhere and undefined in case of ties.
Therefore, while the Jacobians are well-defined almost everywhere, they
are uninformative and prevent gradient backpropagation. We can replace
the argmax with a softargmax

softargmax(p) := exp(p) / Σ_{i=1}^K exp(pi ) ∈ △K

and similarly

softargmin(p) := softargmax(−p) ∈ △K .

Other relaxations of the argmax are possible, as discussed in Sec-


tion 12.4.4. See also Proposition 13.4 for the perturbation perspective.

Fortunately, the definition


cond(π, v1 , . . . , vK ) = Σ_{i=1}^K πi vi

is perfectly valid if we use π ∈ △K instead of π ∈ {e1 , . . . , eK }, and


can therefore be seen as an extension of Eq. (5.3). If we now define

fs (p, u1 , . . . , uK ) := cond(softargmax(p), g1 (u1 ), . . . , gK (uK ))
                        = Σ_{i=1}^K [softargmax(p)]i · gi (ui ),     (5.4)

the Jacobian becomes

∂p fs (p, u1 , . . . , uK ) = ∂softargmax(p)(g1 (u1 ), . . . , gK (uK )),

which is non-null everywhere, allowing gradients to backpropagate


through the computational graph.

Probabilistic perspective

From a probabilistic perspective, we can view Eq. (5.4) as the ex-


pectation of gi (ui ), where i ∈ [K] is a categorical random variable
distributed according to a categorical distribution with parameter
π = softargmax(p):

fs (p, u1 , . . . , uK ) = Ei∼Categorical(softargmax(p)) [gi (ui )] .

Taking the expectation over the K possible branches makes the function
differentiable with respect to p, at the cost of evaluating all branches,
instead of a single one. Similarly as for the if-else case, we can compute
the variance if needed as

Vi∼Categorical(softargmax(p)) [gi (ui )] = Ei∼Categorical(softargmax(p)) [(fs (p, u1 , . . . , uK ) − gi (ui ))2 ].

This is illustrated in Fig. 5.3.



5.8 For loops

For loops are a control flow for sequentially calling a fixed number K
of functions, reusing the output from the previous iteration. In full
generality, a for loop can be written as follows.

Algorithm 5.1 r = forloop(s0 )


for k := 1, . . . , K do
sk := fk (sk−1 )
r := sK

This defines a computation chain. Assuming the functions fk are


all differentiable, this defines a valid computational graph, we can
therefore use automatic differentiation to differentiate forloop w.r.t. its
input s0 . Feedforward networks, reviewed in Section 4.2, can be seen as
parameterized for loops, i.e.,
fk (sk−1 ) := gk (sk−1 , wk ),
for some differentiable function gk .

Example 5.2 (Unrolled gradient descent). Suppose we want to min-


imize w.r.t. w the function
L(w, λ) := (1/N ) Σ_{i=1}^N ℓ(h(xi , w), yi ) + (λ/2) ∥w∥22 .

Given an initialization w0 , gradient descent (Section 15.1) performs


iterations of the form

wk = f (wk−1 , γk , λ) := wk−1 − γk ∇1 L(wk−1 , λ).

Gradient descent can therefore be expressed as a for loop with

fk (wk−1 ) := f (wk−1 , γk , λ).

This means that we can differentiate through the iterations of


gradient descent, as long as f is differentiable, meaning that L is
twice differentiable. This is useful for instance to perform gradient-
based optimization of the hyperparameters γk or λ. This is a special


case of bilevel optimization; see also Chapter 10.
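As a sketch of this idea, the following Python/JAX snippet unrolls a few iterations of gradient descent on a ridge-regression objective and differentiates the result with respect to λ; the outer objective and all names are toy choices of ours.

    import jax
    import jax.numpy as jnp

    def inner_loss(w, lam, X, y):
        return 0.5 * jnp.mean((X @ w - y) ** 2) + 0.5 * lam * jnp.sum(w ** 2)

    def unrolled_gd(lam, X, y, num_steps=10, step_size=0.1):
        # A for loop with a fixed number of iterations: a valid computation graph.
        w = jnp.zeros(X.shape[1])
        for _ in range(num_steps):
            w = w - step_size * jax.grad(inner_loss)(w, lam, X, y)
        return w

    # Differentiating through the unrolled iterations w.r.t. the hyperparameter
    # lam requires inner_loss to be twice differentiable, as noted above.
    X, y = jnp.ones((5, 3)), jnp.ones(5)
    outer = lambda lam: jnp.sum(unrolled_gd(lam, X, y) ** 2)   # toy outer objective
    hypergrad = jax.grad(outer)(0.1)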

Example 5.3 (Bubble sort). Bubble sort is a simple sorting algo-


rithm that works by repeatedly swapping elements if necessary.
Mathematically, swapping two elements i and j can be written as
a function from RN × [N ] × [N ] to RN defined by

swap(v, i, j) := v + (vj − vi )ei + (vi − vj )ej .

We can then write bubble sort as

for i := 1, . . . , N do
for j := 1, . . . , N − i − 1 do
v ′ := swap(v, j, j + 1)
π := step(vj − vj+1 )
v ← ifelse(π, v ′ , v)

Replacing the Heaviside step function with the logistic function


gives a smoothed version of the algorithm.
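A minimal Python/JAX sketch of such a smoothed bubble sort is given below; the temperature argument is our addition and controls how close the soft swaps are to hard ones.

    import jax
    import jax.numpy as jnp

    def soft_bubble_sort(v, temperature=1.0):
        # Each predicate step(v_j - v_{j+1}) is replaced by a logistic sigmoid,
        # so every swap is soft and the output is differentiable w.r.t. v.
        n = v.shape[0]
        for _ in range(n):
            for j in range(n - 1):
                vj, vj1 = v[j], v[j + 1]
                swapped = v.at[j].set(vj1).at[j + 1].set(vj)
                pi = jax.nn.sigmoid((vj - vj1) / temperature)   # ~1 if a swap is needed
                v = pi * swapped + (1.0 - pi) * v
        return v

    v = jnp.array([3.0, 1.0, 2.0])
    v_soft_sorted = soft_bubble_sort(v)
    jac = jax.jacobian(soft_bubble_sort)(v)   # non-zero, unlike with the hard step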

5.9 Scan functions

Scan is a higher-order function (meaning a function of a function)


originating from functional programming. It is useful to perform an
operation f on individual elements uk while carrying the result sk of
that operation to the next iteration.

Algorithm 5.2 r = scan(s0 , u1 , . . . , uK )


for k := 1, . . . , K do
sk , vk := f (sk−1 , uk )
r := (sK , v1 , . . . , vK )

Assuming the function f is differentiable, this again defines a valid


computational graph and can be differentiated through using autodiff.
Recurrent neural networks (RNNs), reviewed in Section 4.6, can be

seen as a parameterized scan. An advantage of this abstraction is


that parallel scan algorithms have been studied extensively in computer
science (Blelloch, 1989; Sengupta et al., 2010).
Example 5.4 (Prefix sum). Scan can be seen as a generalization of
the prefix sum (a.k.a. cumulative sum) from the addition to any
binary operation. Indeed, a prefix sum amounts to performing

v1 := u1
v2 := u1 + u2
v3 := u1 + u2 + u3
..
.

which can be expressed as a scan by defining

vk := sk−1 + uk
f (sk−1 , uk ) := (vk , vk )

starting from s0 = 0 (sK and vK are redundant in this case).
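In JAX, this corresponds directly to jax.lax.scan; a minimal Python sketch of Example 5.4 follows.

    import jax
    import jax.numpy as jnp

    def f(carry, u):
        # f(s_{k-1}, u_k) = (v_k, v_k) with v_k = s_{k-1} + u_k.
        v = carry + u
        return v, v

    u = jnp.array([1.0, 2.0, 3.0, 4.0])
    s_final, prefix = jax.lax.scan(f, 0.0, u)      # prefix = [1, 3, 6, 10]

    # The scan is an ordinary differentiable program: its Jacobian w.r.t. u
    # is the lower-triangular matrix of ones.
    jac = jax.jacobian(lambda u: jax.lax.scan(f, 0.0, u)[1])(u)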

5.10 While loops

5.10.1 While loops as cyclic graphs


A while loop is a control flow used to repeatedly perform an operation,
reusing the output of the previous iteration, until a certain condition is
met. Suppose f : S → {0, 1} is a function to determine whether to stop
(π = 1) or continue (π = 0) and g : S → S is a function for performing
an operation. Then, without loss of generality, a while loop can be
written as follows.

Algorithm 5.3 r = whileloop(s)


π ← f (s)
while π = 0 do
s ← g(s)
π ← f (s)
r := s

This definition is somewhat cyclic, as we used the while keyword.


However, we can equivalently rewrite the algorithm recursively.

Algorithm 5.4 r = whileloop(s)


π := f (s)
if π = 1 then
r := s
else
r := whileloop(g(s))

Unlike for loops and scan, the number of iterations of while loops is
not known ahead of time, and may even be infinite. In this respect, a
while loop can be seen as a cyclic graph.

Importance of lazy evaluation

We can also implement Algorithm 5.4 in terms of the ifelse function


defined in Section 5.6 as

r := ifelse(f (s), s, whileloop(g(s)))


= f (s) · s + (1 − f (s)) · whileloop(g(s)).

However, to avoid an infinite recursion, it is crucial that ifelse supports


lazy evaluation. That is, whileloop(g(s)) in the definition above should
be evaluated if and only if π = f (s) = 0. In other words, the fact that
f (s) ∈ {0, 1} is crucial to ensure that the recursion is well-defined.

5.10.2 Unrolled while loops

To avoid the issues with unbounded while loops, we can enforce that a
while loop stops after T iterations, i.e., we can truncate the while loop.

Unrolling Algorithm 5.4 gives (here with T = 3)

π0 := f (s0 )
if π0 = 1 then
r := s0
else
s1 := g(s0 ), π1 := f (s1 )
if π1 = 1 then
r := s1
else
s2 := g(s1 ), π2 := f (s2 )
if π2 = 1 then
r := s2
else
r := s3 := g(s2 )

Using the ifelse function, we can rewrite it as

r = ifelse(π0 ,
s0 ,
ifelse(π1 ,
s1 ,
ifelse(π2 ,
s2 ,
s3 ))).

which is itself equivalent to

r = π0 s0 + (1 − π0 ) [π1 s1 + (1 − π1 ) [π2 s2 + (1 − π2 )s3 ]]


= π0 s0 + (1 − π0 )π1 s1 + (1 − π0 )(1 − π1 )π2 s2 + (1 − π0 )(1 − π1 )(1 − π2 )s3 .
More generally, for T ∈ N, the formula is

r = Σ_{i=0}^T ((1 − π0 ) · · · (1 − πi−1 )) πi si
  = Σ_{i=0}^T (Π_{j=0}^{i−1} (1 − πj )) πi si ,

where we defined

si := g(si−1 ) = g^i (s0 ) := (g ◦ · · · ◦ g)(s0 ) ∈ S   (g composed i times)
πi := f (si ) ∈ {0, 1}.

Example 5.5 (Computing the square root using Newton’s method).



Computing the square root √x of a real number x > 0 can be cast
as a root finding problem, which we can solve using Newton’s
method. Starting from an initialization s0 , the iterations read
si+1 := g(si ) := (1/2)(si + x/si ).
To measure the error on iteration i, we can define
ε(si ) := (1/2)(si² − x)².
As a stopping criterion, we can then use

πi := {1 if ε(si ) ≤ τ ; 0 otherwise} = step(τ − ε(si )),

where 0 < τ ≪ 1 is an error tolerance and step is the Heaviside


step function.
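For illustration, here is a Python/JAX sketch of Example 5.5 using jax.lax.while_loop (function names ours); consistent with the discussion in this section, reverse-mode differentiation through such a loop is not supported in JAX, although forward-mode is.

    import jax
    import jax.numpy as jnp

    def sqrt_newton(x, s0=1.0, tau=1e-12):
        def error(s):                  # epsilon(s) = (s^2 - x)^2 / 2
            return 0.5 * (s ** 2 - x) ** 2

        def cond_fun(s):               # continue while the error exceeds tau
            return error(s) > tau

        def body_fun(s):               # Newton update g(s) = (s + x / s) / 2
            return 0.5 * (s + x / s)

        return jax.lax.while_loop(cond_fun, body_fun, s0)

    y = sqrt_newton(2.0)               # ~ 1.41421356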

5.10.3 Markov chain perspective


Given the function g : S → S and the initialization s0 ∈ S, a while
loop can only go through a discrete set of values s0 , s1 , s2 , . . . defined

by si = g(si−1 ). This set is potentially countably infinite if the while


loop is unbounded, and finite if the while loop is guaranteed to stop.
Whether the loop moves from the state si to the state si+1 , or stays
at si , is determined by the stopping criterion πi ∈ {0, 1}. To model
the state of the while loop, we can then consider a Markov chain
with a discrete space {s0 , s1 , s2 , . . . }, which we can always identify with
{0, 1, 2, . . . }, with transition probabilities


P(St+1 = si | St = sj ) = pi,j := {πj if i = j; 1 − πj if i = j + 1; 0 otherwise},

and initial state S0 = s0 . Here, St is the value at iteration t of the loop.


Note that since πi ∈ {0, 1}, the distribution of each St is “degenerate”: it puts all its mass on a single state.
However, this framework lets us generalize to a smooth version of the
while loop naturally. To illustrate the framework, if the while loop stops
at T = 3, the transition probabilities can be cast as a matrix

                             s0  s1  s2  s3
                        s0 [  0   1   0   0 ]
P := (pi,j )_{i,j=0}^T := s1 [  0   0   1   0 ]
                        s2 [  0   0   0   1 ]
                        s3 [  0   0   0   1 ]

The output r of the while-loop is determined by the time at which the


state stays at the same value

I = min{i ∈ {1, 2, . . .} s.t. Si = Si−1 }.



Note that I itself is a random variable, as it is defined by the Si variables.


It is called a stopping time. The output of the chain is then
r = E[SI ]
  = Σ_{i=1}^{+∞} P(I = i) E[Si | I = i]
  = Σ_{i=1}^{+∞} P(I = i) si−1
  = Σ_{i=1}^{+∞} (Π_{j=0}^{i−2} (1 − πj )) πi−1 si−1
  = Σ_{i=0}^{+∞} (Π_{j=0}^{i−1} (1 − πj )) πi si .

Because the stopping time is not known ahead of time, the sum over i
goes from 0 to ∞. However, if we enforce in the stopping criterion that
the while loop runs no longer than T iterations, by setting

πi := or(f (si ), eq(i, T )) ∈ {0, 1},

we then naturally recover the expression found by unrolling the while
loop before,

r = E[SI ] = Σ_{i=0}^T (Π_{j=0}^{i−1} (1 − πj )) πi si .

For example, with T = 3, the transition probability matrix is


         s0      s1       s2       s3
    s0 [ π0    1 − π0    0        0      ]
P = s1 [ 0      π1     1 − π1    0      ]
    s2 [ 0      0        π2     1 − π2  ]
    s3 [ 0      0        0        1      ]

Smoothed while loops


With the help of this framework, we can backpropagate even through the
while loop’s stopping criterion, provided that we smooth out the predi-
cate. For example, we saw that the stopping criterion in Example 5.5 is

f (si ) = step(τ − ε(si )) and therefore

πi := or(f (si ), eq(i, T )) ∈ {0, 1}.

With step, the derivative of the while loop with respect to τ will always
be 0, just like it was the case for if-else statements. If we change the
stopping criterion to f (si ) = sigmoid(τ − ε(si )), we then have (recall
that or is well defined on [0, 1] × [0, 1])

πi := or(f (si ), eq(i, T )) ∈ [0, 1].

With sigmoid, we obtain more informative derivatives. In particular,


with sigmoid = logistic, the derivatives w.r.t. τ are always non-zero.
The smoothed output is expressed as before as the expectation
r = E[SI ] = Σ_{i=0}^T (Π_{j=0}^{i−1} (1 − πj )) πi si
           = Σ_{i=0}^T (Π_{j=0}^{i−1} (1 − sigmoid(τ − ε(sj )))) sigmoid(τ − ε(si )) si .
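The following Python/JAX sketch evaluates this smoothed, truncated while loop for the square-root example and differentiates it with respect to τ; the scale argument (ours) sets the sharpness of the sigmoid, and soft_or is the continuous extension of the logical or.

    import jax

    def soft_or(a, b):
        # Continuous extension of "or" on [0, 1] x [0, 1].
        return a + b - a * b

    def smoothed_sqrt_while(x, tau, s0=1.0, T=10, scale=1.0):
        r = 0.0            # accumulates sum_i prod_{j<i} (1 - pi_j) * pi_i * s_i
        keep_going = 1.0   # prod_{j<i} (1 - pi_j)
        s = s0
        for i in range(T + 1):
            eps = 0.5 * (s ** 2 - x) ** 2
            stop = jax.nn.sigmoid((tau - eps) / scale)   # smoothed stopping criterion
            pi = soft_or(stop, 1.0 if i == T else 0.0)   # force stopping at i = T
            r = r + keep_going * pi * s
            keep_going = keep_going * (1.0 - pi)
            s = 0.5 * (s + x / s)                        # Newton update g(s)
        return r

    # Unlike the hard version, the derivative w.r.t. tau is informative.
    d_tau = jax.grad(smoothed_sqrt_while, argnums=1)(2.0, 1e-3)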

5.11 Summary

In this chapter, we studied control flows for differentiable programming.


For conditionals, we saw that differentiating through the branch
variables is not problematic. However, for the predicate variable, we
saw that a differentiable relaxation is required to avoid null derivatives.
To do so, we introduced soft comparison operators in a principled
manner, using a stochastic process perspective, as well as the continuous
extension of logical operators.
For loops and scan define valid computational graphs, as their
number of iterations is fixed ahead of time. Feedforward networks and
RNNs can be seen as parameterized for loops and scan, respectively.
Unlike for loops and scan, the number of iterations of while loops is
not known ahead of time and may even be infinite. However, unrolled
while loops define valid directed acyclic graphs. We defined a principled
way to differentiate through the stopping criterion of a while loop,
thanks to a Markov chain perspective.
Part III

Differentiating through
programs
6 Finite differences

One of the simplest ways to numerically compute derivatives is to use


finite differences, which approximate the infinitesimal definition of
a derivative. Finite differences only require function evaluations,
and can therefore work with blackbox functions (i.e., they ignore the
compositional structure of functions). Without loss of generality, our
exposition focuses on computing directional derivatives ∂f (w)[v], for a
function f : E → F, evaluated at w ∈ E, in the direction v ∈ E.

6.1 Forward differences

From Definition 2.3 and Definition 2.12, the directional derivative and
more generally the JVP are defined as a limit,

∂f (w)[v] := lim_{δ→0} (f (w + δv) − f (w)) / δ.

This suggests that we can approximate the directional derivative and
the JVP using

∂f (w)[v] ≈ (f (w + δv) − f (w)) / δ,


for some 0 < δ ≪ 1. This formula is called a forward difference. From


the Taylor expansion in Section 2.5.4, we indeed have

f (w + δv) − f (w) = δ ∂f (w)[v] + (δ²/2) ∂²f (w)[v, v] + (δ³/3!) ∂³f (w)[v, v, v] + . . .

so that

(f (w + δv) − f (w)) / δ = ∂f (w)[v] + (δ/2) ∂²f (w)[v, v] + (δ²/3!) ∂³f (w)[v, v, v] + . . .
                         = ∂f (w)[v] + o(δ).

The error incurred by choosing a finite rather than infinitesimal δ in the


forward difference formula is called the truncation error. The Taylor
approximation above shows that this error is of the order of o(δ).
However, we cannot choose too small a value of δ, because the
evaluation of the function f on a computer rounds the value of f to
machine precision. Mathematically, a scalar-valued function f evaluated
on a computer becomes a function f˜ such that f˜(w) ≈ [f (w)/ε]ε,
where [f (w)/ε] denotes the closest integer to f (w)/ε ∈ R and ε is
the machine precision, i.e., the smallest non-zero real number encoded
by the machine. This means that the difference f (w + δv) − f (w)
evaluated on a computer is prone to round-off error of the order of
o(ε). We illustrate the trade-off between truncation and round-off errors
in Fig. 6.1.
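A minimal Python sketch of the forward difference, applied to the function used in Fig. 6.1, is given below (function names ours).

    import numpy as np

    def forward_difference_jvp(f, w, v, delta=1e-6):
        # Approximates the directional derivative of f at w along v.
        return (f(w + delta * v) - f(w)) / delta

    f = lambda x: np.log1p(np.exp(-x))           # f(x) = ln(1 + exp(-x))
    approx = forward_difference_jvp(f, 1.0, 1.0)
    exact = -1.0 / (1.0 + np.exp(1.0))           # closed-form f'(1)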

6.2 Backward differences

As an alternative, we can approximate the directional derivative and


the JVP by
∂f (w)[v] ≈ (f (w) − f (w − δv)) / δ,
for some 0 < δ ≪ 1. This formula is called a backward difference.
From the Taylor expansion in Section 2.5.4, we easily verify that (f (w)−
f (w − δv))/δ = ∂f (w)[v] + o(δ), so that the truncation error is the
same as for the forward difference.
[Plot: approximation error as a function of δ for forward differences, central differences, and the complex-step method; round-off error dominates for small δ and truncation error for large δ.]

Figure 6.1: The forward or central difference schemes applied to f (x) := ln(1 +
exp(−x)) to approximate f ′ (x) at x = 1 induce both truncation error (for large δ)
and round-off error (for small δ).

6.3 Central differences

Rather than using an asymmetric formula to approximate the derivative,


we can use
∂f (w)[v] ≈ (f (w + δv) − f (w − δv)) / (2δ),
for some 0 < δ ≪ 1. This formula is called a central difference. From
the Taylor expansion in Section 2.5.4, we have
f (w + δv) − f (w − δv) = 2δ ∂f (w)[v] + (2δ³/3!) ∂³f (w)[v, v, v] + (2δ⁵/5!) ∂⁵f (w)[v, . . . , v] + . . .

so that

(f (w + δv) − f (w − δv)) / (2δ) = ∂f (w)[v] + (δ²/3!) ∂³f (w)[v, v, v] + . . .
                                 = ∂f (w)[v] + o(δ²).

We see that the terms corresponding to derivatives of even order


cancel out, allowing the formula to achieve o(δ²) truncation error.
For any δ < 1, the truncation error of the central difference is much

smaller than the one of the forward or backward differences as confirmed


empirically in Fig. 6.1.

6.4 Higher-accuracy finite differences

The truncation error can be further reduced by making use of additional


function evaluations. One can generalize the forward difference scheme
by a formula of the form
∂f (w)[v] ≈ Σ_{i=0}^p (ai /δ) f (w + iδv)

requiring p + 1 evaluations. To select the ai and reach a truncation error


of order o(δ p ), we can use a Taylor expansion on each term of the sum
to get
Σ_{i=0}^p (ai /δ) f (w + iδv) = Σ_{i=0}^p Σ_{j=0}^p ai (i^j δ^{j−1} /j!) ∂^j f (w)[v, . . . , v] + o(δ^p ).

By grouping the terms in the sum for each order of derivative, we obtain
a set of p + 1 equations to be satisfied by the p + 1 coefficients a0 , . . . , ap ,
that is,

a0 + a1 + . . . + ap = 0
a1 + 2a2 + . . . + p ap = 1
a1 + 2^j a2 + . . . + p^j ap = 0   ∀j ∈ {2, . . . , p}.

This system of equations can be solved analytically to derive the co-


efficients. Backward differences can be generalized similarly by using
∂f (w)[v] ≈ Σ_{i=0}^p (ai /δ) f (w − iδv). Similarly, the central difference
scheme can be generalized by using

∂f (w)[v] ≈ Σ_{i=−p}^p (ai /δ) f (w + iδv),

to reach a truncation error of order o(δ 2p ). Solving for the coefficients


a−p , . . . , ap as above reveals that a0 = 0. Therefore, only 2p evaluations
are necessary.
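In practice, the coefficients can also be obtained numerically by solving the linear system above; the following Python sketch (function names ours) does so for the forward scheme and recovers, e.g., the coefficients (−3/2, 2, −1/2) for p = 2.

    import numpy as np

    def forward_fd_coefficients(p):
        # Solve the (p+1) x (p+1) system: sum_i i^j a_i = 0 for j != 1, and = 1 for j = 1.
        i = np.arange(p + 1, dtype=float)
        V = np.vander(i, increasing=True).T      # V[j, i] = i**j
        b = np.zeros(p + 1)
        b[1] = 1.0
        return np.linalg.solve(V, b)

    def forward_difference(f, w, v, delta=1e-3, p=2):
        a = forward_fd_coefficients(p)           # p = 1 recovers (-1, 1)
        return sum(a[i] / delta * f(w + i * delta * v) for i in range(p + 1))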

6.5 Higher-order finite differences

To approximate higher order derivatives, we can follow a similar rea-


soning. Namely, we can generalize the forward difference scheme to
approximate the derivative of order k by
∂^k f (w)[v, . . . , v] ≈ Σ_{i=0}^p (ai /δ^k ) f (w + iδv).

As before, we can expand the terms in the sum. For the approximation
to capture only the k th derivative, we now require the coefficients ai to
satisfy

0^j a0 + 1^j a1 + 2^j a2 + . . . + p^j ap = 0   ∀j ∈ {0, . . . , k − 1}
0^k a0 + 1^k a1 + 2^k a2 + . . . + p^k ap = k!
0^j a0 + 1^j a1 + 2^j a2 + . . . + p^j ap = 0   ∀j ∈ {k + 1, . . . , p}.

With the resulting coefficients, we obtain a truncation error of order


o(δ p−k+1 ), while making p + 1 evaluations. For example, for p = k = 2,
we can approximate the second-order derivative as

∂²f (w)[v, v] ≈ (−(3/2)f (w) + 2f (w + δv) − (1/2)f (w + 2δv)) / δ²,
with a truncation error of order o(δ).
The central difference scheme can be generalized similarly by
∂^k f (w)[v, . . . , v] ≈ Σ_{i=−p}^p (ai /δ^k ) f (w + iδv),

to reach truncation errors of order o(δ 2p+2−2⌈(k+1)/2⌉ ). For example, for


k = 2, p = 1, we obtain the second-order central difference

∂²f (w)[v, v] ≈ (f (w + δv) + f (w − δv) − 2f (w)) / δ².
By using a Taylor expansion we see that, this time, the terms corre-
sponding to derivatives of odd order cancel out and the truncation
error is o(δ 2 ) while requiring 3 evaluations.

6.6 Complex-step derivatives

Suppose f is well defined on CP , the space of P -dimensional complex



numbers. Let us denote the imaginary unit by i = √−1. Then, the
Taylor expansion of f reads
f (w + (iδ)v) = f (w) + (iδ) ∂f (w)[v] + ((iδ)²/2) ∂²f (w)[v, v] + ((iδ)³/3!) ∂³f (w)[v, v, v] + . . .
             = f (w) + (iδ) ∂f (w)[v] − (δ²/2) ∂²f (w)[v, v] − (iδ³/3!) ∂³f (w)[v, v, v] + . . . .
We see that the real part corresponds to even-degree terms and the
imaginary part corresponds to odd-degree terms. We therefore obtain

Re(f (w + (iδ)v)) = f (w) + o(δ 2 )

and
Im (f (w + (iδ)v) / δ) = ∂f (w)[v] + o(δ²).
This suggests that we can compute directional derivatives using the
approximation
∂f (w)[v] ≈ Im (f (w + (iδ)v) / δ) ,
for 0 < δ ≪ 1. This is called the complex-step derivative approxi-
mation (Squire and Trapp, 1998; Martins et al., 2003).
Contrary to forward, backward and central differences, we see that
only a single function call is necessary. A function call on complex
numbers may take roughly twice the cost of a function call on real
numbers. However, thanks to the fact that a difference of functions is
no longer needed, the complex-step derivative approximation usually
enjoys smaller round-off error as illustrated in Fig. 6.1. That said, one
drawback of the method is that all elementary operations within the
program implementing the function f must be well-defined on complex
numbers, e.g., using overloading.
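A Python sketch of the complex-step approximation is given below; note how small δ can be taken, since no difference of nearby values is computed.

    import numpy as np

    def complex_step_jvp(f, w, v, delta=1e-20):
        # Single (complex) function evaluation; f must accept complex inputs.
        return np.imag(f(w + 1j * delta * v)) / delta

    f = lambda x: np.log(1.0 + np.exp(-x))       # well-defined on complex numbers
    approx = complex_step_jvp(f, 1.0, 1.0)       # ~ f'(1) up to machine precision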

Table 6.1: Computational complexity in number of function evaluations for com-


puting the directional derivative and the gradient of a function f : RP → R by finite
differences and complex step derivatives.

                       Directional derivative    Gradient
Forward difference     2                         P + 1
Backward difference    2                         P + 1
Central difference     2                         2P
Complex step           1                         P

6.7 Complexity

We now discuss the computational complexity in terms of function


evaluations of finite differences and complex-step derivatives. For con-
creteness, as this is the most common use case in machine learning, we
discuss the case of a single M = 1 output, i.e., we want to differentiate
a function f : RP → R. Whether we use forward, backward or central
differences, the computational complexity of computing the directional
derivative ∂f (w)[v] in any direction v amounts to two calls to f . For
computing the gradient ∇f (w), we can use (see Definition 2.6) that

[∇f (w)]j = ⟨∇f (w), ej ⟩ = ∂f (w)[ej ],

for j ∈ [P ]. For forward and backward differences, we therefore need


P + 1 function calls to compute the gradient, while we need 2P function
calls for central differences. For the complex step approximation, we need
P complex function calls. We summarize the complexities in Table 6.1.

6.8 Summary

We reviewed finite differences, a simple way to numerically compute


derivatives using only function evaluations. Central differences achieve
smaller truncation error than forward and backward differences. We
saw that it is possible to achieve smaller truncation error, at the cost
of more function evaluations. Complex-step derivatives achieve smaller

round-off error than central differences but require the function and
the program implementing it to be well-defined on complex numbers.
However, whatever the method used, finite differences require a number
of function calls that is proportional to the number of dimensions. They
are therefore seldom used in machine learning, where there can be
millions or billions of dimensions. The main use cases of finite differ-
ences are therefore i) for blackbox functions of low dimension and ii)
for test purposes (e.g., checking that a gradient function is correctly
implemented). For modern machine learning, the main workhorse is
automatic differentiation, as it leverages the compositional structure of
functions. This is what we study in the next chapter.
7 Automatic differentiation

In Chapter 2, we reviewed the fundamentals of differentiation and


stressed the importance of two linear maps: the Jacobian-vector product
(JVP) and its adjoint, the vector-Jacobian product (VJP). In this chap-
ter, we review forward-mode and reverse-mode autodiff using these
two linear maps. We start with computation chains and then gener-
alize to feedforward networks and general computation graphs. We
also review checkpointing, reversible layers and randomized estimators.

7.1 Computation chains

To begin with, consider a computation chain (Section 4.1.1) repre-


senting a function f : S0 → SK expressed as a sequence of compositions
f := fK ◦ · · · ◦ f1 , where fk : Sk−1 → Sk . The computation of f can be


unrolled into a sequence of operations

s0 ∈ S0
s1 := f1 (s0 ) ∈ S1
..
.
sK := fK (sK−1 ) ∈ SK
f (x) := sK . (7.1)

Our goal is to compute the variations of f around a given input s0 . In


a feedforward network, this amounts to estimating the influence of a
given input s0 for fixed parameters (we will see how to estimate the
variations w.r.t. parameters w in the sequel).

Jacobian matrix. We first consider the computation of the full Jaco-


bian ∂f (s0 ), seen as a matrix, as the notation ∂ indicates. Following
Proposition 2.2, we have

∂f (s0 ) = ∂fK (sK−1 )∂fK−1 (sK−2 ) . . . ∂f2 (s1 )∂f1 (s0 ), (7.2)

where ∂fk (sk−1 ) are the Jacobians of the intermediate functions com-
puted at s0 , . . . , sK , as defined in Eq. (7.1). The main drawback of
this approach is computational: computing the full ∂f (s0 ) requires
to materialize the intermediate Jacobians in memory and to perform
matrix-matrix multiplications. However, in practice, computing the full
Jacobian is rarely needed. Indeed, oftentimes, we only need to right-
multiply or left-multiply with ∂f (s0 ). This gives rise to forward-mode
and reverse-mode autodiff, respectively.

7.1.1 Forward-mode

We now interpret the Jacobian ∂f (s0 ) as a linear map, as the non-bold


∂ indicates. Following Proposition 2.6, ∂f (s0 ) is the composition of the
intermediate linear maps,

∂f (s0 ) = ∂fK (sK−1 ) ◦ ∂fK−1 (sK−2 ) ◦ . . . ◦ ∂f2 (s1 ) ◦ ∂f1 (s0 ).



Figure 7.1: Forward-mode autodiff for a chain of computations. For readability,


we denoted the intermediate JVP as a function of two variables ∂fk : sk−1 , tk−1 7→
∂fk (sk−1 )[tk−1 ] with ∂fk (sk−1 )[tk−1 ] = tk .

Evaluating ∂f (s0 ) on an input direction v ∈ S0 can be decomposed,


like the function Eq. (7.1) itself, into intermediate computations

t0 := v
t1 := ∂f1 (s0 )[t0 ]
..
.
tK := ∂fK (sK−1 )[tK−1 ]
∂f (s0 )[v] := tK .

Each intermediate linear map ∂fk (sk−1 ) amounts to a Jacobian-vector


product (JVP) and can be performed in a forward manner, along the
computation of the intermediate states sk . This can also be seen as
multiplying the matrix defined in Eq. (7.2) with a vector, from right
to left. This is illustrated in Fig. 7.1 and the procedure is summarized
in Algorithm 7.1.

Computational complexity. The JVP follows exactly the computations


of f , with an additional variable tk being propagated. If we consider that
computing ∂fk is roughly as costly as computing fk , then computing a
JVP has roughly twice the computational cost of f . See Section 7.3.3
for a more general and more formal statement.

Algorithm 7.1 Forward-mode autodiff for chains of computations


Functions: f := fK ◦ . . . ◦ f1
Inputs: input s0 ∈ S0 , input direction v ∈ S0
1: Initialize t0 := v
2: for k := 1, . . . , K do
3: Compute sk := fk (sk−1 ) ∈ Sk
4: Compute tk := ∂fk (sk−1 )[tk−1 ] ∈ Sk
Outputs: f (s0 ) := sK , ∂f (s0 )[v] = tK
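A compact Python/JAX sketch of Algorithm 7.1, where each intermediate JVP is obtained with jax.jvp, is shown below; the example chain is an arbitrary toy choice of ours.

    import jax
    import jax.numpy as jnp

    def chain_jvp(fs, s0, v):
        # Algorithm 7.1: propagate (s_k, t_k) jointly through the chain.
        s, t = s0, v
        for f in fs:
            s, t = jax.jvp(f, (s,), (t,))   # s_k = f_k(s_{k-1}), t_k = ∂f_k(s_{k-1})[t_{k-1}]
        return s, t

    fs = [jnp.sin, lambda s: s ** 2, jnp.tanh]
    out, jvp = chain_jvp(fs, jnp.array([0.1, 0.2]), jnp.array([1.0, 0.0]))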

Memory usage. The memory usage of a program at a given evaluation


step is the number of variables that need to be stored in memory to
ensure the execution of all remaining steps. The memory cost of a
program is then the maximal memory usage over all evaluation steps.
For our purposes, we analyze the memory usage and memory cost
by examining the given program. Formal definitions of operations on
memory such as read, write, delete and associated memory costs are
presented by Griewank and Walther (2008, Chapter 4).
For example, to execute the chain f = fK ◦ · · · ◦ f1 , at each step k,
we only need to have access to sk−1 to execute the rest of the program.
As we compute sk , we can delete sk−1 from memory and replace it by
sk . Therefore, the memory cost associated to the evaluation of f is just
the maximal dimension of the sk variables.
For forward mode autodiff, as we follow the computations of f , at
each step k, we only need to have access to sk−1 and tk−1 to execute the
rest of the program. The memory used by sk−1 and tk−1 can directly be
used for sk , tk once they are computed. The memory usage associated
to the JVP is summarized in Fig. 7.2. Overall the memory cost of the
JVP is then exactly twice the memory cost of the function itself.

7.1.2 Reverse-mode
In machine learning, most functions whose gradient we need to compute
take the form ℓ ◦ f , where ℓ is a scalar-valued loss function and f is a
network. As seen in Proposition 2.3, the gradient takes the form

∇(ℓ ◦ f )(s0 ) = ∂f (s0 )∗ [∇ℓ(f (s0 ))].



Figure 7.2: Memory usage of forward-mode autodiff for a computation chain. Here
t0 = v, sK = f (s0 ), tK = ∂f (s0 )[v].

This motivates the need for applying the adjoint ∂f (s0 )∗ to ∇ℓ(f (s0 )) ∈
SK and more generally to any output direction u ∈ SK . From Propo-
sition 2.7, we have

∂f (s0 )∗ = ∂f1 (s0 )∗ ◦ . . . ◦ ∂fK (sK−1 )∗ .

Evaluating ∂f (s0 )∗ on an output direction u ∈ SK is decomposed as

rK = u
rK−1 = ∂fK (sK−1 )∗ [rK ]
..
.
r0 = ∂f1 (s0 )∗ [r1 ]
∂f (s0 )∗ [u] = r0 .

Each intermediate adjoint ∂fk (sk−1 )∗ amounts to a vector-Jacobian


product (VJP). The key difference with the forward mode is that the
procedure runs backward through the chain, hence the name reverse
mode autodiff. This can also be seen as multiplying Eq. (7.2) from left
to right. The procedure is illustrated in Fig. 7.3 and summarized in
Algorithm 7.2.

Computational complexity. In terms of number of operations, the


VJP simply passes two times through the chain, once forward, then
backward. If we consider the intermediate VJPs to be roughly as costly
as the intermediate functions themselves, the VJP amounts just to twice
the cost of the original function, just as the JVP. See Section 7.3.3 for
a more generic and formal statement.

Figure 7.3: Reverse mode of automatic differentiation for a chain of computations.


For readability, we denoted the intermediate VJP as a function of two variables
∂fk∗ : (sk−1 , rk ) 7→ ∂fk (sk−1 )∗ [rk ] with ∂fk (sk−1 )∗ [rk ] = rk−1 .

Algorithm 7.2 Reverse-mode autodiff for chains of computations


Functions: f := fK ◦ . . . ◦ f1 ,
Inputs: input s0 ∈ S0 , output direction u ∈ SK
1: for k := 1, . . . , K do ▷ Forward pass
2: Compute sk := fk (sk−1 ) ∈ Sk
3: Initialize rK := u.
4: for k := K, . . . , 1 do ▷ Backward pass
5: Compute rk−1 := ∂fk (sk−1 )∗ [rk ] ∈ Sk−1
Outputs: f (s0 ) := sK , ∂f (s0 )∗ [u] = r0
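Similarly, a Python/JAX sketch of Algorithm 7.2 stores one VJP closure per layer during the forward pass (via jax.vjp) and applies them in reverse order; names are ours.

    import jax
    import jax.numpy as jnp

    def chain_vjp(fs, s0, u):
        # Algorithm 7.2: forward pass storing VJP closures, then backward pass.
        s, vjps = s0, []
        for f in fs:                        # forward pass
            s, vjp_f = jax.vjp(f, s)
            vjps.append(vjp_f)
        r = u
        for vjp_f in reversed(vjps):        # backward pass
            (r,) = vjp_f(r)
        return s, r

    fs = [jnp.sin, lambda s: s ** 2, jnp.tanh]
    out, vjp = chain_vjp(fs, jnp.array([0.1, 0.2]), jnp.array([1.0, 1.0]))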

Figure 7.4: Memory usage of reverse mode autodiff for a computation chain.

Memory usage. Recall that the memory usage of a program at a given


evaluation step is the number of variables that need to be stored in
memory to ensure the execution of the remaining steps. If we inspect
Algorithm 7.2, to execute all backward steps, that is, the loop in line 4,
we need to have access to all the intermediate inputs s0 , . . . , sK−1 .
Therefore, the memory cost of reverse-mode autodiff is proportional to
the length of the chain K. Fig. 7.4 illustrates the memory usage during
reverse mode autodiff. It grows linearly until the end of the forward
pass and then progressively decreases until it outputs the value of the
function and the VJP. The memory cost can be mitigated by means of
checkpointing techniques presented in Section 7.5.

Decoupled function and VJP evaluations. The additional memory


cost of reverse mode autodiff comes with some advantages. If we need
to compute ∂f (s0 )∗ [ui ] for n different output directions ui , we only
need to compute and store once the intermediate computations sk and
then make n calls to the backward pass. In other words, by storing in
memory the intermediate computations sk , we may instantiate a VJP
operator, which we may apply to any u through the backward pass.

Formally, the forward and backward passes can be decoupled as

forward(f, s0 ) := (f (s0 ), ∂f (s0 )∗ )

where
∂f (s0 )∗ [u] := backward(u; s0 , . . . , sK−1 ).
In functional programming terminology, the VJP ∂f (s0 )∗ is a closure,
as it contains the intermediate computations s0 , . . . , sK . The same can
be done for the JVP ∂f (s0 ) if we want to apply it to multiple directions
vi .

Example 7.1 (Multilayer perceptron with fixed parameters ). As a run-


ning example, consider a multilayer perceptron (MLP) with one
hidden layer and (for now) given fixed weights. As presented in
Chapter 4, an MLP can be decomposed as

s0 = x
s1 = f1 (s0 ) = σ(A1 s0 + b1 )
s2 = f2 (s1 ) = A2 s1 + b2
f (x) = s2 ,

for A1 , A2 , b1 , b2 some fixed parameters and σ an activation func-


tion such as the softplus activation function σ(x) = log(1 + ex ) with
derivative σ ′ (x) = ex /(1 + ex ).
Evaluating the JVP of f on an input x along a direction v can
then be decomposed as

t0 = v
t1 = σ ′ (A1 s0 + b1 ) ⊙ (A1 t0 )
t2 = A2 t1
∂f (x)[v] = t2 ,

where we used in the second line the JVP of element-wise function


as in Example 7.3.
Evaluating the VJP of f at x requires evaluating the intermediate
VJPs at the stored activations

r2 = u
r1 = ∂f2 (s1 )∗ [r2 ] = A2⊤ r2
r0 = ∂f1 (s0 )∗ [r1 ] = A1⊤ (σ ′ (A1 s0 + b1 ) ⊙ r1 )
∂f (x)∗ [u] = r0 .

7.1.3 Complexity of entire Jacobians


In this section, we analyze the time and space complexities of forward-
mode and reverse-mode autodiff for computing the entire Jacobian
matrix ∂f (s0 ) of a computation chain f = fK ◦ · · · ◦ f1 . We assume
Sk ⊆ RDk , DK = M and D0 = D.

Complexity of forward-mode autodiff


Using Definition 2.8, we find that we can extract each column ∂j f (s0 ) ∈
RM of the Jacobian matrix, for j ∈ [D], by multiplying with the standard
basis vector ej ∈ RD :

∂1 f (s0 ) = ∂f (s0 )e1


..
.
∂D f (s0 ) = ∂f (s0 )eD .

Computing the full Jacobian matrix therefore requires D JVPs with


vectors in RD . Assuming each fk in the chain composition has the form
fk : RDk−1 → RDk , seen as a matrix, ∂fk (sk−1 ) has size Dk × Dk−1.
Therefore, the computational cost of D JVPs is O(D Σ_{k=1}^K Dk Dk−1 ).

The memory cost is O(maxk∈[K] Dk ), since we can release intermediate


computations after each layer is processed. Setting D1 = · · · = DK−1 =
D for simplicity and using DK = M , we obtain that the computational
cost of computing D JVPs and therefore of computing the full Jacobian
matrix by forward-mode autodiff is O(M D2 + KD3 ). The memory cost
is O(max{D, M }). If a function is single-input D = 1, then the forward
mode computes at once all the Jacobian, which reduces to a single
directional derivative.

Forward-mode Reverse-mode
Time O(M D2 + KD3 ) O(M 2 D + KM D2 )
Space O(max{M, D}) O(KD + M )

Table 7.1: Time and space complexities of forward-mode and reverse-mode autodiff
for computing the full Jacobian of a chain of functions f = fK ◦ · · · ◦ f1 , where
fk : RD → RD if k = 1, . . . , K − 1 and fK : RD → RM . We assume ∂fk is a dense
linear operator. Forward mode requires D JVPs. Reverse mode requires M VJPs.

Complexity of reverse-mode autodiff

Using Definition 2.8, we find that we can extract each row of the Jaco-
bian matrix, which corresponds to the transposed gradients ∇fi (s0 ) ∈
RD , for i ∈ [M ], by multiplying with the standard basis vector ei ∈ RM :

∇f1 (s0 ) = ∂f (s0 )∗ e1


..
.
∇fM (s0 ) = ∂f (s0 )∗ eM .

Computing the full Jacobian matrix therefore requires M VJPs with


vectors in RM . Assuming as before that each fk in the chain composition
has the form fk : R^{Dk−1} → R^{Dk} , the computational cost of M VJPs
is O(M Σ_{k=1}^K Dk Dk−1 ). However, the memory cost is O(Σ_{k=1}^K Dk ),
as we need to store the intermediate computations for each of the
K layers. Setting D0 = · · · = DK−1 = D for simplicity and using
DK = M , we obtain that the computational cost of computing M VJPs
and therefore of computing the full Jacobian matrix by reverse-mode
autodiff is O(M 2 D + KM D2 ). The memory cost is O(KD + M ). If the
function is single-output M = 1, reverse-mode autodiff computes at
once the Jacobian, which reduces to the gradient.

When to use forward-mode vs. reverse-mode autodiff?

We summarize the time and space complexities in Table 7.1. Generally,


if M < D, reverse-mode is more advantageous at the price of some
memory cost. If M ≥ D, forward mode is more advantageous.

7.2 Feedforward networks

In the previous section, we derived forward-mode autodiff and reverse-


mode autodiff for computation chains with an input s0 ∈ S0 . In this
section, we now derive reverse-mode autodiff for feedforward networks,
in which each layer fk is now allowed to depend explicitly on some
additional parameters wk ∈ Wk . The recursion is
s0 := x ∈ S0
s1 := f1 (s0 , w1 ) ∈ S1
..
.
sK := fK (sK−1 , wK ) ∈ SK
f (x, w) := sK ,
where S0 = X and w = (w1 , . . . , wK ) ∈ W1 × · · · × WK . Each fk is
now a function of two arguments. The first argument depends on the
previous layer, but the second argument does not. This is illustrated in
Fig. 7.5. We now explain how to differentiate a feedforward network.

7.2.1 Computing the adjoint


The function has the form f : E → F, where E := X × (W1 × · · · ×
WK ) and F := SK . From Section 2.3, we know that the VJP has the
form ∂f (x, w)∗ : F → E. Therefore, we want to be able to compute
∂f (x, w)∗ [u] ∈ E for any u ∈ F.
Fortunately, the backward recursion is only a slight modification
of the computation chain case. Indeed, since fk : Ek → Fk , where
Ek := Sk−1 × Wk and Fk := Sk , the intermediate VJPs have the form
∂fk (sk−1 , wk )∗ : Fk → Ek . We therefore arrive at the recursion
rK = u ∈ SK
(rK−1 , gK ) = ∂fK (sK−1 , wK )∗ [rK ] ∈ SK−1 × WK
..
.
(r0 , g1 ) = ∂f1 (s0 , w1 )∗ [r1 ] ∈ S0 × W1 .
The final output is
∂f (x, w)∗ [u] = (r0 , (g1 , . . . , gK )).

Figure 7.5: Computation graph of an MLP as a function of its parameters.

7.2.2 Computing the gradient


We often compose a network with a loss function

L(w; x, y) := ℓ(f (x, w); y) ∈ R.

From Proposition 2.7, the gradient is given by

∇L(w; x, y) = (g1 , . . . , gK ) ∈ W1 × · · · × WK

where
∂f (x, w)∗ [u] = (r0 , (g1 , . . . , gK )),
with u = ∇ℓ(f (x, w); y) ∈ SK . The output r0 ∈ S0 , where S0 =
X , corresponds to the gradient w.r.t. x ∈ X and is typically not
needed, except in generative modelling settings. The full procedure is
summarized in Algorithm 7.3.

Figure 7.6: Reverse mode of automatic differentiation, a.k.a., gradient back-


propagation to compute the gradient of the loss of an MLP on an input label
pair. For readability, we denoted the intermediate VJPs as ∂fk∗ : (sk−1 , wk−1 , rk ) 7→
∂fk (sk−1 , wk )[rk ] with ∂fk (sk−1 , wk )[rk ] = (rk−1 , gk ).

Algorithm 7.3 Gradient back-propagation for feedforward networks


Functions: f1 , . . . , fK in sequential order
Inputs: data point (x, y) ∈ X × Y
parameters w = (w1 , . . . wK ) ∈ W1 × · · · × WK
1: Initialize s0 := x ▷ Forward pass
2: for k := 1, . . . , K do
3: Compute and store sk := fk (sk−1 , wk ) ∈ Sk
4: Compute ℓ(sK ; y) and u := ∇ℓ(sK ; y) ∈ SK
5: Initialize rK := u ∈ SK ▷ Backward pass
6: for k := K, . . . , 1 do
7: Compute (rk−1 , gk ) := ∂fk (sk−1 , wk )∗ [rk ] ∈ Sk−1 × Wk
8: Outputs: L(w; x, y) := ℓ(sK ; y), ∇L(w; x, y) = (g1 , . . . , gK )

Algorithm 7.4 Forward-mode autodiff for computation graphs


Functions: f1 , . . . , fK in topological order
Inputs: input s0 ∈ S0 , input direction v ∈ S0
1: Initialize t0 := v
2: for k := 1, . . . , K do ▷ Forward pass
3: Retrieve parent nodes i1 , . . . , ipk := parents(k)
4: Compute sk := fk (si1 , . . . , sipk )
5: Compute

tk := ∂fk (si1 , . . . , sipk )[ti1 , . . . , tipk ]
   = Σ_{j=1}^{pk} ∂j fk (si1 , . . . , sipk )[tij ].

6: Outputs: f (s0 ) := sK , ∂f (s0 )[v] = tK

7.3 Computation graphs

In the previous sections, we reviewed autodiff for computation chains


and its extension to feedforward networks. In this section, we generalize
autodiff to computation graphs introduced in Section 4.1.3.

7.3.1 Forward-mode
The forward mode corresponds to computing a JVP in an input direction
v ∈ S0 . The algorithm consists in computing intermediate JVPs along
the forward pass. We initialize t0 := v ∈ S0 . Using Proposition 2.8, the
derivatives on iteration k ∈ [K] are propagated as

tk := ∂fk (si1 , . . . , sipk )[ti1 , . . . , tipk ]
   = Σ_{j=1}^{pk} ∂j fk (si1 , . . . , sipk )[tij ],

where i1 , . . . , ipk = parents(k). The final output is ∂f (s0 )[v] = tK . The


resulting generic forward-mode autodiff is summarized in Algorithm 7.4.
Although not explicitly mentioned, we can release sk and tk from
memory when no child node depends on node k.

7.3.2 Reverse-mode

The reverse mode corresponds to computing a VJP in an output di-


rection u ∈ SK . We first perform a forward pass to compute the
intermediate values and store the corresponding VJP as a pointer, since
the VJP shares some computations with the function itself. From Propo-
sition 2.8, the VJP returns a tuple with the same length as the number
of inputs to the function:

δi1 ,k , . . . , δipk ,k = ∂fk (si1 , . . . , sipk )∗ [rk ].

After the forward pass, we traverse the graph in reverse topological


order as illustrated in Fig. 7.7. If an intermediate value sk is used by
later functions fj1 , . . . , fjck for j1 , . . . , jck = children(k), the derivatives
with respect to sk need to sum all the variations through the fj functions
in a variable rk ,
rk := Σ_{j∈{j1 ,...,jck }} δk,j .

In practice, as we go backward through the computation graph, we can


accumulate the VJPs corresponding to node k by doing in-place updates
of the rk values. The topological ordering ensures that rk has been fully
computed when we reach node k. The resulting generic reverse-mode
autodiff is presented in Algorithm 7.5.

7.3.3 Complexity, the Baur-Strassen theorem

For computing the gradient of a function f : E → R represented by


a computation graph, we saw that the reverse mode is more efficient
than the forward mode. As we previously stated, assuming that the
elementary functions fk in the DAG and their VJP have roughly the
same computational complexity, then f and ∇f have roughly the same
computational complexity. This fact is crucial and is the pillar on which
modern machine learning relies: it allows us to optimize high-dimensional
functions by gradient descent.
For arithmetic circuits, reviewed in Section 4.1.4, this crucial fact
is made more precise in the celebrated Baur-Strassen theorem (Baur
and Strassen, 1983). If f is a polynomial, then so is its gradient ∇f .

Algorithm 7.5 Reverse-mode autodiff for computation graphs


Functions: f1 , . . . , fK in topological order
Inputs: input s0 ∈ S0 , output direction u ∈ SK
1: for k := 1, . . . , K do ▷ Forward pass
2: Retrieve parent nodes i1 , . . . , ipk := parents(k)
3: Compute sk := fk (si1 , . . . , sipk )
4: Instantiate VJP lk := ∂fk (si1 , . . . , sipk )∗
5: Initialize rK := u, rk := 0 ∀k ∈ {0, . . . , K − 1} ▷ Backward pass
6: for k := K, . . . , 1 do
7: Retrieve parent nodes i1 , . . . , ipk = parents(k)
8: Compute δi1 ,k , . . . , δipk ,k := lk [rk ]
9: Compute rij ← rij + δij ,k ∀j ∈ {1, . . . , pk }
10: Outputs: f (s0 ) := sK , ∂f (s0 )∗ [u] = r0

The theorem gives a lower bound on the size of the best circuit for
computing ∇f from the size of the best circuit for computing f .

Proposition 7.1 (Baur-Strassen’s theorem). For any polynomial


f : E → R, we have
S(∇f ) ≤ 5 · S(f ),
where S(f ) is defined in Definition 4.1.

A simpler proof by backward induction was given by Morgenstern


(1985). See also the proof of Theorem 9.10 in Chen et al. (2011). For
general computation graphs, that have more primitive functions than
just + and ×, a similar result can be obtained; see, e.g., (Bolte et al.,
2022, Theorem 2).

7.4 Implementation

7.4.1 Primitive functions

An autodiff system implements a set A of primitive or elementary


functions, which serve as building blocks for creating other functions, by
function composition. For instance, we saw that in arithmetic circuits

Figure 7.7: Left: Assuming a topological order, the computation of fk on iteration


k involves pk inputs computed by fi1 , . . . , fipk , where {i1 , . . . , ipk } = parents(k),
and is used in ck functions fj1 , . . . , fjck , where {j1 , . . . , jck } = children(k). Middle:
Forward-mode autodiff on iteration k, denoting the shorthand spk := (si1 , . . . , sipk ).
This computes ∂i fk (spk )[tij ] for incoming tij , j = 1, . . . , pk , then sum these results
to pass tk to the next iterations. Right: Reverse-mode autodiff on iteration k. As
we traverse the graph backward, we first sum the contributions coming from each
child computation fj1 , . . . , fjck for {j1 , . . . , jck } = children(k) to get rk . We then
feed this rk to each ∂i fk (spk ) and continue the procedure in reverse topological order.

(Section 4.1.4), A = {+, ×}. More generally, A may contain all the
necessary functions for expressing programs. We emphasize, however,
that A is not necessarily restricted to low-level functions such as log
and exp, but may also contain higher-level functions. For instance,
even though the log-sum-exp can be expressed as the composition
of elementary operations (log, sum, exp), it is usually included as a
primitive on its own, both because it is a very commonly-used building
block, but also for numerical stability reasons.

7.4.2 Closure under function composition

Each function fk in a computation graph belongs to a set F, the class of


functions supported by the system. A desirable property of an autodiff
implementation is that the set F is closed under function composition,
meaning that if f ∈ F and g ∈ F, then f ◦ g ∈ F. This means that
composed functions can themselves be used for composing new functions.
This property is also crucial for supporting higher-order differentiation
(Chapter 8) and automatic linear transposition (Section 7.4.4). When

fk is a composition of elementary functions in A, then fk itself is a


nested DAG. However, we can always inline each composite function,
such that all functions in the DAG belong to A.

7.4.3 Examples of JVPs and VJPs

An autodiff system must implement for each f ∈ A its JVP for support-
ing the forward mode, and its VJP for supporting the reverse mode.
We give a couple of examples. We start with the JVP and VJP of linear
functions.

Example 7.2 (JVP and VJP of linear functions). Consider the matrix-
vector product f (W ) = W x ∈ RM , where x ∈ RD is fixed and
W ∈ RM ×D . As already mentioned, the JVP of f at W ∈ RM ×D
along an input direction V ∈ RM ×D is simply

∂f (W )[V ] = f (V ) = V x ∈ RM .

To find the associated VJP, we note that for any u ∈ RM and


V ∈ RM ×D , we must have ⟨∂f (W )[V ], u⟩ = ⟨V , ∂f (W )∗ [u]⟩.
Using the properties of the trace, we have

⟨∂f (W )[V ], u⟩ = ⟨V x, u⟩ = tr(x⊤ V ⊤ u) = ⟨V , ux⊤ ⟩.

Therefore, we find that the VJP is given by

∂f (W )∗ [u] = ux⊤ ∈ RM ×D .

Similarly, consider now a matrix-matrix product f (W ) = W X,


where W ∈ RM ×D and where X ∈ RD×N is fixed. The JVP at
W ∈ RM ×D along an input direction V ∈ RM ×D is simply

∂f (W )[V ] = f (V ) = V X ∈ RM ×N .

The VJP along the output direction U ∈ RM ×N is

∂f (W )∗ [U ] = U X ⊤ ∈ RM ×D .

Another simple example are element-wise separable functions.



Example 7.3 (JVP and VJP of separable function). Consider the func-
tion f (w) := (g1 (w1 ), . . . , gP (wP )), where each gi : R → R has a
derivative gi′ . The Jacobian matrix is then a diagonal matrix

∂f (w) = diag(g1′ (w1 ), . . . , gP′ (wP )) ∈ RP ×P .

In this case, the JVP and VJP are actually the same

∂f (w)[v] = ∂f (w)∗ [v] = (g1′ (w1 ), . . . , gP′ (wP )) ⊙ v,

where ⊙ indicates element-wise multiplication.

7.4.4 Automatic linear transposition


At first sight, if we want to support both forward and reverse modes, it
appears like we need to implement both the JVP and the VJP for each
primitive operation f ∈ A. Fortunately, there exists a way to recover
VJPs from JVPs, and vice-versa.
We saw in Section 2.3 that if l(w) is a linear map, then its JVP is
∂l(w)[v] = l(v) (independent of w). Conversely, the VJP is ∂l(w)∗ [u] =
l∗ (u), where l∗ is the adjoint operator of l (again, independent of w).
Let us define l(u; w) := ∂f (w)∗ [u], i.e., the VJP of f in the output
direction u. Since l(u; w) is linear in u, we can apply the reasoning
above to compute its VJP
∂l(u; w)∗ [v] = l∗ (v; w) = ∂f (w)∗∗ [v] = ∂f (w)[v],
which is independent of u. In words, the VJP of a VJP is the cor-
responding JVP! This means that we can implement forward-mode
autodiff even if we only have access to VJPs. As an illustration and
sanity check, we give the following example.

Example 7.4 (Automatic transpose of “dot”). If we define


f (x, W ) := W x, from Example 7.2, we know that

∂f (x, W )∗ [u] = (W ⊤ u, ux⊤ )


= (f (u, W ⊤ ), f (x⊤ , u))
=: l(u; x, W ).

Using Proposition 2.9, we obtain

∂l(u; x, W )∗ [v, V ] = f (v, W ) + f (x, V )


= Wv + V x
= ∂f (x, W )[v, V ].

The other direction, automatically creating a VJP from a JVP, is


also possible but is more technical and relies on the notion of partial
evaluation (Frostig et al., 2021; Radul et al., 2022).
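In JAX, this mechanism is exposed as jax.linear_transpose, which transposes a function that the user guarantees to be linear; a small Python sketch follows.

    import jax
    import jax.numpy as jnp

    W = jnp.arange(6.0).reshape(2, 3)
    f = lambda x: W @ x                       # a linear map from R^3 to R^2

    # Transposing the linear map yields u -> W^T u; the primal argument is
    # only used to specify the input shape and dtype.
    f_t = jax.linear_transpose(f, jnp.zeros(3))
    (wt_u,) = f_t(jnp.ones(2))                # equals W.T @ jnp.ones(2)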

7.5 Checkpointing

We saw that forward-mode autodiff can release intermediate computa-


tions from memory along the way, while reverse-mode autodiff needs to
cache all of them. This means that the memory complexity of reverse-
mode autodiff, in its standard form, grows linearly with the number of
nodes in the computation graph. A commonly-used technique to circum-
vent this issue is checkpointing, which trades-off computation time for
better memory usage. Checkpointing works by selectively storing only a
subset of the intermediate nodes, called checkpoints, and by comput-
ing others on-the-fly. The specific choice of the checkpoint locations in
the computation graph determines the memory-computation trade-off.
While it is possible to heuristically set checkpoints at user-specified
locations, it is also possible to perform a checkpointing strategy algo-
rithmically, as studied in-depth by Griewank (1992) and Griewank and
Walther (2008). In this section, we review two such algorithms: recursive
halving and dynamic programming (divide-and-conquer). Our exposi-
tion focuses on computation chains f = fK ◦ . . . ◦ f1 with fi : RD → RD
for simplicity.

Computational and memory complexities at two extremes. Let C(K)


be the number of calls to the individual functions fi (we ignore the
cost of computing the intermediate VJPs) and M(K) be the number
of function inputs cached, when performing reverse-mode autodiff on
a chain f = fK ◦ . . . ◦ f1 . On one extreme, if we store all intermediate
computations, as done in Algorithm 7.2, to compute only the VJP

Algorithm 7.6 Reverse-mode autodiff with constant memory


vjp_full_recompute(fK ◦ . . . ◦ f1 , s0 , u) := ∂(fK ◦ . . . ◦ f1 )(s0 )∗ [u]
Inputs: Chain fK ◦ . . . ◦ f1 , input s0 ∈ S0 , output direction u ∈ SK
1: if K = 1 then
2: return ∂f1 (s0 )∗ [u]
3: else
4: Set rK = u
5: for k := K, . . . , 1 do
6: Compute sk−1 = (fk−1 ◦ . . . ◦ f1 )(s0 )
7: Compute rk−1 = ∂fk (sk−1 )∗ [rk ]
8: return: r0

∂f (s0 )∗ [u], we have


C(K) = K − 1 and M(K) = K.
This is optimal w.r.t. computational complexity, but suboptimal w.r.t.
memory. On the other extreme, if we only store the initial input, as
done in Algorithm 7.6, then we have
C(K) = K(K − 1)/2 and M(K) = 1.
This is optimal w.r.t. memory but leads to a computational complexity
that is quadratic in K.

7.5.1 Recursive halving


As a first step towards obtaining a better computation-memory trade-off,
we may split the chain sK = fK ◦ · · · ◦ f1 (s0 ) as
sK/2 = fK/2 ◦ . . . ◦ f1 (s0 )
sK = fK ◦ . . . ◦ fK/2+1 (sK/2 ),
for K even. Then, rather than recomputing all intermediate computa-
tions sk from the input s0 as in Algorithm 7.6, we can store sK/2 and
recompute sk for k > K/2 starting from sK/2 . Formally, this strategy
amounts to the following steps.
1. Compute sK/2 = fK/2 ◦ . . . ◦ f1 (s0 )

Algorithm 7.7 Reverse-mode autodiff with recursive halving


vjp_halving(fK ◦ . . . ◦ f1 , s0 , u) := ∂(fK ◦ . . . ◦ f1 )(s0 )∗ [u]
Functions: Chain fK ◦ . . . ◦ f1
Inputs: input s0 ∈ S0 , output direction u ∈ SK
1: if K = 1 then
2: return ∂f1 (s0 )∗ [u]
3: else
4: Compute sK/2 = fK/2 ◦ . . . ◦ f1 (s0 )
5: Compute rK/2 = vjp_halving(fK ◦ . . . ◦ fK/2+1 , sK/2 , u)
6: Compute r0 = vjp_halving(fK/2 ◦ . . . ◦ f1 , s0 , rK/2 )
7: return: r0

2. Compute rK/2 = vjp_full_recompute(fK ◦ . . . ◦ fK/2+1 , sK/2 , u)

3. Compute r0 = vjp_full_recompute(fK/2 ◦ . . . ◦ f1 , s0 , rK/2 )

At the expense of having to store the additional checkpoint sK/2 , this


already roughly halves the computational complexity compared to Al-
gorithm 7.6.
We can then apply this reasoning recursively, as formalized in Al-
gorithm 7.7. The algorithm is known as recursive binary schedule
(Griewank, 2003) and illustrated in Fig. 7.8. In terms of number of
function evaluations C(K), for K even, we make K/2 function calls,
and we call the procedure recursively twice, that is,

C(K) = 2C(K/2) + K/2.

If the chain is of length 1, we directly use the VJP, so C(1) = 0. Hence,


the number of function calls, if K is a power of 2, is

C(K) = (K/2) log₂ K.
In terms of memory usage, Algorithm 7.7 uses s0 not only at line 4
but also at line 6. So when the algorithm is called recursively on the
second half of the chain at line 5, one memory slot is taken by s0 . This
line is called recursively until the chain is reduced to a single function.
At that point, the total number of memory slots used is equal to the

(Figure: function step vs. time step, showing forward computations, checkpoints stored in memory, and backward computations.)

Figure 7.8: Illustration of checkpointing with recursive halving, for a chain of


8 functions. The chain is first fully evaluated while storing some computations
as checkpoints in memory. Then, during the backward pass, we recompute some
intermediate values from the latest checkpoint available. In contrast, vanilla reverse-
mode autodiff (with full caching of the intermediate computations) would lead to a
simple triangle shape.

number of times we split the function in half, that is log2 K for K a


power of 2. On the other hand, the input s0 is no longer used after line 6
of Algorithm 7.7. At that line, the memory slot taken by s0 can be
consumed by the recursive call on the first-half. In other words, calling
the algorithm recursively on the first half does not incur extra memory
cost. So if K is a power of 2, the memory cost of Algorithm 7.7 is

M(K) = log2 K.
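As a minimal sketch (not from the original text), recursive halving can be written as a short recursive routine; the list fs of chain functions and the name vjp_halving below are illustrative, and JAX's vjp is used only for the length-1 base case:

import jax
import jax.numpy as jnp

def vjp_halving(fs, s0, u):
    # VJP of the chain fs[-1] ∘ ... ∘ fs[0] at s0 in output direction u.
    K = len(fs)
    if K == 1:
        _, vjp_fun = jax.vjp(fs[0], s0)
        (r0,) = vjp_fun(u)
        return r0
    half = K // 2
    s_half = s0
    for f in fs[:half]:                          # forward to the checkpoint s_{K/2}
        s_half = f(s_half)
    r_half = vjp_halving(fs[half:], s_half, u)   # reverse the second half
    return vjp_halving(fs[:half], s0, r_half)    # reverse the first half

# Toy usage: the result matches jax.vjp applied to the whole chain.
fs = [jnp.sin, jnp.cos, jnp.tanh, jnp.exp]
print(vjp_halving(fs, jnp.ones(3), jnp.ones(3)))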

7.5.2 Dynamic programming


Recursive halving requires log2 K memory slots for a chain of length
K. However, as illustrated in Fig. 7.8, at a given time step, all memory
slots may not be exploited.
To optimize the approach, we observe that recursive halving is
just one instance of a program that splits the chain and calls itself
recursively on each part. In other words, it is a form of divide-and-
conquer algorithm. Rather than splitting the chain in half, we may
consider splitting the chain at some index l. One split is used to reverse

the computations from l + 1 to K by a recursive call that consumes one


memory slot. The other split is used on a recursive call that reverses
the computations from 0 to l. That second call does not require an
additional memory slot, as it can use directly the memory slot used
by the original input s0 . To split the chain into two such parts, we need
l intermediate computations to go from s0 to sl . The computational
complexity C(k, s), counted as the number of function evaluations, for
a chain of length k with s memory slots then satisfies the recurrence

C(k, s) = C(k − l, s − 1) + C(l, s) + l,

for all l ∈ {1, . . . , k − 1}. By simply taking l = k/2, we recover exactly


the computational complexity of recursive halving. To refine the latter,
we may split the chain by selecting l to minimize the complexity. An
optimal scheme must satisfy the recursive equation,

C∗(k, s) := min_{1≤l≤k−1} {C∗(k − l, s − 1) + C∗(l, s) + l}.          (7.3)

Note that C ∗ (K, S) can be computed from C ∗ (k, s) for k = 1, . . . , K − 1,


s = 1, . . . , S − 1. This suggests a dynamic programming approach to
find an optimal scheme algorithmically. For a chain of length k = 1, the
cost is null as we directly reverse the computation, so C ∗ (1, s) := 0. On
the other hand for a memory s = 1, there is only one possible scheme
that saves only the initial input as in Algorithm 7.6, so C ∗ (k, 1) :=
(k(k − 1))/2. The values C ∗ (k, s) can then be computed incrementally
from k = 1 to K and s = 1 to S using Eq. (7.3). The optimal splits can
be recorded along the way as

l∗(k, s) := argmin_{1≤l≤k−1} {C∗(k − l, s − 1) + C∗(l, s) + l}.

The optimal split for K, S can then be found by backtracking the


optimal splits along both branches corresponding to C ∗ (k − l, s − 1)
and C ∗ (l, s). As the final output consists in traversing a binary tree,
it was called treeverse (Griewank, 1992). Note that the dynamic
programming procedure is generic and could a priori incorporate varying
computational costs for the intermediate functions fk .
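A minimal sketch of this dynamic program (not from the original text; the function name and unit-cost assumption are illustrative) is as follows:

def optimal_checkpointing(K, S):
    # Fill the tables C*(k, s) and l*(k, s) of Eq. (7.3), assuming unit
    # cost per function evaluation.
    INF = float("inf")
    C = [[INF] * (S + 1) for _ in range(K + 1)]
    split = [[0] * (S + 1) for _ in range(K + 1)]
    for s in range(1, S + 1):
        C[1][s] = 0                       # a chain of length 1 costs nothing
    for k in range(1, K + 1):
        C[k][1] = k * (k - 1) // 2        # single memory slot: Algorithm 7.6
    for s in range(2, S + 1):
        for k in range(2, K + 1):
            for l in range(1, k):
                cost = C[k - l][s - 1] + C[l][s] + l
                if cost < C[k][s]:
                    C[k][s], split[k][s] = cost, l
    return C, split

The optimal schedule is then recovered by backtracking the recorded splits, as described next.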

Analytical formula

It turns out that we can also find an optimal scheme analytically. This
scheme was found by Griewank (1992), following the analysis of optimal
inversions of sequential programs by divide-and-conquer algorithms
done by Grimm et al. (1996); see also Griewank (2003, Section 6) for
a simple proof. The main idea consists in considering the number of
times an evaluation step fk is repeated. As we split the chain at l, all
steps from 1 to l will be repeated at least once. In other words, treating
the second half of the chain incurs one memory cost, while treating the
first half of the chain incurs one repetition cost. Griewank (1992) shows
that for fixed K, S, we can find the minimal number of repetitions
analytically and build the corresponding scheme with simple formulas
for the optimal splits.
Compared to the dynamic programming approach, it means that we
do not need to compute the pointers l∗ (k, s), and we can use a simple
formula to set l∗ (k, s). We still need to traverse the corresponding binary
tree given K, S and the l∗ (k, s) to obtain the schedules. Note that such an
optimal scheme does not take into account varying computational costs
for the functions fk .

7.5.3 Online checkpointing

The optimal scheme presented above requires knowing the total number
of nodes in the computation graph ahead of time. However, when
differentiating through for example a while loop (Section 5.10), this is
not the case. To circumvent this issue, online checkpointing schemes
have been developed and proven to be nearly optimal (Stumm and
Walther, 2010; Wang et al., 2009). These schemes start by defining a set
of S checkpoints with the first S computations, then these checkpoints
are rewritten dynamically as the computations keep going. Once the
computations terminate, the optimal approach presented above for a
fixed length is applied on the set of checkpoints recorded.

Algorithm 7.8 Reverse-mode autodiff for reversible chains.


Functions: f := fK ◦ . . . ◦ f1 , with each fk invertible
Inputs: input s0 ∈ S0 , output direction u ∈ SK
1: Compute sK = fK ◦ . . . ◦ f1 (s0 )
2: Set rK = u
3: for k := K, . . . , 1 do
4: Compute sk−1 = fk⁻¹(sk )
5: Compute rk−1 = ∂fk (sk−1 )∗ [rk ]
Outputs: f (s0 ) := sK , ∂f (s0 )∗ [u] = r0

7.6 Reversible layers

7.6.1 General case

The memory requirements of reverse-mode autodiff can be completely


alleviated when the functions fk are invertible (meaning that fk⁻¹ exists)
and when fk⁻¹ is easily accessible. In that case, rather than storing the
intermediate computations sk−1 , necessary to compute the VJP rk ↦
∂fk (sk−1 )∗ [rk ], one can compute them on the fly during the backward
pass from sk using sk−1 = fk⁻¹(sk ). We summarize the procedure for
the case of computation chains in Algorithm 7.8. Compared to vanilla
reverse-mode autodiff in Algorithm 7.2, the algorithm has optimal
memory complexity, as we can release sk and rk as we go.
In practice, fk⁻¹ often does not exist or may not be easily accessible.
However, network architectures can be constructed to be easily invertible
by design. Examples include reversible residual networks (Gomez et
al., 2017), orthonormal RNNs (Helfrich et al., 2018), neural ODEs
(Section 11.6), and momentum residual neural networks (Sander et al.,
2021a); see also references therein.
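A minimal JAX sketch of Algorithm 7.8 (not from the original text; the lists fs and f_invs of layer functions and their inverses are hypothetical inputs) reads:

import jax

def vjp_reversible(fs, f_invs, s0, u):
    s = s0
    for f in fs:                                 # forward pass, nothing cached
        s = f(s)
    sK, r = s, u
    for f, f_inv in zip(reversed(fs), reversed(f_invs)):
        s = f_inv(s)                             # recover s_{k-1} from s_k
        _, vjp_fun = jax.vjp(f, s)
        (r,) = vjp_fun(r)
    return sK, r                                 # f(s0) and ∂f(s0)*[u]

Only a constant number of buffers is kept in memory, at the price of one inverse evaluation per layer during the backward pass.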

7.6.2 Case of orthonormal JVPs

When the JVP of each fk is an orthonormal linear mapping, i.e.,

∂fk (sk−1 )−1 = ∂fk (sk−1 )∗ ,



it is easy to check that the VJP of f = fK ◦ . . . ◦ f1 is equal to the JVP
of f⁻¹ = f1⁻¹ ◦ . . . ◦ fK⁻¹, that is,

∂f (s0 )∗ [u] = ∂f⁻¹(sK )[u].

In other words, in the case of orthonormal JVPs, reverse-mode autodiff of
f coincides with forward-mode autodiff of f⁻¹.

7.7 Randomized forward-mode estimator

Forward-mode autodiff does not require to store intermediate activations.


However, for a function f : RP → R, computing the gradient ∇f using
forward-mode autodiff requires P JVPs, which is intractable if P is large.
Can we approximate ∇f with fewer JVPs? The following proposition
gives an unbiased estimator of ∇f that only involves JVPs.

Proposition 7.2 (Unbiased forward-mode estimator of the gradient).


Let f : RP → R be a differentiable function. Then,

∇f (µ) = EZ∼p [∂f (µ)[Z]Z]


= EZ∼p [⟨∇f (µ), Z⟩Z] .

where p := Normal(0, 1)P is the isotropic Gaussian distribution.

This estimator is for instance used by Baydin et al. (2022). It


can be seen as the zero-temperature limit of the gradient of a
perturbed function, estimated by the score-function estimator (SFE);
see Section 13.4.6.
In practice, the expectation above can be approximated by drawing
M noise vectors z1 , . . . , zM , and averaging ∂f (µ)[zi ] zi = ⟨∇f (µ), zi ⟩ zi over i ∈ [M ].
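A minimal JAX sketch of this Monte-Carlo estimator (not from the original text; the function name and toy example are illustrative):

import jax
import jax.numpy as jnp

def forward_grad_estimate(f, mu, key, num_samples=64):
    # Average (∂f(µ)[z]) z over Gaussian samples z ~ Normal(0, I).
    zs = jax.random.normal(key, (num_samples,) + mu.shape)
    def sample(z):
        _, jvp = jax.jvp(f, (mu,), (z,))
        return jvp * z
    return jnp.mean(jax.vmap(sample)(zs), axis=0)

# Example: f(x) = sum(x**2), whose true gradient is 2*mu.
f = lambda x: jnp.sum(x ** 2)
mu = jnp.arange(3.0)
print(forward_grad_estimate(f, mu, jax.random.PRNGKey(0), 4096))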
A word of caution: while this estimator can be useful for example
when we do not want to store the intermediate activations for memory
reasons, this of course comes at the cost of increasing the variance,
which influences the convergence rate of SGD, as seen in Section 15.2.

7.8 Summary

Computer programs can be seen as directed acyclic graphs, where


nodes correspond to the output of intermediate operations in the pro-

gram, and edges represent the dependences of current operations on


past operations.
Automatic differentiation (autodiff) for a function f : RP → RM has
two main modes: forward mode and reverse mode. The forward
mode: i) uses JVPs, ii) builds the Jacobian one column at a time, iii)
is efficient for tall Jacobians (M ≥ P ), iv) need not store intermediate
computations. The reverse mode: i) uses VJPs, ii) builds the Jacobian one
row at a time, iii) is efficient for wide Jacobians (P ≥ M ), iv) needs
to store intermediate computations, in order to be computationally
optimal. To trade computational efficiency for better memory efficiency,
we can use checkpointing techniques.
We saw that the complexity of computing the gradient of a function
f : RP → R using the reverse mode is at most a constant factor larger
than that of evaluating the function itself. This is the Baur-Strassen
theorem, originally established for arithmetic circuits. This astonishing result is one of the
pillars of modern machine learning.
8
Second-order automatic differentiation

We review in this chapter how to perform automatic differentiation for


second-order derivatives.

8.1 Hessian-vector products

We consider in this section a function f : E → R. Similarly to the


Jacobian, for most purposes, we do not need access to the full Hessian
but rather to the Hessian-vector product (HVP) ∇2 f (w)[v] at w ∈ E,
in a direction v ∈ E, as defined in Definition 2.18. The latter can be
computed in four different ways, depending on how we combine the two
main modes of autodiff.

8.1.1 Four possible methods


An HVP can be computed in four different ways.

1. Reverse on reverse: The Hessian can be seen as the transposed


Jacobian of the gradient, hence the HVP can be computed as the
VJP of the gradient,

∇2 f (w)[v] = ∂(∇f )(w)∗ [v].


2. Forward on reverse: Owing to its symmetry (see Proposi-


tion 2.10), the Hessian can also be seen as the Jacobian of the
gradient, hence the HVP can be computed as the JVP of the
gradient,
∇2 f (w)[v] = ∂(∇f )(w)[v].

3. Reverse on forward: Recall that for any function g : E → E, the


VJP can equivalently be defined as the gradient along an output
direction v ∈ E, that is,
∂g(w)∗ [v] = ∇⟨g, v⟩(w),
where we recall the shorthand ⟨g, v⟩(w) := ⟨v, g(w)⟩, so that
⟨g, v⟩ is a function of w. In our case, we can therefore rewrite the
reverse-on-reverse approach as
∂(∇f )(w)∗ [v] = ∇⟨∇f, v⟩(w).
We know that ⟨∇f, v⟩(w) = ⟨∇f (w), v⟩ = ∂f (w)[v] is the JVP
of f at w along v. Therefore, we can also compute the HVP as
the gradient of the JVP of f at w along v,
∇2 f (w)[v] = ∇(∂f (·)[v])(w),
where we use the notation (∂f (·)[v])(w) := ∂f (w)[v] to insist on
the fact that it is a function of w.

4. Forward on forward: Finally, we can use the definition of the


HVP in Definition 2.18 as a vector of second partial derivatives
along v and each canonical direction. That is, assuming E = RP ,
we can compute the JVP of the JVP P times,
∇2 f (w)[v] = (∂ 2 f (w)[v, ei ])Pi=1 .

The four different ways of computing the HVP are summarized in


Table 8.1.
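As a minimal JAX sketch (not from the original text; the toy function f below is illustrative), the four combinations can be written as:

import jax
import jax.numpy as jnp

f = lambda w: jnp.sum(jnp.sin(w) ** 2)     # any scalar-valued function

def hvp_rev_over_rev(f, w, v):
    # VJP of the gradient: gradient of w ↦ ⟨∇f(w), v⟩.
    return jax.grad(lambda w_: jnp.vdot(jax.grad(f)(w_), v))(w)

def hvp_fwd_over_rev(f, w, v):
    # JVP of the gradient.
    return jax.jvp(jax.grad(f), (w,), (v,))[1]

def hvp_rev_over_fwd(f, w, v):
    # Gradient of the JVP w ↦ ∂f(w)[v].
    return jax.grad(lambda w_: jax.jvp(f, (w_,), (v,))[1])(w)

def hvp_fwd_over_fwd(f, w, v):
    # JVPs of the JVP along each canonical direction.
    jvp_v = lambda w_: jax.jvp(f, (w_,), (v,))[1]
    return jnp.stack([jax.jvp(jvp_v, (w,), (e,))[1] for e in jnp.eye(w.shape[0])])

w, v = jnp.arange(3.0), jnp.ones(3)
print(hvp_fwd_over_rev(f, w, v))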

8.1.2 Complexity
To get a sense of the computational and memory complexity of the four
approaches, we consider a chain of functions f := fK ◦ · · · ◦ f1 as done

Method Computation
Reverse on reverse (VJP of gradient) ∂(∇f )(w)∗ [v]
Forward on reverse (JVP of gradient) ∂(∇f )(w)[v]
Reverse on forward (gradient of JVP) ∇(∂f (·)[v])(w)
Forward on forward (JVPs of JVPs) (∂ 2 f (w)[v, ei ])Pi=1

Table 8.1: Four different ways of computing the HVP ∇2 f (w)[v].

Figure 8.1: Computation graph corresponding to reverse mode autodiff for eval-
uating the gradient of f = fK ◦ . . . f1 . While f is a simple chain, ∇f is a DAG.

in Section 7.1. To simplify our analysis, we assume fk : RP → RP for


k ∈ {1, . . . , K − 1} and fK : RP → R.
The computation graph of the reverse mode is illustrated in Fig. 8.1.
While f = fK ◦ · · · ◦ f1 would be represented by a simple chain, the
computational graph of ∇f is no longer a chain: it is a DAG. This
is due to the computations of ∂fk (sk−1 )[rk ], where both sk−1 and rk
depend on s0 .
We illustrate the computation graphs of reverse-on-reverse and
forward-on-reverse in Fig. 8.2 and Fig. 8.3 respectively. By applying
reverse mode on reverse mode, at each fan-in operation sk−1 , rk 7→
∂fk (sk−1 )[rk ], the reverse mode on ∇f branches out in two paths that
are later merged by a sum. By applying forward mode on top of reverse
mode, the flow of computations simply follows the one of ∇f .
With this in mind, following a similar calculation as for Table 7.1,
we obtain the following results. We assume that each ∂fk (sk−1 ) is a
dense linear operator, so that its application has the same cost as a
matrix-vector multiplication. For the memory complexity, we consider
that the inputs of each operation is saved to compute the required

(Left: gradient computation by reverse-mode autodiff. Right: HVP computation by reverse mode on top of reverse mode.)

Figure 8.2: Computation graph for computing the HVP ∇2 f (x)[v] by using reverse
mode on top of reverse mode. As the computation graph of ∇f induces fan-in
operations sk−1 , rk 7→ ∂fk (sk−1 )[rk ], the reverse mode applied on ∇f induces
branching of the computations at each such node.

(Left: gradient computation by reverse-mode autodiff. Right: HVP computation by forward mode on top of reverse mode.)

Figure 8.3: Computation graph for computing the HVP ∇2 f (x)[v] by using forward
mode on top of reverse mode. The forward mode naturally follows the computations
done for the gradient, except that it passes through the derivatives of the intermediate
operations.

derivatives in the backward passes.

1. Reverse on reverse: O(KP 2 ) time and O(KP ) space.

2. Forward on reverse: O(KP 2 ) time and O(KP ) space.

3. Reverse on forward: O(KP 2 ) time and O(KP ) space.

4. Forward on forward: O(KP 3 ) time and O(3P ) space for the


P JVPs with e1 , . . . , eP .

We see that, for chains of functions, “reverse on reverse”, “forward


on reverse” and “reverse on forward” all have similar time complexities
up to some constant factors. Using reverse mode on top of reverse
mode requires storing the information backpropagated, i.e., the rk
(resp. the information forwarded, i.e., the tk in Fig. 7.1), to perform
the final reverse pass. By using forward mode on top of reverse mode,
this additional cost is not incurred, making it slightly less memory
expensive. In addition, reverse mode on top of reverse mode induces a
few additional summations due to the branching and merge operations
depicted in Fig. 8.2. The same holds when using reverse on top of
forward as we cannot avoid fan-in operations (this time of the form
sk−1 , tk−1 7→ ∂fk (sk−1 )[tk−1 ]). Unfortunately, “forward on forward” is
prohibitively expensive.
To summarize, among the four approaches presented to compute
HVPs, the forward-over-reverse mode is a priori the most preferable
in terms of computational and memory complexities. Note, however,
that computations of higher derivatives can benefit from dedicated
autodiff implementations such as Taylor mode autodiff, that do not
merely compose forward and reverse modes. For general functions f , it
is reasonable to benchmark the first three methods to determine which
method is the best for the function at hand.

8.2 Gauss-Newton matrix

8.2.1 An approximation of the Hessian


The Hessian matrix ∇2 L(w) of a function L : W → R is often used to
construct a quadratic approximation of L(w),
L(w + v) ≈ L(w) + ⟨∇L(w), v⟩ + ½⟨v, ∇²L(w)v⟩.
Unfortunately, when L is nonconvex, ∇2 L(w) is typically an indefinite
matrix, which means that the above approximation is a nonconvex
quadratic w.r.t. v. For instance, if L = ℓ ◦ f with ℓ convex, then L is
convex if f is linear, but it is typically nonconvex if f is nonlinear. The
(generalized) Gauss-Newton matrix is a principled alternative to the
Hessian, which is defined for L := ℓ ◦ f .

Definition 8.1 (Gauss-Newton matrix). Given a differentiable func-


tion f : W → M and a twice differentiable function ℓ : M → R, the
(generalized) Gauss-Newton matrix of the composition L = ℓ ◦ f
evaluated at a point w ∈ W is defined as

∇2GN (ℓ ◦ f )(w) := ∂f (w)∗ ∇2 ℓ(f (w))∂f (w).

As studied in Section 16.2, the Gauss-Newton matrix is a key ingre-


dient of the Gauss-Newton method. An advantage of the Gauss-Newton
matrix is its positive semi-definiteness provided that ℓ is convex.

Proposition 8.1 (Positive semi-definiteness of the GN matrix). If ℓ is


convex, then ∇2GN (ℓ ◦ f )(w) is positive semi-definite for all f .

This means that the approximation


L(w + v) ≈ L(w) + ⟨∇L(w), v⟩ + ½⟨v, ∇²GN L(w)v⟩
is a convex quadratic w.r.t. v.
Using the chain rule, we find that the Hessian of L = ℓ◦f decomposes
into the sum of two terms (see also Proposition 8.7).

Proposition 8.2 (Approximation of the Hessian). For f differentiable


and ℓ twice differentiable, we have

∇²(ℓ ◦ f )(w) = ∂f (w)∗ ∇²ℓ(f (w)) ∂f (w) + ∂²f (w)∗ [∇ℓ(f (w))]
             = ∇²GN (ℓ ◦ f )(w) + Σ_{j=1}^{M} ∇j ℓ(f (w)) ∇²fj (w).

If f is linear, then the Hessian and Gauss-Newton matrices coincide,

∇2 (ℓ ◦ f )(w) = ∇2GN (ℓ ◦ f )(w).

The Gauss-Newton operator ∇2GN (ℓ ◦ f ) can therefore be seen as an


approximation of the Hessian ∇2 (ℓ ◦ f ), with equality if f is linear.

8.2.2 Gauss-Newton chain rule

A chain rule for computing the Hessian of a composition of two functions


is presented in Proposition 8.7, but the formula is relatively complicated,
due to the cross-terms. In contrast, a Gauss-Newton chain rule is
straightforward.

Proposition 8.3 (Gauss-Newton chain rule).

∇2GN (ℓ ◦ f ◦ g)(w) = ∂g(w)∗ ∇2GN (ℓ ◦ f )(g(w))∂g(w).

8.2.3 Gauss-Newton vector product

As for the Hessian, we rarely need to materialize the full Gauss-Newton


matrix in memory. Indeed, we can define the Gauss-Newton vector
product (GNVP), a linear map for a direction v ∈ W, as

∇2GN (ℓ ◦ f )(w)[v] := ∂f (w)∗ ∇2 ℓ(f (w))∂f (w)v, (8.1)

where ∇2 ℓ(θ)u is the HVP of ℓ, a linear map from M to M. The


GNVP can be computed using the JVP of f , the HVP of ℓ and the
VJP of f . Instantiating the VJP requires 1 forward pass through f ,
from which we get both the value f (w) and the adjoint linear map
u 7→ (∂f (w)∗ u). Evaluating the VJP requires 1 backward pass through

f . Evaluating the JVP requires 1 forward pass through f . In total,


evaluating v 7→ ∇2GN (ℓ ◦ f )(w)v therefore requires 2 forward passes and
1 backward pass through f .
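A minimal JAX sketch of the GNVP (not from the original text; the function name and toy usage are illustrative) combines a JVP of f, an HVP of ℓ and a VJP of f:

import jax
import jax.numpy as jnp

def gnvp(f, loss, w, v):
    out, f_jvp = jax.jvp(f, (w,), (v,))                  # θ = f(w) and ∂f(w)[v]
    _, vjp_fun = jax.vjp(f, w)                           # u ↦ ∂f(w)*[u]
    h_jvp = jax.jvp(jax.grad(loss), (out,), (f_jvp,))[1] # ∇²ℓ(θ)[∂f(w)[v]]
    (gnvp_v,) = vjp_fun(h_jvp)
    return gnvp_v

loss = lambda y: jnp.sum(y ** 2)     # example convex loss
f = lambda w: jnp.sin(w)             # example "network"
print(gnvp(f, loss, jnp.arange(3.0), jnp.ones(3)))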

8.2.4 Gauss-Newton matrix factorization

In this section, we assume W ⊆ RP and M ⊆ RM . When ℓ is convex,


we know that the Gauss-Newton matrix is positive semi-definite and
therefore it can be factorized into ∇2GN (ℓ ◦ f )(w) = V V ⊤ for some
V ∈ RP ×R , where R ≤ min{P, M } is the rank of the matrix. Such a
decomposition can actually be computed easily from a factorization of
the Hessian of ℓ. For instance, suppose we know the eigendecomposition
of the Hessian of ℓ, ∇²ℓ(f (w)) = Σ_{i=1}^{M} λi ui ui⊤, where the ui are the
eigenvectors and the λi ≥ 0 are the eigenvalues (which we know are
non-negative due to positive semi-definiteness). Then, the Gauss-Newton
matrix can be decomposed as

∇²GN (ℓ ◦ f )(w) = Σ_{i=1}^{M} λi ∂f (w)∗ ui ui⊤ ∂f (w)
                = Σ_{i=1}^{M} (√λi ∂f (w)∗ ui)(√λi ∂f (w)∗ ui)⊤
                = Σ_{i=1}^{M} vi vi⊤, where vi := √λi ∂f (w)∗ ui .

Stacking the vectors vi into a matrix V = (v1 , . . . , vM ), we recover


the factorization ∇2GN (ℓ ◦ f )(w) = V V ⊤ . To form this decomposition,
we need to perform the eigendecomposition of ∇2 ℓ(f (w)) ∈ RM ×M ,
which takes O(M 3 ) time. We also need M calls to the VJP of f at w.
Compared to the direct implementation in Eq. (8.1), the factorization,
once computed, allows us to compute the Gauss-Newton vector product
(GNVP) as ∇2GN (ℓ ◦ f )(w)[v] = V V ⊤ v. The factorization only requires
P × M memory, while the direct implementation in Eq. (8.1) requires
us to maintain the intermediate computations of f . The computation-
memory trade-offs therefore depend on the function considered.

8.2.5 Stochastic setting

Suppose the objective function is of the form

L(w; x, y) := ℓ(f (w; x); y).

With some slight abuse of notation, we then have that the Gauss-Newton
matrix associated with a pair (x, y) is

∇2GN L(w; x, y) := ∂f (w; x)∗ ∇2 ℓ(θ; y)∂f (w; x).

Given a distribution ρ over (x, y) pairs, the Gauss-Newton matrix


associated with the averaged loss

L(w) := EX,Y ∼ρ [L(w; X, Y )]

is then

∇²GN L(w) = EX,Y∼ρ [∇²GN L(w; X, Y )].

8.3 Fisher information matrix

8.3.1 Definition using the score function

The Fisher information is a way to measure the amount of information


in a random variable S.

Definition 8.2 (Fisher information matrix). The Fisher informa-


tion matrix, or Fisher for short, associated with the negative
log-likelihood L(w; S) = − log qw (S) of a probability distribution
qw with parameters w is the covariance of the gradients of L at w
for S distributed according to qw ,

∇2F L(w) := ES∼qw [∇L(w; S) ⊗ ∇L(w; S)]


= ES∼qw [∇w log qw (S) ⊗ ∇w log qw (S)].

The gradient ∇w log qw (S) is known as the score function.

As studied in Section 16.3, the Fisher information matrix is a key


ingredient of the natural gradient descent method.

8.3.2 Link with the Hessian


Provided that the probability distribution is twice differentiable w.r.t. w
with integrable second derivatives, the Fisher information matrix can
also be expressed as the Hessian of the negative log-likelihood (Amari,
1998; Martens, 2020).

Proposition 8.4 (Connection with the Hessian). The Fisher infor-


mation matrix of the negative log-likelihood L(w; S) = − log qw (S)
satisfies

∇2F L(w) = ES∼qw [∇2 L(w; S)] = ES∼qw [−∇2w log qw (S)].

Remark 8.1 (Empirical Fisher). We emphasize that in the above


definitions, S is sampled from the model distribution qw , not from
the data distribution ρ. That is, we have

∇2F L(w) = ES∼qw [∇w log qw (S)∇w log qw (S)⊤ ]


̸= ES∼ρ [∇w log qw (S)∇w log qw (S)⊤ ]

The latter is sometimes called ambiguously the “empirical” Fisher,


though this name has generated confusion (Kunstner et al., 2019).

8.3.3 Equivalence with the Gauss-Newton matrix


So far, we discussed the Fisher information for a generic random variable
S ∼ qw . We now discuss the supervised probabilistic learning setting
where S = (X, Y ) and where, using the product rule of probability,
we define the PDF qw (X, Y ) := ρX (X)pθ (Y ), with the shorthand θ :=
f (w; X).

Proposition 8.5 (Fisher matrix in supervised setting). Suppose


(X, Y ) ∼ qw where the PDF of qw is qw (X, Y ) := ρX (X)pθ (Y ).
In that case, the Fisher information matrix of the negative log-

likelihood L(w; x, y) = − log qw (x, y) decomposes as,

∇2F L(w) = E(X,Y )∼qw [∇w log qw (X, Y ) ⊗ ∇w log qw (X, Y )]


= EX∼ρX [EY ∼pθ [∇w log pθ (Y ) ⊗ ∇w log pθ (Y )]]
= EX∼ρX [∂f (w; X)∗ ∇²F ℓ(θ) ∂f (w; X)],

where we defined the shorthand θ := f (w; X) and where we defined


the negative log-likelihood loss ℓ(θ; Y ) := − log pθ (Y ).

When pθ is an exponential family distribution, we can show that the


Fisher information matrix and the Gauss-Newton matrix are equivalent.

Proposition 8.6 (Equivalence between Fisher and Gauss-Newton). If


pθ is an exponential family distribution, then

∇2F L(w) = EX∼ρX EY ∼pθ [∇L(w; X, Y ) ⊗ ∇L(w; X, Y )]


= EX∼ρX EY ∼pθ [∂f (w; X)∗ ∇ℓ(θ; Y ) ⊗ ∇ℓ(θ; Y )∂f (w; X)]
= EX∼ρX EY ∼pθ [∂f (w; X)∗ ∇2 ℓ(θ; Y )∂f (w; X)]
= EX,Y∼ρ [∇²GN L(w; X, Y )],

where ρX (x) := ∫ ρ(x, y) dy.

Proof. From Proposition 3.3, if pθ is an exponential family distribution,


∇2 ℓ(θ, y) is actually independent of y. Using Bartlett’s second identity
Eq. (11.5), we then obtain

∇²ℓ(θ; ·) = EY∼pθ [∇²ℓ(θ; Y )]
          = EY∼pθ [−∇²θ log pθ (Y )]
= EY ∼pθ [∇θ log pθ (Y ) ⊗ ∇θ log pθ (Y )]
= EY ∼pθ [∇ℓ(θ; Y ) ⊗ ∇ℓ(θ; Y )],

where we used · to indicate that the results holds for all y. Plugging the
result back in the Fisher information matrix concludes the proof.

8.4 Inverse-Hessian vector product

8.4.1 Definition as a linear map

We saw in Section 16.1 that Newton’s method uses iterations as

wt+1 = wt − ∇2 L(wt )−1 ∇L(wt ).

The inverse is well-defined if for example L is strictly convex. Otherwise,


we saw that some additional regularization can be added. Newton’s
method therefore requires to access inverse-Hessian vector products
(IHVPs), as defined below.

Definition 8.3 (Inverse-Hessian vector product). For a twice differ-


entiable function L : RP → R, we define the inverse-Hessian
vector product (IHVP) of L at w ∈ RP as the linear map

u ↦ ∇²L(w)⁻¹ u,

provided that it exists. In other words, it is the linear map which


to u associates v such that ∇2 L(w)v = u.

8.4.2 Implementation with matrix-free linear solvers

Numerous direct methods exist to compute the inverse of a matrix,


such as the Cholesky decomposition, QR decomposition and Gaussian
elimination. However, these algorithms require accessing elementary
entries of the matrix, while an autodiff framework gives access to the
Hessian through HVPs. Fortunately, there exist so-called matrix-free
algorithms that can solve a linear system of equations

H[v] = u

by only accessing the linear map v 7→ H[v] for any v. Among such
algorithms, we have the conjugate gradient (CG) method, that applies
for H positive-definite, i.e., such that ⟨v, H[v]⟩ > 0 for all v ̸= 0, or the
generalized minimal residual (GMRES) method, that applies for
any invertible H. A longer list of solvers can be found in public software
such as SciPy (Virtanen et al., 2020). The IHVP of a strictly convex

function (ensuring that the Hessian is positive definite) can therefore


be computed by instantiating CG on the HVP,

∇2 L(w)−1 u ≈ CG(∇2 L(w)[·], u).

Positive-definiteness of the Hessian is indeed guaranteed for strictly


convex functions for example, while for generic non-convex functions,
such property may be verified around a minimizer but not in general.
The conjugate gradient method is recalled in Algorithm 8.1 in its
simplest form. In theory, the exact solution of the linear system is found
after at most T = P iterations of CG, though in practice numerical
errors may prevent us from obtaining an exact solution.
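A minimal JAX sketch of this idea (not from the original text; the toy quadratic below is illustrative) combines a matrix-free CG solver with a forward-over-reverse HVP:

import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

def ihvp(L, w, u):
    hvp = lambda v: jax.jvp(jax.grad(L), (w,), (v,))[1]   # v ↦ ∇²L(w)[v]
    v, _ = cg(hvp, u)
    return v

# Example on a strictly convex quadratic L(w) = 0.5 * w @ A @ w.
A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
L = lambda w: 0.5 * w @ A @ w
print(ihvp(L, jnp.zeros(2), jnp.array([1.0, 0.0])))       # ≈ A⁻¹ [1, 0]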

Algorithm 8.1 Conjugate gradient method


Inputs: linear map H[·] : RP → RP , target u ∈ RP , initialization
v0 (default 0), number of iterations T (default P ), target accuracy
ε (default machine precision)
1: r0 = u − H[v0 ]
2: p0 = r0
3: for t = 0, . . . , T do
4: αt = ⟨rt , rt ⟩ / ⟨pt , H[pt ]⟩
5: vt+1 = vt + αt pt
6: rt+1 = rt − αt H[pt ]
7: if ⟨rt+1 , rt+1 ⟩ ≤ ε then break
8: βt = ⟨rt+1 , rt+1 ⟩ / ⟨rt , rt ⟩
9: pt+1 = rt+1 + βt pt
Output: vT , such that H[vT ] ≈ u

8.4.3 Complexity

For a given matrix H ∈ RP ×P , solving Hv = u can be done with


decomposition methods (LU, QR, Cholesky) in O(P 3 ) time. For matrix-
free methods such as CG or GMRES, the cost per iteration is O(P 2 ).
Since they theoretically solve the linear system in O(P ) iterations, the
cost to obtain an exact solution is theoretically the same, O(P 3 ).

However, CG or GMRES differ from decomposition methods in that


they are iterative methods, meaning that, at each iteration, they get
closer to a solution. Unlike decomposition methods, this means that we
can stop them before an exact solution is found. In practice, the number
of iterations required to find a good approximate solution depends on
the matrix. Well conditioned matrices require only few iterations. Badly
conditioned matrices lead to some numerical instabilities for CG, so
that more than P iterations may be needed to get a good solution. In
contrast, decomposition methods proceed in two steps: first they build
a decomposition of H at a cost of O(P 3 ), and second they solve a linear
system at a cost of O(P 2 ), by leveraging the structure. LU and QR
decompositions are known to be generally more stable and are therefore
often preferred in practice, when we can access entries of H at no cost.
If we do not have access to the Hessian H, but only to its HVP,
accessing entries of H comes at a prohibitive cost. Indeed, entries of H
can still be recovered from HVPs, since e⊤ i Hej = Hi,j , but accessing
each row or column of H costs one HVP (matrix-vector product). To
access the information necessary to use a decomposition method, we
therefore need P calls to HVPs before being able to actually compute
the solution. For the same number of calls, CG or GMRES will already
have found an approximate solution. In addition, the CG method does
not require storing any additional matrix in memory.

8.5 Second-order backpropagation

8.5.1 Second-order Jacobian chain rule

The essential ingredient to develop forward-mode and reverse-mode


autodiff hinged upon the chain rule for composed functions, h = g ◦ f .
For second derivatives, a similar rule can be obtained. To do so, we
slightly abuse notations and denote

∂ 2 h(w)∗ [u] := ∇2 ⟨h, u⟩(w) ∈ RP ×P ,

where h : RP → RQ , w ∈ RP , u ∈ RQ , and where we recall the


shorthand notation ⟨u, h⟩(w) := ⟨u, h(w)⟩. Moreover, we view the
above quantity as a linear map. Strictly speaking, the superscript ∗ is

not a linear adjoint anymore, since v1 , v2 7→ ∂ 2 h(w)[v1 , v2 ] is no longer


linear but bilinear. However, this superscript plays the same role as the
VJP, since it takes an output vector and returns the input derivatives
that correspond to infinitesimal variations along that output vector.

Proposition 8.7 (Hessian chain-rule). For two twice differentiable


functions f : RP → RM and g : RM → RQ , the second directional
derivative of the composition g ◦ f is a bilinear map from RP × RP
to RQ along input directions v1 , v2 ∈ RP of the form

∂ 2 (g ◦ f )(w)[v1 , v2 ] = ∂g(f (w))[∂ 2 f (w)[v1 , v2 ]]


+ ∂ 2 g(f (w))[∂f (w)[v1 ], ∂f (w)[v2 ]].

The Hessian of the composition g ◦ f along an output direction


u ∈ RQ is, seen as a linear map,

∂²(g ◦ f )(w)∗ [u] = ∂²f (w)∗ [∂g(f (w))∗ u]                          (8.2)
                   + ∂f (w)∗ ∂²g(f (w))∗ [u] ∂f (w).

For the composition of f : RP → RM with a scalar-valued function
ℓ : RM → R, we have in matrix form

∇²(ℓ ◦ f )(w) = Σ_{j=1}^{M} (∇ℓ(f (w)))j ∇²fj (w) + ∂f (w)⊤ ∇²ℓ(f (w)) ∂f (w).

Note that, while the Hessian is usually defined for scalar-valued


functions h : RP → R, the above definition is for a generalized notion
of Hessian that works for any function h : RP → RQ .
The Hessian back-propagation rule in Eq. (8.2) reveals two terms.
The first one ∂ 2 f (w)∗ [∂g(f (w))∗ u] simply computes the Hessian of
the intermediate function along the output direction normally back-
propagated by a VJP. The second term ∂f (w)∗ ∂ 2 g(f (w))∗ [u]∂f (w)
shows how intermediate first-order variations influence second order
derivatives of the output.

Example 8.1 (Composition with an elementwise nonlinear function).
Consider the element-wise application of a twice differentiable
scalar-valued function, f (x) = (f (xi))_{i=1}^{M}, followed by some twice
differentiable function ℓ. Note that ∇²fi (x) = f″(xi) ei ei⊤. Hence,
the Hessian of the composition reads

∇²(ℓ ◦ f )(x) = Σ_{i=1}^{M} (∇ℓ(f (x)))i f″(xi) ei ei⊤ + diag(f′(x)) ∇²ℓ(f (x)) diag(f′(x))
              = diag(∇ℓ(f (x)) ⊙ f″(x)) + ∇²ℓ(f (x)) ⊙ (f′(x) f′(x)⊤),

where f′(x) := (f′(xi))_{i=1}^{M} and f″(x) := (f″(xi))_{i=1}^{M}.

Example 8.2 (Hessian of the composition with a linear function). Consider


a linear function f (W ) = W x, for W ∈ RM ×D , composed with
some twice differentiable function ℓ : RM → R. From Proposi-
tion 8.7, we get, in terms of linear maps,

∇2 (ℓ ◦ f )(W ) = ∂f (W )∗ ∇2 ℓ(f (W ))∂f (W ).

As already noted in Section 2.3, we have that ∂f (W )[V ] = V x


and ∂f (W )∗ [u] = ux⊤ . Hence, the Hessian seen as a linear map
reads

∇2 (ℓ ◦ f )(W )[V ] = ∂f (W )∗ [∇2 ℓ(f (W ))[∂f (W )[V ]]] = HV xx⊤ ,

where H := ∇2 ℓ(f (W )).

8.5.2 Computation chains

For a simple computation chain f = fK ◦ . . . ◦ f1 as in Section 7.1, the


formula derived in Proposition 8.7 suffices to develop an algorithm that
backpropagates the Hessian, as shown in Algorithm 8.2. Compared to
Algorithm 7.2, we simply backpropagate both the vectors rk and the
matrices Rk using intermediate first and second derivatives.

Algorithm 8.2 Hessian backprop for computation chains


Functions: ℓ, f := fK ◦ . . . ◦ f1 ,
Inputs: input x, output direction u
1: Initialize and store s0 := x ▷ Forward pass
2: for k := 1, . . . , K do
3: Compute and store sk := fk (sk−1 )
4: Initialize rK := ∇ℓ(sK ), RK := ∇2 ℓ(sK ) ▷ Backward pass
5: for k := K, . . . , 1 do
6: Compute rk−1 := ∂fk (sk−1 )∗ [rk ]
7: Compute Rk−1 := ∂ 2 fk (sk−1 )∗ [rk ] + ∂fk (sk−1 )∗ Rk ∂fk (sk−1 )
8: Release sk−1 from memory
Outputs: ℓ(f (x)) = ℓ(sK ), ∇(ℓ ◦ f )(x) = r0 , ∇2 (ℓ ◦ f )(x) = R0
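A minimal JAX sketch of Algorithm 8.2 (not from the original text; it assumes 1-D states and, for readability, materializes the intermediate Jacobians and Hessians explicitly):

import jax
import jax.numpy as jnp

def hessian_backprop_chain(fs, loss, x):
    states = [x]
    for f in fs:                                   # forward pass, states cached
        states.append(f(states[-1]))
    sK = states[-1]
    r, R = jax.grad(loss)(sK), jax.hessian(loss)(sK)
    for f, s in zip(reversed(fs), reversed(states[:-1])):   # backward pass
        J = jax.jacobian(f)(s)                     # ∂f_k(s_{k-1})
        H_r = jax.hessian(lambda s_: jnp.vdot(f(s_), r))(s) # ∂²f_k(s_{k-1})*[r_k]
        r, R = J.T @ r, H_r + J.T @ R @ J
    return loss(sK), r, R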

8.5.3 Fan-in and fan-out


For generic computation graphs (see Section 7.3), we saw that multi-
input functions (fan-in) were crucial. For Hessian backpropagation in
computation graphs, we therefore need to develop a similar formula.

Proposition 8.8 (Hessian chain-rule for fan-in). Consider n+1 twice


differentiable functions f1 , . . . , fn and g with fi : RP → RMi and
g : RM1 × . . . × RMn → RQ . The Hessian of g ◦ f for f (w) =
(f1 (w), . . . , fn (w)) along an output direction u ∈ RQ is given by
∂²(g ◦ f )(w)∗ [u] = Σ_{i=1}^{n} ∂²fi (w)∗ [∂i g(f (w))∗ [u]]
                   + Σ_{i,j=1}^{n} ∂fi (w)∗ ∂²_{i,j} g(f (w))∗ [u] ∂fj (w).

The gradient backpropagation expression for fan-in is simple because


the functions fi are not linked by any path. In contrast, the Hessian
backpropagation involves cross-product terms
∂fi (w)∗ ∂²_{i,j} g(f (w))∗ [u] ∂fj (w) for i ≠ j. The nodes associated with the
fi computations cannot be treated independently anymore.
On the other hand, developing a backpropagation rule for fan-out
does not pose any issue, since each output function can be treated

independently.

Proposition 8.9 (Hessian chain-rule for fan-out). Consider n+1 twice


differentiable functions g1 , . . . , gn and f with gi : RM → RQi and
f : RP → RM . The Hessian of g ◦ f for g(w) = (g1 (w), . . . , gn (w))
along a direction u = (u1 , . . . , un ) ∈ RQ1 × . . . × RQn is given by
∂²(g ◦ f )(w)∗ [u] = Σ_{i=1}^{n} ∂²f (w)∗ [∂gi (f (w))∗ [ui ]]
                   + Σ_{i=1}^{n} ∂f (w)∗ ∂²gi (f (w))∗ [ui ] ∂f (w).

8.6 Block diagonal approximations

Rather than computing the whole Hessian or Gauss-Newton matrices,


we can consider computing block-diagonal or diagonal approximations,
which are easier to invert. The approximation rules we present in this
section build upon the Hessian chain rule studied in Section 8.5.

8.6.1 Feedforward networks

Recall the definition of a feedforward network:

s0 := x
sk := fk (sk−1 , wk ) ∀k ∈ {1, . . . , K}
f (x, w) := sK ,

where w := (w1 , . . . , wK ). Rather than computing the entire Hessian of


ℓ ◦ f w.r.t. w, we can compute the Hessians w.r.t. each set of parameters
wk . For the case of computation chains, the Hessian backpropagation
recursion we used in Algorithm 8.2 was

Rk−1 := ∂ 2 fk (sk−1 )∗ [rk ] + ∂fk (sk−1 )∗ Rk ∂fk (sk−1 ).



Extending this recursion to the feedforward network case, we obtain,


starting from rK := ∇ℓ(sK ) and RK := ∇2 ℓ(sK ),
rk−1 := ∂fk (sk−1 , wk )∗ [rk ]

[ Rk−1   ∼  ]
[  ∼     Hk ] := ∂²fk (sk−1 , wk )∗ [rk ] + ∂fk (sk−1 , wk )∗ Rk ∂fk (sk−1 , wk ),
where we used ∼ to indicate that these blocks are not used. The Hessians
w.r.t each set of parameters are then
R0 = ∇²xx (ℓ ◦ f )(x, w)
H1 = ∇²w1w1 (ℓ ◦ f )(x, w)
⋮
HK = ∇²wKwK (ℓ ◦ f )(x, w).
The validity of this result stems from the fact that we can view the
Hessian w.r.t. wk as computing the Hessian w.r.t. wk of
f̃K ◦ . . . ◦ f̃k+1 ◦ fk (sk−1 , wk )
where f̃i := fi (·, wi ), for i ∈ {k + 1, . . . , K}. As the computations of
the block-wise Hessians share most of the computations, they can be
evaluated in a single backward pass just as the gradients.

Example 8.3 (Block-wise computation of the Gauss-Newton matrix).


Our blockwise backpropagation scheme can readily be adapted for
the Gauss-Newton matrix as
[ Rk−1   ∼  ]
[  ∼     Gk ] := ∂fk (sk−1 , wk )∗ Rk ∂fk (sk−1 , wk ),

starting from RK := ∇²ℓ(sK ). The outputs R0 , G1 , . . . , GK give a


block-wise approximation of the Gauss-Newton matrix.
Now, consider a simple multilayer perceptron such that

fk (sk−1 , wk ) := a(Wk sk−1 ) with wk := vec(Wk )

Using Example 8.2 and Example 8.1 adapted to the Gauss-Newton



matrix, we can compute the block-wise decomposition of the Gauss-


Newton matrix as, for k = K, . . . , 1,

Rk−1 := Wk⊤ Jk Wk
Jk := Rk ⊙ (a′(Wk sk−1 ) a′(Wk sk−1 )⊤)
Gk := Jk ⊗ sk−1 sk−1⊤

starting from RK := ∇2 ℓ(sK ). The outputs G1 , . . . , GK correspond


to the block-wise elements of the Gauss-Newton matrix of f for the
vectorized weights w1 , . . . , wK . Similar computations were done in
KFRA (Botev et al., 2017) and BackPack (Dangel et al., 2019).

8.6.2 Computation graphs


For generic computation graphs, consider a function f (x, w) defined
by, denoting i1 , . . . , ipk := parents(k),
sk := fk (si1 , . . . , sipk ) ∀k ∈ {1, . . . , K}
such that f (x, w) = sK , and k is following a topological ordering of the
graph (see Section 7.3). We can consider the following backpropagation
scheme, for k = K, . . . , 1 and j ∈ parents(k)
rij ← rij + ∂j fk (si1 , . . . , sipk )∗ [rk ]                              (8.3)
Rij ← Rij + ∂²jj fk (si1 , . . . , sipk )∗ [rk ]
          + ∂j fk (si1 , . . . , sipk )∗ Rk ∂j fk (si1 , . . . , sipk ),    (8.4)

starting from RK := ∇²ℓ(sK ) and rK := ∇ℓ(sK ). Recall that for
multiple inputs, the chain-rule presented in Proposition 8.8 involves
the cross-derivatives. For this reason, the back-propagation scheme
in Eq. (8.4) only computes an approximation. For example, one can
verify that using Eq. (8.4) to compute the Hessian of ℓ(f1 (w), f2 (w))
does not provide an exact expression for that Hessian. This scheme
is easy to implement and may provide a relevant proxy for the Hessian.

8.7 Diagonal approximations

Similarly to the idea of designing a backpropagation scheme that ap-


proximates blocks of the Hessian, we can design a backpropagation

scheme that approximates the diagonal of the Hessian. The approach


was originally proposed by Becker and Le Cun (1988) for feedforward
networks, but our exposition, new to our knowledge, has the benefit
that it naturally extends to computational graphs, as we shall see.

8.7.1 Computation chains


The idea stems from modifying the Hessian backpropagation rule
in Proposition 8.7 to only keep the diagonal of the Hessian. Formally,
given a matrix M ∈ RD×D, we denote by diag(M) = (Mii)_{i=1}^{D} ∈ RD
the vector of diagonal entries of M, and for a vector m ∈ RD, we
denote Diag(m) = Σ_{i=1}^{D} mi ei ei⊤ the diagonal matrix with entries mi.

For the backpropagation of the Hessian of ℓ ◦ fK ◦ . . . ◦ f1 , we see from


Algorithm 8.2 that diag(Hk−1 ) can be expressed in terms of Hk as
diag(Hk−1 ) = diag(∂ 2 fk (sk−1 )∗ rk )
+ diag(∂fk (sk−1 )∗ Hk ∂fk (sk−1 )).
Unfortunately, that recursion needs access to the whole Hessian Hk ,
and would therefore be too expensive. A natural idea is to modify the
recursion to approximate diag(Hk ) by backpropagating vectors:
dk−1 := diag(∂ 2 fk (sk−1 )∗ rk )
+ diag(∂fk (sk−1 )∗ Diag(dk )∂fk (sk−1 )).
The diagonal matrix Diag(dk ) serves as a surrogate for Hk . Each
iteration of this recursion can be computed in linear time in the output
dimension Dk since
dk−1,i = Σ_{j=1}^{Dk} rk,j · ∂²ii fk,j (sk−1 ) + Σ_{j=1}^{Dk} dk,j (∂i fk,j (sk−1 ))².

To initialize the recursion, we can set dK := diag(∇2 ℓ(sK )). As an


alternative, as proposed by Elsayed and Mahmood (2022), if HK has
a simple form, we can use ∇2 ℓ(sK ) instead of Diag(dK ) at the first
iteration. This is the case for instance if ℓ is a cross-entropy loss. The
recursion is repeated until we obtain the approximate diagonal Hessian
d0 ≈ diag(∇2 (ℓ ◦ f )(x)). The gradients rk , needed to compute dk , are
computed along the way and the algorithm can therefore also return
r0 = ∇(ℓ ◦ f )(x).
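A minimal JAX sketch of this diagonal recursion (not from the original text; for readability it materializes Jacobians and Hessians rather than running in linear time as described above):

import jax
import jax.numpy as jnp

def diag_hessian_chain(fs, loss, x):
    states = [x]
    for f in fs:
        states.append(f(states[-1]))
    sK = states[-1]
    r = jax.grad(loss)(sK)
    d = jnp.diag(jax.hessian(loss)(sK))
    for f, s in zip(reversed(fs), reversed(states[:-1])):
        J = jax.jacobian(f)(s)
        # diag(∂²f_k(s_{k-1})*[r_k]) plus the squared-Jacobian term.
        h = jnp.diag(jax.hessian(lambda s_: jnp.vdot(f(s_), r))(s))
        r, d = J.T @ r, h + (J ** 2).T @ d
    return r, d   # gradient and approximate Hessian diagonal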

8.7.2 Computation graphs


Although this diagonal approximation was originally derived for feed-
forward networks Becker and Le Cun (1988), it is straightforward to
generalize it to computation graphs. Namely, for a function f (x, w) de-
composed along a computation graph, we can backpropagate a diagonal
approximation in reverse topological order as

rij ← rij + ∂j fk (si1 , . . . , sipk )∗ [rk ]
dij ← dij + diag(∂²jj fk (si1 , . . . , sipk )∗ [rk ])
          + diag(∂j fk (si1 , . . . , sipk )∗ Diag(dk ) ∂j fk (si1 , . . . , sipk )),   (8.5)

for j ∈ parents(k), starting from rK = ∇ℓ(sK ) and dK = diag(∇2 ℓ(sK ))


or Diag(dK ) = ∇2 ℓ(sK ). To implement such an algorithm, each ele-
mentary function in the computational graph needs to be augmented
with an oracle that computes the Hessian diagonal approximation of
the current function, given the previous ones. An example with MLPs
is presented in Example 8.4.

Example 8.4 (Hessian diagonal approximation for MLPs ). Consider


a multilayer perceptron

sk := ak (Wk sk−1 ) ∀k ∈ {1, . . . , K − 1}


f (w, x) := sK

starting from s0 = x. Here ak is the element-wise activation func-


tion (potentially the identity) and w encapsulates the weight ma-
trices W1 , . . . , WK . We consider the derivatives w.r.t. the flattened
matrices, so that gradients and diagonal approximations w.r.t. these
flattened quantities are vectors. The backpropagation scheme (8.5)

then reduces to, denoting tk = Wk sk−1 ,

rk−1 := Wk⊤ (a′(tk ) ⊙ rk )
gk := vec((a′(tk ) ⊙ rk ) sk−1⊤)
δk := rk ⊙ a″(tk ) + dk ⊙ a′(tk )²
dk−1 := (Σ_{j=1}^{Dk} Wk,ij² δk,j)_{i=1}^{Dk−1}
hk := vec(δk (sk−1²)⊤)

starting from rK = ∇ℓ(sK ) and, e.g., dK = diag(∇2 ℓ(sK )). The


algorithm returns g1 , . . . , gK as the gradients of f w.r.t. w1 , . . . , wK ,
with wi = vec(Wi ), and h1 , . . . , hK as the diagonal approximations
of the Hessian w.r.t. w1 , . . . , wK .

8.8 Randomized estimators

In this section, we describe randomized estimators of the diagonal of


the Hessian or Gauss-Newton matrices.

8.8.1 Girard-Hutchinson estimator

We begin with a generic estimator, originally proposed for trace es-


timation by Girard (1989) and extended by Hutchinson (1989). Let
A ∈ RP ×P be an arbitrary square matrix, whose matrix-vector product
(matvec) is available. Suppose ω ∈ RP is an isotropic random vector,
i.e., such that Eω∼p [ωω ⊤ ] = I. For example, two common choices are
the Rademacher distribution p = Uniform({−1, 1}) and the standard
normal distribution p = Normal(0, I). Then, we have

Eω∼p [⟨ω, Aω⟩] = tr(A).

Applications include generalized cross-validation, computing the Kullback-


Leibler divergence between two Gaussians, and computing the deriva-
tives of the log-determinant.
The approach can be extended (Bekas et al., 2007; Baston and
Nakatsukasa, 2022; Hallman et al., 2023) to obtain an estimator of the

diagonal of A,
Eω∼p [ω ⊙ Aω] = diag(A),
where ⊙ denotes the Hadamard product (element-wise multiplication).
This suggests that we can use the Monte-Carlo method to estimate the
diagonal of A,
diag(A) ≈ (1/S) Σ_{i=1}^{S} ωi ⊙ Aωi ,
with equality as S → ∞, since the estimator is unbiased. Since, as
reviewed in Section 8.1 and Section 8.2, we know how to multiply
efficiently with the Hessian and the Gauss-Newton matrices, we can
apply the technique with these matrices. The variance is determined
by the number S of samples drawn and therefore by the number of
matvecs performed. More elaborated approaches have been proposed to
further reduce the variance (Meyer et al., 2021; Epperly et al., 2023).
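A minimal JAX sketch of this diagonal estimator applied to the Hessian (not from the original text; the function name and toy example are illustrative), using Rademacher probes and HVPs:

import jax
import jax.numpy as jnp

def hessian_diag_estimate(f, w, key, num_samples=100):
    hvp = lambda v: jax.jvp(jax.grad(f), (w,), (v,))[1]
    probes = jax.random.rademacher(key, (num_samples,) + w.shape).astype(w.dtype)
    return jnp.mean(jax.vmap(lambda z: z * hvp(z))(probes), axis=0)

f = lambda w: jnp.sum(w ** 3)          # Hessian diagonal: 6 * w
w = jnp.arange(1.0, 4.0)
print(hessian_diag_estimate(f, w, jax.random.PRNGKey(0)))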

8.8.2 Bartlett estimator for the factorization

Suppose the objective function is of the form L(w; x, y) := ℓ(f (w; x); y)
where ℓ is the negative log-likelihood ℓ(θ; y) := − log pθ (y) of an ex-
ponential family distribution, and θ := f (w; x), for some network f .
We saw from the equivalence between the Fisher and Gauss-Newton
matrices in Proposition 8.6 (which follows from the Bartlett identity)
that

∇2GN L(w; x, ·) = EY ∼pθ [∂f (w; x)∗ ∇ℓ(θ; Y ) ⊗ ∇ℓ(θ; Y )∂f (w; x)]
= EY ∼pθ [∇L(w; x, Y ) ⊗ ∇L(w; x, Y )],

where · indicates that the result holds for any value of the second
argument. This suggests a Monte-Carlo scheme
∇²GN L(w; x, ·) ≈ (1/S) Σ_{j=1}^{S} ∇L(w; x, yij ) ⊗ ∇L(w; x, yij ),

where yi1 , . . . , yiS ∼ pθ and θ = f (w, x). In words, we can approxi-


mate the Gauss-Newton matrix with S gradient computations. This
factorization can also be used to approximate the GNVP in Eq. (8.1).

8.8.3 Bartlett estimator for the diagonal


Following a similar approach, we obtain
diag(∇²GN L(w; x, ·)) = EY∼pθ [∇L(w; x, Y ) ⊙ ∇L(w; x, Y )],
where ⊙ indicates the element-wise (Hadamard) product. Using a Monte-
Carlo scheme, sampling yi1 , . . . , yiS from pθ , we therefore obtain
diag(∇²GN L(w; x, ·)) ≈ (1/S) Σ_{j=1}^{S} ∇L(w; x, yij ) ⊙ ∇L(w; x, yij ),
with equality when all labels in the support of pθ have been sampled.
That estimator, used for instance in (Wei et al., 2020, Appendix C.1.),
requires access to individual gradients evaluated at the sampled
labels. Another possible estimator of the diagonal is given by
(1/S) diag(∇²GN L(w; x, ·))
    = EY1,...,YS∼pθ [∇((1/S) Σ_{i=1}^{S} L(w; x, Yi )) ⊙ ∇((1/S) Σ_{i=1}^{S} L(w; x, Yi ))].
Letting γi := ∇L(w; x, Yi ), this follows from

E[(Σ_i γi ) ⊙ (Σ_j γj )] = E[Σ_i γi ⊙ γi + Σ_{i≠j} γi ⊙ γj ]
                         = E[Σ_i γi ⊙ γi ],
where we used that E[γi ⊙ γj ] = E[γi ] ⊙ E[γj ] = 0 since γi and γj are
independent variables for i ̸= j and have zero mean, from Bartlett’s
first identity Eq. (11.4). We can then use the Monte-Carlo method to obtain

(1/S) diag(∇²GN L(w; x, ·)) ≈ ∇((1/S) Σ_{j=1}^{S} L(w; x, yij )) ⊙ ∇((1/S) Σ_{j=1}^{S} L(w; x, yij )),

with equality when all labels in the support of pθ have been sampled.
This estimator can be more convenient to implement, since it only needs
access to the gradient of the averaged losses. However, it may suffer
from higher variance. A special case of this estimator is used by Liu
et al. (2023), where they draw only one y for each x.

8.9 Summary

By using a Hessian chain rule, we can develop a “Hessian backprop-


agation”. While it is reasonably simple for computation chains, it be-
comes computationally prohibitive for computation graphs, due to the
cross-product terms occurring with fan-in. A better approach is to use
Hessian-vector products (HVPs). We saw that there are four possible
methods to compute HVPs, but the forward-over-reverse method is a
priori the most efficient. Similarly as for computing gradients, comput-
ing HVPs is only a constant times more expensive than evaluating the
function itself.
The Gauss-Newton matrix associated with the composition ℓ ◦
f can be seen as an approximation of the Hessian. It is a positive
semidefinite matrix if ℓ is convex, and can be used to build a principled
quadratic approximation of a function. It is equivalent to the Fisher
information matrix in the case of exponential families. Gauss-Newton-
vector products can be computed efficiently, like HVPs.
We also described other approximations, such as (block) diagonal
approximations, and randomized estimators.
9
Inference in graphical models as differentiation

A graphical model specifies how random variables depend on each


other and therefore determines how their joint probability distribution
factorizes. In this chapter, we review key concepts in graphical models
and how they relate to differentiation, drawing in the process analogies
with computation chains and computation graphs.

9.1 Chain rule of probability

The chain rule of probability is a fundamental law in probability theory


for computing the joint probability of events. In the case of only two
events A1 and A2 , it reduces to the product rule

P(A1 ∩ A2 ) = P(A2 |A1 )P(A1 ).

For two discrete random variables S1 and S2 , using the events A1 :=


{S1 = s1 } and A2 := {S2 = s2 }, the product rule becomes

P(S1 = s1 , S2 = s2 ) = P(S2 = s2 |S1 = s1 )P(S1 = s1 ).

More generally, using the product rule, we have for K events

P (A1 ∩ . . . ∩ AK ) = P (AK | A1 ∩ . . . ∩ AK−1 ) P (A1 ∩ . . . ∩ AK−1 ) .


Applying the product rule one more time, we have


P (A1 ∩ . . . ∩ AK−1 ) = P (AK−1 | A1 ∩ . . . ∩ AK−2 ) P (A1 ∩ . . . ∩ AK−2 ) .
Repeating the process recursively, we arrive at the chain rule of
probability
P(A1 ∩ . . . ∩ AK ) = Π_{j=1}^{K} P(Aj | A1 ∩ · · · ∩ Aj−1 )
                    = Π_{j=1}^{K} P(Aj | ∩_{i=1}^{j−1} Ai ).

For K discrete random variables Sj , using the events Aj := {Sj = sj },


the chain rule of probability becomes
P(S1 = s1 , . . . , SK = sK ) = Π_{j=1}^{K} P(Sj = sj | S1 = s1 , . . . , Sj−1 = sj−1 ).

Importantly, this factorization holds without any independence assump-


tion on the variables S1 , . . . , SK . We can further simplify this expression
if we make conditional independence assumptions.

9.2 Conditional independence

We know that if two events A and B are independent, then


P(A|B) = P(A).
Similarly, if two random variables S1 and S2 are independent, then
P(S2 = s2 |S1 = s1 ) = P(S2 = s2 ).
More generally, if we work with K variables S1 , . . . , SK , some variables
may depend on each other, while others may not. To simplify the
notation, given a set C, we define the shorthands
SC := (Si : i ∈ C)
sC := (si : i ∈ C).
We say that a variable Sj is independent of SD conditioned on SC ,
with C ∩ D = ∅, if for any sj , sC , sD
P(Sj = sj | SC = sC , SD = sD ) = P(Sj = sj | SC = sC ).

9.3 Inference problems

9.3.1 Joint probability distributions


We consider a collection of K variables s := (s1 , . . . , sK ), potentially
ordered or unordered. Each s belongs to the Cartesian product
S := S1 × · · · × SK . Throughout this chapter, we assume that the sets
Sk are discrete for concreteness, with Sk := {v1 , . . . , vMk }. Note that
because Sk is discrete, we can always identify it with {1, . . . , Mk }. A
graphical model specifies a joint probability distribution
P(S = s) = P(S1 = s1 , . . . , SK = sK )
= p(s)
= p(s1 , . . . , sK ),
where p is the probability mass function of the joint probability distribu-
tion. Summing over the Cartesian product of all possible configurations,
we obtain
Σ_{s∈S} p(s) = Σ_{s1,...,sK∈S} p(s1 , . . . , sK ) = 1.
As we shall see, the graph of a graphical model encodes the dependen-
cies between the variables (S1 , . . . , SK ) and therefore how their joint
distribution factorizes. Given access to a joint probability distribution,
there are several inference problems one typically needs to solve.

9.3.2 Likelihood
A trivial task is to compute the likelihood of some observations s =
(s1 , . . . , sK ),
P(S1 = s1 , . . . , Sk = sk ) = p(s1 , . . . , sk ).
It is also common to compute the log-likelihood,
log P(S1 = s1 , . . . , Sk = sk ) = log p(s1 , . . . , sk ).

9.3.3 Maximum a-posteriori inference


Another common task is to compute the most likely configuration,
arg max p(s1 , . . . , sK ).
s1 ∈S1 ,...,sK ∈SK

This is the mode of the joint probability distribution.

9.3.4 Marginal inference

The operation of marginalization consists in summing (or integrat-


ing) over all possible values of a given variable in a joint probability
distribution. This allows us to compute the marginal probability of
the remaining variables. For instance, we may want to marginalize all
variables but Sk = sk . To do so, we define the Cartesian product

Ck (sk ) := Ak−1 × {sk } × Bk+1 ,                                        (9.1)

where Ak−1 := S1 × · · · × Sk−1 and Bk+1 := Sk+1 × · · · × SK .

Summing over all variables but Sk , we obtain the marginal probability


of Sk = sk as

P(Sk = sk ) = Σ_{(s1,...,sK)∈Ck(sk)} p(s1 , . . . , sK )
            = Σ_{(s1,...,sk−1)∈Ak−1} Σ_{(sk+1,...,sK)∈Bk+1} p(s1 , . . . , sK ).

Defining similarly

Ck,l (sk , sl ) := S1 × · · · × {sk } × · · · × {sl } × · · · × SK ,

we obtain

P(Sk = sk , Sl = sl ) = Σ_{(s1,...,sK)∈Ck,l(sk,sl)} p(s1 , . . . , sK ).

In particular, we may want to compute the marginal probability of two


consecutive variables, P(Sk−1 = sk−1 , Sk = sk ).

9.3.5 Expectation, convex hull, marginal polytope

Another common operation is to compute the expectation of ϕ(S) under


a distribution p. It is defined by

µ := ES∼p [ϕ(S)] = Σ_{s∈S} p(s) ϕ(s) ∈ M.

For the expectation under pθ , we write

µ(θ) := ES∼pθ [ϕ(S)] = Σ_{s∈S} pθ (s) ϕ(s) ∈ M.

In exponential family distributions (Section 3.4), the function ϕ is called


a sufficient statistic. It decomposes as

ϕ(s) := (ϕC (sC ))C∈C ,

where C ⊆ [K]. Intuitively, ϕ(s) can be thought as an encoding or


embedding of s (a potentially discrete object such as a sequence of
integers) in a vector space. Under this decomposition, we can also
compute
µC := ES [ϕC (SC )] = Σ_{s∈S} p(s) ϕC (sC ).

Convex hull

The mean µ belongs to the convex hull of ϕ(S) := {ϕ(s) : s ∈ S},


M := conv(ϕ(S)) := { Σ_{s∈S} p(s) ϕ(s) : p ∈ P(S) },

where P(S) is the set of all possible probability distributions over S. In


other words, M is the set of all possible convex combinations of ϕ(s)
for s ∈ S. The vertices of M are all the s ∈ S.

Case of binary encodings: the marginal polytope

In the special case of a discrete set Sk = {v1 , . . . , vM } and of a binary


encoding (indicator function) ϕ(s), the set M is called the marginal
polytope (Wainwright and Jordan, 2008), because each point µ ∈
M contains marginal probabilities. To see why, consider the unary
potential
[ϕ(s)]k,i = [ϕk (sk )]i = I(sk = vi ) (9.2)

where I(p) := 1 if p is true, 0 otherwise. We then obtain the marginal


probability of Sk = vi ,

[µ]k,i = E_S[[ϕ(S)]k,i]
       = E_{Sk}[[ϕk(Sk)]i]
       = E_{Sk}[I(Sk = vi)]
       = Σ_{sk∈Sk} P(Sk = sk)I(sk = vi)
       = P(Sk = vi).

Likewise, consider the pairwise potential

[ϕ(s)]k,l,i,j = [ϕk,l (sk , sl )]i,j = I(sk = vi , sl = vj ). (9.3)

We then obtain the marginal probability of Sk = vi and Sl = vj ,

[µ]k,l,i,j = E_S[[ϕ(S)]k,l,i,j]
           = E_{Sk,Sl}[[ϕk,l(Sk, Sl)]i,j]
           = E_{Sk,Sl}[I(Sk = vi, Sl = vj)]
           = Σ_{sk∈Sk} Σ_{sl∈Sl} P(Sk = sk, Sl = sl)I(sk = vi, sl = vj)
           = P(Sk = vi, Sl = vj).

We can do the same with higher-order potential functions.

9.3.6 Complexity of brute force

Apart from computing the likelihood, which is trivial, computing the


marginal, mode and expectation by brute force takes O(∏_{k=1}^{K} |Sk|) time.
In particular, if |Sk| = M for all k ∈ [K], brute force takes O(M^K) time.
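To make this complexity concrete, here is a small brute-force sketch (hypothetical Python code with made-up potentials, not part of any library discussed in this book) that computes a marginal by enumerating all configurations; the dynamic-programming algorithms of Section 9.7 avoid this exponential cost on chains.

```python
# Brute-force marginal inference by enumerating all prod_k |S_k| configurations.
# The potentials below are made up for illustration.
import itertools
import numpy as np

def brute_force_marginal(joint, sizes, k, v):
    """Compute P(S_k = v) from an unnormalized joint probability function."""
    num, Z = 0.0, 0.0
    for s in itertools.product(*(range(M) for M in sizes)):
        p = joint(s)          # unnormalized joint probability p(s)
        Z += p
        if s[k] == v:
            num += p
    return num / Z

rng = np.random.default_rng(0)
psi = [rng.random((3, 3)) for _ in range(2)]                 # pairwise potentials
joint = lambda s: psi[0][s[0], s[1]] * psi[1][s[1], s[2]]    # K = 3 chain
print(brute_force_marginal(joint, [3, 3, 3], k=1, v=0))      # P(S_2 = v_1)
```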

9.4 Markov chains

In this section, we briefly review Markov chains. Our notation is chosen


to emphasize the analogies with computation chains.

9.4.1 The Markov property

When random variables are organized sequentially as S1 , . . . , SK ,


a simple example of conditional independence is when each variable
Sk ∈ Sk only depends on the previous variable Sk−1 ∈ Sk−1 , that is,

P(Sk = sk | Sk−1 = sk−1 , . . . , S1 = s1 ) = P(Sk = sk | Sk−1 = sk−1 )


:= pk (sk | sk−1 ),

A probability distribution satisfying the above is said to satisfy the


Markov property, and is called a Markov chain. A computation
chain is specified by the functions fk , that take sk−1 as input and
output sk . In analogy, a Markov chain is specified by the conditional
probability distributions pk of Sk given Sk−1 . We can then define the
generative process

S0 := s0
S1 ∼ p1 (· | S0 )
S2 ∼ p2 (· | S1 )
..
.
SK ∼ pK (· | SK−1 ).

Strictly speaking, we should write Sk | Sk−1 ∼ pk (· | Sk−1 ). We choose


our notation both for conciseness and for analogy with computation
chains. Furthermore, to simplify the notation, we assume without loss
of generality that S0 is deterministic (if this is not the case, we can
always move S0 to S1 and add a dummy variable as S0 ). That is,
P(S0 = s0 ) = p0 (s0 ) := 1 and S0 := {s0 }. This amounts to setting the
initial distribution of S1 as

P(S1 = s1 ) := P(S0 = s0 )P(S1 = s1 |S0 = s0 ) = P(S1 = s1 |S0 = s0 ).


Figure 9.1: Left: Markov chain. Right: Computation graph of the forward-backward and the Viterbi algorithms: a lattice.

We can then compute the joint probability of the Markov chain by

P(S1 = s1, . . . , SK = sK) = p(s1, . . . , sK)
                            = ∏_{k=1}^{K} P(Sk = sk | Sk−1 = sk−1)
                            = ∏_{k=1}^{K} pk(sk | sk−1),

where we left the dependence on s0 implicit, since p0 (s0 ) = 1. A Markov


chain with Sk = {1, 2, 3} is illustrated in Fig. 9.1. A chain defines a
totally ordered set {1, . . . , K}, since two nodes in the graph are
necessarily linked to each other by a path.

Example 9.1 (Chain of categorical distributions). Suppose our goal


is to predict, from x ∈ X , a sequence of length K, where each Sk
belongs to Sk = {1, . . . , M }. In natural language processing, this
task is called sequence tagging. We can define

Sk ∼ Categorical(πk−1,k,Sk−1 )

where

πk−1,k,i := softargmax(θk−1,k,i) = (πk−1,k,i,j)_{j=1}^{M} ∈ △^M
θk−1,k,i := (θk−1,k,i,j)_{j=1}^{M} ∈ R^M
θk−1,k,i,j := fk−1,k(x, i, j, wk) ∈ R.



We therefore have

P(Sk = j | Sk−1 = i) = pk(j | i)
                     = πk−1,k,i,j
                     = [softargmax(θk−1,k,i)]j
                     = exp(θk−1,k,i,j) / Σ_{j′} exp(θk−1,k,i,j′)

and

log P(Sk = j | Sk−1 = i) = log pk(j | i)
                         = θk−1,k,i,j − logsumexp(θk−1,k,i)
                         = θk−1,k,i,j − log Σ_{j′} exp(θk−1,k,i,j′).

We emphasize that because k − 1 and k are always consecutive,


the representation θk−1,k,i,j is inefficient; we could use θk,i,j instead.
Our notation is designed for consistency with Markov random fields.

9.4.2 Time-homogeneous Markov chains

A time-homogeneous discrete-time Markov chain corresponds to the


case when the distribution of Sk given Sk−1 is the same regardless of k:

p1 = · · · = pK = p.

The finite-space case corresponds to when each Sk ∈ S can take a


finite set of values S = {v1 , . . . , vM } and

P(Sk = vj | Sk−1 = vi ) = p(vj |vi ) = πi,j ,

where πi,j ∈ [0, 1] is the transition probability from vi to vj . Because


the set S = {v1 , . . . , vM } is discrete, we can always identify it with
{1, . . . , M }. That is, we can instead write

P(Sk = j | Sk−1 = i) = p(j|i) = πi,j .



9.4.3 Higher-order Markov chains

More generally, an nth-order Markov chain may depend, not only on the
last variable, but on the last n variables,

P(Sk = sk | Sk−1 = sk−1, . . . , S1 = s1)
= P(Sk = sk | Sk−1 = sk−1, . . . , Sk−n = sk−n)
= pk(sk | sk−1, . . . , sk−n).

Autoregressive models such as Transformers can be seen as specifying


a higher-order Markov chain, with a context window of size n. The
larger context makes exact inference using dynamic programming com-
putationally intractable. This is why practitioners use beam search or
ancestral sampling (Section 9.5.3) instead.

9.5 Bayesian networks

In this section, we briefly review Bayesian networks. Our notation is


chosen to emphasize the analogies with computation graphs.

9.5.1 Expressing variable dependencies using DAGs

Markov chains and more generally higher-order Markov chains are a spe-
cial case of Bayesian network. Similarly to computation graphs reviewed
in Section 7.3, variable dependencies can be expressed using a directed
acylic graph (DAG) G = (V, E), where the vertices V = {1, . . . , K}
represent variables and edges E represent variable dependencies. The
set {i1 , . . . , ink } = parents(k) ⊆ V, where nk := |parents(k)|, indicates
the variables Si1 , . . . , Sink that Sk depends on. This defines a partially
ordered set (poset). For notational simplicity, we again assume with-
out loss of generality that S0 is deterministic. A computation graph
is specified by functions f1 , . . . , fK in topological order. In analogy, a
Bayesian network is specified by conditional probability distribu-
tions pk of Sk given Sparents(k) . We can then define the generative

process

S0 := s0
S1 ∼ p1 (· | S0 )
S2 ∼ p2 (· | Sparents(2) )
..
.
SK ∼ pK (· | Sparents(K) ).

Using the chain rule of probability and variable independences expressed


by the DAG, the joint probability distribution is then (assuming a
topological order for S0 , S1 , . . . , SK )

P(S = s) := P(S1 = s1, . . . , SK = sK)
          = ∏_{k=1}^{K} P(Sk = sk | Sparents(k) = sparents(k))
          =: ∏_{k=1}^{K} pk(sk | sparents(k)).

This representation is well suited to express causal relationships between


random variables.

9.5.2 Parameterizing Bayesian networks

In a Bayesian framework, observed data, latent variables, parameters


and noise variables are all treated as random variables. If the conditional
distribution pk associated to node k depends on some parameters, they
can be provided to pk as conditioning, using parent nodes.
A Bayesian network is specified by the conditional distributions pk .
Therefore, unlike computation graphs, there is no notion of function fk in
a Bayesian network. However, the root nodes of the Bayesian network can
be the output of a neural network. For instance, autoregressive models,
such as RNNs or Transformers, specify the conditional probability
distribution of a token given past tokens, and the chain rule of probability
is used to obtain a probability distribution over entire sequences.

9.5.3 Ancestral sampling


A major advantage of Bayesian networks is that, provided that each
conditional distribution pk is normalized, the joint distribution of S =
(S1 , . . . , SK ) is automatically normalized. This means that we can very
easily draw i.i.d. samples from the joint distribution, by following the
generative process: we follow the topological order k = 1, . . . , K and
on iteration k we draw a value sk ∼ pk (·|sparents(k) ) conditioned on the
previous values sparents(k) . This is known as ancestral sampling.
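As an illustration, here is a minimal Python sketch of ancestral sampling for a Markov chain (a special case of Bayesian network in which parents(k) = {k − 1}); the initial distribution and transition matrices below are made up for the example.

```python
# Ancestral sampling: draw s_1, ..., s_K in topological order, each conditioned
# on the previously drawn parent values (here, just the previous state).
import numpy as np

def ancestral_sample(pi_init, transitions, rng):
    s = [rng.choice(len(pi_init), p=pi_init)]          # s_1 ~ pi_init
    for P in transitions:                               # k = 2, ..., K
        s.append(rng.choice(P.shape[1], p=P[s[-1]]))    # s_k ~ p_k(. | s_{k-1})
    return s

rng = np.random.default_rng(0)
pi1 = np.array([0.2, 0.5, 0.3])                         # distribution of S_1
P = np.array([[0.9, 0.05, 0.05],
              [0.1, 0.8, 0.1],
              [0.25, 0.25, 0.5]])                       # transition probabilities
print(ancestral_sample(pi1, [P, P, P], rng))            # one draw of (s_1, ..., s_4)
```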

9.6 Markov random fields

9.6.1 Expressing factors using undirected graphs


A Markov random field (MRF), a.k.a. undirected graphical model,
specifies a distribution that factorizes as
P(S = s) = p(s) := (1/Z) ∏_{C∈C} ψC(sC),

where C is the set of maximal cliques of G, that is, subsets of V that


are fully connected, Z is a normalization constant defined by

Z := Σ_{s∈S} ∏_{C∈C} ψC(sC),

and ψC : SC → R+ is a potential function (a.k.a. compatibility func-


tion), with SC := (Sj )j∈C . According to the Hammersley-Clifford theo-
rem, an MRF can be equivalently defined in terms of Markov properties;
we refer the interested reader to Wainwright and Jordan (2008). For the
sake of this chapter, the definition above is sufficient for our purposes.

Example 9.2 (Markov chains as Markov random fields). For a chain,


letting S = (S1 , . . . , SK ) and s = (s1 , . . . , sK ), recall that
P(S = s) = ∏_{k=1}^{K} pk(sk | sk−1).

This is equivalent to an MRF with Z = 1 (since a chain is auto-



matically normalized),

C := {{0, 1}, {1, 2}, . . . , {K − 1, K}}

and with potential function

ψ{k−1,k} (sk−1 , sk ) := pk (sk |sk−1 ).

More generally, a Bayesian network can be similarly written as an


MRF by creating appropriate potential functions corresponding to
the parents of each node.

9.6.2 MRFs as exponential family distributions

Let us define the potential functions

ψC (sC ; θC ) := exp(⟨θC , ϕC (sC )⟩)

for some sufficient statistic function ϕC : SC → ΘC and parameters


θC ∈ ΘC . Then,

pθ(s) := (1/Z(θ)) ∏_{C∈C} ψC(sC; θC)
       = (1/Z(θ)) ∏_{C∈C} exp(⟨θC, ϕC(sC)⟩)
       = (1/Z(θ)) exp(Σ_{C∈C} ⟨θC, ϕC(sC)⟩)
       = (1/Z(θ)) exp(⟨θ, ϕ(s)⟩)
       = exp(⟨θ, ϕ(s)⟩ − A(θ)),

where

ϕ(s) := (ϕC(sC))C∈C
θ := (θC)C∈C
Z(θ) := Σ_{s∈S} ∏_{C∈C} ψC(sC; θC) = Σ_{s∈S} exp(⟨θ, ϕ(s)⟩)
A(θ) := log Z(θ).

Therefore, for this choice of potential functions, we can view an MRF


as an exponential family distribution (Section 3.4) with natural pa-
rameters θ, sufficient statistic ϕ and log-partition function A(θ).

Example 9.3 (Ising model). The Ising model is a classical example


of MRF. Let Y = (Y1 , . . . , YM ) ∈ {0, 1}M be an unordered collec-
tion of binary variables Yi ∈ {0, 1}. This forms a graph G = (V, E),
where V = [M ] and E ⊆ V 2 , such that (i, j) ∈ E means that Yi in-
teracts with Yj . In statistical physics, Yi may indicate the presence
or absence of particles, or the orientation of magnets. In image
processing, Yi may represent a black and white pixel. In multi-label
classification, Yi may indicate the presence or absence of a label.
The probability of y = (y1 , . . . , yM ) ∈ {0, 1}M is then

P(Y = y) = pθ(y)
         = exp(Σ_{i∈V} θi yi + Σ_{(i,j)∈E} θi,j yi yj − A(θ))
         = exp(Σ_{C∈C} ⟨θC, ϕC(y)⟩ − A(θ)),

where C := V∪E and θ ∈ R|V|+|E| is the concatenation of (θi )i∈V and


(θi,j )(i,j)∈E . These models are also known as Boltzmann machines
in a neural network context. MAP inference in general Ising models
is known to be NP-hard, but when the interaction weights θi,j are non-
negative, MAP inference can be reduced to graph cut algorithms

(Greig et al., 1989). There are two ways the above equation can
be extended. First, we can use higher-order interactions, such as
yi yj yk for (i, j, k) ∈ V 3 . Second, we may want to use categorical
variables, which leads to the Potts model.

9.6.3 Conditional random fields


Conditional random fields (Lafferty et al., 2001; Sutton, McCallum,
et al., 2012) are a special case of Markov random field, in which a
conditioning variable is explicitly incorporated. For example, when
the goal is to predict a variable y conditioned on a variable x, CRFs
are defined as
P(Y = y | X = x) = p(y | x) = (1/Z(x)) ∏_{C∈C} ψC(yC, x).

Note that the potential functions ψC are allowed to depend on the


whole x, as x is just a conditioning variable.

9.6.4 Sampling
Contrary to Bayesian networks, MRFs require an explicit normalization
constant Z. As a result, sampling from a distribution represented by a
general MRF is usually more involved than for Bayesian networks. A
commonly-used technique is Gibbs sampling.

9.7 Inference on chains

In this section, we review how to perform marginal inference and


maximum a-posteriori inference on joint distributions of the form
p(s1, . . . , sK) = (1/Z) ∏_{k=1}^{K} ψk(sk−1, sk),

where
Z := Σ_{s∈S} ∏_{k=1}^{K} ψk(sk−1, sk)
and where we used ψk as a shorthand for ψk−1,k , since k − 1 and k are
consecutive. As explained in Example 9.2, this also includes Markov

chains by setting

ψk (sk−1 , sk ) := pk (sk | sk−1 ),

in which case Z = 1.

9.7.1 The forward-backward algorithm

The key idea of the forward-backward algorithm is to use the distribu-


tivity of multiplication over addition to write

Z = Σ_{s1∈S1} ψ1(s0, s1) Σ_{s2∈S2} ψ2(s1, s2) · · · Σ_{sK∈SK} ψK(sK−1, sK).

We can compute these sums recursively, either forward or backward.


Recalling the definitions of Ak−1 and Bk+1 in Eq. (9.1), we define the
summations up to and down to k,
αk(sk) := Σ_{(s1,...,sk−1)∈Ak−1} ∏_{j=1}^{k} ψj(sj−1, sj)
        = Σ_{sk−1∈Sk−1} ψk(sk−1, sk) · · · Σ_{s1∈S1} ψ2(s1, s2)ψ1(s0, s1)

βk(sk) := Σ_{(sk+1,...,sK)∈Bk+1} ∏_{j=k+1}^{K} ψj(sj−1, sj)
        = Σ_{sk+1∈Sk+1} ψk+1(sk, sk+1) · · · Σ_{sK∈SK} ψK(sK−1, sK).

We can compute the two quantities by recursing forward and backward

αk(sk) = Σ_{sk−1∈Sk−1} ψk(sk−1, sk)αk−1(sk−1)
βk(sk) = Σ_{sk+1∈Sk+1} ψk+1(sk, sk+1)βk+1(sk+1)

where we defined the initializations

α1 (s1 ) := ψ1 (s0 , s1 ) ∀s1 ∈ S1


βK (sK ) := 1 ∀sK ∈ SK .

The normalization term can then be computed by


Z = Σ_{sK∈SK} αK(sK)βK(sK) = Σ_{s1∈S1} α1(s1)β1(s1)

and the marginal probabilities by

P(Sk = sk) = (1/Z) αk(sk)βk(sk)
P(Sk−1 = sk−1, Sk = sk) = (1/Z) αk−1(sk−1)ψk(sk−1, sk)βk(sk).

We can also compute the conditional probabilities by

P(Sk = sk | Sk−1 = sk−1) = P(Sk−1 = sk−1, Sk = sk) / P(Sk−1 = sk−1)
                         = αk−1(sk−1)ψk(sk−1, sk)βk(sk) / (αk−1(sk−1)βk−1(sk−1))
                         = ψk(sk−1, sk)βk(sk) / βk−1(sk−1).
In practice, the two recursions are often implemented in the log-domain
for numerical stability,
log αk(sk) = log Σ_{sk−1∈Sk−1} exp(log ψk(sk−1, sk) + log αk−1(sk−1))
log βk(sk) = log Σ_{sk+1∈Sk+1} exp(log ψk+1(sk, sk+1) + log βk+1(sk+1)).

We recognize the log-sum-exp operator, which can be implemented


in a numerically stable way (Section 4.4.2). The overall dynamic pro-
gramming procedure, a.k.a. forward-backward algorithm (Baum
and Petrie, 1966; Rabiner, 1989), is summarized in Algorithm 9.1. We
notice that the forward and backward passes are actually independent
of each other, and can therefore be performed in parallel.
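Before turning to MAP inference, here is a small numerical sketch of these recursions (a hypothetical NumPy implementation in the log domain, not the book's reference code), assuming for simplicity that every Sk has the same cardinality M; it mirrors Algorithm 9.1 below.

```python
# Forward-backward recursions in the log domain. log_psi1 has shape (M,) and
# stores log psi_1(s_0, s_1); each element of log_psis has shape (M, M) and
# stores log psi_k(s_{k-1}, s_k) for k = 2, ..., K.
import numpy as np
from scipy.special import logsumexp

def forward_backward(log_psi1, log_psis):
    log_alpha = [log_psi1]
    for log_psi in log_psis:                                      # forward pass
        log_alpha.append(logsumexp(log_alpha[-1][:, None] + log_psi, axis=0))
    log_beta = [np.zeros_like(log_psi1)]                          # beta_K = 1
    for log_psi in reversed(log_psis):                            # backward pass
        log_beta.append(logsumexp(log_psi + log_beta[-1][None, :], axis=1))
    log_beta = log_beta[::-1]
    log_Z = logsumexp(log_alpha[-1])                              # log-partition
    marginals = [np.exp(a + b - log_Z) for a, b in zip(log_alpha, log_beta)]
    return log_Z, marginals                                       # P(S_k = .) for all k
```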

Algorithm 9.1 Marginal inference on a chain

Potential functions: ψ1, . . . , ψK
Input: s0
1: Initialize α1(s1) := ψ1(s0, s1) ∀s1 ∈ S1
2: for k := 2, . . . , K do                    ▷ Forward pass
3:   for sk ∈ Sk do
4:     αk(sk) := Σ_{sk−1∈Sk−1} ψk(sk−1, sk)αk−1(sk−1)
5: Initialize βK(sK) := 1 ∀sK ∈ SK
6: for k := K − 1, . . . , 1 do                ▷ Backward pass
7:   for sk ∈ Sk do
8:     βk(sk) := Σ_{sk+1∈Sk+1} ψk+1(sk, sk+1)βk+1(sk+1)
9: Compute Z = Σ_{sK∈SK} αK(sK)βK(sK) = Σ_{sK∈SK} αK(sK)
Outputs: ∀k ∈ [K]:
P(Sk = sk) = (1/Z) αk(sk)βk(sk)
P(Sk−1 = sk−1, Sk = sk) = (1/Z) αk−1(sk−1)ψk(sk−1, sk)βk(sk)

9.7.2 The Viterbi algorithm

Similarly, using the distributivity of multiplication over maximization,

max_{s1∈S1,...,sK∈SK} ∏_{k=1}^{K} ψk(sk−1, sk)
= max_{sK∈SK} max_{sK−1∈SK−1} ψK(sK−1, sK) . . . max_{s1∈S1} ψ2(s1, s2)ψ1(s0, s1).

Let us define for k ∈ [K]

δk(sk) := max_{sk−1∈Sk−1} ψk(sk−1, sk) . . . max_{s1∈S1} ψ2(s1, s2)ψ1(s0, s1).

We can compute these quantities recursively, since for k ∈ {2, . . . , K}

δk(sk) = max_{sk−1∈Sk−1} ψk(sk−1, sk)δk−1(sk−1),

with δ1(s1) := ψ1(s0, s1). We finally have

max_{s1∈S1,...,sK∈SK} p(s1, . . . , sK) = (1/Z) max_{sK∈SK} δK(sK).
In practice, for numerical stability, we often implement the forward
recursion in the log-domain. Using the fact that the logarithm is
monotonic, we indeed have for all k ∈ [K]

log δk(sk) = max_{sk−1∈Sk−1} log ψk(sk−1, sk) + log δk−1(sk−1).

To enable efficient backtracking, during the forward pass, we compute


qk(sk) := arg max_{sk−1∈Sk−1} ψk(sk−1, sk)δk−1(sk−1)

which can be thought of as backpointers from s⋆k to s⋆k−1.


The resulting dynamic programming procedure, a.k.a. Viterbi al-
gorithm (Viterbi, 1967; Forney, 1973), is summarized in Algorithm 9.2.

Algorithm 9.2 MAP inference on a chain

Potential functions: ψ1, . . . , ψK
Input: s0
1: Initialize δ1(s1) := ψ1(s0, s1) ∀s1 ∈ S1
2: for k := 2, . . . , K do                    ▷ Forward pass
3:   for sk ∈ Sk do
4:     δk(sk) := max_{sk−1∈Sk−1} ψk(sk−1, sk)δk−1(sk−1)
5:     qk(sk) := arg max_{sk−1∈Sk−1} ψk(sk−1, sk)δk−1(sk−1)
6: δ⋆ := max_{sK∈SK} δK(sK)
7: s⋆K := arg max_{sK∈SK} δK(sK)
8: for k := K − 1, . . . , 1 do                ▷ Backtracking
9:   s⋆k := qk+1(s⋆k+1)
Outputs: max_{s1∈S1,...,sK∈SK} p(s1, . . . , sK) ∝ δ⋆
         arg max_{s1∈S1,...,sK∈SK} p(s1, . . . , sK) = (s⋆1, . . . , s⋆K)
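The following is a hypothetical NumPy sketch of Algorithm 9.2, using the same log-domain conventions and assumptions as the forward-backward sketch of the previous section.

```python
# Viterbi in the log domain with backpointers; log_psi1 has shape (M,), each
# element of log_psis has shape (M, M), as in the forward-backward sketch.
import numpy as np

def viterbi(log_psi1, log_psis):
    log_delta = log_psi1
    backpointers = []
    for log_psi in log_psis:                        # forward pass
        scores = log_delta[:, None] + log_psi       # indexed by (s_{k-1}, s_k)
        backpointers.append(np.argmax(scores, axis=0))
        log_delta = np.max(scores, axis=0)
    s = [int(np.argmax(log_delta))]                 # s*_K
    for q in reversed(backpointers):                # backtracking
        s.append(int(q[s[-1]]))
    return s[::-1], float(np.max(log_delta))        # (s*_1, ..., s*_K), log delta*
```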

9.8 Inference on trees

More generally, efficient inference based on dynamic programming


can be performed when dependencies between variables are expressed
using a tree or polytree. The resulting marginal inference and MAP
inference algorithms are often referred to as the sum-product and
max-sum algorithms. The sum-product algorithm is also known as
belief propagation or message passing, since it can be interpreted
as propagating “local messages” through the graph. See for instance
(Wainwright and Jordan, 2008, Section 2.5.1) for more details.

9.9 Inference as differentiation

In this section, we review the profound connections between differenti-


ating the log-partition function of an exponential family distribution
on one hand, and performing marginal inference (as well as maximum
a-posteriori inference in the zero-temperature limit) on the other hand.

9.9.1 Inference as gradient of the log-partition


We first discuss a well-known fact in the graphical model literature: when
using a binary encoding as the sufficient statistic ϕ in an exponential
family distribution, the gradient ∇A(θ) of the log-partition A(θ) gathers
all the marginals (Wainwright and Jordan, 2008).
To see why, recall from Section 3.4 the definition of an exponential
family distribution
pθ (s) = h(s) exp [⟨θ, ϕ(s)⟩ − A(θ)]
and of its log-partition
A(θ) := log Σ_{s∈S} h(s) exp [⟨θ, ϕ(s)⟩].

From Proposition 3.2,


µ(θ) := ∇A(θ) = EY ∼pθ [ϕ(Y )] ∈ M.
Therefore, with the binary encodings in Eq. (9.2) and Eq. (9.3),
P(Sk = vi ) = [∇A(θ)]k,i
P(Sk = vi , Sl = vj ) = [∇A(θ)]k,l,i,j .
Put differently, if we have an efficient algorithm for computing A(θ), we
can perform reverse-mode autodiff on A(θ) to obtain ∇A(θ), and
therefore obtain the marginal probabilities. Following Section 7.3.3, the
complexity of computing all marginal probabilities is therefore roughly
the same as that of computing A(θ).
In the special case of chains, we obtain
P(Sk = vi) = [∇A(θ)]k,i = (1/Z) αk(vi)βk(vi)
P(Sk−1 = vi, Sk = vj) = [∇A(θ)]k−1,k,i,j = (1/Z) αk−1(vi)ψk(vi, vj)βk(vj),

where we left the dependence of Z, α and β on θ implicit.


If we define Aε (θ) := εA(θ/ε), in the zero-temperature limit ε → 0,
we obtain that µ(θ) is a binary encoding of the mode, i.e., of the
maximum a-posteriori inference solution.
We now show i) how to unify the forward pass of the forward-
backward and Viterbi algorithms using semirings and softmax operators
ii) how to compute the gradient of the log-partition using backpropaga-
tion.

9.9.2 Semirings and softmax operators

The forward passes in the forward-backward and Viterbi algorithms are


clearly similar. In fact, they can be formally linked to each other using
semirings.

Definition 9.1 (Semiring). A semiring is a set K equipped with two


binary operations (⊕, ⊗) such that

• ⊕ is commutative and associative,

• ⊗ is associative and distributive over ⊕,

• ⊕ and ⊗ have identity elements 0̄ and 1̄, respectively.

We use the notations ⊕, ⊗, 0̄ and 1̄ to clearly distinguish them from


the classical addition, multiplication, 0 and 1.
We recall the following laws for binary operations:

• Commutativity of ⊕: a ⊕ b = b ⊕ a,

• Associativity of ⊕: a ⊕ (b ⊕ c) = (a ⊕ b) ⊕ c,

• Distributivity of ⊗ over ⊕: a ⊗ (b ⊕ c) = (a ⊗ b) ⊕ (a ⊗ c).

A set equipped with a binary operation supporting associativity and an


identity element is called a monoid. A monoid such that every element
has an inverse element is called a group. The difference between a ring
and a semiring is that the latter only requires (K, ⊕) and (K, ⊗) to be
monoids, not groups.

Equipped with these definitions, we can interpret the forward passes


in the Viterbi and forward-backward algorithms as follows:

• the forward-backward algorithm in the exponential domain uses


the semiring R+ equipped with (+, ×) and identity elements (0, 1);

• the Viterbi algorithm in the log domain uses the semiring R


equipped with (max, +) and identity elements (−∞, 0);

• the forward-backward algorithm in the log domain uses the semir-


ing R equipped with (maxε , +) and identity elements (−∞, 0),
where we defined the soft max operator (log-add-exp)

maxε(a, b) := ε log(exp(a/ε) + exp(b/ε)),

with ε := 1 by default.

It can be checked that indeed maxε is commutative, associative, and


addition is distributive over maxε . Its identity element is −∞. By
associativity,

maxε(a1, maxε(a2, a3)) = logsumexpε(a1, a2, a3) = ε log Σ_{i=1}^{3} exp(ai/ε).

In contrast, note that the sparsemax in Section 12.4.2 is not associative.


Thanks to associativity, we can introduce the shorthand notations

maxε_{v∈V} f(v) := ε log Σ_{v∈V} exp(f(v)/ε) ∈ R

and

argmaxε_{v∈V} f(v) := ( exp(f(v)/ε) / Σ_{v′∈V} exp(f(v′)/ε) )_{v∈V} ∈ P(V).
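As a small illustration (hypothetical Python code relying on SciPy's logsumexp and softmax), the operators maxε and argmaxε can be implemented in a few lines; as ε → 0, they approach the hard max and a one-hot argmax.

```python
# Soft max and soft argmax operators with temperature eps.
import numpy as np
from scipy.special import logsumexp, softmax

def max_eps(a, eps=1.0):
    return eps * logsumexp(np.asarray(a) / eps)

def argmax_eps(a, eps=1.0):
    return softmax(np.asarray(a) / eps)

a = np.array([1.0, 2.0, 3.0])
print(max_eps(a, 1.0), max_eps(a, 1e-3))    # ~3.408 vs ~3.000 (hard max)
print(argmax_eps(a, 1e-3))                  # ~[0, 0, 1] (one-hot argmax)
```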

Many algorithms can be generalized thanks to the use of semirings;


see among others Aji and McEliece (2000) and Mohri et al. (2008). The
distributive and associative properties play a key role in breaking down
large problems into smaller ones (Verdu and Poor, 1987).

9.9.3 Inference as backpropagation

In this section, we show that, algorithmically, backtracking is recovered


as a special case of backpropagation. See also (Eisner, 2016; Mensch
and Blondel, 2018).
For notational simplicity, we assume S0 = {1} and Sk = {1, . . . , M }
for all k ∈ [K]. We focus on the case

log ψk (i, j) = ⟨θk , ϕk (i, j)⟩ = θk,i,j .

We also introduce the shorthands

a1,j := log α1(j) = θ1,1,j
ak,j := log αk(j) = maxε_{i∈[M]} θk,i,j + ak−1,i

and

qk,j := argmaxε_{i∈[M]} θk,i,j + ak−1,i.

Our goal is to compute the gradient w.r.t. θ ∈ RK×M ×M of

log Z = A = maxε_{j∈[M]} aK,j.

The soft argmax counterpart of this quantity is

Q := argmaxε_{j∈[M]} aK,j ∈ △^M,

where we used P([M ]) = △M .


Computing the gradient of A is similar to computing the gradient
of a feedforward network, in the sense that θk influences not only ak
but also ak+1 , . . . , aK . Let us introduce the adjoint variable
rk,i := ∂A/∂ak,i,

which we initialize as

rK,i = ∂A/∂aK,i = Qi.

Since θk,i,j directly influences ak,j , we have for k ∈ [K], i ∈ [M ] and


j ∈ [M ]
µk,i,j := ∂A/∂θk,i,j = (∂A/∂ak,j) · (∂ak,j/∂θk,i,j) = rk,j · qk,j,i.

Since ak,i directly influences ak+1,j for j ∈ [M ], we have for k ∈


{1, . . . , K − 1} and i ∈ [M ]
rk,i = ∂A/∂ak,i = Σ_{j∈[M]} (∂A/∂ak+1,j)(∂ak+1,j/∂ak,i)
     = Σ_{j∈[M]} rk+1,j qk+1,j,i
     = Σ_{j∈[M]} µk+1,i,j.

We summarize the procedure in Algorithm 9.3. The forward pass


uses the softmax operator maxε and the softargmax operator argmaxε .
In the hard max case, in Algorithm 9.2, we used q to store backpointers
from integer to integer. In the soft max case, in Algorithm 9.3, we used q
to store soft backpointers, that is, discrete probability distributions. In
the zero-temperature limit, backpropagation outputs a binary encoding
of the solution of backtracking.

Algorithm 9.3 Inference on a chain as backprop with max operators

Input: θ ∈ R^{K×M×M}
Max operator: maxε
1: Initialize a1,j := θ1,1,j ∀j ∈ [M]
2: for k := 2, . . . , K do                    ▷ Forward pass
3:   for j ∈ [M] do
4:     ak,j := maxε_{i∈[M]} θk,i,j + ak−1,i ∈ R
5:     qk,j := argmaxε_{i∈[M]} θk,i,j + ak−1,i ∈ △^M
6: A := maxε_{i∈[M]} aK,i ∈ R
7: Q := argmaxε_{i∈[M]} aK,i ∈ △^M
8: Initialize rK,j := Qj ∀j ∈ [M]
9: for k := K − 1, . . . , 1 do                 ▷ Backward pass
10:   for i ∈ [M] do
11:     for j ∈ [M] do
12:       µk+1,i,j := rk+1,j · qk+1,j,i
13:     rk,i := Σ_{j∈[M]} µk+1,i,j
Outputs: maxε_{i1,...,iK∈[M]^K} (θ1,1,i1 + Σ_{k=2}^{K} θk,ik−1,ik) = A, ∇A(θ) = µ
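The following JAX sketch illustrates this equivalence under the same assumptions as Algorithm 9.3 (it is an illustration, not the reference implementation): we only implement the forward pass with logsumexp, and obtain µ, the marginals, by reverse-mode autodiff.

```python
# Differentiating the log-partition of a chain gives the (pairwise) marginals.
# theta has shape (K, M, M); only theta[0, 0, :] is used at the first step,
# matching the convention S_0 = {1}.
import jax
import jax.numpy as jnp

def log_partition(theta):
    a = theta[0, 0, :]                               # a_{1,j} = theta_{1,1,j}
    for k in range(1, theta.shape[0]):               # forward pass, eps = 1
        a = jax.scipy.special.logsumexp(theta[k] + a[:, None], axis=0)
    return jax.scipy.special.logsumexp(a)            # A(theta) = log Z

theta = jax.random.normal(jax.random.PRNGKey(0), (4, 3, 3))
mu = jax.grad(log_partition)(theta)                  # marginal probabilities
print(mu[1].sum())                                   # each pairwise slice sums to 1
```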

9.10 Summary

Graphical models represent the conditional dependencies between vari-


ables and therefore specify how their joint distribution factorizes. We
saw clear analogies between the worlds of functions and of distributions:
the counterparts of computation chains and computation graphs are
Markov chains and Bayesian networks. Inference on chains and more
generally on trees, for exponential family distributions, is equivalent,
both statistically and algorithmically, to differentiating the log-partition
function. The forward-backward algorithm can be seen as using a sum-
product algebra, while the Viterbi algorithm can be seen as using a
max-plus algebra. Equivalently, in the log domain, we can see the former
as using a soft max, and the latter as using a hard max.
10 Differentiating through optimization

In this chapter, we study how to differentiate through optimization


problems, and more generally through nonlinear systems of equations.

10.1 Implicit functions

Implicit functions are functions that do not enjoy an explicit decompo-


sition into elementary functions, for which automatic differentiation, as
studied in Chapter 7, can therefore not be directly applied. We describe
in this chapter techniques to differentiate through such functions and
how to integrate them into an autodiff framework.
Formally, we will denote an implicit function by w⋆ (λ), where
w⋆ : Λ → W. One question is then how to compute the Jacobian
∂w⋆ (λ). As a first application one can consider sensitivity analysis
of a system. For example, w⋆ (λ) could correspond to the equilibrium
state of a physical system and in this case, ∂w⋆ (λ) would tell us about
the sensitivity of the system to some parameters λ ∈ Λ.


10.1.1 Optimization problems


Another example is a function implicitly defined as the solution (assumed
unique) of an optimization problem

w⋆(λ) = arg max_{w∈W} f(w, λ),

where f : W × Λ → R and W denotes a constraint set. Note that we


use an arg max for convenience, but the same applies when using an
arg min.

10.1.2 Nonlinear equations


More generally, w⋆ (λ) can be defined as the root of some function
F : W × Λ → W, i.e., w⋆ (λ) is implicitly defined as the function
satisfying the (potentially nonlinear) system of equations

F(w⋆(λ), λ) = 0

for all λ ∈ Λ.

10.1.3 Application to bilevel optimization


Besides sensitivity analysis, another example of application is bilevel
optimization. Many times, we want to minimize a function defined as
the composition of a fixed function and the solution of an optimization
problem. Formally, let f, g : W × Λ → R. We consider the composition
h(λ) defined as

h(λ) := g(w⋆(λ), λ), where w⋆(λ) = arg max_{w∈W} f(w, λ). (10.1)

This includes for instance hyperparameter optimization, where f is an


inner log-likelihood objective, g is an outer validation loss, w ∈ W
are model parameters and λ ∈ Λ are model hyperparameters, such
as regularization strength. To minimize h(λ) one generally resorts to
a gradient descent scheme w.r.t. λ, which requires computing ∇h(λ).
Assuming that w⋆ (λ) is differentiable at λ, by the chain rule, we obtain
the Jacobian

∂h(λ) = ∂1 g(w⋆ (λ), λ)∂w⋆ (λ) + ∂2 g(w⋆ (λ), λ).



Figure 10.1: The graph of h(λ) = maxw∈W f (w, λ) is the upper-envelope of the
graphs of the functions λ 7→ f (w, λ) for all w ∈ W.

Using ∂h(λ)⊤ = ∇h(λ) (see Remark 2.3), we obtain the gradient

∇h(λ) = ∂w⋆ (λ)⊤ ∇1 g(w⋆ (λ), λ) + ∇2 g(w⋆ (λ), λ).

The only problematic term is ∂w⋆ (λ), as it requires argmax differenti-


ation. Indeed, most of the time, there is no explicit formula for w⋆ (λ)
and it does not decompose into elementary functions.

10.2 Envelope theorems

In the special case g = f , the composition h defined in Eq. (10.1) is


simply given by

h(λ) = f(w⋆(λ), λ) = max_{w∈W} f(w, λ).

That is, we no longer need argmax differentiation, but only max


differentiation, which, as we shall now see is much easier. The function
h is often called a value function (Fleming and Rishel, 2012). The
reason for the name “envelope” is illustrated in Fig. 10.1. We emphasize
that there is not one, but several envelope theorems, depending on the
assumptions on f .

10.2.1 Danskin’s theorem


When f is concave-convex, we can use Danskin’s theorem.

Theorem 10.1 (Danskin’s theorem). Let f : W × Λ → R and W be


a compact convex set. Let

h(λ) := max_{w∈W} f(w, λ)

and

w⋆(λ) := arg max_{w∈W} f(w, λ).
If f is concave in w, convex in λ, and the maximum w⋆ (λ) is
unique, then the function h is differentiable with gradient

∇h(λ) = ∇2 f (w⋆ (λ), λ).

If the maximum is not unique, we get a subgradient.

Informally, Danskin’s theorem means that we can treat w⋆ (λ) as if


it were constant with respect to λ, i.e., we do not need to differentiate through it,
even though it depends on λ. Danskin’s theorem can also be used to
differentiate through a minimum, h(λ) = minw∈W f (w, λ), if f (w, λ)
is convex in w and concave in λ, as we now illustrate.

Example 10.1 (Illustration of Danskin's theorem). Let us define h(λ) =
min_{w∈R} f(w, λ), where f(w, λ) = (λ/2)w² + bw + c and λ > 0. Let w⋆(λ)
be the minimum. The derivative of f w.r.t. λ is (1/2)w². From Dan-
skin's theorem, we have h′(λ) = (1/2)w⋆(λ)². The derivative of f w.r.t.
w is λw + b. Setting it to zero, we get w⋆(λ) = −b/λ. We thus obtain
h′(λ) = b²/(2λ²). Let us check that this result is indeed correct. Plugging
w⋆(λ) back into f(w, λ), we get h(λ) = −b²/(2λ) + c. Using (1/λ)′ = −1/λ²,
we indeed obtain the same result for h′(λ).
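As a sanity check of this example (a hypothetical Python sketch with made-up constants b and c), we can compare the derivative given by Danskin's theorem with a finite-difference estimate.

```python
# Numerical check of Example 10.1: h(lam) = min_w (lam/2) w^2 + b w + c.
import numpy as np

b, c = 2.0, 1.0
f = lambda w, lam: 0.5 * lam * w**2 + b * w + c
w_star = lambda lam: -b / lam                    # argmin in w (lam > 0)
h = lambda lam: f(w_star(lam), lam)

lam = 1.5
danskin = 0.5 * w_star(lam) ** 2                 # df/dlam evaluated at w*(lam)
fd = (h(lam + 1e-5) - h(lam - 1e-5)) / 2e-5      # finite-difference estimate
print(danskin, fd)                               # both ~0.888...
```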

Danskin’s theorem has a simple interpretation for functions that are


linear in λ as shown below.

Example 10.2 (Convex conjugate). Let f (w, λ) := ⟨w, λ⟩ − Ω(w)


with Ω convex. We then have h(λ) = maxw∈W ⟨w, λ⟩ − Ω(w) =:
Ω∗ (λ), where Ω∗ denotes the convex conjugate of Ω. Since f satisfies
the conditions of Danskin’s theorem and since we have ∇2 f (w, λ) =
w, we obtain ∇h(λ) = ∇Ω∗ (λ) = w⋆ (λ). In other words, in this
special case, the gradient of the max is equal to the argmax. This
is due to the fact that f (w, λ) is linear in λ.

Another application is saddle point optimization.


Example 10.3 (Saddle point problem). Consider the saddle point
problem minλ∈Λ maxw∈W f (w, λ). If it is difficult to minimize w.r.t.
λ but easy to maximize w.r.t. w, we can rewrite the problem as
minλ∈Λ h(λ), where h(λ) := maxw∈W f (w, λ), and use ∇h(λ) to
perform (projected) gradient descent w.r.t. λ.

10.2.2 Rockafellar’s theorem


A related theorem can be proved under different assumptions about f ,
in particular without concavity w.r.t. w.
Theorem 10.2 (Rockafellar’s envelope theorem). Let f : W × Λ →
R and W be a compact convex set. Let

h(λ) := max_{w∈W} f(w, λ)

and

w⋆(λ) := arg max_{w∈W} f(w, λ).
If f is continuously differentiable in λ for all w ∈ W, ∇1 f is
continuous and the maximum w⋆ (λ) is unique, then the function
h is differentiable with gradient

∇h(λ) = ∇2 f (w⋆ (λ), λ).

See Rockafellar and Wets (2009, Theorem 10.31). Compared to


Danskin’s theorem, Rockafellar’s theorem does not require f to be
concave-convex, but requires stronger assumptions on the differentiabil-
ity of f .

10.3 Implicit function theorem

10.3.1 Univariate functions


The implicit function theorem (IFT) provides conditions under which
an implicit relationship of the form F (x, λ) = 0 can be rewritten as a
function x = x⋆ (λ) locally, and provides a way to compute its derivative
w.r.t. λ.
Theorem 10.3 (Implicit function theorem, univariate case). Let
F : R × R → R. Assume F (x, λ) is a continuously differentiable
function in a neighborhood U of (x0 , λ0 ) such that F (x0 , λ0 ) = 0
and ∂1 F (x0 , λ0 ) ̸= 0. Then there exists a neighborhood V ⊆ U of
(x0 , λ0 ) in which there is a function x⋆ (λ) such that

• x⋆ (λ0 ) = x0 ,

• F (x⋆ (λ), λ) = 0 for all λ in the neighborhood V,



• ∂x⋆(λ) = −∂2F(x⋆(λ), λ) / ∂1F(x⋆(λ), λ).

We postpone the proof to the multivariate case and begin with a


classical example of application of the theorem.

Example 10.4 (Equation of the unit circle). We use w ≡ x and λ ≡ y for
clarity. Let F(x, y) := x² + y² − 1. In general, we cannot rewrite the unit
circle equation F(x, y) = 0 as a function from y to x, because for every
y ∈ [−1, 1], there are always two possible x values, namely, x = √(1 − y²)
or x = −√(1 − y²). However, locally around some point (x0, y0), e.g., such
that x0 > 0 and y0 > 0 (upper-right quadrant), the function x = x⋆(y) =
√(1 − y²) is well-defined. Using ∂1F(x, y) = 2x and ∂2F(x, y) = 2y, we get
∂x⋆(y) = −2y/(2x⋆(y)) = −y/√(1 − y²) in that neighborhood (the upper-right
quadrant in this case). This is indeed the same derivative expression as if we
used the chain rule on √(1 − y²), and it is well-defined on y ∈ [0, 1).

In the above simple example, we can easily derive an explicit function


relating y to x in a given neighborhood, but this is not always the case.
The IFT gives us conditions guaranteeing that such function exists and

a way to differentiate it, but not a way to construct such a function.


In fact, finding x⋆ (λ) such that F (x⋆ (λ), λ) = 0 typically involves a root
finding algorithm, an optimization algorithm, a nonlinear system solver,
etc.

Example 10.5 (Polynomial). Let F(w, λ) = w⁵ + w³ + w − λ. Ac-
cording to the Abel-Ruffini theorem (Tignol, 2015), quintics (poly-
nomials of degree 5) do not admit roots in terms of radicals and
one must resort to numerical root finding. In addition, odd-degree
polynomials have at least one real root. Moreover, ∂1F(w, λ) = 5w⁴ + 3w² + 1
is strictly positive, so F(·, λ) is strictly increasing. Therefore, by the
intermediate value theorem, there is exactly one root w⋆(λ) such that
F(w⋆(λ), λ) = 0. This unique root can for example be found by bisection.
Using the IFT, its derivative is found to be ∂w⋆(λ) = (5w⋆(λ)⁴ + 3w⋆(λ)² + 1)⁻¹.
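As an illustration of this example, here is a hypothetical Python sketch that finds the root by bisection and compares the IFT derivative with a finite-difference estimate.

```python
# Root of F(w, lam) = w^5 + w^3 + w - lam by bisection, differentiated with the IFT.
def root(lam, lo=-10.0, hi=10.0):
    F = lambda w: w**5 + w**3 + w - lam
    for _ in range(200):                          # bisection
        mid = 0.5 * (lo + hi)
        lo, hi = (mid, hi) if F(mid) < 0 else (lo, mid)
    return 0.5 * (lo + hi)

lam = 2.0
w = root(lam)
ift = 1.0 / (5 * w**4 + 3 * w**2 + 1)             # dw*/dlam from the IFT
fd = (root(lam + 1e-6) - root(lam - 1e-6)) / 2e-6
print(ift, fd)                                    # the two estimates agree closely
```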

While an implicit function is differentiable at a point if the assump-


tions of the IFT hold in a neighborhood of that point, the converse is
not true: failure of the IFT assumptions does not necessarily mean that
the implicit function is not differentiable, as we now illustrate.

Example 10.6 (IFT conditions are not necessary for differentiability).


Consider F (w, λ) = (w − λ)2 . We clearly have that F (w⋆ (λ), λ) = 0
if we define w⋆ (λ) = λ, the identity function. It is clearly differen-
tiable for all λ, yet the assumptions of the IFT fail, since we have
∂1 F (w, λ) = 2(w − λ) and therefore ∂1 F (0, 0) = 0.

10.3.2 Multivariate functions


We now present the IFT in the general multivariate setting. Informally,
if F (w⋆ (λ), λ) = 0, then by the chain rule, we have

∂1 F (w⋆ (λ), λ)∂w⋆ (λ) + ∂2 F (w⋆ (λ), λ) = 0,

meaning that the Jacobian ∂w⋆ (λ), assuming that it exists, satisfies

−∂1 F (w⋆ (λ), λ)∂w⋆ (λ) = ∂2 F (w⋆ (λ), λ).

The IFT gives us conditions for the existence of ∂w⋆ (λ).



Theorem 10.4 (Implicit function theorem, multivariate case). Let us


define F : W × Λ → W. Assume F (w, λ) is a continuously differen-
tiable function in a neighborhood of (w0 , λ0 ) such that F (w0 , λ0 ) =
0 and ∂1 F (w0 , λ0 ) is invertible, i.e., its determinant is nonzero.
Then there exists a neighborhood of λ0 in which there is a function
w⋆ (λ) such that

• w⋆ (λ0 ) = w0 ,

• F (w⋆ (λ), λ) = 0 for all λ in the neighborhood,

• −∂1 F (w⋆ (λ), λ)∂w⋆ (λ) = ∂2 F (w⋆ (λ), λ)


⇐⇒ ∂w⋆ (λ) = −∂1 F (w⋆ (λ), λ)−1 ∂2 F (w⋆ (λ), λ).

We begin with a simple unconstrained optimization algorithm.


Example 10.7 (Unconstrained optimization). Assume we want to
differentiate through w⋆ (λ) = arg minw∈RP f (w, λ), where f is
strictly convex in w, which ensures that the solution is unique. From
the stationary conditions, if we define F (w, λ) := ∇1 f (w, λ), then
w⋆ (λ) is uniquely characterized as the root of F in the first argu-
ment, i.e., F (w⋆ (λ), λ) = 0. We have ∂1 F (w, λ) = ∇21 f (w, λ), the
Hessian of f in w, and ∂2 F (w, λ) = ∂2 ∇1 f (w, λ), the cross deriva-
tives of f . Therefore, assuming that the Hessian is well-defined and
invertible at (w⋆ (λ), λ), we can use the IFT to differentiate through
w⋆(λ) and obtain ∂w⋆(λ) = −[∇²1f(w⋆(λ), λ)]⁻¹ ∂2∇1f(w⋆(λ), λ).
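As a concrete illustration (a hypothetical Python sketch on a made-up ridge-regression inner problem, f(w, λ) = (1/2)∥Xw − y∥² + (λ/2)∥w∥², which is strictly convex in w), the IFT formula can be checked against finite differences.

```python
# Implicit differentiation of w*(lam) = argmin_w f(w, lam) for ridge regression:
# the IFT gives dw*/dlam = -(X^T X + lam I)^{-1} w*(lam).
import numpy as np

rng = np.random.default_rng(0)
X, y = rng.normal(size=(20, 3)), rng.normal(size=20)

def w_star(lam):
    return np.linalg.solve(X.T @ X + lam * np.eye(3), X.T @ y)

lam = 0.7
H = X.T @ X + lam * np.eye(3)                  # Hessian of f in w
cross = w_star(lam)                            # d/dlam of grad_w f, at w*(lam)
ift = -np.linalg.solve(H, cross)               # IFT expression for dw*/dlam
fd = (w_star(lam + 1e-6) - w_star(lam - 1e-6)) / 2e-6
print(np.allclose(ift, fd, atol=1e-4))         # True
```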

Next, we generalize the previous example, by allowing constraints


in the optimization problem.
Example 10.8 (Constrained optimization). Now, assume we want
to differentiate through w⋆ (λ) = arg minw∈C f (w, λ), where f is
strictly convex in w and C ⊆ W is a convex set. A solution is
characterized by the fixed point equation w⋆ (λ) = PC (w⋆ (λ) −
η∇1f(w⋆(λ), λ)), for any η > 0, where PC(y) := arg min_{x∈C} ∥x − y∥²2
is the Euclidean projection of y onto C. Therefore, w⋆ (λ) is the
root of F (w, λ) = w − PC (w − η∇1 f (w, λ)) (see Chapter 15). We

can differentiate through w⋆ (λ) using the IFT, assuming that the
conditions of the theorem apply. Note that ∂1 F (w, λ) requires
the expression of the Jacobian ∂PC (y). Fortunately, PC (y) and its
Jacobian are easy to compute for many sets C (Blondel et al., 2021).

10.3.3 JVP and VJP of implicit functions

To integrate an implicit function w⋆ (λ) in an autodiff framework, we


need to be able to compute its JVP or VJP. This is the purpose of the
next proposition.

Proposition 10.1 (JVP and VJP of implicit functions). Let w⋆ : Λ →


W be a function implicitly defined as the solution of F (w⋆ (λ), λ) =
0, for some function F : W × Λ → W. Define

A := −∂1 F (w⋆ (λ), λ)


B := ∂2 F (w⋆ (λ), λ).

Assume the assumptions of the IFT hold. The JVP t := ∂w⋆ (λ)v
in the input direction v ∈ Λ is obtained by solving the linear system

At = Bv.

The VJP ∂w⋆ (λ)∗ u in the output direction u ∈ W is obtained by


solving the linear system

A∗ r = u.

Using the solution r, we get

∂w⋆ (λ)∗ u = ∂w⋆ (λ)∗ A∗ r = B ∗ r.

Note that in the above linear systems, we can access A and B as
linear maps, the JVPs of F. Their adjoints, A∗ and B∗, correspond to
the VJPs of F. To solve these systems, we can therefore use matrix-free
solvers as detailed in Section 8.4. For example, when A is symmetric pos-
itive semi-definite, we can use the conjugate gradient method (Hestenes,
Stiefel, et al., 1952). When A is not symmetric positive definite, we can
use GMRES (Saad and Schultz, 1986) or BiCGSTAB (van der Vorst,
1992).
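The following hypothetical Python sketch illustrates the VJP computation of Proposition 10.1 with a matrix-free conjugate gradient solve; the linear maps (and the toy matrices standing in for them) are made up, and CG assumes that A∗ is symmetric positive definite.

```python
# Matrix-free VJP of an implicit function: solve A* r = u with CG, return B* r.
import numpy as np
from scipy.sparse.linalg import LinearOperator, cg

def implicit_vjp(A_star_matvec, B_star_matvec, u):
    """A := -d1 F and B := d2 F at the solution; only r -> A* r and r -> B* r
    are needed, not the materialized matrices."""
    n = u.shape[0]
    A_star = LinearOperator((n, n), matvec=A_star_matvec, dtype=float)
    r, _ = cg(A_star, u)                       # solve A* r = u
    return B_star_matvec(r)                    # dw*(lam)* u = B* r

# Toy usage with explicit matrices standing in for the linear maps.
A = np.array([[2.0, 0.3], [0.3, 1.5]])              # plays the role of -d1F
B = np.array([[0.5, 0.0, 1.0], [0.0, 2.0, 0.0]])    # plays the role of d2F
u = np.array([1.0, -1.0])
print(implicit_vjp(lambda r: A.T @ r, lambda r: B.T @ r, u))
```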

10.3.4 Proof of the implicit function theorem


We prove the theorem using the inverse function theorem presented
in Theorem 10.5. Define
f (λ, w) = (λ, F (w, λ))
which goes from RQ × RP onto RQ × RP . The Jacobian of f is
!
I 0
∂f (λ, w) = .
∂2 F (w, λ) ∂1 F (w, λ)
So at w0 , λ0 , we have det(∂f (λ0 , w0 )) = det(I) det(∂1 F (w0 , λ0 )) >
0 since we assumed ∂1 F (w0 , λ0 ) invertible. By the inverse function
theorem, the function f is then invertible in a neighborhood N of
f (λ0 , w0 ) = (λ0 , 0). In particular, it is invertible in N ∩ {(λ, 0), λ ∈
RQ }. The solution of the implicit equation in a neighborhood of λ0
is then (λ, w∗ (λ)) = f −1 (λ, 0). By the inverse function theorem, f −1
is continuously differentiable inverse and so is w∗ (λ). The derivative
∂w∗ (λ) from the differential of the inverse as
!
∼ ∼
= ∂f −1 (λ, 0),
∂w∗ (λ) ∼
and by the inverse function theorem, we have ∂f −1 (λ, 0) = (∂f (λ, w∗ (λ)))−1 .
So using block matrix inversions formula
!−1 !
A B ∼ ∼
= ,
C D −(D − CA−1 B)−1 CA−1 ∼
we get the claimed expression. Though we expressed the proof in terms of
Jacobians and matrices, the result naturally holds for the corresponding
linear operators, JVPs, VJPs, and their inverses.

10.4 Adjoint state method

10.4.1 Differentiating nonlinear equations


We describe in this section the adjoint state method (a.k.a. adjoint
method, method of adjoints, adjoint sensitivity method). The method

can be used to compute the gradient of the composition of an explicit


function and an implicit function, defined through an equality
constraint (e.g., a nonlinear equation). The method dates back
to Céa (1986).
Suppose a variable s ∈ S (which corresponds to a state in optimal
control) is implicitly defined given some parameters w ∈ W through
the (potentially nonlinear) equation c(s, w) = 0, where c : S × W → S.
Assuming s is uniquely determined for all w ∈ W, this defines an
implicit function s⋆ (w) from W to S such that c(s⋆ (w), w) = 0.
Given an objective function L : S × W → R, the goal of the adjoint
state method is then to compute the gradient of
L(w) := L(s⋆ (w), w).
However, this is not trivial as s⋆ (w) is an implicit function. For instance,
this can be used to convert the equality-constrained problem
min_{w∈W} L(s, w) s.t. c(s, w) = 0

into the unconstrained problem

min_{w∈W} L(s⋆(w), w).

Access to ∇L(w) allows us to solve this problem by gradient descent.

Proposition 10.2 (Adjoint state method). Let c : S × W → S be a


mapping defining constraints of the form c(s, w). Assume that for
each w ∈ W, there exists a unique s⋆ (w) satisfying c(s⋆ (w), w) = 0
and that s⋆ (w) is differentiable. The gradient of

L(w) := L(s⋆ (w), w),

for some differentiable function L : S × W → R, is given by

∇L(w) = ∇2 L(s⋆ (w), w) + ∂2 c(s⋆ (w), w)∗ r ⋆ (w),

where r ⋆ (w) is the solution of the linear system

∂1 c(s⋆ (w), w)∗ r = −∇1 L(s⋆ (w), w).

As shown in the proof below, r ⋆ (w) corresponds to a Lagrange


multiplier. The linear system can be solved using matrix-free solvers.

10.4.2 Relation with envelope theorems


Because s is uniquely determined for any w ∈ W by c(s, w) = 0, we can
alternatively rewrite L(w) as the trivial minimization or maximization,
L(w) = min_{s∈S} L(s, w) s.t. c(s, w) = 0
     = max_{s∈S} L(s, w) s.t. c(s, w) = 0.

Therefore, the adjoint state method can be seen as an envelope theorem


for computing ∇L(w), for the case when w is involved in both the
objective function and in the equality constraint.

10.4.3 Proof using the method of Lagrange multipliers


Classically, the adjoint state method is derived using the method of
Lagrange multipliers. Let us introduce the Lagrangian associated with
L and c,
L(s, w, r) := L(s, w) + ⟨r, c(s, w)⟩,
where r ∈ S is the Lagrange multiplier associated with the equality
constraint c(s, w) = 0. In the optimal control literature, r is often
called the adjoint variable or adjoint state. The gradients of the
Lagrangian are
∇s L(s, w, r) = ∇1 L(s, w) + ∂1 c(s, w)∗ r
∇w L(s, w, r) = ∇2 L(s, w) + ∂2 c(s, w)∗ r
∇r L(s, w, r) = c(s, w),
where ∂i c(s, w)∗ are the adjoint operators. Setting ∇r L(s, w, r) to
zero gives the constraint c(s, w) = 0. Setting ∇s L(s, w, r) to zero gives
the so-called adjoint state equation
∂1 c(s, w)∗ r = −∇1 L(s, w).
Solving this linear system w.r.t. r at s = s⋆ (w) gives the adjoint variable
r ⋆ (w). We then get
∇L(w) = ∇2 L(s⋆ (w), w, r ⋆ (w))
= ∇2 L(s⋆ (w), w) + ∂2 c(s⋆ (w), w)∗ r ⋆ (w),
which concludes the proof.

10.4.4 Proof using the implicit function theorem

A more direct proof is possible thanks to the implicit function theorem


(Section 10.3). Using the chain rule, we get

∇L(w) = ∇2 L(s⋆ (w), w) + ∂s⋆ (w)∗ ∇1 L(s⋆ (w), w),

where ∂s⋆ (w)∗ is the VJP of s⋆ , a linear map from S to W.


Computationally, the main difficulty is to apply ∂s⋆ (w)∗ to the
vector u = ∇1 L(s⋆ (w), w) ∈ S. Using the implicit function theorem
(Section 10.3) on the implicit function c(s⋆ (w), w) = 0, and Proposi-
tion 10.1, we get the linear system A∗ r = u, where A∗ := ∂1 c(s⋆ (w), w)∗
is a linear map from S to S. After solving for r, we get ∂s⋆ (w)∗ u = B ∗ r,
where B ∗ := ∂2 c(s⋆ (w), w)∗ is a linear map from S to W. Putting ev-
erything together, we get

∇L(w) = ∇2 L(s⋆ (w), w) + ∂2 c(s⋆ (w), w)∗ r.

10.4.5 Reverse mode as adjoint method with backsubstitution

In this section, we revisit reverse-mode autodiff from the perspective


of the adjoint state method. For clarity, we focus our exposition on
feedforward networks with input x ∈ X and network weights w =
(w1 , . . . , wK ) ∈ W1 × . . . × WK ,

s0 := x ∈ X
s1 := f1 (s0 , w1 ) ∈ S1
..
.
sK := fK (sK−1 , wK ) ∈ SK
f (w) := sK . (10.2)

Here we focus on gradients with respect to the parameters w, hence


the notation f (w). We can use the adjoint state method to recover
reverse-mode autodiff, and prove its correctness in the process. While
we focus for simplicity on feedforward networks, our exposition can be
generalized to computation graphs.

Feedforward networks as the solution of a nonlinear equation


While we defined the set of intermediate computations s = (s1 , . . . , sK ) ∈
S1 × . . . × SK as a sequence of operations, they can also be defined as
the unique solution of the nonlinear equation c(s, w) = 0, where

c(s, w) := ( s1 − f1(x, w1),  s2 − f2(s1, w2),  . . . ,  sK − fK(sK−1, wK) ).

This defines an implicit function s⋆ (w) = (s⋆1 (w), . . . , s⋆K (w)), the
solution of this nonlinear system, which is given by the variables
s1 , . . . , sK defined in (10.2). The output of the feedforward network is
then f (w) = s⋆K (w).
In machine learning, the final layer s⋆K (w) is typically fed into a
loss ℓ, to define
L(w) := ℓ(s⋆K (w); y).
Note that an alternative is to write L(w) as

L(w) = min_{s∈S} ℓ(s; y) s.t. c(s, w) = 0.

More generally, if we just want to compute the VJP of s⋆K (w) in


some direction uK ∈ SK , we can define the scalar-valued function

L(w) := ℓ(s⋆K (w); uK ) := ⟨s⋆K (w), uK ⟩

so that
∂f (w)∗ uK = ∇L(w).
Let us define u ∈ S1 × · · · × SK−1 × SK as u := (0, . . . , 0, ∇1 ℓ(f (w); y))
(gradient of the loss ℓ case) or u := (0, . . . , 0, uK ) (VJP of f in the
direction uK case). Using the adjoint state method, we know that the
gradient of this objective is obtained as

∇L(w) = ∂2 c(s(w), w)∗ r ⋆ (w),

for r ⋆ (w) the solution of the linear system

∂1 c(s(w), w)∗ r = −u.



Solving the linear system using backsubstitution


The JVP of the constraint function c at s⋆(w), materialized as a matrix,
takes the form of a block lower-triangular matrix

∂1c(s⋆(w), w) = [  I     0     0    · · ·    0 ]
                [ −A2    I     0    · · ·    0 ]
                [  0    −A3    I    · · ·    0 ]
                [  ⋮           ⋱     ⋱       ⋮ ]
                [  0    · · ·  0    −AK      I ],

where Ak := ∂1fk(sk−1, wk). Crucially, the triangular structure of the
Jacobian stems from the fact that each intermediate activation only de-
pends on the past intermediate activations. Therefore, the constraints,
corresponding to the rows of the Jacobian, cannot introduce non-zero
blocks beyond the diagonal. The VJP, the adjoint, takes the form of a block upper-
triangular matrix

∂1c(s⋆(w), w)∗ = [ I   −A∗2    0     · · ·     0   ]
                 [ 0     I    −A∗3   · · ·     0   ]
                 [ ⋮            ⋱      ⋱       ⋮   ]
                 [ 0    · · ·  0       I     −A∗K  ]
                 [ 0    · · ·  0       0       I   ].

Solving an upper-triangular system like ∂1c(s(w), w)∗ r = u can then
be done efficiently by backsubstitution. Starting from the last adjoint
state rK = uK, we can compute each adjoint state rk from the one computed
at k + 1. Namely, for k ∈ (K − 1, . . . , 1), we have

rk − A∗k+1 rk+1 = 0 ⇐⇒ rk = ∂1fk+1(sk, wk+1)∗ rk+1.

The VJPs with respect to the parameters are then obtained by

∂2c(s(w), w)∗ r = ( ∂2f1(x, w1)∗ r1,  ∂2f2(s1(w), w2)∗ r2,  . . . ,  ∂2fK(sK−1(w), wK)∗ rK ),

recovering reverse-mode autodiff.


The Lagrangian perspective of backpropagation for networks with
separate parameters w = (w1 , . . . , wK ) is well-known; see for instance
LeCun (1988) or Recht (2016). The Lagrangian perspective of back-
propagation through time (Werbos, 1990) for networks with shared
parameter w is discussed for instance by Franceschi et al. (2017). Our
exposition uses the adjoint state method, which can itself be proved
either using the method of Lagrange multipliers (Section 10.4.3) or by
the implicit function theorem (Section 10.4.4), combined with backsub-
titution for solving the upper-triangular linear system. Past works often
minimize over w but we do not require this, as gradients are not neces-
sarily used for optimization. Our exposition also supports computing
the VJP of any vector-valued function f , while existing works derive
the gradient of a scalar-valued loss function.

10.5 Inverse function theorem

10.5.1 Differentiating inverse functions


In some cases (see for instance Section 11.4.3), it is useful to compute
the Jacobian of an inverse function f −1 . The inverse function theorem
below allows us to relate the Jacobian of f −1 with the Jacobian of f .

Theorem 10.5 (Inverse function theorem). Assume f : W → W is


continuously differentiable with invertible Jacobian ∂f (w0 ) at w0 .
Then f is bijective from a neighborhood of w0 to a neighborhood
of f (w0 ). Moreover, the inverse f −1 is continuously differentiable
near ω0 = f (w0 ) and the Jacobian of the inverse ∂f −1 (ω) is

∂f (w)∂f −1 (ω) = I ⇔ ∂f −1 (ω) = (∂f (w))−1 ,

with w = f −1 (ω).
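As a numerical illustration (a hypothetical Python sketch with a made-up smooth map f), we can check that the Jacobian of f⁻¹ at ω = f(w), estimated by finite differences on a Newton-based inverse, matches the inverse of the Jacobian of f at w.

```python
# Numerical check of the inverse function theorem on a made-up map f.
import numpy as np

f = lambda w: np.array([w[0] + 0.1 * np.sin(w[1]), w[1] + 0.1 * w[0] ** 2])

def jac(g, x, eps=1e-6):
    """Finite-difference Jacobian of g at x."""
    cols = [(g(x + eps * e) - g(x - eps * e)) / (2 * eps) for e in np.eye(len(x))]
    return np.stack(cols, axis=1)

def f_inverse(omega):
    w = np.zeros(2)
    for _ in range(50):                        # Newton iterations for f(w) = omega
        w = w - np.linalg.solve(jac(f, w), f(w) - omega)
    return w

w0 = np.array([0.3, -0.2])
J_f = jac(f, w0)
J_inv = jac(f_inverse, f(w0))                  # Jacobian of f^{-1} at f(w0)
print(np.allclose(J_inv, np.linalg.inv(J_f), atol=1e-4))   # True
```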

10.5.2 Link with the implicit function theorem


The inverse function theorem can be used to prove the implicit function
theorem; see proof of Theorem 10.4. Conversely, recall that, in order to
use the implicit function theorem, we need to choose a root objective

F : W × Λ → W. If we set W = Λ = RQ and F (w, ω) = f (w) −


ω, with f : RQ → RQ , then we have that the root w⋆ (ω) satisfying
F (w⋆ (ω), ω) = 0 is exactly w⋆ (ω) = f −1 (ω). Moreover, ∂1 F (w, ω) =
∂f (w) and ∂2 F (w, ω) = −I. By applying the implicit function theorem
with this F , we indeed recover the inverse function theorem.

10.5.3 Proof of inverse function theorem


We first give a proof of the formula assuming that f −1 is well-defined
and continuously differentiable in a neighborhood of f (w0 ). In that
case, we have for any ω in a neighborhood of f (w0 ),

f ◦ f −1 (ω) = ω.

Differentiating both sides w.r.t. ω, we get

∂f (f −1 (ω))∂f −1 (ω) = I,

where I is the identity function in RQ . In particular, for w = f −1 (ω)


we recover the formula presented in the statement.
Now, it remains to show that invertibility of the JVP ensures that
the function is invertible in a neighborhood of f (w0 ) and that the
inverse is continuously differentiable. For that, denote l = ∂f (w0 ) such
that l−1 is well-defined by assumption. f is invertible with continuously
differentiable inverse if and only if w ↦ l−1(f(w) − f(w0)) is invertible with
continuously differentiable inverse. So without loss of generality, we
consider ∂f (w0 ) = I, f (w0 ) = 0, w0 = 0.
As f is continuously differentiable, there exists a neighborhood
N = {w : ∥w − w0 ∥2 ≤ δ} on which we have ∥∂f (w) − I ∥2 ≤ 1/2. In
this neighborhood, the function g(w) = f (w) − w is contractive by the
mean value theorem with contraction factor 1/2. For any ω such that
∥ω − f(w0)∥2 ≤ δ/2, the sequence wk+1 = wk − (f(wk) − ω) remains in
N and converges (since it is a Cauchy sequence by the contraction of g)
to a unique fixed point w∞ satisfying w∞ = w∞ − (f(w∞) − ω) ⇐⇒
f(w∞) = ω. This shows the existence of the inverse in the neighborhood
M = {ω : ∥ω − ω0 ∥2 ≤ δ/2} of ω0 = f (w0 ) onto N .
We tackle now the differentiability (hence the continuity) of f −1 .
For any ω in the neighborhood of ω0 with inverse w := f −1 (ω) ∈ N ,

the JVP of f at w satisfies by assumption ∥∂f (w) − I ∥2 ≤ 1/2. Hence,


a := I − ∂f(w) defines a convergent series b := Σ_{k=0}^{+∞} a^k and one verifies
easily that b = (I − a)⁻¹ = ∂f(w)⁻¹, that is, ∂f(w) is invertible and ∥∂f(w)⁻¹∥2 ≤ 2.


To compute the JVP of the inverse, we consider then ∂f (w)−1 as the
candidate JVP and examine
∥f⁻¹(ω + η) − f⁻¹(ω) − (∂f(w))⁻¹η∥2 / ∥η∥2.

Denote then v such that f −1 (ω + η) = w + v. As g(w) = f (w) − w is


1/2-contractive in N , we have ∥v−η∥2 = ∥g(w+v)−g(w)∥2 ≤ 1/2∥v∥2 .
So ∥η∥2 ≥ ∥v∥2/2. We then get

∥f⁻¹(ω + η) − f⁻¹(ω) − (∂f(w))⁻¹η∥2 / ∥η∥2
= ∥v − (∂f(w))⁻¹(f(w + v) − f(w))∥2 / ∥η∥2
≤ 4 ∥f(w + v) − f(w) − ∂f(w)v∥2 / ∥v∥2.
As ∥η∥2 → 0, we have ∥v∥2 → 0 and so ∥f (w+v)−f (w)−∂f (w)v∥2 /∥v∥2 →
0. Hence, f −1 is differentiable with JVP ∂f −1 (ω) = (∂f (w))−1 =
(∂f (f −1 (ω)))−1 . This shows that f −1 is continuous and so ∂f −1 (ω) is
continuous as a composition of continuous functions.

10.6 Summary

Implicit functions are functions that cannot be decomposed into ele-


mentary operations and for which autodiff can therefore not be directly
applied. Examples are optimization problems and nonlinear equations.
Envelope theorems can be used for differentiating through the min or
max value (not solution) of a function. More generally, the implicit func-
tion theorem allows us to differentiate through such implicit functions.
It gives conditions for the existence of derivatives and how to obtain
them. The adjoint state method can be used to obtain the gradient
of the composition of an explicit function and of an implicit function,
specified by equality constraints. It can be used to prove the correctness
of reverse-mode autodiff. The inverse function theorem can be used to

differentiate function inverses. In a sense, the implicit function theorem


can be thought as the mother theorem, as it can be used to prove
envelope theorems, the adjoint state method and the inverse function
theorem.
11 Differentiating through integration

In this chapter, we study how to differentiate through integrals, with a


focus on expectations and solutions of ordinary differential equations.

11.1 Differentiation under the integral sign

Given two Euclidean spaces Θ and Y, and a function f : Θ × Y → R,


we often want to differentiate an integral of the form
F(θ) := ∫_Y f(θ, y)dy.

Provided that we can swap integration and differentiation, we have

∇F(θ) = ∫_Y ∇θ f(θ, y)dy.
The conditions enabling us to do so are best examined in the context
of measure theory. We refer the reader to e.g. (Cohn, 2013) for a course
on measure theory and Flanders (1973) for an in-depth study of the
differentiation under the integral sign. Briefly, if Θ = Y = R, the
following conditions are sufficient.
1. f is measurable in both its arguments, and f (θ, ·) is integrable
for almost all θ ∈ Θ fixed,


2. f (·, y) is absolutely continuous for almost all y ∈ Y, that is,


there exists an integrable function g(·, y) such that f (θ, y) =
f(θ0, y) + ∫_{θ0}^{θ} g(τ, y)dτ,

3. ∂1 f (θ, y) (which exists almost everywhere if f (·, y) is absolutely


continuous), is locally integrable, that is, for any closed interval
[θ0, θ1], the integral ∫_{θ0}^{θ1} ∫ |∂1f(θ, y)| dy dθ is finite.

Any differentiable function f : Θ × Y → R is absolutely continuous.


However, the conditions also hold if f is just absolutely continuous, that
is, if f (·, y) is differentiable for almost all y. This weaker assumption
can be used to smooth out differentiable almost-everywhere functions,
such as the ReLU, as we study in Section 13.4.

11.2 Differentiating through expectations

A special case of differentiating through integrals is differentiating


through expectations. Consider an expectation of the form

E(θ) := EY ∼pθ [g(Y )], (11.1)

where Y ∈ Y ⊆ RM is a random variable, distributed according to


a distribution pθ parameterized by θ ∈ Θ and where g : RM → R is,
depending on the setting, potentially a blackbox function (i.e., we do not
have access to its gradients). Typically, θ ∈ Θ could be parameters we
wish to estimate, or it could indirectly be generated by θ = f (x, w) ∈ Θ,
where f is a neural network with parameters w ∈ W we wish to estimate.
The operation of computing the expectation over Y removes Y from
the computational graph and, depending on the choice of pθ , it can
smooth out g. In this chapter, we review estimation techniques to
(approximately) compute ∇E(θ), allowing us to optimize θ (or w using
the chain rule) by gradient-based algorithms.
The main difficulty in computing ∇E(θ) stems from the fact that θ
are the parameters of the distribution pθ . Estimating an expectation
E(θ) = EY ∼pθ [g(Y )] using Monte-Carlo estimation requires us to sample
from pθ . However, it is not clear how to differentiate E w.r.t. θ if θ is
involved in the sampling process.

11.2.1 The easy case


If instead we consider an expectation where θ parameterizes the function inside the expectation, not the distribution p, the situation is very different. Formally, consider
the function
θ 7→ EY ∼p [g(Y, θ)],
where p does not depend on θ. Then, under mild conditions recalled
in Section 11.1, we can swap differentiation and integration to obtain

∇θ EY ∼p [g(Y, θ)] = EY ∼p [∇θ g(Y, θ)].

Because the gradient is itself an expectation, we can easily obtain an


unbiased estimate of it by Monte-Carlo estimation. That is, we draw
m i.i.d. samples Y1 , . . . , Ym from p, compute ∇θ g(Yi , θ) for each Yi and
average the results.
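As a concrete illustration, here is a minimal JAX sketch of this estimator; the integrand g, the base distribution and the parameter values are illustrative choices, not prescribed by the text above.

```python
import jax
import jax.numpy as jnp

def g(y, theta):
    # Toy integrand, differentiable in theta.
    return jnp.sum((y - theta) ** 2)

key = jax.random.PRNGKey(0)
ys = jax.random.normal(key, (1000, 3))   # i.i.d. samples Y_1, ..., Y_m from p (here N(0, I))
theta = jnp.array([0.5, -1.0, 2.0])

# Unbiased Monte-Carlo estimate of grad_theta E_{Y~p}[g(Y, theta)]:
# average the per-sample gradients grad_theta g(Y_i, theta).
per_sample_grads = jax.vmap(jax.grad(g, argnums=1), in_axes=(0, None))(ys, theta)
grad_estimate = jnp.mean(per_sample_grads, axis=0)
```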

11.2.2 Exact gradients


Continuous case
When Y is a continuous set (that is, pθ (y) is a probability density
function), we can rewrite Eq. (11.1) as
E(θ) = ∫_Y pθ(y) g(y) dy.

Provided that we can swap integration and differentiation (see Sec-


tion 11.1), we then have
∇E(θ) = ∇θ ∫_Y pθ(y) g(y) dy
= ∫_Y ∇θ pθ(y) g(y) dy.

Discrete case
When Y is a discrete set (that is, pθ (y) is a probability mass function),
we can rewrite Eq. (11.1) as

E(θ) = ∑_{y∈Y} pθ(y) g(y).

We then obtain
∇E(θ) = ∑_{y∈Y} g(y) ∇θ pθ(y). (11.2)

We notice that, as written in Eq. (11.2), ∇E(θ) is not an expectation.


We therefore cannot use Monte-Carlo estimation to estimate (11.2).
Instead, we can compute it by brute force, i.e., by summing over all
possible y ∈ Y. However, this is clearly only computationally tractable
if |Y| is small or if pθ is designed to have sparse support, i.e., so that the
set {y ∈ Y : pθ (y) ̸= 0} is small. Moreover, even if these conditions hold,
summing over y could be problematic if g(y) is expensive to compute.
Therefore, exact gradients are seldom used in practice.
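When |Y| is indeed small, the brute-force sum in Eq. (11.2) is straightforward to implement; the sketch below (with a hypothetical four-element Y, arbitrary values of g, and a softargmax parametrization of pθ) simply differentiates the explicit sum with autodiff.

```python
import jax
import jax.numpy as jnp

g_values = jnp.array([0.3, -1.2, 2.0, 0.7])   # g(y) for each of the |Y| = 4 outcomes

def E(theta):
    p = jax.nn.softmax(theta)                  # p_theta(y) for all y in Y
    return jnp.dot(p, g_values)                # sum over y of p_theta(y) g(y)

theta = jnp.array([0.1, 0.5, -0.3, 0.0])
exact_grad = jax.grad(E)(theta)                # sum over y of g(y) grad_theta p_theta(y)
```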

11.2.3 Application to expected loss functions

This framework is particularly useful to work with expected loss func-


tions of the form
L(θ; y) := EY ∼pθ [ℓ(Y, y)],

where y is some ground truth. Equivalently, we can set ℓ = −r, where


r is a reward function. As we shall see, some gradient estimators will
support a discrete loss function ℓ : Y × Y → R, while others will require
a differentiable loss function ℓ : RM × Y → R. Intuitively, L(θ; y) will
be low if pθ assigns high probability to predictions ŷ with low loss value ℓ(ŷ, y).
In the classification setting, where Y = [M ], pθ is often chosen to be
the Gibbs distribution, which is a categorical distribution induced
by a softargmax

pθ(Y = y) := exp(θy) / ∑_{i∈[M]} exp(θi) = [softargmax(θ)]y ∈ (0, 1),

where θy := f(x, y, w) ∈ R is the logit produced by a neural network f .


More generally, in the structured prediction setting, where Y ⊆ RM but
|Y| ≫ M , we often use the distribution

pθ(Y = y) := exp(⟨ϕ(y), θ⟩) / ∑_{y′∈Y} exp(⟨ϕ(y′), θ⟩),

where θ = f (x, w) ∈ RM . Given a distribution ρ over X × Y, we then


want to minimize the expected loss function, also known as risk,

R(w) := E(X,Y )∼ρ [L(f (X, w); Y )].

Typically, minimizing R(w) is done through some form of gradient


descent, which requires us to be able to compute

∇R(w) = E_{(X,Y)∼ρ}[∇w L(f(X, w); Y)]
= E_{(X,Y)∼ρ}[∂2 f(X, w)∗ ∇L(f(X, w); Y)].

Computing ∇R(w) therefore boils down to computing the gradient of


L(θ; y), which is of the same form as Eq. (11.1).

11.2.4 Application to experimental design


In experimental design, we wish to minimize a function g(λ), which we
assume costly to evaluate. As an example, evaluating g(λ) could require
us to run a scientific experiment with parameters λ ∈ RQ . As another
example, in hyperparameter optimization, evaluating g(λ) would require
us to run a learning algorithm with hyperparameters λ ∈ RQ . Instead
of solving the problem arg minλ∈RQ g(λ), we can lift the problem to
probability distributions and solve arg minθ∈RM E(θ), where E(θ) =
Eλ∼pθ [g(λ)]. This requires the probability distribution pθ to assign
high probability to λ values that achieve small g(λ) value. Solving this
problem by stochastic gradient descent requires us to be able to compute
estimates of ∇E(θ). This can be done for instance with SFE explained
in Section 11.3, which does not require gradients of g, unlike implicit
differentiation explained in Chapter 10. This approach also requires us
to choose a distribution pθ over λ. For continuous hyperparameters,
a natural choice would be the normal distribution λ ∼ Normal(µ, Σ),
setting θ = (µ, Σ). Once we have obtained θ by minimizing E(θ), we need a way to recover λ. This can be done for example by choosing the mode of the distribution, i.e., arg max_{λ∈RQ} pθ(λ), or the mean of the distribution E_{λ∼pθ}[λ]. Of course, in the case of the normal distribution, they coincide.

11.3 Score function estimators, REINFORCE

11.3.1 Scalar-valued functions


The key idea of the score function estimator (SFE), also known as
REINFORCE, is to rewrite ∇E(θ) as an expectation. The estimator is
based on the logarithmic derivative identity
∇θ log pθ(y) = ∇θ pθ(y) / pθ(y) ⇐⇒ ∇θ pθ(y) = pθ(y) ∇θ log pθ(y).
This is known as the score function, hence the estimator name. Using
this identity, we obtain the following gradient estimator.
Proposition 11.1 (SFE for scalar-valued functions). Let
E(θ) := E_{Y∼pθ}[g(Y)] = ∫_Y pθ(y) g(y) dy,

where Y ∈ Y ⊆ RM and g : Y → R. Then,

∇E(θ) = EY ∼pθ [g(Y )∇θ log pθ (Y )].


Proof.
∇E(θ) = ∫_Y ∇θ pθ(y) g(y) dy
= ∫_Y pθ(y) g(y) ∇θ log pθ(y) dy
= E_{Y∼pθ}[g(Y) ∇θ log pθ(Y)].

SFE is suitable when two requirements are met: it is easy to sample


from pθ and the gradient of the log-PDF, ∇θ log pθ(y), is available in closed
form. Since the SFE gradient is an expectation, we can use Monte-Carlo
estimation to compute an unbiased estimator of ∇E(θ):
∇E(θ) ≈ γn(θ) := (1/n) ∑_{i=1}^{n} g(yi) ∇θ log pθ(yi), (11.3)
where y1 , . . . , yn are sampled from pθ . Interestingly, the gradient of g
is not needed in this estimator. Therefore, there is no differentiability
assumption about g. This is why SFE is useful when g is a discrete loss
function or more generally a blackbox function.
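The sketch below implements the estimator of Eq. (11.3) for a diagonal Gaussian pθ = Normal(µ, exp(log σ)²) in JAX; the blackbox objective g and the parameter names are illustrative choices, and only evaluations of g are needed, never its gradient.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def g(y):
    # Blackbox objective: only function values are used, never gradients.
    return jnp.sum(jnp.abs(y))

def log_prob(theta, y):
    mu, log_sigma = jnp.split(theta, 2)
    return jnp.sum(norm.logpdf(y, loc=mu, scale=jnp.exp(log_sigma)))

def sfe_gradient(theta, key, n=1000):
    mu, log_sigma = jnp.split(theta, 2)
    ys = mu + jnp.exp(log_sigma) * jax.random.normal(key, (n, mu.shape[0]))
    # Score of each sample: grad_theta log p_theta(y_i), shape (n, dim(theta)).
    scores = jax.vmap(jax.grad(log_prob), in_axes=(None, 0))(theta, ys)
    return jnp.mean(jax.vmap(g)(ys)[:, None] * scores, axis=0)

theta = jnp.concatenate([jnp.ones(2), jnp.zeros(2)])   # (mu, log_sigma)
grad_estimate = sfe_gradient(theta, jax.random.PRNGKey(0))
```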

Example 11.1 (SFE with a language model). In a language model,


the probability of a sentence y = (y1 , . . . , yL ) is typically factored
using the chain rule of probability

pθ (y) := pθ (y1 )pθ (y2 |y1 ) . . . pθ (yL |y1 , . . . , yL−1 ),


where pθ is modeled using a transformer or RNN. Note that the
probabilities are normalized by construction, so there is no need for
an explicit normalization constant. Thanks to this factorization, it
is easy to sample from pθ and the gradient of the log-probability enjoys the simple expression

∇θ log pθ(y) = ∇θ log pθ(y1) + ∇θ log pθ(y2 |y1) + · · · + ∇θ log pθ(yL |y1, . . . , yL−1).

This gradient is easy to compute, since the token-wise distribu-


tions pθ(yj |y1, . . . , yj−1) are typically parameterized using a softargmax. We can
therefore easily compute ∇E(θ) under pθ using SFE. This is for
instance useful to optimize an expected reward, in order to finetune
or align a language model (Ziegler et al., 2019).

Another example when ∇θ pθ (y) is available in closed form is in the


context of reinforcement learning, where pθ is the policy of a Markov decision process (MDP). Applying the SFE leads to the
(vanilla) policy gradient method (Sutton et al., 1999) and can then be
used to compute the gradient of an expected cumulative reward. How-
ever, SFE is more problematic when used with the Gibbs distribution,
due to the explicit normalization constant.

Example 11.2 (SFE with a Gibbs distribution). The Gibbs distribu-


tion is parameterized as

pθ (y) := exp(θy /γ − A(θ)) = exp(θy /γ)/ exp(A(θ))

where we defined the log-partition function

A(θ) := log ∑_{y∈Y} exp(θy/γ).

A typical parametrization is θy = f (x, y, w). We then have

log pθ (y) = θy /γ − A(θ),

so that
∇θ log pθ (y) = ey /γ − ∇A(θ).
We therefore see that ∇θ log pθ (y) crucially depends on ∇A(θ),
the gradient of the log-partition. This gradient is available for some
structured sets Y but not in general.
As another example, we apply SFE in Section 13.4 to derive the
gradient of perturbed functions.

Differentiating through both the distribution and the function


Suppose both the distribution and the function now depend on θ. When
g is scalar-valued and differentiable w.r.t. θ, we want to differentiate
E(θ) := EY ∼pθ [g(Y, θ)].
Using the product rule, we obtain
∇E(θ) = EY ∼pθ [g(Y, θ)∇θ log pθ (Y )] + EY ∼pθ [∇θ g(Y, θ)].

Differentiating through joint distributions


Suppose we now want to differentiate through
E(θ) := EY1 ∼pθ ,Y2 ∼qθ [g(Y1 , Y2 )].
The gradient is then given by
∇E(θ) = E_{Y1∼pθ, Y2∼qθ}[(∇θ log pθ(Y1) + ∇θ log qθ(Y2)) g(Y1, Y2)],
which is easily seen by applying Proposition 11.1 on the joint distribution
ρθ := pθ ·qθ . The extension to more than two variables is straightforward.

11.3.2 Variance reduction


Bias and variance
Recall the definition of γn in Eq. (11.3). SFE is an unbiased estimator,
meaning that
∇E(θ) = E[γn (θ)],

where the expectation is taken with respect to the n samples drawn.


Since the gradient is vector-valued, we need to define a scalar-valued
notion of variance. We do so by using the squared Euclidean distance
in the usual variance definition to define

V[γn (θ)] := E[∥γn (θ) − ∇E(θ)∥22 ]


= E[∥γn (θ)∥22 ] − ∥∇E(θ)∥22 .

This is equivalent to the trace of the covariance matrix. The variance


goes to zero as n → ∞.

Baseline
SFE is known to suffer from high variance (Mohamed et al., 2020).
This means that this estimator may require us to draw many samples
from the distribution pθ to work well in practice. One of the simplest variance reduction techniques consists in shifting the function g by a constant β, called a baseline, to obtain

∇E(θ) = EY ∼pθ [(g(Y ) − β)∇θ log pθ (Y )].

The reason this is still a valid estimator of ∇E(θ) stems from


E_{Y∼pθ}[∇θ log pθ(Y)] = E_{Y∼pθ}[∇θ pθ(Y) / pθ(Y)]
= ∇θ E_{Y∼pθ}[1]
= ∇θ 1
= 0,

for any valid distribution pθ . The baseline β is often set to the running
average of past values of the function g, though it is neither optimal
nor does it guarantee to lower the variance (Mohamed et al., 2020).
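In code, adding a baseline to the sketch of Section 11.3.1 amounts to a one-line change (reusing the illustrative g and log_prob defined there); the estimator remains unbiased for any constant β.

```python
def sfe_gradient_with_baseline(theta, key, beta, n=1000):
    mu, log_sigma = jnp.split(theta, 2)
    ys = mu + jnp.exp(log_sigma) * jax.random.normal(key, (n, mu.shape[0]))
    scores = jax.vmap(jax.grad(log_prob), in_axes=(None, 0))(theta, ys)
    # Shift g by the baseline: E[(g(Y) - beta) * score] = E[g(Y) * score].
    return jnp.mean((jax.vmap(g)(ys) - beta)[:, None] * scores, axis=0)
```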

Control variates
Another general technique is control variates. Let us denote the
expectation of a function h : RM → R under the distribution pθ as

H(θ) := EY ∼pθ [h(Y )].



Suppose that H(θ) and its gradient ∇H(θ) are known in closed form.
Then, for any γ ≥ 0, we clearly have

E(θ) = EY ∼pθ [g(Y )]


= EY ∼pθ [g(Y ) − γ(h(Y ) − H(θ))]
= EY ∼pθ [g(Y ) − γh(Y )] + γH(θ)

and therefore

∇E(θ) = ∇θ EY ∼pθ [g(Y ) − γh(Y )] + γ∇H(θ).

Applying SFE, we then obtain

∇E(θ) = EY ∼pθ [(g(Y ) − γh(Y ))∇θ log pθ (Y )] + γ∇H(θ).

Examples of h include a bound on g or a second-order Taylor expansion of g, assuming that these approximations are easier to integrate than g (Mohamed et al., 2020).

11.3.3 Vector-valued functions

It is straightforward to extend the SFE to vector-valued functions.

Proposition 11.2 (SFE for vector-valued functions). Let


E(θ) := E_{Y∼pθ}[g(Y)] = ∫_Y pθ(y) g(y) dy,

where Y ∈ Y, g : Y → G and θ ∈ Θ 7→ pθ (y) ∈ [0, 1]. Then,

∂E(θ) = EY ∼pθ [g(Y ) ⊗ ∇θ log pθ (Y )].

The Jacobian ∂E(θ) is a linear map from Θ to G such that

∂E(θ)u = EY ∼pθ [⟨∇θ log pθ (Y ), u⟩g(Y )] ∈ G ∀u ∈ Θ.



Proof.
∂E(θ) = ∂θ ∫_Y pθ(y) g(y) dy
= ∫_Y ∂θ [pθ(y) g(y)] dy
= ∫_Y g(y) ⊗ ∇θ pθ(y) dy
= ∫_Y pθ(y) g(y) ⊗ ∇θ log pθ(y) dy
= E_{Y∼pθ}[g(Y) ⊗ ∇θ log pθ(Y)].

Differentiating through both the distribution and the function


If θ now influences both the distribution and the function,
E(θ) := EY ∼pθ [g(Y, θ)],
then, we obtain
∂E(θ) = EY ∼pθ [g(Y, θ) ⊗ ∇θ log pθ (Y )] + EY ∼pθ [∂θ g(Y, θ)].

11.3.4 Second derivatives


Using the previous subsection with g(y, θ) := g(y) ∇θ log pθ(y), we easily obtain an estimator of the Hessian.
Proposition 11.3 (SFE for the Hessian). Let us define the scalar-
valued function E(θ) := EY ∼pθ [g(Y )]. Then,

∇2 E(θ) =EY ∼pθ [g(Y )∇θ log pθ (Y ) ⊗ ∇θ log pθ (Y )]+


EY ∼pθ [g(Y )∇2θ log pθ (Y )].

This can also be derived using the second-order log-derivative


∇²θ log pθ(y) = (1/pθ(y)) ∇²θ pθ(y) − (1/pθ(y)²) ∇θ pθ(y) ⊗ ∇θ pθ(y),
so that
∇²θ pθ(y) = pθ(y) [∇²θ log pθ(y) + ∇θ log pθ(y) ⊗ ∇θ log pθ(y)].

Link with the Bartlett identities


The Bartlett identities are expressions relating the moments of the score
function (gradient of the log-likelihood function). Using Proposition 11.1
with g(y) = 1 and ∫_Y pθ(y) dy = 1, we obtain

E_{Y∼pθ}[∇θ log pθ(Y)] = 0, (11.4)

which is known as Bartlett’s first identity. Similarly, using Proposi-


tion 11.3, we obtain

E_{Y∼pθ}[∇²θ log pθ(Y)] + E_{Y∼pθ}[∇θ log pθ(Y) ⊗ ∇θ log pθ(Y)]
= E_{Y∼pθ}[∇²θ log pθ(Y)] + cov[∇θ log pθ(Y)] (11.5)
= 0,

which is known as Bartlett’s second identity.

11.4 Path gradient estimators, reparametrization trick

As we saw previously, the main difficulty in computing gradients of


expectations arises when the parameters θ play a role in the distribution
pθ being sampled. The key idea of the path gradient estimator (PGE),
also known as reparametrization trick, is to rewrite the expectation in
such a way that the parameters are moved from the distribution to the
function, using a change of variable.

11.4.1 Location-scale transforms


For example, let us set pθ := Normal(µ, σ 2 ), where θ := (µ, σ) and
Normal is a univariate normal distribution. If U ∼ pθ , then it is easy
to check that U = µ + σZ, where Z ∼ Normal(0, 1). We can therefore
write
E(θ) := EU ∼Normal(µ,σ2 ) [g(U )]
= EZ∼Normal(0,1) [g(µ + σZ)]
= E_{Z∼Normal(0,1)}[g(T(Z, θ))],
where we defined the location-scale transformation

U = T (Z, θ) := µ + σZ. (11.6)



Such a transformation exists, not only for the normal distribution, but
for location-scale family distributions. The key advantage is that we
can now easily compute ∇E(θ), since θ is no longer involved in the
distribution. We can generalize this idea, as summarized below.

Proposition 11.4 (Path gradient estimator). Let us define

E(θ) := EU ∼pθ [g(U )],

where U ∈ U ⊆ RM and g : RM → R is differentiable. Suppose


there is a differentiable transformation T : RM × RQ → RM such
that if Z ∼ p (where p does not depend on θ) and U = T (Z, θ),
then U ∼ pθ . Then, we have

E(θ) = EZ∼p [h(Z, θ)] = EZ∼p [g(T (Z, θ))],

where h(z, θ) := g(T (z, θ)). This implies

∇E(θ) = EZ∼p [∇2 h(Z, θ)]


= EZ∼p [∂2 T (Z, θ)∗ ∇g(T (Z, θ))].

The reparametrization trick gives an unbiased estimator of ∇E(θ).


It has however two key disadvantages. First, it assumes that g is dif-
ferentiable (almost everywhere), which may not always be the case.
Second, it assumes that g is well-defined on all of RM, not just on U, which could
be problematic for some discrete loss functions, such as the zero-one
loss function or ranking loss functions.
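The sketch below instantiates Proposition 11.4 with the location-scale transform of Eq. (11.6) for a diagonal Gaussian in JAX; the differentiable objective g and all names are illustrative.

```python
import jax
import jax.numpy as jnp

def g(u):
    # Differentiable objective.
    return jnp.sum(u ** 2)

def expected_g(theta, zs):
    mu, log_sigma = jnp.split(theta, 2)
    us = mu + jnp.exp(log_sigma) * zs           # U = T(Z, theta), location-scale transform
    return jnp.mean(jax.vmap(g)(us))

key = jax.random.PRNGKey(0)
zs = jax.random.normal(key, (1000, 2))          # samples from the base distribution p = N(0, I)
theta = jnp.concatenate([jnp.ones(2), jnp.zeros(2)])
grad_estimate = jax.grad(expected_g)(theta, zs) # path gradient estimate of grad E(theta)
```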

11.4.2 Inverse transforms

The inverse transform method can be used for sampling from a proba-
bility distribution, given access to its associated quantile function.
Recall that the cumulative distribution function (CDF) associated with
a random variable Y is the function FY : R → [0, 1] defined by

FY (y) := P(Y ≤ y).

The quantile function is then a function Q : [0, 1] → R such that Q(π) =


y for π = FY (y). Assuming FY is continuous and strictly increasing, we

have that Q is the inverse CDF,

Q(π) = FY−1 (π).

In the general case of CDF functions that are not strictly increasing,
the quantile function is usually defined as

Q(π) := inf{y ∈ R : π ≤ FY (y)}.

Given access to the quantile function Q(π) associated with a distribution


p, inverse transform sampling allows us to sample from p by first drawing
a sample from the uniform distribution and then making this sample
go through the quantile function.

Proposition 11.5 (Inverse transform sampling). Suppose Y ∼ p,


where p is a distribution with quantile function Q(π). If π ∼
Uniform(0, 1), then Q(π) ∼ p.

If the quantile function is differentiable, we can therefore use it


as a transformation within the reparametrization trick. Indeed,
if Y ∼ pθ , where pθ is a distribution with parameter θ and quantile
function Q(π, θ), then we have

E(θ) = EY ∼pθ [g(Y )] = Eπ∼Uniform(0,1) [g(Q(π, θ))]

and therefore, by the reparametrization trick (Proposition 11.4),

∇E(θ) = Eπ∼Uniform(0,1) [∂2 Q(π, θ)∗ ∇g(Q(π, θ))].

Example 11.3 (Examples of quantile functions). If Y ∼ Exponential(λ), the CDF of Y is π = FY(y) = 1 − exp(−λy) for y ≥ 0 and therefore the quantile function is Q(π, λ) = −log(1 − π)/λ. If Y ∼ Normal(µ, σ²), the CDF is FY(y) = ½[1 + erf((y − µ)/(σ√2))] and the quantile function is Q(π, θ) = µ + σ√2 · erf⁻¹(2π − 1), where θ = (µ, σ). This therefore defines an alternative transformation to the location-scale transformation in Eq. (11.6).

Note that, in the above example, the error function erf and its
inverse do not enjoy analytical expressions but autodiff packages usually

provide numerical routines to compute them and differentiate through


them. Nonetheless, one caveat of the inverse transform is that it indeed
requires access to (approximations of) the quantile function and its
derivatives, which may be difficult for complicated distributions.
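As a small sketch of the inverse transform used as a reparametrization, consider the exponential distribution of Example 11.3 with g(u) = u²; then E(λ) = 2/λ², so the exact gradient at λ = 2 is −4/λ³ = −0.5, which the estimate below approaches. The helper names are ours.

```python
import jax
import jax.numpy as jnp

def quantile(pi, lam):
    # Quantile function of Exponential(lam): Q(pi, lam) = -log(1 - pi) / lam.
    return -jnp.log1p(-pi) / lam

def expected_g(lam, pis):
    # Monte-Carlo estimate of E[g(Y)] with g(u) = u^2, reparametrized through Q.
    return jnp.mean(quantile(pis, lam) ** 2)

key = jax.random.PRNGKey(0)
pis = jax.random.uniform(key, (100_000,))        # pi ~ Uniform(0, 1)
grad_estimate = jax.grad(expected_g)(2.0, pis)   # approaches -0.5
```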

11.4.3 Pushforward operators


Pushforward distributions
We saw so far that the reparametrization trick is based on using a
change of variables in order to differentiate an expectation w.r.t. the
parameters of the distribution. In this section, we further formalize that
approach using pushforward distributions.

Definition 11.1 (Pushforward distribution). Suppose Z ∼ p, where


p is a distribution over Z. Given a continuous map T : Z → U,
the pushforward distribution of p through T is the distribution q
according to which U := T (Z) ∈ U is distributed, i.e., U ∼ q.

Although not explicit in the above, the transformation T can depend


on some learnable parameters, for example if T is a neural network.
Intuitively, the pushforward distribution is obtained by moving the
position of all the points in the support of p. Inverse transform sampling
studied in Section 11.4.2 can be seen as performing the pushforward
of the uniform distribution through T = Q, where Q is the quantile
function. The Gumbel trick studied in Section 13.5 can be seen as the
pushforward of Gumbel noise through T = argmax (a discontinuous
function) and Gumbel noise can itself be obtained by pushing forward
the uniform distribution through T = − log(− log(·)) (Remark 13.2). In
a generative modeling setting, we use the pushforward of Gaussian noise
through a parametrized transformation x = T (z, w) called a generator,
typically a neural network.
A crucial aspect of the pushforward distribution q is that it can be
implicitly defined, meaning that we do not necessarily need to know
the explicit form of the associated PDF. In fact, it is easy to sample
from q, provided that it is easy to sample from p:
U ∼ q ⇐⇒ Z ∼ p, U := T (Z).

Hence the usefulness of the pushforward distribution in generative


modeling. Furthermore, if p has associated PDF pZ , we can compute
the expectation of a function f according to q as
E_{U∼q}[f(U)] = E_{Z∼p}[f(T(Z))] = ∫_Z f(T(z)) pZ(z) dz,

even though we do not know the explicit form of the PDF of q.

Pushforward measures

More generally, we can define the notion of pushforward, in the language


of measures. A measure α ∈ M(Z), that has a density dα(z) :=
pZ (z)dz, can be integrated against as
∫_Z f(z) dα(z) = ∫_Z f(z) pZ(z) dz.

A measure α is called a probability measure if it is positive and satisfies


α(Z) = ∫_Z dα(z) = ∫_Z pZ(z) dz = 1. See Peyré and Cuturi (2019, Chapter 2) for a concise introduction.

Definition 11.2 (Pushforward operator and measure). Given a con-


tinuous map T : Z → U and some measure α ∈ M(Z), the push-
forward measure β = T♯ α ∈ M(U) is such that for all functions
f ∈ C(U),
∫_U f(u) dβ(u) = ∫_Z f(T(z)) dα(z).
Equivalently, for any measurable set A ⊂ U, we have

β(A) = α({z ∈ Z : T (z) ∈ A}) = α(T −1 (A)).

Importantly, the pushforward operator preserves positivity and mass,


therefore if α is a probability measure, then so is T♯ α. The pushforward
of a probability measure therefore defines a pushforward distribution
(since a distribution can be parametrized by a probability measure).

11.4.4 Change-of-variables theorem


Assuming a transformation T is invertible, we have Z = T −1 (U ) and
therefore for A ⊆ U, we have
P(U ∈ A) = P(Z ∈ T⁻¹(A)) = ∫_{T⁻¹(A)} pZ(z) dz.

Using the change-of-variables theorem from multivariate calculus,


assuming T −1 is available, we can give an explicit formula for the PDF
of the pushforward distribution.

Proposition 11.6 (PDF of the pushforward distribution). Suppose


Z ∼ p, where p is a distribution over Z, with PDF pZ . Given
a diffeomorphism T : Z → U (i.e., an invertible and differen-
tiable map), the pushforward distribution of p through T is the
distribution q such that U := T (Z) ∼ q and its PDF is

qU (u) = | det(∂T −1 (u))|pZ (T −1 (u)),

where ∂T −1 (u) is the Jacobian of T −1 : U → Z.

Using this formula, we obtain

P(U ∈ A) = ∫_A qU(u) du = ∫_A |det(∂T⁻¹(u))| pZ(T⁻¹(u)) du.
Using the inverse function theorem (Theorem 10.5), we then have
∂T −1 (u) = (∂T (T −1 (u)))−1 ,
under the assumption that T (z) is continuously differentiable and
has invertible Jacobian ∂T (z). Normalizing flows are parametrized
transformations T designed such that T −1 and its Jacobian ∂T −1 are
easy to compute; see e.g. Kobyzev et al. (2019) and Papamakarios et al.
(2021) for a review.
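The formula of Proposition 11.6 translates directly into code; in the sketch below, the affine map is a toy stand-in for a normalizing-flow transformation, and all names are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

A = jnp.array([[2.0, 0.5],
               [0.0, 1.5]])
b = jnp.array([1.0, -1.0])

def T(z):        # the (invertible, differentiable) transformation
    return A @ z + b

def T_inv(u):    # its inverse
    return jnp.linalg.solve(A, u - b)

def log_q(u):
    # log q_U(u) = log |det dT^{-1}(u)| + log p_Z(T^{-1}(u)), with p_Z = N(0, I).
    z = T_inv(u)
    jac = jax.jacobian(T_inv)(u)
    log_det = jnp.log(jnp.abs(jnp.linalg.det(jac)))
    return log_det + jnp.sum(norm.logpdf(z))

print(log_q(jnp.array([0.3, 0.7])))
```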

11.5 Stochastic programs

A stochastic program is a program that involves some form of random-


ness. In a stochastic program, the final output, as well as intermediate

variables, may therefore be random variables. In other words, a stochas-


tic program induces a probability distribution over program outputs,
as well as over execution trajectories.

11.5.1 Stochastic computation graphs

A stochastic program can be represented by a stochastic computation


graph as originally introduced by Schulman et al. (2015). Departing from
that work, our exposition explicitly supports two types of intermediate
operations: sampling from a conditional distribution or evaluating a
function. These operations can produce either deterministic variables
or random variables.

Function and distribution nodes

Formally, we define a stochastic computation graph as a directed acyclic


graph G = (V, E), where V = Vf ∪ Vp , Vf is the set of function nodes
and Vp is the set of distribution nodes. Similarly to computation graphs
reviewed in Section 4.1.3, we number the nodes as V = {0, 1, . . . , K}.
Node 0 corresponds to the input s0 ∈ S0 , which we assume to be deter-
ministic. It is the variable with respect to which we wish to differentiate.
Node K corresponds to the program output SK ∈ SK , which we assume
to be a random variable. A node k ∈ {1, . . . , K} can either be a func-
tion node k ∈ Vf with an associated function fk or a distribution
node k ∈ Vp , with associated conditional distribution pk . A stochastic
program has at least one distribution node, the source of randomness.
Otherwise, it is a deterministic program. As for computation graphs,
the set of edges E is used to represent dependences between nodes. We
denote the parents of node k by parents(k).

Deterministic and random variables

We distinguish between two types of intermediate variables: deter-


ministic variables sk and random variables Sk . Therefore, a distri-
bution pk or a function fk may receive both types of variables as
conditioning or input. It is then convenient to split parents(k) as

parents(k) = determ(k) ∪ random(k), where we defined the determin-


istic parents determ(k) := {i1 , . . . , ipk } and the random parents
random(k) := {j1 , . . . , jqk }. Therefore, si1 , . . . , sipk are the determinis-
tic parent variables and Sj1 , . . . , Sjqk are the random parent variables,
of node k.

Executing a stochastic program


We assume that nodes 0, 1, . . . , K are in topological order (if this is
not the case, we need to perform a topological sort). Given parent
variables si1 , . . . , sipk and Sj1 , . . . , Sjqk , a node k ∈ {1, . . . , K} produces
an output as follows.

• If k ∈ Vp (distribution node), the output is

Sk ∼ pk (· | sdeterm(k) , Srandom(k) )
⇐⇒ Sk ∼ pk (· | si1 , . . . , sipk , Sj1 , . . . , Sjqk )

Note that technically pk is the distribution of Sk conditioned on its


parents, not the distribution of Sk . Therefore, we should in princi-
ple write Sk | sdeterm(k) , Srandom(k) ∼ pk (· | sdeterm(k) , Srandom(k) ).
We avoid this notation for conciseness and for symmetry with
function nodes.
Contrary to a function node, a distribution node can have no
parents. That is, if k ∈ Vp , it is possible that parents(k) = ∅. A
good example would be a parameter-free noise distribution.

• If k ∈ Vf (function node), the output is in general

Sk := fk (sdeterm(k) , Srandom(k) )
:= fk (si1 , . . . , sipk , Sj1 , . . . , Sjqk )

and in the special case qk = |random(k)| = 0, the output is

sk := fk (sdeterm(k) )
:= fk (si1 , . . . , sipk ).

Unless the associated conditional distribution pk is a delta distribution,


that puts all the probability mass on a single point, the output of a

Algorithm 11.1 Executing a stochastic program


Nodes: 1, . . . , K in topological order, where node k is either a
function fk or a conditional distribution pk
Input: input s0 ∈ S0
1: for k := 1, . . . , K do
2: Retrieve parents(k) = determ(k) ∪ random(k)
3: if k ∈ Vp then ▷ Distribution node
4: Sk ∼ pk (·|sdeterm(k) , Srandom(k) )
5: else if k ∈ Vf then ▷ Function node
6: if |random(k)| ≠ 0 then
7: Sk := fk (sdeterm(k) , Srandom(k) ) ▷ Output is a R.V.
8: else if |random(k)| = 0 then
9: sk := fk (sdeterm(k) ) ▷ Output is deterministic
10: Output: f (s0 ) := SK ∈ SK

distribution node k ∈ Vp is necessarily a random variable Sk ∈ Sk .


For function nodes k ∈ Vf , the output of the function fk is a random
variable Sk ∈ Sk if at least one of the parents of k produces a random
variable. Otherwise, if all parents of k produce deterministic variables,
the output of fk is a deterministic variable sk ∈ Sk .
The entire procedure is summarized in Algorithm 11.1. We emphasize
that SK = f (s0 ) ∈ SK is a random variable. Therefore, a stochastic
program (implicitly) induces a distribution over SK , and also over
intermediate random variables Sk . Executing the stochastic program
allows us to draw samples from that distribution.

Special cases
If all nodes are function nodes, we recover computation graphs, reviewed
in Section 4.1.3. If all nodes are distribution nodes, we recover Bayesian
networks, reviewed in Section 9.5.

11.5.2 Examples
We now present several examples that illustrate our formalism. We use
the legend below in the following illustrations.

(Legend for the illustrations: deterministic variable, stochastic variable, function, sampler.)

• Example 1 (SFE estimator):

S1 ∼ p1 (· | s0 )
S2 := f2 (S1 )
E(s0 ) := E[S2 ]
∇E(s0 ) = ES1 [f2 (S1 )∇s0 log p1 (S1 | s0 )]

• Example 2 (Pathwise estimator):

S1 ∼ p1
S2 := f2 (S1 , s0 )
E(s0 ) := E[S2 ]
∇E(s0 ) = ES1 [∇s0 f2 (S1 , s0 )]

• Example 3 (SFE estimator + chain rule):

s1 := f1 (s0 )
S2 ∼ p2 (· | s1 )
S3 := f3 (S2 )
E(s0 ) := E[S3 ]
∇E(s0) = ∂f1(s0)∗ ES2 [f3(S2) ∇s1 log p2(S2 | s1)]

• Example 4:

s1 := f1 (s0 )
s2 := f2 (s0 )
S3 ∼ p3 (· | s1 )
S4 ∼ p4 (· | s2 , S3 )
S5 := f5 (S4 )
E(s0 ) := E[S5 ] = ES3 [ES4 [f5 (S4 )]]
∇E(s0) = ∂f1(s0)∗ ES3 [∇s1 log p3(S3 | s1) ES4 [f5(S4)]] + ∂f2(s0)∗ ES3 [ES4 [f5(S4) ∇s2 log p4(S4 | s2, S3)]]

As can be seen, the gradient expressions can quickly become quite


complicated, demonstrating the merits of automatic differentiation in
stochastic computation graphs.

11.5.3 Unbiased gradient estimators


The output of a stochastic program is a random variable

SK := f (s0 ).

It implicitly defines a probability distribution p(·|s0 ) such that SK ∼


p(·|s0 ). Executing the stochastic program once gives us an i.i.d. sample
from p(·|s0 ).
Since derivatives are defined for deterministic variables, we need a
way to convert a random variable to a deterministic variable. One way

to do so is to consider the expected value (another way would be the


mode)
E(s0 ) := E[SK ] = E[f (s0 )] ∈ conv(SK ),
where the expectation is over SK ∼ p(·|s0 ) or equivalently over the
intermediate random variables Sk

Sk ∼ pk (·|sdeterm(k) , Srandom(k) ),

for k ∈ Vp (the distribution nodes). We then wish to compute the


gradient or more generally the Jacobian of E(s0 ).
If all nodes in the stochastic computation graph are function nodes,
we can estimate the gradient of E(s0 ) using the pathwise estimator
a.k.a. reparametrization trick (Section 11.4). This is the approach taken
by Kingma and Welling (2013) and Rezende et al. (2014).
If all nodes in the stochastic computation graph are distribution
nodes, we can use the SFE estimator (Section 11.3). Schulman et al.
(2015) propose a surrogate loss so that using autodiff on that loss
produces an unbiased gradient of the expectation, using the SFE esti-
mator. Foerster et al. (2018) extend the approach to support high-order
differentiation. Krieken et al. (2021) further extend the approach by
supporting different estimators per node, as well as control variates.

Converting distribution nodes into function nodes and vice-versa


Our formalism uses two types of nodes: distribution nodes with asso-
ciated conditional distribution pk and function nodes with associated
function fk . It is often possible to convert between node types.
Converting a distribution node into a function node is exactly the
reparametrization trick studied in Section 11.4. We can use transforma-
tions such as the location-scale transform or the inverse transform.
Converting a function node into a distribution node can be done
using the change-of-variables theorem, studied in Section 11.4.4, on a
pushforward distribution.
Because the pathwise estimator has lower variance than SFE, this is
the method of choice when the fk functions are available. The conversion
from distribution node to function node and vice-versa is illustrated in
Fig. 11.1.

Figure 11.1: It is sometimes possible to convert a distribution node to a function node and vice-versa using a suitable transformation: a parametric (explicit) distribution, handled by the score function estimator (SFE), can be converted into a pushforward (implicit) distribution, handled by the path gradient estimator (PGE), via a transformation such as a location-scale or inverse transform, and back via the change-of-variables theorem.

11.5.4 Local vs. global expectations

A stochastic computation graph can be seen as a stochastic process,


a collection of random variables Sk , indexed by k, the position in the
topological order. However, random variables are incompatible with
autodiff. Replacing random variables by their expectation can be seen
as a way to make them compatible with autodiff. Two strategies are
then possible.
As we saw in the previous section, a strategy is to consider the
expectation of the last output SK . This strategy corresponds to a
global smoothing. The two major advantages are that i) we do not
need to assume that fk+1 is well-defined on conv(Sk ) and ii) this induces
a probability distribution over program executions. This is for instance
useful to compute the variance of the program. The gradient of the
program’s expected value can be estimated by the reparametrization
trick or by the SFE, depending on the type of nodes used.
A second strategy is to replace an intermediate random variable
Sk ∈ Sk , for k ∈ {1, . . . , K}, by its expectation E[Sk ] ∈ conv(Sk ). This
strategy corresponds to a local smoothing. A potential drawback of
this approach is that E[Sk ] belongs to conv(Sk ), the convex hull of Sk .

Therefore, the function fk+1 in which E[Sk ] is fed must be well-defined


on conv(Sk ), which may not always be the case. In the case of control
flows, another disadvantage is computational. We saw in Section 5.6 and
Section 5.7 that using a soft comparison operator within a conditional
statement induces a distribution on a binary or categorical random
variable, corresponding to the branch to be selected. A conditional
statement can then be locally smoothed out by replacing the random
variable by its expectation, i.e., a convex combination of all the
branches. This means that, unless the distribution has sparse support,
all branches must be evaluated.

11.6 Differential equations

11.6.1 Parameterized differential equations


From residual networks to neural ODEs
Starting from s0 := x, residual networks, reviewed in Section 4.5, iterate
for k ∈ {1, . . . , K}

sk := sk−1 + hk (sk−1 , wk ).

A residual network can be seen as parameterizing incremental discrete-


time input changes (hence the name “residual”)

sk − sk−1 = hk(sk−1, wk).

Chen et al. (2018) proposed to parameterize continuous-time (instan-


taneous) changes instead. They considered the evolution s(t) of the
inputs in continuous time driven by a function h(t, s, w) parameterized
by w, starting from x. Formally, the evolution s(t) is the solution of
the ordinary differential equation (ODE)

s(0) = x
s′ (t) = h(t, s(t), w) t ∈ [0, T ] (11.7)

Here, s′ (t) is the vector of derivatives of s as defined in Remark 2.3, and


T denotes a final time for the trajectory. The output of such a neural
ODE (Chen et al., 2018) is then f (x, w) := s(T ). Alternatively, the

output can be seen as the solution of an integration problem


f(x, w) = s(T) = x + ∫_0^T h(t, s(t), w) dt. (11.8)

Differential equations like Eq. (11.7) arise in many contexts beyond neu-
ral ODEs, ranging from modeling physical systems to pandemics (Braun
and Golubitsky, 1983). Moreover, the differential equation presented
in Eq. (11.7) is just an example of an ordinary differential equation,
while controlled differential equations or stochastic differential equations
can also be considered.

Existence of a solution

First and foremost, the question is whether s(t) is well-defined. For-


tunately, the answer is positive under mild conditions, as shown by
Picard-Lindelöf’s theorem recalled below (Butcher, 2016, Theorem 16).

Theorem 11.1 (Existence and uniqueness of ODE solutions). If h :


[0, T ] × S → S is continuous in its first variable and Lipschitz-
continuous in its second variable, then there exists a unique differ-
entiable map s : [0, T ] → S satisfying

s(0) = s0
s′ (t) = h(t, s(t)) t ∈ [0, T ],

for some given s0 ∈ S.

For time-independent linear functions h(t, s) = As, the integral


in Eq. (11.8) can be computed in closed form as

s(t) = exp(tA)s0,

where exp(A) is the matrix exponential. Hence, the output s(T ) can
be expressed as a simple function of the parameters (A in this case).
However, generally, we do not have access to such analytical solutions,
and, just as for solving optimization problems in Chapter 10, we need
to resort to some iterative algorithms.

Integration methods
To numerically solve an ODE, we can use integration methods,
whose goal is to build a sequence sk that approximates the solution
s(t) at times tk . The simplest integration method is the explicit Euler
method, which approximates the solution between times tk−1 and tk as

s(tk) − s(tk−1) = ∫_{tk−1}^{tk} h(t, s(t), w) dt ≈ δk h(tk−1, s(tk−1), w),

for a time-step
δk := tk − tk−1 .
The resulting integration scheme consists in computing starting from
s0 = x, for k ∈ {1, . . . , K},

sk := sk−1 + δk h(tk−1 , sk−1 , w).

Assimilating δk h(tk−1 , sk−1 , w) with hk (sk−1 , wk ), we find that residual


networks are essentially the discretization of a neural ODE by an
explicit Euler method; more precisely, of a non-autonomous neural ODE,
see e.g. (Davis et al., 2020).
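The explicit Euler scheme is a few lines of JAX; the vector field h below is a toy stand-in for a neural network. Because the solver is built from ordinary differentiable operations, it can also be differentiated directly, which is the discretize-then-optimize approach discussed in Section 11.6.4.

```python
import jax
import jax.numpy as jnp

def h(t, s, w):
    # Toy vector field; in a neural ODE, this would be a small neural network.
    return jnp.tanh(w @ s)

def euler_solve(x, w, T=1.0, K=100):
    delta = T / K
    def step(s, k):
        return s + delta * h(k * delta, s, w), None
    s_final, _ = jax.lax.scan(step, x, jnp.arange(K))
    return s_final                        # approximates s(T) = f(x, w)

x = jnp.array([1.0, -0.5])
w = jnp.eye(2)
s_T = euler_solve(x, w)
# Discretize-then-optimize: differentiate a loss of the discretized solution.
grad_w = jax.grad(lambda w_: jnp.sum(euler_solve(x, w_) ** 2))(w)
```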
Euler’s forward method is only one integration method among
many. To cite a few, there are implicit Euler methods, semi-implicit
methods, Runge-Kutta methods, linear multistep methods, etc. See,
e.g., Gautschi (2011) for a detailed review. The quality of an integration
method is measured by its consistency and its stability (Gautschi, 2011).
These concepts naturally influence the development of evaluation and
differentiation techniques for ODEs. We briefly summarize them below.
Given a fixed time interval δk = δ and K = ⌈T/δ⌉ points, an integration method is consistent of order p if ∥sk − s(kδ)∥ = O(δ^p) as δ → 0 and therefore K → +∞. The higher the order p, the fewer points we need to reach an approximation error ε on the points considered. The term ∥sk − s(kδ)∥ = O(δ^p) is reminiscent of the error encountered
in finite differences (Chapter 6) and is called the truncation error.
The (absolute) stability of a method is defined by the set of time-
steps such that the integration method can integrate s′ (t) = λs(t) for
some λ ∈ C without blowing up as t → +∞.

11.6.2 Continuous adjoint method


Since different parameters w induce different trajectories associated to
h(t, s, w) in Eq. (11.7), we may want to select one of these trajectories
by minimizing some criterion. For example, we may consider selecting
w ∈ W by minimizing a loss L on the final point of the trajectory,

min L(f (x, w), y), (11.9)


w∈W

where f(x, w) := s(T) = x + ∫_0^T h(t, s(t), w) dt.
To solve such problems, we need access to gradients of L composed with f through VJPs of the solution of the ODE. The VJPs can actually be characterized as solutions of an ODE themselves, thanks to the continuous-time adjoint method (Pontryagin, 1985), presented below, whose proof is postponed to Section 11.6.6.

Proposition 11.7 (Continuous-time adjoint method). Consider a func-


tion h : [0, T ]×S ×W → S, continuous in its first variable, Lipschitz-
continuous and continuously differentiable in its second variable.
Assume that ∂3 h(t, s, w) exists for any t, s, w, and is also continu-
ous in its first variable, Lipschitz-continuous in its second variable.
Denote s : [0, T ] → S the solution of the ODE

s(0) = x
s′ (t) = h(t, s(t), w) t ∈ [0, T ],

and f (x, w) = s(T ) the final state of the ODE at time T .


Then, the function f is differentiable, and for an output direction
u ∈ S, its VJP along u is given by

∂f (x, w)∗ u = (r(0), g)

for

g = ∫_0^T ∂3h(t, s(t), w)∗ r(t) dt

and for r solving the adjoint (backward) ODE

r ′ (t) = −∂2 h(t, s(t), w)∗ r(t)


r(T ) = u.

In particular, the gradient ∇(L ◦ f )(x, w) for L : S → R a


differentiable loss is obtained by solving the adjoint ODE with
r(T ) = ∇L(s(T )).

11.6.3 Gradients via the continuous adjoint method

Proposition 11.7 gives a formal definition of the gradient. However, just


as computing the mapping f (x, w) itself, computing its VJP or the
gradient of L ◦ f requires solving an integration problem. Note that
the integration of r(t) in Proposition 11.7 also requires the values of s(t). Therefore, we need to integrate both r(t) and s(t). Such an approach is generally referred to as optimize-then-discretize because we first
formulate the gradient in continuous time (the “optimize part”) and
then discretize the resulting ODE.

Simple discretization scheme

A first approach consists in defining a backward discretization scheme


that can approximate s(t) backward in time. Namely, by defining σ(t) := s(T − t), ρ(t) := r(T − t), and γ(t) := ∫_{T−t}^{T} ∂3h(τ, s(τ), w)∗ r(τ) dτ, the derivative of L ◦ f is given by (ρ(T), γ(T)). The functions σ, ρ, γ


are solutions of a standard ODE

σ(0) = s(T ), σ ′ (t) = −h(T − t, σ(t), w),


ρ(0) = ∇L(s(T )), ρ′ (t) = ∂2 h(T − t, σ(t), w)∗ ρ(t),
γ(0) = 0, γ ′ (t) = ∂3 h(T − t, σ(t), w)∗ ρ(t).

The above ODE can then be solved by any integration method. Note,
however, that it requires first computing s(T ) and ∇L(s(T )) by an
integration method. The overall computation of the gradient using
an explicit Euler method to solve forward and backward ODEs is
summarized in Algorithm 11.2.

Algorithm 11.2 Gradient computation via continuous adjoint method


with Euler explicit discretization
1: Functions: h : [0, T ] × S × W → R, L : S → R
2: Inputs: input x, parameters w, number of discretization steps K.
3: Set discretization step δ = T /K, denote hk (s, w) = h(kδ, s, w).
4: Set s0 := x
5: for k := 1, . . . , K do ▷ Forward discretization
6: Compute sk := sk−1 + δhk−1 (sk−1 , w).
7: Compute u := ∇L(sK ).
8: Initialize rK := u, ŝK = sK , gK = 0
9: for k := K, . . . , 1 do ▷ Backward discretization
10: Compute ŝk−1 := ŝk − δhk (ŝk , w)
11: Compute rk−1 := rk + δ∂2 hk (ŝk , w)∗ rk
12: Compute gk−1 := gk + δ∂3 hk (ŝk , w)∗ rk
13: Output: (r0 , g0 ) ≈ ∇(L ◦ f )(x, w)
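Below is a sketch of Algorithm 11.2 in JAX, where the VJPs ∂2hk∗ and ∂3hk∗ are obtained with jax.vjp; the vector field h and the loss L in the usage lines are illustrative toy choices, and w is assumed to be a single array for simplicity.

```python
import jax
import jax.numpy as jnp

def adjoint_gradient(h, L, x, w, T=1.0, K=100):
    delta = T / K
    # Forward discretization (explicit Euler); only the final state is kept.
    s = x
    for k in range(K):
        s = s + delta * h(k * delta, s, w)
    u = jax.grad(L)(s)
    # Backward discretization: recompute the states and integrate the adjoint ODE.
    r, s_hat = u, s
    g = jnp.zeros_like(w)
    for k in range(K, 0, -1):
        t = k * delta
        _, vjp = jax.vjp(lambda s_, w_: h(t, s_, w_), s_hat, w)
        vjp_s, vjp_w = vjp(r)                 # (d2 h)^* r and (d3 h)^* r at (t, s_hat, w)
        s_hat = s_hat - delta * h(t, s_hat, w)
        r = r + delta * vjp_s
        g = g + delta * vjp_w
    return r, g                               # approximates the gradient of L(f(x, w)) w.r.t. (x, w)

h = lambda t, s, w: jnp.tanh(w @ s)           # toy vector field
L = lambda s: jnp.sum(s ** 2)                 # toy loss
grad_x, grad_w = adjoint_gradient(h, L, jnp.array([1.0, -0.5]), jnp.eye(2))
```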

Algorithm 11.2 naturally looks like the reverse mode of autodiff for
a residual neural network with shared weights. A striking difference
is that the intermediate computations sk are not kept in memory and,
instead, new variables ŝk are computed along the backward ODE. One
may believe that by switching to continuous time, we solved the memory
issues encountered in reverse-mode autodiff. Unfortunately, this comes
at the cost of numerical stability. As we use a discretization scheme
to recompute the intermediate states backward in time through ŝk in
Algorithm 11.2, we accumulate some truncation errors.

To understand the issue here, consider applying Algorithm 11.2


repeatedly on the same parameters but using ŝ0 instead of s0 = x each
time. In the continuous realm, σ(T ) = s(0). But after discretization,
ŝ0 ≈ σ(T ) does not match s0 . Therefore, by applying Algorithm 11.2
with s0 = ŝ0 , we would not get the same output even if in continuous
time we naturally should have. This thought exercise illustrates that
Algorithm 11.2 induces some noise in the estimation of the gradient.

Multiple shooting scheme


An alternative approach consists in integrating both the forward and
backward ODEs jointly. Namely, we may solve an ODE with boundary
values

s′(t) = h(t, s(t), w), s(0) = x,
r′(t) = −∂2h(t, s(t), w)∗ r(t), r(T) = ∇L(s(T)),
g′(t) = −∂3h(t, s(t), w)∗ r(t), g(T) = 0,

by means of a multiple shooting method or a collocation method (Stoer


et al., 1980). This approach still requires ∇L(s(T )) to be approximated
first.

11.6.4 Gradients via reverse-mode on discretization


A simpler approach consists in replacing the objective in Eq. (11.9) by
its version discretized using some numerical method, such as an Euler
forward discretization scheme. That is, we seek to solve

min_{w∈W} L(sK) where sk = sk−1 + δ h((k − 1)δ, sk−1, w), k ∈ {1, . . . , K},

with s0 = x and δ some discretization step. Gradients of the objective


can be computed by automatic differentiation. That approach is often
referred to as discretize-then-optimize. At first glance, this approach
may suffer from very high memory requirements. Indeed, to get an
accurate solution of the ODE, a numerical integration method may
require K to be very large. Since a naive implementation of reverse-mode
automatic differentiation has a memory that scales linearly with K,
computing the gradient by a discretize-then-optimize method could be
prohibitive. However, the memory requirements may easily be amortized
using checkpointing, as explained in Section 7.5; see also (Gholaminejad
et al., 2019).
As for the optimize-then-discretize method, we still accumulate
some truncation errors in the forward discretization process. This dis-
cretization error occurs when computing the gradient in reverse-mode
too. The discretize-then-optimize method can be seen as computing
gradients of a surrogate objective. For that objective, the gradients are

correct and well-defined. However, they may not match the gradients of
the true ODE formulation.
To compare the discretize-then-optimize and optimize-then-discretize
approaches, Gholaminejad et al. (2019) compared their performance on
an ODE whose solution can be computed analytically by selecting h
to be linear in s. The authors observed that discretize-then-optimize
generally outperformed optimize-then-discretize. A middle ground can
actually be found by using reversible discretization schemes.

11.6.5 Reversible discretization schemes


Our exposition of the optimize-then-discretize or discretize-then-optimize
approaches used a simple Euler explicit discretization scheme. However,
for both approaches, we could have used other discretization schemes
instead, such as reversible discretization schemes.
A reversible discretization scheme is a discretization scheme such
that we have access to a closed-form formula for the inverse of its
discretization step. Formally, a discretization method M builds an
approximation (sk )K k=1 of the solution of an ODE s (t) = h(t, s(t)) on

an interval [0, T ] by computing for k ∈ (1, . . . , K)


tk , sk , ck = M(tk−1 , sk−1 , ck−1 ; h, δ), (11.10)
where δ > 0 is some fixed discretization step, tk is the time step (typically
tk = tk−1 +δ), sk is the approximation of s(tk ), and ck is some additional
context variables used by the discretization method to build the iterates.
An explicit Euler method does not have a context, but just as an
optimization method may update some internal states, a discretization
method can update some context variable. The discretization scheme
in Eq. (11.10) is a forward discretization scheme as we took a positive
discretization step. By taking a negative discretization step, we obtain
the corresponding backward discretization scheme, for k ∈ (K, . . . , 1),
tk−1 , sk−1 , ck−1 = M(tk , sk , ck ; h, −δ).
A discretization method is reversible if we have access to M−1 to
recompute the inputs of the discretization step from its outputs,
tk−1, sk−1, ck−1 = M−1(tk, sk, ck; h, δ).

A reversible discretization method is symmetric if the backward dis-


cretization scheme is exactly the inverse of the forward discretization
scheme, i.e.,
M(tk , sk , ck ; h, −δ) = M−1 (tk , sk , ck ; h, δ).
The explicit Euler method is clearly not symmetric and a priori not reversible, unless we can solve the equation sk = sk−1 + δh(tk−1, sk−1) for sk−1.

Leapfrog method
The (asynchronous) leapfrog method (Zhuang et al., 2021; Mutze,
2013), on the other hand, is an example of a symmetric reversible discretization method. For a constant discretization step δ, given tk−1, sk−1, ck−1 and a function h, it computes

t̄k−1 := tk−1 + δ/2
s̄k−1 := sk−1 + (δ/2) ck−1
c̄k−1 := h(t̄k−1, s̄k−1)
tk := t̄k−1 + δ/2
ck := 2 c̄k−1 − ck−1
sk := s̄k−1 + (δ/2) ck
M(tk−1, sk−1, ck−1; h, δ) := (tk, sk, ck).

One can verify that we indeed have M(tk, sk, ck; h, −δ) = (tk−1, sk−1, ck−1).
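This identity is easy to check numerically; the sketch below (with an arbitrary toy vector field and our own helper names) applies one step with +δ followed by one step with −δ and recovers the initial triplet up to floating-point error.

```python
import jax.numpy as jnp

def leapfrog_step(t, s, c, h, delta):
    t_mid = t + delta / 2
    s_mid = s + (delta / 2) * c
    c_mid = h(t_mid, s_mid)
    t_new = t_mid + delta / 2
    c_new = 2 * c_mid - c
    s_new = s_mid + (delta / 2) * c_new
    return t_new, s_new, c_new

h = lambda t, s: jnp.sin(s) + t               # toy vector field
t0, s0, c0 = 0.0, jnp.array([1.0, -2.0]), jnp.array([0.5, 0.3])
t1, s1, c1 = leapfrog_step(t0, s0, c0, h, 0.1)
t2, s2, c2 = leapfrog_step(t1, s1, c1, h, -0.1)
print(jnp.allclose(s2, s0), jnp.allclose(c2, c0))   # True True
```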
By using a reversible symmetric discretization scheme in the optimize-
then-discretize approach, we ensure that, at the end of the backward
discretization pass, we recover exactly the original input. Therefore, by
repeating forward and backward discretization schemes we always get
the same gradient, which was not the case for an Euler explicit scheme.
By using a reversible discretization scheme in the discretize-then-
optimize method, we address the memory issues of reverse mode autodiff.
As explained in Section 7.6, we can recompute intermediate values during
the backward pass rather than storing them.

Momentum residual networks

In the leapfrog method, the additional variables ck may actually be


interpreted as velocities of a system whose acceleration is driven by
the given function, that is, s′′(t) = h(t, s(t), w). Such an interpretation suggests alternatives to the usual neural ODE paradigm. For instance, momentum neural networks (Sander et al., 2021b) can be interpreted as the discretization of second-order ordinary differential equations, which are naturally amenable to reversible discretization schemes with a low memory footprint.

11.6.6 Proof of the continuous adjoint method

In the following, we denote s(t, x, w) the solution of the ODE at time t


given the input x and the parameters w. We focus here on the formula-
tion of the VJP. The proof relies on the existence of partial derivatives of s(t, x, w), which we do not cover here; we refer to, e.g., Pontryagin (1985) for a complete proof of such facts under the stated assumptions.

We use the ODE constraint to introduce adjoint variables, this time


in the form of a continuously differentiable function r. For any such function r, we have

⟨f(x, w), u⟩ = ⟨s(T, x, w), u⟩ + ∫_0^T ⟨r(t), h(t, s(t, x, w), w) − ∂t s(t, x, w)⟩ dt,

using Leibniz notations such as ∂t s(t, x, w) = ∂1 s(t, x, w). The VJPs



then decompose as
∂w f(x, w)∗[u]
= ∂w s(T, x, w)∗ u
+ ∫_0^T (∂w s(t, x, w)∗ ∂s h(t, s(t, x, w), w)∗ − ∂²wt s(t, x, w)∗) r(t) dt
+ ∫_0^T ∂w h(t, s(t, x, w), w)∗ r(t) dt,

∂x f(x, w)∗[u]
= ∂x s(T, x, w)∗ u
+ ∫_0^T (∂x s(t, x, w)∗ ∂s h(t, s(t, x, w), w)∗ − ∂²xt s(t, x, w)∗) r(t) dt.

Here, the second derivative terms ∂²wt s(t, x, w)∗ r and ∂²xt s(t, x, w)∗ r correspond to second derivatives of ⟨s(t, x, w), r⟩. Since the Hessian is symmetric (Schwartz’s theorem, presented in Proposition 2.10), we can swap the derivatives in t and w or x. Then, to express the gradient uniquely in terms of first derivatives of s, we use an integration by parts to get, for example,
∫_0^T ∂²wt s(t, x, w)∗ r(t) dt = ∫_0^T ∂²tw s(t, x, w)∗ r(t) dt
= ∂w s(T, x, w)∗ r(T) − ∂w s(0, x, w)∗ r(0) − ∫_0^T ∂w s(t, x, w)∗ ∂t r(t) dt.
Since s(0) = x, we have ∂w s(0, x, w)∗ r(0) = 0. The VJP w.r.t. w can
then be written as
∂w f(x, w)∗[u]
= ∂w s(T, x, w)∗[u − r(T)]
+ ∫_0^T ∂w s(t, x, w)∗ [∂s h(t, s(t, x, w), w)∗ r(t) + ∂t r(t)] dt
+ ∫_0^T ∂w h(t, s(t, x, w), w)∗ r(t) dt.
By choosing r(t) to satisfy the adjoint ODE
∂t r(t) = −∂s h(t, s(t, x, w), w)∗ r(t), r(T ) = u,

the expression of the VJP simplifies as


∂w f(x, w)∗[u] = ∫_0^T ∂w h(t, s(t, x, w), w)∗ r(t) dt.

For the VJP w.r.t. x, we can proceed similarly. Using an integration by parts, we have, this time, ∂x s(0, x, w)∗ r(0) = r(0) since s(0) = x.
Choosing the same curve r(t) satisfying the adjoint ODE we get

∂x f (x, w)∗ [u] = r(0).

The existence of a curve r solution of the backward ODE can easily be


shown from Picard-Lindelöf’s theorem and the assumptions.

11.7 Summary

In this chapter, we studied how to differentiate integrals, with a focus


on expectations and solutions of a differential equation.
For differentiating through expectations, we studied two main meth-
ods: the score function estimator (SFE, a.k.a. REINFORCE) and the
path gradient estimator (PGE, a.k.a. reparametrization trick). The
SFE is suitable when it is easy to sample from the distribution and
its log-PDF is explicitly available. It is an unbiased estimator, but is
known to suffer from high variance. The PGE is suitable for pushfor-
ward distributions, distributions that are implicitly defined through
a transformation, or a sequence of them. These distributions can be
easily sampled from, by injecting a source of randomness (such as noise)
through the transformations. An unbiased, low-variance estimator of
the gradient of their expectation is easily obtained, provided that we
can interchange integration and differentiation.
If we have an explicit distribution, we can sometimes convert it to
an implicit distribution, thanks to the location-scale transformation
or the inverse transformation. Conversely, if we have an implicit
distribution, we can convert it to an explicit distribution using the
change-of-variables theorem. However, this formula requires computing the determinant of an inverse Jacobian, and is computationally
expensive in general. Normalizing flows use invertible transformations
so that the inverse Jacobian is cheap to compute, by design. We saw

that stochastic computation graphs can use a mix of explicit and


implicit distributions at each node.
For differentiating through the solution of a differential equation,
two approaches can be considered. We can express the gradient as the
solution of a differential equation thanks to the continuous adjoint
method. We may then discretize backwards in time the differential equa-
tion that the gradient satisfies. This is the optimize-then-discretize
approach. We can also first discretize the problem in such a way that
the gradient can simply be computed by reverse mode auto-diff, applied
on the discretization steps. This is the discretize-then-optimize ap-
proach. The optimize-then-discretize approach has no memory cost, but
discrepancies between the forward and backward discretization passes
often lead to numerical errors. The discretize-then-optimize introduces
no such discrepancies but may come at a large memory cost. Reversible
discretization schemes can circumvent the memory cost, as they
enable the recomputation of intermediate discretization steps backwards
in time.
Part IV

Smoothing programs
12
Smoothing by optimization

In this chapter, we review smoothing by infimal convolution.

12.1 Primal approach

We first review how to smooth functions in the primal space.

12.1.1 Infimal convolution


The infimal convolution, sometimes abbreviated inf-conv, is defined as
follows.
Definition 12.1 (Infimal convolution). The infimal convolution be-
tween two functions f : RM → R and g : RM → R is defined by

(f □g)(µ) := inf_{u∈RM} f(u) + g(µ − u)
= inf_{z∈RM} f(µ − z) + g(z)
= inf_{u,z∈RM} f(u) + g(z) s.t. µ = u + z.

Some examples are given in Table 12.1.

Table 12.1: Examples of infimal convolutions.

f1          f2            f1 □ f2
f           0             inf_{u∈RM} f(u)
ιC          ∥·∥2          dC := inf_{u∈C} ∥· − u∥2
f           ½∥·∥2²        Mf := inf_{u∈RM} ½∥· − u∥2² + f(u)
ιC          ιD            ιC+D
f           ι{v}          f(· − v)
f           ⟨·, v⟩        ⟨·, v⟩ − f∗(v)

As can be seen from the above definition, the infimal convolution, like the classical convolution, is commutative: (f □g)(µ) = (g□f)(µ).

12.1.2 Moreau envelope


When g(z) is set to a quadratic regularization R(z) := ½∥z∥2², the infimal convolution f □R gives the so-called Moreau envelope of f, also known as the Moreau-Yosida regularization of f.

Definition 12.2 (Moreau envelope). Given a function f : RM → R,


its Moreau envelope is given by

Mf(µ) := (f □R)(µ)
= inf_{u∈RM} f(u) + ½∥µ − u∥2²
= inf_{z∈RM} f(µ + z) + ½∥z∥2².

Note that µ and u belong to the same primal space. Intuitively,


the Moreau envelope is the minimal value over u ∈ RM of a trade-
off between staying close to the input µ according to the proximity term ½∥µ − u∥2² and f(u). As a comparison, the proximal operator, reviewed in Definition 15.3, returns the solution, instead of the value:

proxf(µ) := arg min_{u∈RM} ½∥µ − u∥2² + f(u).

That is, we have


Mf(µ) = ½∥µ − proxf(µ)∥2² + f(proxf(µ)). (12.1)
We now state useful properties of the Moreau envelope.

Proposition 12.1 (Properties of Moreau envelope). Let f : RM →


R.

1. Gradient:
∇Mf (µ) = µ − proxf (µ).
If f is convex, then

∇Mf ∗ (µ) = proxf (µ),

where f ∗ is the convex conjugate of f .

2. Same minimum as original function:

min_{µ∈RM} Mf(µ) = min_{u∈RM} f(u).

Proof.

1. The first claim follows from Danskin's theorem, reviewed in Sec-
   tion 10.2. The second claim is Moreau's decomposition.

2. Provided that the Moreau envelope exists, we have

   min_{µ∈R^M} M_f(µ) = min_{µ∈R^M} min_{u∈R^M} ½∥µ − u∥₂² + f(u)
                      = min_{u∈R^M} min_{µ∈R^M} ½∥µ − u∥₂² + f(u)
                      = min_{u∈R^M} f(u).

To illustrate smoothing from the Moreau envelope perspective, we


show how to smooth the 1-norm.

Example 12.1 (Smoothing the 1-norm via infimal convolution). We
wish to smooth f(u) = ∥u∥₁. The corresponding proximal operator
is the soft-thresholding operator (see Section 15.4)

prox_f(µ) = arg min_{u∈R^M} ½∥µ − u∥₂² + ∥u∥₁
          = sign(µ) · max(|µ| − 1, 0).

Using Eq. (12.1) and after some algebraic manipulations, we obtain

[M_f(µ)]_j = { µ_j²/2        if |µ_j| ≤ 1
             { |µ_j| − 1/2   if |µ_j| > 1.

That is, we recover the Huber loss.
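To make this concrete, here is a small numerical check of Example 12.1, written
as a NumPy sketch (the function names are ours, not part of the book's code): it
evaluates the Moreau envelope coordinate-wise through Eq. (12.1), using the
soft-thresholding operator, and compares it with the Huber formula.

import numpy as np

def prox_l1(mu):
    # Soft-thresholding: proximal operator of the 1-norm.
    return np.sign(mu) * np.maximum(np.abs(mu) - 1.0, 0.0)

def moreau_envelope_l1(mu):
    # Coordinate-wise evaluation of Eq. (12.1).
    p = prox_l1(mu)
    return 0.5 * (mu - p) ** 2 + np.abs(p)

def huber(mu):
    # Coordinate-wise Huber loss from Example 12.1.
    return np.where(np.abs(mu) <= 1.0, 0.5 * mu ** 2, np.abs(mu) - 0.5)

mu = np.array([-2.0, -0.3, 0.0, 0.7, 3.0])
print(np.allclose(moreau_envelope_l1(mu), huber(mu)))  # True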

12.2 Legendre–Fenchel transforms, convex conjugates

12.2.1 Definition
Consider affine functions of the form

u ↦ ⟨u, v⟩ − b.

These functions are parametrized by their slope v ∈ R^M and their
intercept −b ∈ R. Now, suppose we fix v. Given a function f(u), affine
lower bounds of f(u) are all the functions of u such that b satisfies

⟨u, v⟩ − b ≤ f (u) ⇔ ⟨u, v⟩ − f (u) ≤ b.

The tightest lower bound is then the function such that b is defined by

b := sup_{u∈dom(f)} ⟨u, v⟩ − f(u),

where we recall that the domain of f is defined by

dom(f ) := {u ∈ RM : f (u) < ∞}.

This leads to the definition of the Legendre-Fenchel transform, a.k.a.
convex conjugate.

Figure 12.1: For a fixed slope v, the function u ↦ uv − f*(v) is the tightest affine
lower bound of f.

Definition 12.3 (Convex conjugate). Given a function f : R^M → R ∪ {∞},
its convex conjugate is defined by

f*(v) := sup_{u∈dom(f)} ⟨u, v⟩ − f(u).

We use a sup rather than a max to indicate that f*(v) is potentially
∞. Following the previous discussion, −f*(v) is the intercept of the
tightest affine lower bound with slope v of f(u). This is illustrated in
Fig. 12.1.

12.2.2 Closed-form examples


Computing f*(v) involves solving a maximization problem, which can
be difficult in general without assumptions on f. In some cases, however,
we can compute an analytical solution, as we now illustrate.

Example 12.2 (Analytical conjugate examples). When f(u) = ½∥u∥₂²,
with dom(f) = R^M, the conjugate is

f*(v) = max_{u∈R^M} ⟨u, v⟩ − ½∥u∥₂².

Setting the gradient of u ↦ ⟨u, v⟩ − ½∥u∥₂² to zero, we obtain u⋆ = v.
Plugging u⋆ back, we therefore obtain

f*(v) = ⟨u⋆, v⟩ − ½∥u⋆∥₂² = ½∥v∥₂².

Therefore, f = f* in this case.
When f(u) = ⟨u, log u⟩, with dom(f) = R^M_+, the maximizer of
u ↦ ⟨u, v⟩ − ⟨u, log u⟩ is u⋆ = exp(v − 1) and the conjugate is

f*(v) = Σ_{j=1}^M exp(v_j − 1).

See for instance Boyd and Vandenberghe (2004) or Beck (2017)


for many more examples. We emphasize that the image of f can be
the extended real line R ∪ {∞}. This is useful for instance in order to
incorporate constraints using an indicator function. Indeed, if

f(u) = ι_C(u) = { 0  if u ∈ C
                { ∞  otherwise,

so that dom(f) = C, then

f*(v) = sup_{u∈dom(f)} ⟨u, v⟩ − f(u) = sup_{u∈C} ⟨u, v⟩,

which is known as the support function of the set C, with maximizer
known as the linear maximization oracle (LMO).

12.2.3 Properties
The conjugate enjoys several useful properties, which we now summarize.

Proposition 12.2 (Convex conjugate properties).

1. Convexity: f*(v) is a convex function for all f : R^M → R ∪ {∞}
   (even if f is nonconvex).

2. Fenchel-Young inequality: for all u, v ∈ R^M,

   f(u) + f*(v) − ⟨u, v⟩ ≥ 0.

3. Gradient: if the supremum in Definition 12.3 is uniquely
   achieved, then f*(v) is differentiable at v and its gradient is

   ∇f*(v) = arg max_{u∈dom(f)} ⟨u, v⟩ − f(u).

   Otherwise, f*(v) is sub-differentiable at v and we get a sub-
   gradient instead.

4. Maps: if f and f* are differentiable, then

   v = ∇f(u) ⇔ u = ∇f*(v) ⇔ f*(v) + f(u) − ⟨u, v⟩ = 0.

5. Biconjugate: f = f** if and only if f is convex and closed
   (i.e., its sublevel sets are closed), otherwise f** ≤ f.

Proof. We include proofs for each claim.

1. This follows from the fact that v ↦ sup_{u∈C} g(u, v) is convex if g
   is convex in v. Note that this is true even if g is nonconvex in u.
   Here, g(u, v) = ⟨u, v⟩ − f(u), which is affine in v and therefore
   convex in v.

2. This follows immediately from Definition 12.3.

3. This follows from Danskin's theorem, reviewed in Section 10.2.
   Another way to see this is by observing that

   f*(v) = ⟨g, v⟩ − f(g)
   f*(v′) ≥ ⟨g, v′⟩ − f(g),

   where g := arg max_{u∈dom(f)} ⟨u, v⟩ − f(u). Subtracting the two, we obtain

   f*(v′) ≥ f*(v) + ⟨g, v′ − v⟩.

   Now, using that f* is convex and Definition 14.6, we obtain that
   g = ∇f*(v).

4. See, e.g., Bauschke and Combettes (2011, Proposition 16.10).

5. See Boyd and Vandenberghe (2004, Section 3.3).

12.2.4 Conjugate calculus


While deriving a conjugate expression can be difficult in general, in
some cases, it is possible to use simple rules to derive conjugates in
terms of other conjugates.

Proposition 12.3 (Conjugate calculus rules).

1. Separable functions: if f(u) = Σ_{j=1}^M f_j(u_j), then

   f*(v) = Σ_{j=1}^M f_j*(v_j).

2. Scalar multiplication: if f(u) = c · g(u), for c > 0, then

   f*(v) = c · g*(v/c).

3. Addition to an affine function and translation: if f(u) =
   g(u) + ⟨α, u⟩ + β, then

   f*(v) = g*(v − α) − β.

4. Composition with an invertible linear map: if f(u) =
   g(Mu), where x ↦ Mx is an invertible linear map, then

   f*(v) = g*(M^{−T} v).

12.2.5 Fast Legendre transform

When an analytical solution is not available, we can resort to numerical
schemes to approximately compute the conjugate. When f is convex,
because −f is concave, the maximization in Definition 12.3 is that of a
concave function. Therefore, the conjugate can be computed to arbitrary
precision in polynomial time using algorithms such as projected gradient
ascent or conditional gradient. Without convexity assumptions on f,
f*(v) can be approximated by

f*(v) ≈ sup_{u∈U} ⟨u, v⟩ − f(u),

where U ⊆ dom(f) is a discrete grid of values. We can then compute
f*(v) for v ∈ V using the linear-time Legendre transform algorithm
(Lucet, 1997), where V ⊆ dom(f*) is another discrete grid. The com-
plexity is O(|U| + |V|), which is linear in the grid sizes. However, the
grid sizes are typically |U| = |V| = O(N^M), for N equally-distributed
points in each of the M dimensions. Therefore, this approach is limited
to small-dimensional settings, e.g., M ∈ {1, 2, 3}.
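As an illustration, the following NumPy sketch implements the naive grid-based
approximation of the conjugate (the O(|U| · |V|) brute force, not Lucet's
linear-time algorithm); the function name and the grids are ours, chosen for the
one-dimensional case.

import numpy as np

def conjugate_on_grid(f_vals, u_grid, v_grid):
    # Approximate f*(v) = sup_u <u, v> - f(u) by restricting u to a grid (M = 1).
    scores = np.outer(v_grid, u_grid) - f_vals[None, :]  # shape (|V|, |U|)
    return scores.max(axis=1)

u = np.linspace(-3.0, 3.0, 601)
v = np.linspace(-2.0, 2.0, 401)
approx = conjugate_on_grid(0.5 * u ** 2, u, v)
print(np.max(np.abs(approx - 0.5 * v ** 2)))  # small: (1/2)||.||^2 is self-conjugate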

12.3 Dual approach

12.3.1 Duality between strong convexity and smoothness

We now state a well-known result that will underpin this whole chapter:
smoothness and strong convexity are dual to each other (Hiriart-Urruty
and Lemaréchal, 1993; Kakade et al., 2009; Beck, 2017; Zhou, 2018).
For a review of the notions of smoothness and strong convexity, see
Section 14.4.

Proposition 12.4 (Duality between strong convexity and smoothness).
f is µ-smooth w.r.t. the norm ∥·∥ over dom(f) if and only if f* is
1/µ-strongly convex w.r.t. the dual norm ∥·∥* over dom(f*).

We give two examples in Table 12.2.


Table 12.2: Examples of strongly-convex and smooth conjugate pairs.

Function       Norm    Domain   Conjugate       Dual norm   Dual domain
½∥u∥₂²         ∥·∥₂    R^M      ½∥v∥₂²          ∥·∥₂        R^M
⟨u, log u⟩     ∥·∥₁    △^M      logsumexp(v)    ∥·∥∞        R^M

12.3.2 Smoothing by dual regularization


The duality between smoothness and strong convexity suggests a generic
approach in order to smooth any function f : R^M → R, by going through
the dual space.

1. Compute the conjugate f*:

   f*(v) := sup_{u∈dom(f)} ⟨u, v⟩ − f(u).

2. Add strongly-convex regularization Ω to the conjugate:

   f_Ω*(v) := f*(v) + Ω(v).

3. Compute the conjugate of f_Ω*:

   f_Ω(u) := (f_Ω*)*(u).

Note that u and v belong to different spaces, i.e., u ∈ dom(f) and
v ∈ dom(f*). Following Proposition 12.4, if Ω is µ-strongly convex, then
f_Ω(u) is 1/µ-smooth. Furthermore, following Proposition 12.2, f_Ω(u) is
convex, even if f is nonconvex. Therefore, f_Ω(u) is a smooth convex
relaxation of f(u).
Steps 1 and 3 are the most challenging, as they both require the
derivation of a conjugate. While an analytical solution may not exist in
general, in some simple cases it does, as we now illustrate.

Example 12.3 (Smoothing the 1-norm via dual regularization). We wish
to smooth the 1-norm f(u) := ∥u∥₁ = Σ_{j=1}^M |u_j|.

1. Compute the conjugate. The conjugate of any norm ∥·∥
   is the indicator function of the dual norm's unit ball {v ∈
   R^M : ∥v∥* ≤ 1}. The dual norm of ∥u∥₁ is ∥v∥∞. Moreover,
   {v ∈ R^M : ∥v∥∞ ≤ 1} = [−1, 1]^M.
   Recalling that ι_C is the indicator function of C, we obtain

   f*(v) = ι_{[−1,1]^M}(v).

2. Adding strongly-convex regularization. We add quadratic
   regularization Ω(v) := ½∥v∥₂² to define

   f_Ω*(v) := ι_{[−1,1]^M}(v) + Ω(v).

3. Computing the conjugate.

   f_Ω(u) = ⟨u, v⋆⟩ − Ω(v⋆) = Σ_{i=1}^M H(u_i),

   where v⋆ = clip(u) := max(min(u, 1), −1), and H, known as the
   Huber loss, is

   H(u) = { ½u²       if |u| ≤ 1
          { |u| − ½   otherwise.
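A quick numerical check of this closed form (a sketch under our own naming, not
the book's code): evaluating ⟨u, v⋆⟩ − Ω(v⋆) coordinate-wise with v⋆ = clip(u)
reproduces the Huber values of Example 12.1.

import numpy as np

def huber_via_dual(u):
    # Coordinate-wise u * v_star - (1/2) v_star^2 with v_star = clip(u, -1, 1).
    v_star = np.clip(u, -1.0, 1.0)
    return u * v_star - 0.5 * v_star ** 2

u = np.linspace(-3.0, 3.0, 13)
huber = np.where(np.abs(u) <= 1.0, 0.5 * u ** 2, np.abs(u) - 0.5)
print(np.allclose(huber_via_dual(u), huber))  # True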

ReLU functions can be smoothed out in a similar way, as we see in
more detail in Section 12.4.1.

Remark 12.1 (Regularization scaling). The conjugate of a 1-strongly


convex regularizer Ω gives a 1-smooth approximation of the origi-
nal function. To control the smoothness of the approximation, it
suffices to smooth with εΩ for ε > 0, leading to a 1/ε-smooth
approximation. This is easily achieved by using

fεΩ (v) = εfΩ (v/ε)


∇fεΩ (v) = ∇fΩ (v/ε).

Therefore, if we know how to compute fΩ , we can also compute


fεΩ and its gradient easily.

12.3.3 Equivalence between primal and dual regularizations


We call R : dom(f ) → R the primal regularization to distinguish it
from the dual regularization Ω : dom(f ∗ ) → R. It turns out that both
regularizations are equivalent.
Proposition 12.5 (Equivalence between primal and dual regularizations).
Let f : RM → R and R : RM → R. Then, fΩ = (f □R) with Ω = R∗ .
Proof. We have

f_Ω(u) = (f* + Ω)*(u) = sup_{v∈dom(f*)} ⟨u, v⟩ − f*(v) − Ω(v).

If h1 and h2 are convex, we have (h1 + h2 )∗ = h∗1 □h∗2 (Beck, 2017,


Theorem 4.17). Using h1 = f ∗ and h2 = Ω = R∗ gives the desired result
(note that f ∗ and R∗ are convex even if f and R are not).

In particular, with R(u) = ½∥u∥₂², this shows that

M_f = f_R = f_{R*}.
Given the equivalence between primal and dual regularizations, using
one way or the other is mainly a matter of mathematical or algorithmic
convenience, depending on the case.
Remark 12.2 (Regularization scaling). Following Definition 12.2, if
we use dual regularization εΩ, where ε > 0 controls the reg-
ularization strength, the corresponding primal regularization is
R = εΩ∗ (·/ε). That is, we have

fεΩ = f □εΩ∗ (·/ε).


For applications of smoothing techniques to non-smooth optimiza-
tion, see (Nesterov, 2005; Beck and Teboulle, 2012).

12.4 Examples

12.4.1 Smoothed ReLU functions


The ReLU function is defined by

relu(u) := max(u, 0) = { u  if u ≥ 0
                       { 0  otherwise.

Its conjugate is

relu*(v) = ι_[0,1](v).

To see why, we observe that, since the objective is linear, the maximum
over [0, 1] is attained at a vertex:

max_{v∈[0,1]} uv = max_{v∈{0,1}} uv = relu(u),
with maximizer
arg max_{v∈[0,1]} uv = { 1  if u ≥ 0
                       { 0  otherwise.       (12.2)

If we use the regularizer Ω(v) = v log v + (1 − v) log(1 − v), we obtain

reluΩ(u) = softplus(u) = log(1 + exp(u)).

If we use the regularizer Ω(v) = v(v − 1), we obtain

reluΩ(u) = sparseplus(u) = { 0            if u ≤ −1
                           { ¼(u + 1)²    if −1 < u < 1
                           { u            if u ≥ 1.
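In code, both smoothed ReLUs are one-liners. A minimal NumPy sketch (our naming),
using a numerically stable form of the softplus:

import numpy as np

def softplus(u):
    # log(1 + exp(u)), computed stably.
    return np.logaddexp(0.0, u)

def sparseplus(u):
    # Piecewise-quadratic smoothed ReLU.
    return np.where(u <= -1.0, 0.0, np.where(u >= 1.0, u, 0.25 * (u + 1.0) ** 2))

print(softplus(0.0), sparseplus(0.0))  # 0.693..., 0.25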

12.4.2 Smoothed max operators


With a slight notation overloading, let us define the maximum operator
as

max(u) := max_{j∈[M]} u_j.

Its conjugate is

max*(v) = ι_{△^M}(v).

To see why, we observe that the vertices of the probability simplex
△^M are the standard basis vectors e₁, . . . , e_M. Since the objective is
linear, we then have

max(u) = max_{v∈△^M} ⟨u, v⟩ = max_{v∈{e₁,...,e_M}} ⟨u, v⟩.

We can therefore define the max operator smoothed by a strongly convex
regularizer Ω as

maxΩ(u) := max_{v∈△^M} ⟨u, v⟩ − Ω(v).

This smoothed max operator can be useful in a neural network, for


example as a smoothed max pooling layer.
Besides strong convexity, two useful assumptions about Ω are

A.1. Ω(v) = 0 for all v ∈ {e1 , . . . , eM },

A.2. Ω(P v) = Ω(v) for any permutation matrix P .

Using Jensen’s inequality, these assumptions imply that −Ω is an


“entropy function”: it is maximized at v = M 1
1 (the uniform distribution)
and minimized at any v ∈ {e1 , . . . , eM } (a delta distribution) (Blondel
et al., 2020, Proposition 4). We give two examples of Ω below satisfying
A.1 and A.2.
When Ω(v) = ⟨v, log v⟩ (Shannon's negative entropy), we obtain

maxΩ(u) = logsumexp(u) = log Σ_{j=1}^M e^{u_j}.

Alternatively, we can use Ω(v) = ½⟨v, v − 1⟩ (Gini's negative en-
tropy), which is up to a constant equal to quadratic regularization. For
this choice, we therefore obtain

maxΩ(u) = sparsemax(u) := max_{v∈△^M} ⟨u, v⟩ − ½⟨v, v − 1⟩.

We discuss how to compute v⋆ = sparseargmax(u) in the sequel.
The properties of the smoothed max operator maxΩ have been
studied in (Mensch and Blondel, 2018, Lemma 1).
Using the vector u = [u, 0] ∈ R² as input, the smoothed max
operator recovers the smoothed ReLU:

maxΩ([u, 0]) = relu_Ψ(u),

where we defined Ψ(v) := Ω([v, 1 − v]). With Ω(v) = ⟨v, log v⟩, we recover
Ψ(v) = v log v + (1 − v) log(1 − v); with Ω(v) = ½⟨v, v − 1⟩, we recover
Ψ(v) = v(v − 1), the regularizers we used to smooth the ReLU.
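The identity maxΩ([u, 0]) = reluΨ(u) is easy to check numerically. A small sketch
(our naming), using a numerically stable logsumexp, verifies that it reduces to
the softplus for Shannon's negative entropy:

import numpy as np

def logsumexp(u):
    # Smoothed max with Shannon's negative entropy.
    m = np.max(u)
    return m + np.log(np.sum(np.exp(u - m)))

u = 0.7
print(np.isclose(logsumexp(np.array([u, 0.0])), np.log1p(np.exp(u))))  # softplus(u): True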

Smoothed min

Given a smoothed max operator maxΩ , we can easily define a smoothed


min operator as

minΩ (u) := −maxΩ (−u).



12.4.3 Relaxed step functions (sigmoids)


The binary step function, a.k.a. Heaviside step function, is defined by

step(u) := { 1  if u ≥ 0
           { 0  otherwise.

From Eq. (12.2), its variational form is

step(u) = arg max_{v∈[0,1]} uv.

We can therefore define the relaxation

stepΩ(u) := arg max_{v∈[0,1]} uv − Ω(v).

Notice that, unlike the case of the smoothed ReLU, it is a regularized
argmax, not a regularized max. Strongly convex regularization Ω ensures
that stepΩ(u) is a Lipschitz continuous function of u, unlike step(u). If
we use the regularizer Ω(v) = v log v + (1 − v) log(1 − v), we obtain the
closed form

stepΩ(u) = logistic(u) := 1/(1 + e^{−u}) = e^u/(1 + e^u).
This function is differentiable everywhere. As an alternative, if we use
Ω(v) = v(v − 1), we obtain a piecewise linear sigmoid,

stepΩ(u) = sparsesigmoid(u) := { 0          if u ≤ −1
                               { ½(u + 1)   if −1 < u < 1
                               { 1          if u ≥ 1.

Unlike the logistic function, it can reach the exact values 0 or 1. However,
the function has two kinks, where the function is non-differentiable.
It turns out that the three sigmoids we presented above (step, logistic,
sparsesigmoid) are all equal to the derivative of their corresponding
non-linearity:

step(u) = relu′ (u)


logistic(u) = softplus′ (u)
sparsesigmoid(u) = sparseplus′ (u)

Figure 12.2: Some ReLU functions (ReLU, SoftPlus, SparsePlus; left) and sigmoids
(Heaviside, Logistic, SparseSigmoid; right). Differentiating the left functions gives
the right functions.

and more generally

relu′_Ω(u) = stepΩ(u).

This is again a consequence of Danskin's theorem; see Example 10.2.
We illustrate the smoothed ReLU functions and relaxed step func-
tions (sigmoids) in Fig. 12.2.
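The sketch below (our naming) implements both relaxed step functions and checks
sparsesigmoid = sparseplus′ by finite differences, away from the kinks:

import numpy as np

def logistic(u):
    return 1.0 / (1.0 + np.exp(-u))

def sparsesigmoid(u):
    return np.clip(0.5 * (u + 1.0), 0.0, 1.0)

def sparseplus(u):
    return np.where(u <= -1.0, 0.0, np.where(u >= 1.0, u, 0.25 * (u + 1.0) ** 2))

u, eps = 0.3, 1e-6
fd = (sparseplus(u + eps) - sparseplus(u - eps)) / (2.0 * eps)
print(np.isclose(fd, sparsesigmoid(u)))  # True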

12.4.4 Relaxed argmax operators


With a slight notation overloading, let us now define

argmax(u) := ϕ(arg max_{j∈[M]} u_j),

where ϕ(j) = onehot(j) = e_j is used to embed any integer j ∈ [M] into
R^M. Following the previous paragraph, we have the variational form

argmax(u) = arg max_{v∈△^M} ⟨u, v⟩ = arg max_{v∈{e₁,...,e_M}} ⟨u, v⟩,

where the second equality uses that a linear function is maximized at
one vertex of the simplex. This variational form suggests defining the
relaxation

argmaxΩ(u) := arg max_{v∈△^M} ⟨u, v⟩ − Ω(v).

When Ω(v) = ⟨v, log v⟩, we obtain

argmaxΩ(u) = softargmax(u) = exp(u) / Σ_{j=1}^M exp(u_j),

which is differentiable everywhere.



When Ω(v) = ½⟨v, v − 1⟩, which is up to a constant equal to ½∥v∥₂²,
we obtain the sparseargmax (Martins and Astudillo, 2016)

argmaxΩ(u) = sparseargmax(u)
           := arg max_{v∈△^M} ⟨u, v⟩ − ½⟨v, v − 1⟩
           = arg min_{v∈△^M} ∥u − v∥₂²,

which is nothing but the Euclidean projection onto the probability
simplex, discussed in Section 15.3. As its name indicates, sparseargmax
is sparse but it is only differentiable almost everywhere.
Similarly to sigmoids, it turns out that these mappings are equal to
the gradient of their corresponding non-linearities:
argmaxΩ (u) = ∇maxΩ (u).
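As an illustration, the sketch below (our naming) implements softargmax and
sparseargmax; for the latter we use a standard sort-based routine for the
Euclidean projection onto the simplex, which is only detailed in Section 15.3.

import numpy as np

def softargmax(u):
    z = np.exp(u - np.max(u))  # shift for numerical stability
    return z / np.sum(z)

def sparseargmax(u):
    # Euclidean projection of u onto the probability simplex (sort-based).
    z = np.sort(u)[::-1]
    cumsum = np.cumsum(z)
    k = np.arange(1, len(u) + 1)
    rho = k[1.0 + k * z > cumsum][-1]
    tau = (cumsum[rho - 1] - 1.0) / rho
    return np.maximum(u - tau, 0.0)

u = np.array([0.5, 2.1, -0.3, 1.9])
print(softargmax(u).round(3))    # dense, sums to 1
print(sparseargmax(u).round(3))  # sparse, sums to 1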

Relaxed argmin operators


Given a smoothed max operator maxΩ , we can easily define a relaxed
argmin
argminΩ (u) := ∇minΩ (u) = ∇maxΩ (−u).

12.5 Summary

We reviewed how to smooth a function by infimal convolution between


the function and primal regularization. The Moreau envelope is a special
case, obtained by using quadratic regularization.
The Legendre-Fenchel transformation, a.k.a. convex conjugate, can
be seen as a dual representation of a function: instead of representing f
by its graph (u, f (u)) for u ∈ dom(f ), we can represent it by the set
of tangents with slope v and intercept −f*(v) for v ∈ dom(f*). As its
name indicates, it is convex, even if the original function is not.
Infimal convolution can equivalently be obtained by going to this
dual space, applying dual regularization there and coming back. We
showed how to apply these techniques to create smoothed ReLU and
max functions. We also showed that taking their gradients allowed us
to obtain generalized sigmoids and argmax functions.
13
Smoothing by integration

In this chapter, we review smoothing techniques based on convolution.

13.1 Convolution

13.1.1 Convolution operators

The convolution between two functions f and g produces another
function, denoted f ∗ g. It is defined by

(f ∗ g)(µ) := ∫_{−∞}^{∞} f(u) g(µ − u) du,    (13.1)

assuming that the integral is well defined. It is therefore the integral of
the product of f and g after g is reflected about the y-axis and shifted.
It can be seen as a generalization of the moving average. Using the
change of variable z := µ − u, we can also write

(f ∗ g)(µ) = ∫_{−∞}^{∞} f(µ − z) g(z) dz = (g ∗ f)(µ).    (13.2)

The convolution operator is therefore commutative.


13.1.2 Convolution with a kernel

The convolution is frequently used together with a kernel κ to create a
smooth approximation f ∗ κ of f. The most frequently used kernel is
the Gaussian kernel with width σ, defined by

κ_σ(z) := (1/(√(2π) σ)) e^{−½(z/σ)²}.

This is the probability density function (PDF) of the normal distribution
with zero mean and variance σ². The term 1/(√(2π) σ) is a normalization
constant, ensuring that the kernel integrates to 1 for all σ. We therefore
say that κ_σ is a normalized kernel.

Averaging perspective

Applying the definition of the convolution in Eq. (13.1), we obtain

(f ∗ κ_σ)(µ) = (1/(√(2π) σ)) ∫_{−∞}^{∞} f(u) e^{−½((µ−u)/σ)²} du
             = E_{U∼p_{µ,σ}}[f(U)],

where

p_{µ,σ}(u) := κ_σ(µ − u) = (1/(√(2π) σ)) e^{−½((µ−u)/σ)²}

is the PDF of the Gaussian distribution with mean µ and variance
σ². Therefore, we can see f ∗ κ_σ as the expectation of f(u) over a
Gaussian centered around µ. This property is true for all translation-
invariant kernels that correspond to a location-scale family distribution
(e.g., the Laplace distribution). The convolution therefore performs an
averaging over all points, with points near µ given more weight by
the distribution. The parameter σ controls the importance we want to
give to farther points. We call this viewpoint averaging, as we replace
f(u) by E[f(U)].

Perturbation perspective

Conversely, using the alternative definition of the convolution operator
in Eq. (13.2), which stems from the commutativity of the convolution,
we have

(f ∗ κ_σ)(µ) = (1/(√(2π) σ)) ∫_{−∞}^{∞} f(µ − z) e^{−½(z/σ)²} dz
             = E_{Z∼p_{0,σ}}[f(µ − Z)]
             = E_{Z∼p_{0,σ}}[f(µ + Z)],

where, in the third line, we used that p_{0,σ} is sign invariant, i.e., p_{0,σ}(z) =
p_{0,σ}(−z). This viewpoint shows that smoothing by convolution with
a Gaussian kernel can also be seen as injecting Gaussian noise or
perturbations into the function's input.
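The perturbation viewpoint suggests a simple Monte Carlo estimator of the smoothed
function. A minimal sketch (our naming), smoothing the absolute value:

import numpy as np

rng = np.random.default_rng(0)
f = np.abs                      # function to smooth
mu, sigma = 0.3, 0.5
z = rng.standard_normal(100_000)
estimate = np.mean(f(mu + sigma * z))  # approximates (f * kappa_sigma)(mu)
print(estimate)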

Limit case

When σ → 0, the kernel κ_σ converges to a Dirac delta function,

lim_{σ→0} κ_σ(z) = δ(z).

Since the Dirac delta is the multiplicative identity of the convolution
algebra (this is also known as the sifting property), when σ → 0, f ∗ κ_σ
converges to f, i.e.,

lim_{σ→0} (f ∗ κ_σ)(u) = f(u).

13.1.3 Discrete convolution

Often, we work with functions whose convolution does not have
an analytical form. In these cases, we can use a discrete convolution on
a grid of values. For two functions f and g defined over Z, the discrete
convolution is defined by

(f ∗ g)[i] := Σ_{j=−∞}^{∞} f[j] g[i − j].

As for its continuous counterpart, the discrete convolution is commu-
tative, namely,

(f ∗ g)[i] = Σ_{j=−∞}^{∞} f[i − j] g[j] = (g ∗ f)[i].

Figure 13.1: Smoothing of the signal f[t] := t² + 0.3 sin(6πt) with a sampled and
renormalized Gaussian kernel (σ ∈ {0.25, 0.5, 1.0}).

When g has finite support over the set S := {−M, −M + 1, . . . , 0, . . . , M − 1, M},
meaning that g[i] = 0 for all i ∉ S, a finite summation may be
used instead, i.e.,

(f ∗ g)[i] = Σ_{j=−M}^{M} f[i − j] g[j] = (g ∗ f)[i].

In practice, convolution between a discrete signal f : Z → R and a


continuous kernel κ : R → R is implemented by discretizing the kernel.
One of the simplest approaches consists in sampling points on an interval,
evaluating the kernel at these points and renormalizing the obtained
values, so that the sampled kernel sums to 1. This is illustrated with
the Gaussian kernel in Fig. 13.1. Since the Gaussian kernel decays
exponentially fast, we can choose a small interval around 0. For a survey
of other possible discretizations of the Gaussian kernel, see Getreuer
(2013).
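A minimal sketch of this discretization (our naming; the signal is the one of
Fig. 13.1): sample the Gaussian kernel on a symmetric window, renormalize it, and
apply NumPy's discrete convolution.

import numpy as np

def sampled_gaussian_kernel(sigma_in_samples, radius):
    t = np.arange(-radius, radius + 1)
    k = np.exp(-0.5 * (t / sigma_in_samples) ** 2)
    return k / k.sum()  # renormalize so the sampled kernel sums to 1

t = np.linspace(-3.0, 3.0, 601)  # grid spacing 0.01
f = t ** 2 + 0.3 * np.sin(6 * np.pi * t)
kernel = sampled_gaussian_kernel(sigma_in_samples=0.25 / 0.01, radius=100)
smoothed = np.convolve(f, kernel, mode="same")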

13.1.4 Differentiation
Remarkably, provided that the two functions are integrable with inte-
grable derivatives, the derivative of the convolution satisfies
(f ∗ g)′ = (f ′ ∗ g) = (f ∗ g ′ ),
which simply stems from switching derivative and integral in the defini-
tion of the convolution. Moreover, we have the following proposition.

Proposition 13.1 (Differentiability of the convolution). If g is n-times


differentiable with compact support over R and f is locally inte-
grable over R, then f ∗ g is n-times differentiable over R.

13.1.5 Multidimensional convolution


So far, we studied the convolution of one-dimensional functions. The
definition can be naturally extended to multidimensional functions
f : R^M → R and g : R^M → R as

(f ∗ g)(µ) := ∫_{R^M} f(u) g(µ − u) du,

assuming again that the integral exists. Typically, a Gaussian kernel
with diagonal covariance matrix is used,

κ_σ(z) := Π_{j=1}^M (1/(√(2π) σ_j)) e^{−½(z_j/σ_j)²} = (1/(√(2π)^M σ^M)) e^{−½ ∥z∥₂²/σ²},

where, in the second equality, we assumed σ₁ = · · · = σ_M =: σ. In an image
processing context, where M = 2, it is approximated using a discrete
convolution and it is called a Gaussian blur.

13.1.6 Link between convolution and infimal convolution


Recall the definition of the infimal convolution (Section 12.1):

h(µ) := (f □ g)(µ) := inf_{u∈R} f(u) + g(µ − u).

If we replace the minimum with a soft minimum (Section 12.4.2), we
then obtain

h_ε(µ) := (f □ g)_ε(µ) := −ε log ∫ e^{−(1/ε)(f(u) + g(µ−u))} du.

By using the exponential change of variable (sometimes referred to as
the Cole-Hopf transformation in a partial differential equation context)

C_ε{f}(u) := e^{−f(u)/ε}
C_ε^{-1}{F}(v) := −ε log F(v),

we can define each function in the exponential domain,

F_ε := C_ε{f},   G_ε := C_ε{g},   H_ε := C_ε{h_ε}.

We then obtain

H_ε(µ) = (F_ε ∗ G_ε)(µ).

In the exponential domain, the convolution is therefore the counterpart
of the infimal convolution, if we replace the min-plus algebra with
the sum-product algebra. Back to the log domain, we obtain

h_ε(µ) = C_ε^{-1}{H_ε}(µ).

Combining the transformation and its inverse, we can write

h_ε(µ) = C_ε^{-1}{C_ε{f} ∗ C_ε{g}}(µ).

As an example, when g(z) = ½z², in which case f □ g is the Moreau
envelope (see Section 12.1), we obtain G_ε(z) = exp(−z²/(2ε)), the (unnor-
malized) Gaussian kernel.

13.2 Fourier and Laplace transforms

Let us define the Fourier transform of f by

F(s) := F{f}(s) := ∫_{−∞}^{∞} f(t) e^{−i2πst} dt,   s ∈ R.

Note that F{f} is a function transformation: it transforms f into
another function F.

13.2.1 Convolution theorem


Now, consider the convolution

h(t) := (f ∗ g)(t).
If we define the three transformations

F := F{f},   G := F{g},   H := F{h},

the convolution theorem states that

H(s) = F{h}(s) = F(s) · G(s),   s ∈ R.

Written differently, we have

F{f ∗ g} = F{f} · F{g}.

In words, in the Fourier domain, the convolution operation becomes a
multiplication. Conversely,

h(t) = (f ∗ g)(t) = F^{-1}{F · G}(t),   t ∈ R.

The convolution theorem also holds if we replace the Fourier transform
with the Laplace transform or with the two-sided (bilateral) Laplace
transform.
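The discrete analogue is easy to verify with the DFT. A minimal NumPy sketch,
zero-padding so that the circular and linear convolutions coincide:

import numpy as np

rng = np.random.default_rng(0)
f, g = rng.standard_normal(8), rng.standard_normal(8)
full = np.convolve(f, g)                         # linear convolution, length 15
lhs = np.fft.fft(full)
rhs = np.fft.fft(f, n=15) * np.fft.fft(g, n=15)  # product of zero-padded DFTs
print(np.allclose(lhs, rhs))                     # True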

13.2.2 Link between Fourier and Legendre transforms


In Section 12.2, we studied another function transformation: the convex
conjugate, also known as Legendre-Fenchel transform. We recap the
analogies between these transforms in Table 13.1. In particular, the
counterpart of
F{f ∗ g} = F{f} · F{g}
for the infimal convolution is
(f □ g)* = f* + g*.
In words, the Legendre-Fenchel transform is to the infimal convolution
what the Fourier transform is to the convolution.

13.2.3 The soft Legendre-Fenchel transform


Let f : R^M → R. The convex conjugate of f (see Section 12.2) is

f*(v) := max_{u∈R^M} ⟨u, v⟩ − f(u).
If necessary, we can support constraints by including an indicator
function in the definition of f . The conjugate can be smoothed out using
a log-sum-exp, which plays the role of a soft maximum (Section 12.4.2).

Table 13.1: Analogy between Fourier and Legendre transforms. See Proposition 12.3
for more conjugate calculus rules.

                       Fourier F{f}                        Legendre f*
Semiring               (+, ·)                              (min, +)
Scaling (a > 0)        f(t) = g(t/a)                       f(t) = a g(t/a)
                       F{f}(s) = a F{g}(as)                f*(s) = a g*(s)
Translation            f(t) = g(t − t₀)                    f(t) = g(t − t₀)
                       F{f}(s) = e^{−i2πt₀s} F{g}(s)       f*(s) = g*(s) + t₀ s
Convolution            h = f ∗ g                           h = f □ g
                       F{h} = F{f} · F{g}                  h* = f* + g*
Gaussian / quadratic   f(t) = e^{−at²}                     f(t) = (a/2) t²
                       F{f}(s) = √(π/a) e^{−π²s²/a}        f*(s) = s²/(2a)
Smoothing              f ∗ κ_σ                             f □ (1/(2ε)) ∥·∥₂²

Definition 13.1 (Soft convex conjugate).

f_ε*(v) := ε log ∫ exp((1/ε)(⟨u, v⟩ − f(u))) du.

In the limit ε → 0, we recover the convex conjugate. We now show


that this smoothed conjugate can be rewritten using a convolution if
we apply a bijective transformation to f .

Proposition 13.2 (Smoothed convex conjugate as convolution). The
smoothed conjugate can be rewritten as

f_ε*(v) = Q_ε^{-1}{ 1/(Q_ε{f} ∗ G_ε) }(v),

where

G_ε := C_ε{½∥·∥₂²} = exp(−(1/(2ε)) ∥·∥₂²)
Q_ε{f} := C_ε{f(·) − ½∥·∥₂²} = exp((1/ε)(½∥·∥₂² − f(·)))
Q_ε^{-1}{F} := ½∥·∥₂² − ε log(F(·)).

This insight was tweeted by Gabriel Peyré in April 2020.

Proof.

f_ε*(v) := ε log ∫ exp((1/ε)(⟨u, v⟩ − f(u))) du
         = ε log ∫ exp(−(1/(2ε))∥u − v∥₂² + (1/(2ε))∥u∥₂² + (1/(2ε))∥v∥₂² − (1/ε)f(u)) du
         = ε log ∫ exp(−(1/(2ε))∥u − v∥₂² + (1/(2ε))∥u∥₂² − (1/ε)f(u)) du + ½∥v∥₂²
         = ε log ∫ G_ε(v − u) Q_ε{f}(u) du + ½∥v∥₂²
         = ε log (Q_ε{f} ∗ G_ε)(v) + ½∥v∥₂²
         = ½∥v∥₂² − ε log( 1/((Q_ε{f} ∗ G_ε)(v)) )
         = Q_ε^{-1}{ 1/(Q_ε{f} ∗ G_ε) }(v).

What did we gain from this viewpoint? The convex conjugate can
often be difficult to compute in closed form. If we replace R^M with a dis-
crete set S (i.e., a grid), we can then approximate the smoothed convex
conjugate in O(n log n), where n = |S|, using a discrete convolution,

(Q_ε{f} ∗ G_ε)(v) ≈ Σ_{u∈S} G_ε(v − u) Q_ε{f}(u)
                  = Kq,

where K is the n×n Gaussian kernel matrix whose entries correspond to
exp(−(1/(2ε))∥u − u′∥₂²) for u, u′ ∈ S and q is the n-dimensional vector whose
entries are Q_ε{f}(u) for u ∈ S.
