
Supplementary Material for

“Probabilistic Machine Learning: Advanced Topics”

Kevin Murphy

August 15, 2023


Contents

1 Introduction 5

I Fundamentals 7
2 Probability 9
2.1 More fun with MVNs 9
2.1.1 Inference in the presence of missing data 9
2.1.2 Sensor fusion with unknown measurement noise 9
2.2 Google’s PageRank algorithm 12
2.2.1 Retrieving relevant pages using inverted indices 12
2.2.2 The PageRank score 13
2.2.3 Efficiently computing the PageRank vector 14
2.2.4 Web spam 15
2.2.5 Personalized PageRank 16
3 Bayesian statistics 17
3.1 Bayesian concept learning 17
3.1.1 Learning a discrete concept: the number game 17
3.1.2 Learning a continuous concept: the healthy levels game 23
3.2 Informative priors 26
3.2.1 Domain specific priors 27
3.2.2 Gaussian prior 28
3.2.3 Power-law prior 29
3.2.4 Erlang prior 29
3.3 Tweedie’s formula (Empirical Bayes without estimating the prior) 30
4 Graphical models 33
4.1 More examples of DGMs 33
4.1.1 Water sprinkler 33
4.1.2 Asia network 34
4.1.3 The QMR network 35
4.1.4 Genetic linkage analysis 37
4.2 More examples of UGMs 40
4.3 Restricted Boltzmann machines (RBMs) in more detail 40
4.3.1 Binary RBMs 40
4.3.2 Categorical RBMs 41
4.3.3 Gaussian RBMs 41
4.3.4 RBMs with Gaussian hidden units 42

5 Information theory 43
5.1 Minimizing KL between two Gaussians 43
5.1.1 Moment projection 43
5.1.2 Information projection 43

6 Optimization 45
6.1 Proximal methods 45
6.1.1 Proximal operators 45
6.1.2 Computing proximal operators 48
6.1.3 Proximal point methods (PPM) 51
6.1.4 Mirror descent 53
6.1.5 Proximal gradient method 53
6.1.6 Alternating direction method of multipliers (ADMM) 55
6.2 Dynamic programming 57
6.2.1 Example: computing Fibonacci numbers 57
6.2.2 ML examples 58
6.3 Conjugate duality 58
6.3.1 Introduction 58
6.3.2 Example: exponential function 60
6.3.3 Conjugate of a conjugate 61
6.3.4 Bounds for the logistic (sigmoid) function 61
6.4 The Bayesian learning rule 63
6.4.1 Deriving inference algorithms from BLR 64
6.4.2 Deriving optimization algorithms from BLR 66
6.4.3 Variational optimization 70

II Inference 71

7 Inference algorithms: an overview 73

8 Inference for state-space models 75
8.1 More Kalman filtering 75
8.1.1 Example: tracking an object with spiral dynamics 75
8.1.2 Derivation of RLS 75
8.1.3 Handling unknown observation noise 77
8.1.4 Predictive coding as Kalman filtering 78
8.2 More extended Kalman filtering 79
8.2.1 Derivation of the EKF 79
8.2.2 Example: Tracking a pendulum 80
8.3 Exponential-family EKF 81
8.3.1 Modeling assumptions 81
8.3.2 Algorithm 82
8.3.3 EEKF for training logistic regression 82
8.3.4 EEKF performs online natural gradient descent 83

9 Inference for graphical models 89
9.1 Belief propagation on trees 89
9.1.1 BP for polytrees 89
9.2 The junction tree algorithm (JTA) 92
9.2.1 Tree decompositions 92
9.2.2 Message passing on a junction tree 96
9.2.3 The generalized distributive law 99
9.2.4 JTA applied to a chain 100
9.2.5 JTA for general temporal graphical models 101
9.3 MAP estimation for discrete PGMs 103
9.3.1 Notation 103
9.3.2 The marginal polytope 104
9.3.3 Linear programming relaxation 105
9.3.4 Graphcuts 108

10 Variational inference 113
10.1 More Gaussian VI 113
10.1.1 Example: Full-rank vs diagonal GVI on 1d linear regression 113
10.1.2 Example: Full-rank vs rank-1 GVI for logistic regression 115
10.1.3 Structured (sparse) Gaussian VI 116
10.2 Online variational inference 118
10.2.1 FOO-VB 118
10.2.2 Bayesian gradient descent 119
10.3 Beyond mean field 120
10.3.1 Exploiting partial conjugacy 120
10.3.2 Structured mean field for factorial HMMs 124
10.4 VI for graphical model inference 126
10.4.1 Exact inference as VI 126
10.4.2 Mean field VI 127
10.4.3 Loopy belief propagation as VI 128
10.4.4 Convex belief propagation 131
10.4.5 Tree-reweighted belief propagation 132
10.4.6 Other tractable versions of convex BP 133

11 Monte Carlo Inference 135

12 Markov Chain Monte Carlo (MCMC) inference 137

13 Sequential Monte Carlo (SMC) inference 139
13.1 More applications of particle filtering 139
13.1.1 1d pendulum model with outliers 139
13.1.2 Visual object tracking 139
13.1.3 Online parameter estimation 141
13.1.4 Monte Carlo robot localization 141
13.2 Particle MCMC methods 142
13.2.1 Particle Marginal Metropolis Hastings 143
13.2.2 Particle Independent Metropolis Hastings 143
13.2.3 Particle Gibbs 144

III Prediction 145

14 Predictive models: an overview 147

15 Generalized linear models 149
15.1 Variational inference for logistic regression 149
15.1.1 Binary logistic regression 149
15.1.2 Multinomial logistic regression 151
15.2 Converting multinomial logistic regression to Poisson regression 155
15.2.1 Beta-binomial logistic regression 155
15.2.2 Poisson regression 156
15.2.3 GLMM (hierarchical Bayes) regression 157

16 Deep neural networks 161
16.1 More canonical examples of neural networks 161
16.1.1 Transformers 161
16.1.2 Graph neural networks (GNNs) 163

17 Bayesian neural networks 169
17.1 More details on EKF for training MLPs 169
17.1.1 Global EKF 169
17.1.2 Decoupled EKF 169
17.1.3 Mini-batch EKF 170

18 Gaussian processes 171
18.1 Deep GPs 171
18.2 GPs and SSMs 176

19 Beyond the iid assumption 177

IV Generation 179

20 Generative models: an overview 181

21 Variational autoencoders 183
21.0.1 VAEs with missing data 183

22 Auto-regressive models 187

23 Normalizing flows 189

24 Energy-based models 191

25 Denoising diffusion models 193

26 Generative adversarial networks 195

V Discovery 197

27 Discovery methods: an overview 199

28 Latent factor models 201
28.1 Inference in topic models 201
28.1.1 Collapsed Gibbs sampling for LDA 201
28.1.2 Variational inference for LDA 203

29 State-space models 207
29.1 Continuous time SSMs 207
29.1.1 Ordinary differential equations 207
29.1.2 Example: Noiseless 1d spring-mass system 208
29.1.3 Example: tracking a moving object in continuous time 209
29.1.4 Example: tracking a particle in 2d 212
29.2 Structured State Space Sequence model (S4) 212

30 Graph learning 215
30.1 Latent variable models for graphs 215
30.1.1 Stochastic block model 215
30.1.2 Mixed membership stochastic block model 217
30.1.3 Infinite relational model 219
30.2 Learning tree structures 220
30.2.1 Chow-Liu algorithm 221
30.2.2 Finding the MAP forest 223
30.2.3 Mixtures of trees 223
30.3 Learning DAG structures 224
30.3.1 Faithfulness 224
30.3.2 Markov equivalence 225
30.3.3 Bayesian model selection: statistical foundations 226
30.3.4 Bayesian model selection: algorithms 229
30.3.5 Constraint-based approach 231
30.3.6 Methods based on sparse optimization 234
30.3.7 Consistent estimators 234
30.3.8 Handling latent variables 235
30.4 Learning undirected graph structures 243
30.4.1 Dependency networks 243
30.4.2 Graphical lasso for GGMs 245
30.4.3 Graphical lasso for discrete MRFs/CRFs 246
30.4.4 Bayesian inference for undirected graph structures 247
30.5 Learning causal DAGs 249
30.5.1 Learning cause-effect pairs 249
30.5.2 Learning causal DAGs from interventional data 252
30.5.3 Learning from low-level inputs 253

31 Non-parametric Bayesian models 255
31.1 Dirichlet processes 255
31.1.1 Definition of a DP 255
31.1.2 Stick breaking construction of the DP 257
31.1.3 The Chinese restaurant process (CRP) 259
31.2 Dirichlet process mixture models 260
31.2.1 Model definition 260
31.2.2 Fitting using collapsed Gibbs sampling 262
31.2.3 Fitting using variational Bayes 265
31.2.4 Other fitting algorithms 266
31.2.5 Choosing the hyper-parameters 267
31.3 Generalizations of the Dirichlet process 267
31.3.1 Pitman-Yor process 267
31.3.2 Dependent random probability measures 268
31.4 The Indian buffet process and the beta process 271
31.5 Small-variance asymptotics 274
31.6 Completely random measures 277
31.7 Lévy processes 278
31.8 Point processes with repulsion and reinforcement 280
31.8.1 Poisson process 280
31.8.2 Renewal process 281
31.8.3 Hawkes process 282
31.8.4 Gibbs point process 284
31.8.5 Determinantal point process 285

32 Representation learning 289

33 Interpretability 291

VI Decision making 293

34 Multi-step decision problems 295

35 Reinforcement learning 297

36 Causality 299


1 Introduction

This book contains supplementary material for [book2]. Some sections have not been checked as
carefully as the main book, so caveat lector.
Part I

Fundamentals
2 Probability

2.1 More fun with MVNs


2.1.1 Inference in the presence of missing data
Suppose we have a linear Gaussian system where we only observe part of y, call it y1, while the
other part, y2, is hidden. That is, we generalize Main Equation (2.119) as follows:

p(z) = N (z|µz , Σz ) (2.1)


          
$$p\!\left(\begin{bmatrix} y_1 \\ y_2 \end{bmatrix} \,\middle|\, z\right) = \mathcal{N}\!\left(\begin{bmatrix} y_1 \\ y_2 \end{bmatrix} \,\middle|\, \begin{bmatrix} W_1 \\ W_2 \end{bmatrix} z + \begin{bmatrix} b_1 \\ b_2 \end{bmatrix},\ \begin{bmatrix} \Sigma_{11} & \Sigma_{12} \\ \Sigma_{21} & \Sigma_{22} \end{bmatrix}\right) \tag{2.2}$$

We can compute p(z|y1 ) by partitioning the joint into p(z, y1 , y2 ), marginalizing out y2 , and then
conditioning on y1 . The result is as follows:

p(z|y1 ) = N (z|µz|1 , Σz|1 ) (2.3)


$$\Sigma_{z|1}^{-1} = \Sigma_z^{-1} + W_1^\top \Sigma_{11}^{-1} W_1 \tag{2.4}$$
$$\mu_{z|1} = \Sigma_{z|1}\left[W_1^\top \Sigma_{11}^{-1}(y_1 - b_1) + \Sigma_z^{-1}\mu_z\right] \tag{2.5}$$
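The following is a minimal NumPy sketch of Equations (2.4)-(2.5); the function name and the toy matrices are made up for illustration and are not from the book.

import numpy as np

def posterior_given_partial_obs(mu_z, Sigma_z, W1, b1, Sigma11, y1):
    """Compute p(z | y1) = N(z | mu_post, Sigma_post) using Eqs. (2.4)-(2.5)."""
    Sigma_z_inv = np.linalg.inv(Sigma_z)
    Sigma11_inv = np.linalg.inv(Sigma11)
    # Eq. (2.4): posterior precision = prior precision + W1^T Sigma11^{-1} W1
    post_prec = Sigma_z_inv + W1.T @ Sigma11_inv @ W1
    Sigma_post = np.linalg.inv(post_prec)
    # Eq. (2.5): posterior mean
    mu_post = Sigma_post @ (W1.T @ Sigma11_inv @ (y1 - b1) + Sigma_z_inv @ mu_z)
    return mu_post, Sigma_post

# Toy example (made-up numbers): 2d latent z, observing only the y1 slice of y
mu_z = np.zeros(2)
Sigma_z = np.eye(2)
W1 = np.array([[1.0, 0.0], [1.0, 1.0]])
b1 = np.zeros(2)
Sigma11 = 0.1 * np.eye(2)
y1 = np.array([0.5, 1.5])
print(posterior_given_partial_obs(mu_z, Sigma_z, W1, b1, Sigma11, y1))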

2.1.2 Sensor fusion with unknown measurement noise


In this section, we extend the sensor fusion results from Main Section 2.3.2.3 to the case where the
precision of each measurement device is unknown. This turns out to yield a potentially multi-modal
posterior, as we will see, which is quite different from the Gaussian case. Our presentation is based
on [Minka01GaussVar].
For simplicity, we assume the latent quantity is scalar, z ∈ R, and that we just have two measurement devices, x and y. However, we allow these to have different precisions, so the data generating mechanism has the form $x_n|z \sim \mathcal{N}(z, \lambda_x^{-1})$ and $y_n|z \sim \mathcal{N}(z, \lambda_y^{-1})$. We will use a non-informative prior for z, p(z) ∝ 1, which we can emulate using an infinitely broad Gaussian, $p(z) = \mathcal{N}(z|m_0 = 0, \lambda_0^{-1} = \infty)$. So the unknown parameters are the two measurement precisions, θ = (λx, λy).
Suppose we make 2 independent measurements with each device, which turn out to be

x1 = 1.1, x2 = 1.9, y1 = 2.9, y2 = 4.1 (2.6)


If the parameters θ were known, then the posterior would be Gaussian:

$$p(z|\mathcal{D}, \lambda_x, \lambda_y) = \mathcal{N}(z|m_N, \lambda_N^{-1}) \tag{2.7}$$
$$\lambda_N = \lambda_0 + N_x \lambda_x + N_y \lambda_y \tag{2.8}$$
$$m_N = \frac{\lambda_x N_x \bar{x} + \lambda_y N_y \bar{y}}{N_x \lambda_x + N_y \lambda_y} \tag{2.9}$$

where $N_x = 2$ is the number of x measurements, $N_y = 2$ is the number of y measurements, $\bar{x} = \frac{1}{N_x}\sum_{n=1}^{N_x} x_n = 1.5$ and $\bar{y} = \frac{1}{N_y}\sum_{n=1}^{N_y} y_n = 3.5$. This result follows because the posterior precision is the sum of the measurement precisions, and the posterior mean is a weighted sum of the prior mean (which is 0) and the data means.
However, the measurement precisions are not known. A simple solution is to estimate them by maximum likelihood. The log-likelihood is given by

$$\ell(z, \lambda_x, \lambda_y) = \frac{N_x}{2}\log \lambda_x - \frac{\lambda_x}{2}\sum_n (x_n - z)^2 + \frac{N_y}{2}\log \lambda_y - \frac{\lambda_y}{2}\sum_n (y_n - z)^2 \tag{2.10}$$

The MLE is obtained by solving the following simultaneous equations:

$$\frac{\partial \ell}{\partial z} = \lambda_x N_x (\bar{x} - z) + \lambda_y N_y (\bar{y} - z) = 0 \tag{2.11}$$
$$\frac{\partial \ell}{\partial \lambda_x} = \frac{1}{\lambda_x} - \frac{1}{N_x}\sum_{n=1}^{N_x}(x_n - z)^2 = 0 \tag{2.12}$$
$$\frac{\partial \ell}{\partial \lambda_y} = \frac{1}{\lambda_y} - \frac{1}{N_y}\sum_{n=1}^{N_y}(y_n - z)^2 = 0 \tag{2.13}$$

This gives

$$\hat{z} = \frac{N_x \hat{\lambda}_x \bar{x} + N_y \hat{\lambda}_y \bar{y}}{N_x \hat{\lambda}_x + N_y \hat{\lambda}_y} \tag{2.14}$$
$$1/\hat{\lambda}_x = \frac{1}{N_x}\sum_n (x_n - \hat{z})^2 \tag{2.15}$$
$$1/\hat{\lambda}_y = \frac{1}{N_y}\sum_n (y_n - \hat{z})^2 \tag{2.16}$$

We notice that the MLE for z has the same form as the posterior mean, $m_N$.
We can solve these equations by fixed point iteration. Let us initialize by estimating $\lambda_x = 1/s_x^2$ and $\lambda_y = 1/s_y^2$, where $s_x^2 = \frac{1}{N_x}\sum_{n=1}^{N_x}(x_n - \bar{x})^2 = 0.16$ and $s_y^2 = \frac{1}{N_y}\sum_{n=1}^{N_y}(y_n - \bar{y})^2 = 0.36$. Using this, we get $\hat{z} = 2.1154$, so $p(z|\mathcal{D}, \hat{\lambda}_x, \hat{\lambda}_y) = \mathcal{N}(z|2.1154, 0.0554)$. If we now iterate, we converge to $\hat{\lambda}_x = 1/0.1662$, $\hat{\lambda}_y = 1/4.0509$, $p(z|\mathcal{D}, \hat{\lambda}_x, \hat{\lambda}_y) = \mathcal{N}(z|1.5788, 0.0798)$.
The plug-in approximation to the posterior is plotted in Figure 2.1(a). This weights each sensor according to its estimated precision. Since sensor y was estimated to be much less reliable than sensor x, we have $\mathbb{E}[z|\mathcal{D}, \hat{\lambda}_x, \hat{\lambda}_y] \approx \bar{x}$, so we effectively ignore the y sensor.
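As a check on the numbers above, here is a minimal NumPy sketch of the fixed-point iteration for Equations (2.14)-(2.16); the data values are the ones given in the text, and everything else is just illustrative code.

import numpy as np

x = np.array([1.1, 1.9]); y = np.array([2.9, 4.1])
Nx, Ny = len(x), len(y)
xbar, ybar = x.mean(), y.mean()

# Initialize with the sample variances, then iterate Eqs. (2.14)-(2.16)
lam_x, lam_y = 1 / np.mean((x - xbar) ** 2), 1 / np.mean((y - ybar) ** 2)
for _ in range(100):
    z = (Nx * lam_x * xbar + Ny * lam_y * ybar) / (Nx * lam_x + Ny * lam_y)   # Eq. (2.14)
    lam_x = 1 / np.mean((x - z) ** 2)                                          # Eq. (2.15)
    lam_y = 1 / np.mean((y - z) ** 2)                                          # Eq. (2.16)

print(z, 1 / lam_x, 1 / lam_y)   # should approach the values quoted in the text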
Now we will adopt a Bayesian approach and integrate out the unknown precisions, following Main Section 3.4.3.3. That is, we compute

$$p(z|\mathcal{D}) \propto p(z)\left[\int p(\mathcal{D}_x|z,\lambda_x)\,p(\lambda_x|z)\,d\lambda_x\right]\left[\int p(\mathcal{D}_y|z,\lambda_y)\,p(\lambda_y|z)\,d\lambda_y\right] \tag{2.17}$$

We will use uninformative Jeffreys priors (Main Section 3.5.2) p(z) ∝ 1, p(λx|z) ∝ 1/λx and p(λy|z) ∝ 1/λy. Since the x and y terms are symmetric, we will just focus on one of them. The key integral is

$$I = \int p(\mathcal{D}_x|z,\lambda_x)\,p(\lambda_x|z)\,d\lambda_x \tag{2.18}$$
$$\propto \int \lambda_x^{-1}\,\lambda_x^{N_x/2}\exp\!\left(-\frac{N_x}{2}\lambda_x(\bar{x}-z)^2 - \frac{N_x}{2}s_x^2\lambda_x\right)d\lambda_x \tag{2.19}$$

Exploiting the fact that $N_x = 2$ this simplifies to

$$I = \int \lambda_x^{-1}\,\lambda_x^{1}\,\exp\!\left(-\lambda_x\left[(\bar{x}-z)^2 + s_x^2\right]\right)d\lambda_x \tag{2.20}$$

We recognize this as proportional to the integral of an unnormalized Gamma density

$$\mathrm{Ga}(\lambda|a,b) \propto \lambda^{a-1}e^{-\lambda b} \tag{2.21}$$

where $a = 1$ and $b = (\bar{x} - z)^2 + s_x^2$. Hence the integral is proportional to the normalizing constant of the Gamma distribution, $\Gamma(a)b^{-a}$, so we get

$$I \propto \int p(\mathcal{D}_x|z,\lambda_x)\,p(\lambda_x|z)\,d\lambda_x \propto \left[(\bar{x}-z)^2 + s_x^2\right]^{-1} \tag{2.22}$$

and the posterior becomes

$$p(z|\mathcal{D}) \propto \frac{1}{(\bar{x}-z)^2 + s_x^2}\,\frac{1}{(\bar{y}-z)^2 + s_y^2} \tag{2.23}$$

The exact posterior is plotted in Figure 2.1(b). We see that it has two modes, one near $\bar{x} = 1.5$ and one near $\bar{y} = 3.5$. These correspond to the beliefs that the x sensor is more reliable than the y one, and vice versa. The weight of the first mode is larger, since the data from the x sensor agree more with each other, so it seems slightly more likely that the x sensor is the reliable one. (They obviously cannot both be reliable, since they disagree on the values that they are reporting.) However, the Bayesian solution keeps open the possibility that the y sensor is the more reliable one; from two measurements, we cannot tell, and choosing just the x sensor, as the plug-in approximation does, results in overconfidence (a posterior that is too narrow).
So far, we have assumed the prior is conjugate to the likelihood, so we have been able to compute the posterior analytically. However, this is rarely the case. A common alternative is to approximate the integral using Monte Carlo sampling, as follows:

$$p(z|\mathcal{D}) \propto \int p(z|\mathcal{D},\theta)\,p(\theta|\mathcal{D})\,d\theta \tag{2.24}$$
$$\approx \frac{1}{S}\sum_s p(z|\mathcal{D},\theta^s) \tag{2.25}$$

Figure 2.1: Posterior for z. (a) Plug-in approximation. (b) Exact posterior. Generated by sensor_fusion_unknown_prec.ipynb.
where $\theta^s \sim p(\theta|\mathcal{D})$. Note that $p(z|\mathcal{D}, \theta^s)$ is conditionally Gaussian, and is easy to compute. So we just need a way to draw samples from the parameter posterior, p(θ|D). We discuss suitable methods for this in Main Chapter 11.
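To make the contrast between the plug-in and exact posteriors concrete, here is a small NumPy sketch that evaluates the unnormalized exact posterior of Equation (2.23) on a grid, using the measurements given earlier; the grid range and resolution are arbitrary choices.

import numpy as np

# Data from the text: two measurements from each sensor
x = np.array([1.1, 1.9])
y = np.array([2.9, 4.1])
xbar, ybar = x.mean(), y.mean()
sx2 = np.mean((x - xbar) ** 2)   # 0.16
sy2 = np.mean((y - ybar) ** 2)   # 0.36

# Evaluate the unnormalized exact posterior of Eq. (2.23) on a grid of z values
zs = np.linspace(-2, 6, 2001)
post = 1.0 / (((xbar - zs) ** 2 + sx2) * ((ybar - zs) ** 2 + sy2))
post /= np.trapz(post, zs)       # normalize numerically

# The two local maxima correspond to "x is reliable" vs "y is reliable"
local_max = zs[1:-1][(post[1:-1] > post[:-2]) & (post[1:-1] > post[2:])]
print("posterior modes near z =", local_max)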
2.2 Google’s PageRank algorithm
In this section, we discuss Google's PageRank algorithm, since it provides an interesting application of Markov chain theory. PageRank is one of the components used for ranking web page search results. We sketch the basic idea below; see [Bryan06] for a more detailed explanation.

2.2.1 Retrieving relevant pages using inverted indices

We will treat the web as a giant directed graph, where nodes represent web pages (documents) and edges represent hyper-links.1 We then perform a process called web crawling. We start at a few designated root nodes, such as wikipedia.org, and then follow the links, storing all the pages that we encounter, until we run out of time.
Next, all of the words in each web page are entered into a data structure called an inverted index. That is, for each word, we store a list of the documents where this word occurs. At test time, when a user enters a query, we can find potentially relevant pages as follows: for each word in the query, look up all the documents containing each word, and intersect these lists. (We can get a more refined search by storing the location of each word in each document, and then testing if the words in a document occur in the same order as in the query.)
Let us give an example, from http://en.wikipedia.org/wiki/Inverted_index. Suppose we have 3 documents, D0 = "it is what it is", D1 = "what is it" and D2 = "it is a banana". Then we can create the following inverted index, where each pair represents a document and word location:

"a": {(2, 2)}
"banana": {(2, 3)}
"is": {(0, 1), (0, 4), (1, 1), (2, 1)}
"it": {(0, 0), (0, 3), (1, 2), (2, 0)}
"what": {(0, 2), (1, 0)}

For example, we see that the word "what" occurs in document 0 at location 2 (counting from 0), and in document 1 at location 0. Suppose we search for "what is it". If we ignore word order, we retrieve the following documents:

$$\{D_0, D_1\} \cap \{D_0, D_1, D_2\} \cap \{D_0, D_1, D_2\} = \{D_0, D_1\} \tag{2.26}$$

If we require that the word order matches, only document D1 would be returned. More generally, we can allow out-of-order matches, but can give "bonus points" to documents whose word order matches the query's word order, or to other features, such as if the words occur in the title of a document. We can then return the matching documents in decreasing order of their score/relevance. This is called document ranking.

1. In 2008, Google said it had indexed 1 trillion (10^12) unique URLs. If we assume there are about 10 URLs per page (on average), this means there were about 100 billion unique web pages. Estimates for 2010 are about 121 billion unique web pages. Source: https://bit.ly/2keQeyi
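The following is a minimal Python sketch of the inverted index idea for the three example documents above; it is not the implementation used by any real search engine, and the helper names are made up.

from collections import defaultdict

docs = ["it is what it is", "what is it", "it is a banana"]

# Build the inverted index: word -> set of (doc id, position) pairs
index = defaultdict(set)
for d, text in enumerate(docs):
    for pos, word in enumerate(text.split()):
        index[word].add((d, pos))

def retrieve(query):
    """Return the ids of documents containing every query word (bag-of-words match)."""
    doc_sets = [{d for d, _ in index[w]} for w in query.split()]
    return set.intersection(*doc_sets) if doc_sets else set()

print(retrieve("what is it"))   # {0, 1}, as in Eq. (2.26)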
2.2.2 The PageRank score
So far, we have described the standard process of information retrieval. But the link structure of the web provides an additional source of information. The basic idea is that some web pages are more authoritative than others, so these should be ranked higher (assuming they match the query). A web page is considered an authority if it is linked to by many other pages. But to protect against the effect of so-called link farms, which are dummy pages which just link to a given site to boost its apparent relevance, we will weight each incoming link by the source's authority. Thus we get the following recursive definition for the authoritativeness of page j, also called its PageRank:

$$\pi_j = \sum_i A_{ij}\,\pi_i \tag{2.27}$$

where $A_{ij}$ is the probability of following a link from i to j. (The term "PageRank" is named after Larry Page, one of Google's co-founders.)
We recognize Equation (2.27) as the stationary distribution of a Markov chain. But how do we define the transition matrix? In the simplest setting, we define $A_{i,:}$ as a uniform distribution over all states that i is connected to. However, to ensure the distribution is unique, we need to make the chain into a regular chain. This can be done by allowing each state i to jump to any other state (including itself) with some small probability. This effectively makes the transition matrix aperiodic and fully connected (although the adjacency matrix $G_{ij}$ of the web itself is highly sparse).
We discuss efficient methods for computing the leading eigenvector of this giant matrix below. Here we ignore computational issues, and just give some examples.
First, consider the small web in Figure 2.2. We find that the stationary distribution is

$$\pi = (0.3209, 0.1706, 0.1065, 0.1368, 0.0643, 0.2008) \tag{2.28}$$

So a random surfer will visit site 1 about 32% of the time. We see that node 1 has a higher PageRank than nodes 4 or 6, even though they all have the same number of in-links. This is because being linked to from an influential node helps increase your PageRank score more than being linked to by a less influential node.
Figure 2.2: (a) A very small world wide web. Generated by pagerank_small_plot_graph.ipynb. (b) The corresponding stationary distribution. Generated by pagerank_demo_small.ipynb.
Figure 2.3: (a) Web graph of 500 sites rooted at www.harvard.edu. (b) Corresponding PageRank vector. Generated by pagerank_demo_harvard.ipynb.

As a slightly larger example, Figure 2.3(a) shows a web graph, derived from the root of harvard.edu. Figure 2.3(b) shows the corresponding PageRank vector.
2.2.3 Efficiently computing the PageRank vector

Let $G_{ij} = 1$ iff there is a link from j to i. Now imagine performing a random walk on this graph, where at every time step, with probability p you follow one of the outlinks uniformly at random, and with probability 1 − p you jump to a random node, again chosen uniformly at random. If there are no outlinks, you just jump to a random page. (These random jumps, including self-transitions, ensure the chain is irreducible (singly connected) and regular. Hence we can solve for its unique stationary distribution using eigenvector methods.) This defines the following transition matrix:

$$M_{ij} = \begin{cases} pG_{ij}/c_j + \delta & \text{if } c_j \neq 0 \\ 1/n & \text{if } c_j = 0 \end{cases} \tag{2.29}$$
where n is the number of nodes, $\delta = (1-p)/n$ is the probability of jumping from one page to another without following a link, and $c_j = \sum_i G_{ij}$ represents the out-degree of page j. (If $n = 4 \times 10^9$ and p = 0.85, then $\delta = 3.75 \times 10^{-11}$.) Here M is a stochastic matrix in which columns sum to one. Note that $M = A^\top$ in our earlier notation.
We can represent the transition matrix compactly as follows. Define the diagonal matrix D with entries

$$d_{jj} = \begin{cases} 1/c_j & \text{if } c_j \neq 0 \\ 0 & \text{if } c_j = 0 \end{cases} \tag{2.30}$$

Define the vector z with components

$$z_j = \begin{cases} \delta & \text{if } c_j \neq 0 \\ 1/n & \text{if } c_j = 0 \end{cases} \tag{2.31}$$

Then we can rewrite Equation (2.29) as follows:

$$M = pGD + \mathbf{1}z^\top \tag{2.32}$$

The matrix M is not sparse, but it is a rank one modification of a sparse matrix. Most of the elements of M are equal to the small constant δ. Obviously these do not need to be stored explicitly.
Our goal is to solve $v = Mv$, where $v = \pi^\top$. One efficient method to find the leading eigenvector of a large matrix is known as the power method. This simply consists of repeated matrix-vector multiplication, followed by normalization:

$$v \propto Mv = pGDv + \mathbf{1}z^\top v \tag{2.33}$$

It is possible to implement the power method without using any matrix multiplications, by simply sampling from the transition matrix and counting how often you visit each state. This is essentially a Monte Carlo approximation to the sum implied by v = Mv. Applying this to the data in Figure 2.3(a) yields the stationary distribution in Figure 2.3(b). This took 13 iterations to converge, starting from a uniform distribution. To handle changing web structure, we can re-run this algorithm every day or every week, starting v off at the old distribution; this is called warm starting [Langville06].
For details on how to perform this Monte Carlo power method in a parallel distributed computing environment, see e.g., [Rajaraman10].
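Here is a small NumPy sketch of the power method of Equation (2.33), using the rank-one structure M = pGD + 1zᵀ so that the dense matrix M is never formed; the 4-page adjacency matrix is made up for illustration (the graphs in Figures 2.2 and 2.3 are not reproduced here).

import numpy as np

def pagerank(G, p=0.85, n_iter=100):
    """Power method for Eq. (2.33): v <- p G D v + 1 (z^T v), renormalized each step.
    G[i, j] = 1 iff there is a link from page j to page i."""
    n = G.shape[0]
    c = G.sum(axis=0)                                        # out-degrees c_j
    D = np.where(c > 0, 1.0 / np.where(c > 0, c, 1), 0.0)    # diagonal entries d_jj
    z = np.where(c > 0, (1 - p) / n, 1.0 / n)                # jump vector z_j
    v = np.full(n, 1.0 / n)                                  # start from a uniform distribution
    for _ in range(n_iter):
        v = p * (G * D) @ v + z @ v                          # rank-one update; M never formed
        v /= v.sum()
    return v

# Made-up 4-page web: column j lists the pages that page j links to
G = np.array([[0, 1, 1, 0],
              [1, 0, 0, 1],
              [1, 1, 0, 0],
              [0, 0, 1, 0]], dtype=float)
print(pagerank(G))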
2.2.4 Web spam
PageRank is not foolproof. For example, consider the strategy adopted by JC Penney, a department store in the USA. During the Christmas season of 2010, it planted many links to its home page on 1000s of irrelevant web pages, thus increasing its ranking on Google's search engine [Segal11]. Even though each of these source pages has low PageRank, there were so many of them that their effect added up. Businesses call this search engine optimization; Google calls it web spam. When Google was notified of this scam (by the New York Times), it manually downweighted JC Penney, since such behavior violates Google's code of conduct. The result was that JC Penney dropped from rank 1 to rank 65, essentially making it disappear from view. Automatically detecting such scams relies on various techniques which are beyond the scope of this chapter.
2.2.5 Personalized PageRank

The PageRank algorithm computes a single global notion of importance of each web page. In some cases, it is useful for each user to define his own notion of importance. The Personalized PageRank algorithm (aka random walks with restart) computes a stationary distribution relative to node k, by returning with some probability to a specific starting node k rather than a random node. The corresponding stationary distribution, $\pi_k$, gives a measure of how important each node is relative to k. See [Lofgren2015] for details. (A similar system is used by Pinterest to infer the similarity of one "pin" (bookmarked webpage) to another, as explained in [Eksombatchai2018].)
3 Bayesian statistics

3.1 Bayesian concept learning


In this section, we introduce Bayesian statistics using some simple examples inspired by Bayesian models of human learning. This will let us get familiar with the key ideas without getting bogged down by mathematical technicalities.
Consider how a child learns the meaning of a word, such as “dog”. Typically the child’s parents
will point out positive examples of this concept, saying such things as, “look at the cute dog!”, or
“mind the doggy”, etc. The core challenge is to figure out what we mean by the concept “dog”, based
on a finite (and possibly quite small) number of such examples. Note that the parent is unlikely to
provide negative examples; for example, people do not usually say “look at that non-dog”. Negative
examples may be obtained during an active learning process (e.g., the child says “look at the dog”
and the parent says “that’s a cat, dear, not a dog”), but psychological research has shown that people
can learn concepts from positive examples alone [Xu07fei]. This means that standard supervised
learning methods cannot be used.
We formulate the problem by assuming the data that we see are generated by some hidden concept
h ∈ H, where H is called the hypothesis space. (We use the notation h rather than θ to be
consistent with the concept learning literature.) We then focus on computing the posterior p(h|D).
In Section 3.1.1, we assume the hypothesis space consists of a finite number of alternative hypotheses;
this will significantly simplify the computation of the posterior, allowing us to focus on the ideas and
not get too distracted by the math. In Section 3.1.2, we will extend this to continuous hypothesis
spaces. This will form the foundation for Bayesian inference of real-valued parameters for more
familiar probability models, such as the Bernoulli and the Gaussian, logistic regression, and deep
neural networks, that we discuss in later chapters. (See also [Jia2013] for an application of these
ideas to the problem of concept learning from images.)

3.1.1 Learning a discrete concept: the number game


Suppose that we are trying to learn some mathematical concept from a teacher who provides examples
of that concept. We assume that a concept is defined as the set of positive integers that belong to
its extension; for example, the concept “even number” is defined by heven = {2, 4, 6, . . .}, and the
concept “powers of two ” is defined by htwo = {2, 4, 8, 16, . . .}. For simplicity, we assume the range of
numbers is between 1 and 100.
For example, suppose we see one example, D = {16}. What other numbers do you think are
examples of this concept? 17? 6? 32? 99? It’s hard to tell with only one example, so your predictions
will be quite vague. Presumably numbers that are similar in some sense to 16 are more likely. But similar in what way? 17 is similar, because it is "close by", 6 is similar because it has a digit in common, 32 is similar because it is also even and a power of 2, but 99 does not seem similar. Thus some numbers are more likely than others.
Now suppose I tell you that D = {2, 8, 16, 64} are positive examples. You may guess that the hidden concept is "powers of two". Given your beliefs about the true (but hidden) concept, you may confidently predict that y ∈ {2, 4, 8, 16, 32, 64} may also be generated in the future by the teacher. This is an example of generalization, since we are making predictions about future data that we have not seen.

Figure 3.1: Empirical membership distribution in the numbers game, derived from predictions from 8 humans. First two rows: after seeing D = {16} and D = {60}. This illustrates diffuse similarity. Third row: after seeing D = {16, 8, 2, 64}. This illustrates rule-like behavior (powers of 2). Bottom row: after seeing D = {16, 23, 19, 20}. This illustrates focussed similarity (numbers near 20). From Figure 5.5 of [Tenenbaum99]. Used with kind permission of Josh Tenenbaum.
33 have not seen.
34 Figure 3.1 gives an example of how humans perform at this task. Given a single example, such
35 as D = {16} or D = {60}, humans make fairly diffuse predictions over the other numbers that are
36 similar in magnitude. But when given several examples, such as D = {2, 8, 16, 64}, humans often find
37 an underlying pattern, and use this to make fairly precise predictions about which other numbers
38 might be part of the same concept, even if those other numbers are “far away”.
39 How can we explain this behavior and emulate it in a machine? The classic approach to the
40 problem of induction is to suppose we have a hypothesis space H of concepts (such as even numbers,
41 all numbers between 1 and 10, etc.), and then to identify the smallest subset of H that is consistent
42 with the observed data D; this is called the version space. As we see more examples, the version
43 space shrinks and we become increasingly certain about the underlying hypothesis [Mitchell97].
44 However, the version space theory cannot explain the human behavior we saw in Figure 3.1. For
45 example, after seeing D = {16, 8, 2, 64}, why do people choose the rule “powers of two” and not,
46 say, “all even numbers”, or “powers of two except for 32”, both of which are equally consistent with
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
3.1. BAYESIAN CONCEPT LEARNING

1
2 Examples

3 16 1
0.5
4
0
5
4 8 12 16 20 24 28 32 36 40 44 48 52 56 60 64 68 72 76 80 84 88 92 96 100
6
60 1
7 0.5

8 0

9 4 8 12 16 20 24 28 32 36 40 44 48 52 56 60 64 68 72 76 80 84 88 92 96 100
1
10 16 8 2 64
0.5
11
0
12
4 8 12 16 20 24 28 32 36 40 44 48 52 56 60 64 68 72 76 80 84 88 92 96 100
13
16 23 19 20 1
14 0.5

15 0

16 4 8 12 16 20 24 28 32 36 40 44 48 52 56 60 64 68 72 76 80 84 88 92 96 100

17 60 80 10 30 1

18
Figure 3.2: Posterior membership probabilities derived using the full hypothesis space. Compare to Figure 3.1.
0.5
0
The predictions of the Bayesian model are only plotted for those values for which human data is available;
19
this is why the top line looks sparser
4 8 12 than Figure
16 20 24 3.4.
28 32 36 From
40 44 48 52 Figure 5.672 of
56 60 64 68 76 [Tenenbaum99].
80 84 88 92 96 100 Used with kind
20
52 57 55 1
permission of Josh 60Tenenbaum.
21 0.5
0
22
4 8 12 16 20 24 28 32 36 40 44 48 52 56 60 64 68 72 76 80 84 88 92 96 100
23
the evidence? We81will now show how Bayesian inference can explain this behavior. The resulting
25 4 36 1
24
predictions are shown in 0.5Figure 3.2.
25 0
26 4 8 12 16 20 24 28 32 36 40 44 48 52 56 60 64 68 72 76 80 84 88 92 96 100
27
3.1.1.1 Likelihood
81 98 86 93 1
28 We must explain why people chose htwo and not, say, heven after seeing D = {16, 8, 2, 64}, given
0.5
0
29 that both hypotheses are consistent with the evidence. The key intuition is that we want to avoid
suspicious coincidences. For example, if the true concept was even numbers, it would be surprising
4 8 12 16 20 24 28 32 36 40 44 48 52 56 60 64 68 72 76 80 84 88 92 96 100
30
31 if we just happened to only see powers of two.
To formalize this, let us assume that the examples are sampled uniformly at random from the extension of the concept. (Tenenbaum calls this the strong sampling assumption.) Given this assumption, the probability of independently sampling N items (with replacement) from the unknown concept h is given by

$$p(\mathcal{D}|h) = \prod_{n=1}^{N} p(y_n|h) = \prod_{n=1}^{N} \frac{1}{\text{size}(h)}\,\mathbb{I}(y_n \in h) = \left(\frac{1}{\text{size}(h)}\right)^{N} \mathbb{I}(\mathcal{D} \in h) \tag{3.1}$$

where I(D ∈ h) is non zero iff all the data points lie in the support of h. This crucial equation embodies what Tenenbaum calls the size principle, which means the model favors the simplest (smallest) hypothesis consistent with the data. This is more commonly known as Occam's razor.
To see how it works, let D = {16}. Then $p(\mathcal{D}|h_{\text{two}}) = 1/6$, since there are only 6 powers of two less than 100, but $p(\mathcal{D}|h_{\text{even}}) = 1/50$, since there are 50 even numbers. So the likelihood that $h = h_{\text{two}}$ is higher than if $h = h_{\text{even}}$. After 4 examples, the likelihood of $h_{\text{two}}$ is $(1/6)^4 = 7.7 \times 10^{-4}$, whereas the likelihood of $h_{\text{even}}$ is $(1/50)^4 = 1.6 \times 10^{-7}$. This is a likelihood ratio of almost 5000:1 in favor of $h_{\text{two}}$. This quantifies our earlier intuition that D = {16, 8, 2, 64} would be a very suspicious coincidence if generated by $h_{\text{even}}$.
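A few lines of Python are enough to reproduce the size-principle arithmetic above; the hypothesis extensions below are just the two concepts discussed in the text, restricted to the integers 1 to 100.

# Size principle (Eq. 3.1) for two hypotheses over the integers 1..100
h_two = {2 ** k for k in range(1, 7)}                       # {2, 4, ..., 64}, size 6
h_even = {n for n in range(1, 101) if n % 2 == 0}           # size 50

def likelihood(data, h):
    """p(D|h) under strong sampling: (1/|h|)^N if all points lie in h, else 0."""
    return (1.0 / len(h)) ** len(data) if all(y in h for y in data) else 0.0

D = [16, 8, 2, 64]
lik_two, lik_even = likelihood(D, h_two), likelihood(D, h_even)
print(lik_two, lik_even, lik_two / lik_even)   # ratio is roughly 5000:1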
3.1.1.2 Prior

In the Bayesian approach, we must specify a prior over unknowns, p(h), as well as the likelihood, p(D|h). To see why this is useful, suppose D = {16, 8, 2, 64}. Given this data, the concept h′ = "powers of two except 32" is more likely than h = "powers of two", since h′ does not need to explain the coincidence that 32 is missing from the set of examples. However, the hypothesis h′ = "powers of two except 32" seems "conceptually unnatural". We can capture such intuition by assigning low prior probability to unnatural concepts. Of course, your prior might be different than mine. This subjective aspect of Bayesian reasoning is a source of much controversy, since it means, for example, that a child and a math professor will reach different answers.1
Although the subjectivity of the prior is controversial, it is actually quite useful. If you are told the numbers are from some arithmetic rule, then given 1200, 1500, 900 and 1400, you may think 400 is likely but 1183 is unlikely. But if you are told that the numbers are examples of healthy cholesterol levels, you would probably think 400 is unlikely and 1183 is likely, since you assume that healthy levels lie within some range. Thus we see that the prior is the mechanism by which background knowledge can be brought to bear on a problem. Without this, rapid learning (i.e., from small sample sizes) is impossible.
So, what prior should we use? We will initially consider 30 simple arithmetical concepts, such as "even numbers", "odd numbers", "prime numbers", or "numbers ending in 9". We could use a uniform prior over these concepts; however, for illustration purposes, we make the concepts even and odd more likely a priori, and use a uniform prior over the others. We also include two "unnatural" concepts, namely "powers of 2, plus 37" and "powers of 2, except 32", but give them low prior weight. See Figure 3.3a (bottom row) for a plot of this prior.
In addition to "rule-like" hypotheses, we consider the set of intervals between n and m for 1 ≤ n, m ≤ 100. This allows us to capture concepts based on being "close to" some number, rather than satisfying some more abstract property. We put a uniform prior over the intervals.
We can combine these two priors by using a mixture distribution, as follows:

$$p(h) = \pi\,\mathrm{Unif}(h|\text{rules}) + (1 - \pi)\,\mathrm{Unif}(h|\text{intervals}) \tag{3.2}$$

where 0 < π < 1 is the mixture weight assigned to the rules prior, and Unif(h|S) is the uniform distribution over the set S.
3.1.1.3 Posterior
The posterior is simply the likelihood times the prior, normalized: p(h|D) ∝ p(D|h)p(h). Figure 3.3a plots the prior, likelihood and posterior after seeing D = {16}. (In this figure, we only consider rule-like hypotheses, not intervals, for simplicity.) We see that the posterior is a combination of prior and likelihood. In the case of most of the concepts, the prior is uniform, so the posterior is proportional to the likelihood. However, the "unnatural" concepts of "powers of 2, plus 37" and "powers of 2, except 32" have low posterior support, despite having high likelihood, due to the low prior. Conversely, the concept of odd numbers has low posterior support, despite having a high prior, due to the low likelihood.
Figure 3.3b plots the prior, likelihood and posterior after seeing D = {16, 8, 2, 64}. Now the likelihood is much more peaked on the powers of two concept, so this dominates the posterior. Essentially the learner has an "aha" moment, and figures out the true concept.2 This example also illustrates why we need the low prior on the unnatural concepts, otherwise we would have overfit the data and picked "powers of 2, except for 32".

1. A child and a math professor presumably not only have different priors, but also different hypothesis spaces. However, we can finesse that by defining the hypothesis space of the child and the math professor to be the same, and then setting the child's prior weight to be zero on certain "advanced" concepts. Thus there is no sharp distinction between the prior and the hypothesis space.

Figure 3.3: (a) Prior, likelihood and posterior for the model when the data is D = {16}. (b) Results when D = {2, 8, 16, 64}. Adapted from [Tenenbaum99]. Generated by numbers_game.ipynb.
3.1.1.4 Posterior predictive

The posterior over hypotheses is our internal belief state about the world. The way to test if our beliefs are justified is to use them to predict objectively observable quantities (this is the basis of the scientific method). To do this, we compute the posterior predictive distribution over possible future observations:

$$p(y|\mathcal{D}) = \sum_h p(y|h)\,p(h|\mathcal{D}) \tag{3.3}$$

This is called Bayes model averaging [Hoeting99]. Each term is just a weighted average of the predictions of each individual hypothesis. This is illustrated in Figure 3.4. The dots at the bottom show the predictions from each hypothesis; the vertical curve on the right shows the weight associated with each hypothesis. If we multiply each row by its weight and add up, we get the distribution at the top.

2. Humans have a natural desire to figure things out; Alison Gopnik, in her paper "Explanation as orgasm" [Gopnik98], argued that evolution has ensured that we enjoy reducing our posterior uncertainty.
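Below is a minimal Python sketch of Equation (3.3) for a tiny, made-up subset of the hypothesis space (the full model lives in numbers_game.ipynb); it computes the membership probability p(y ∈ concept | D) by weighting each hypothesis by its posterior.

# A tiny made-up hypothesis space (subset of the concepts in the text)
hypotheses = {
    "even":        {n for n in range(1, 101) if n % 2 == 0},
    "powers of 2": {2, 4, 8, 16, 32, 64},
    "powers of 4": {4, 16, 64},
    "mult of 8":   {n for n in range(1, 101) if n % 8 == 0},
}
prior = {h: 1.0 / len(hypotheses) for h in hypotheses}       # uniform prior over these

def likelihood(data, ext):
    """Size principle, Eq. (3.1)."""
    return (1.0 / len(ext)) ** len(data) if all(y in ext for y in data) else 0.0

def posterior(data):
    unnorm = {h: likelihood(data, ext) * prior[h] for h, ext in hypotheses.items()}
    Z = sum(unnorm.values())
    return {h: p / Z for h, p in unnorm.items()}

def membership_prob(y, data):
    """Eq. (3.3): p(y in concept | D) = sum_h I(y in h) p(h|D)."""
    post = posterior(data)
    return sum(post[h] * (y in hypotheses[h]) for h in hypotheses)

print(posterior([16]))
print(membership_prob(32, [16]))   # broad support after one example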
Figure 3.4: Posterior over hypotheses, and the induced posterior over membership, after seeing one example, D = {16}. A dot means this number is consistent with this hypothesis. The graph p(h|D) on the right is the weight given to hypothesis h. By taking a weighted sum of dots, we get p(y ∈ h|D) (top). Adapted from Figure 2.9 of [Tenenbaum99]. Generated by numbers_game.ipynb.
3.1.1.5 MAP, MLE, and the plugin approximation

As the amount of data increases, the posterior will (usually) become concentrated around a single point, namely the posterior mode, as we saw in Figure 3.3 (top right plot). The posterior mode is defined as the hypothesis with maximum posterior probability:

$$h_{\text{map}} \triangleq \arg\max_h\, p(h|\mathcal{D}) \tag{3.4}$$

This is also called the maximum a posteriori or MAP estimate.
We can compute the MAP estimate by solving the following optimization problem:

$$h_{\text{map}} = \arg\max_h\, p(h|\mathcal{D}) = \arg\max_h\, \log p(\mathcal{D}|h) + \log p(h) \tag{3.5}$$

The first term, log p(D|h), is the log of the likelihood, p(D|h). The second term, log p(h), is the log of the prior. As the data set increases in size, the log likelihood grows in magnitude, but the log prior term remains constant. We thus say that the likelihood overwhelms the prior. In this context, a reasonable approximation to the MAP estimate is to ignore the prior term, and just pick the maximum likelihood estimate or MLE, which is defined as

$$h_{\text{mle}} \triangleq \arg\max_h\, p(\mathcal{D}|h) = \arg\max_h\, \log p(\mathcal{D}|h) = \arg\max_h\, \sum_{n=1}^{N} \log p(y_n|h) \tag{3.6}$$

Suppose we approximate the posterior by a single point estimate $\hat{h}$, which might be the MAP estimate or the MLE. We can represent this degenerate distribution as a single point mass

$$p(h|\mathcal{D}) \approx \mathbb{I}\left(h = \hat{h}\right) \tag{3.7}$$

where I(·) is the indicator function. The corresponding posterior predictive distribution becomes

$$p(y|\mathcal{D}) \approx \sum_h p(y|h)\,\mathbb{I}\left(h = \hat{h}\right) = p(y|\hat{h}) \tag{3.8}$$
This is called a plug-in approximation, and is very widely used, due to its simplicity.
Although the plug-in approximation is simple, it behaves in a qualitatively inferior way to the fully Bayesian approach when the dataset is small. In the Bayesian approach, we start with broad predictions, and then become more precise in our forecasts as we see more data, which makes intuitive sense. For example, given D = {16}, there are many hypotheses with non-negligible posterior mass, so the predicted support over the integers is broad. However, when we see D = {16, 8, 2, 64}, the posterior concentrates its mass on one or two specific hypotheses, so the overall predicted support becomes more focused. By contrast, the MLE picks the minimal consistent hypothesis, and predicts the future using that single model. For example, if we see D = {16}, we compute hmle to be "all powers of 4" (or the interval hypothesis h = {16}), and the resulting plugin approximation only predicts {4, 16, 64} as having non-zero probability. This is an example of overfitting, where we pay too much attention to the specific data that we saw in training, and fail to generalise correctly to novel examples. When we observe more data, the MLE will be forced to pick a broader hypothesis to explain all the data. For example, if we see D = {16, 8, 2, 64}, the MLE broadens to become "all powers of two", similar to the Bayesian approach. Thus in the limit of infinite data, both approaches converge to the same predictions. However, in the small sample regime, the fully Bayesian approach, in which we consider multiple hypotheses, will give better (less overconfident) predictions.
3.1.2 Learning a continuous concept: the healthy levels game
The number game involved observing a series of discrete variables, and inferring a distribution over another discrete variable from a finite hypothesis space. This made the computations particularly simple: we just needed to sum, multiply and divide. However, in many applications, the variables that we observe are real-valued continuous quantities. More importantly, the unknown parameters are also usually continuous, so the hypothesis space becomes (some subset of) $\mathbb{R}^K$, where K is the number of parameters. This complicates the mathematics, since we have to replace sums with integrals. However, the basic ideas are the same.
We illustrate these ideas by considering another example of concept learning called the healthy levels game, also due to Tenenbaum. The idea is this: we measure two continuous variables, representing the cholesterol and insulin levels of some randomly chosen healthy patients. We would like to know what range of values corresponds to a healthy range. As in the numbers game, the challenge is to learn the concept from positive data alone.
Let our hypothesis space be axis-parallel rectangles in the plane, as in Figure 3.5. This is a classic example which has been widely studied in machine learning [Mitchell97]. It is also a reasonable assumption for the healthy levels game, since we know (from prior domain knowledge) that healthy levels of both insulin and cholesterol must fall between (unknown) lower and upper bounds. We can represent a rectangle hypothesis as $h = (\ell_1, \ell_2, s_1, s_2)$, where $\ell_j \in (-\infty, \infty)$ are the coordinates (locations) of the lower left corner, and $s_j \in [0, \infty)$ are the lengths of the two sides. Hence the hypothesis space is $\mathcal{H} = \mathbb{R}^2 \times \mathbb{R}_{\geq 0}^2$, where $\mathbb{R}_{\geq 0}$ is the set of non-negative reals.
More complex concepts might require discontinuous regions of space to represent them. Alternatively, we might want to use latent rectangular regions to represent more complex, high dimensional concepts [Li2019]. The question of where the hypothesis space comes from is a very interesting one, but is beyond the scope of this chapter. (One approach is to use hierarchical Bayesian models, as discussed in [Tenenbaum11].)

Figure 3.5: Samples from the posterior in the "healthy levels" game. The axes represent "cholesterol level" and "insulin level". (a) Given a small number of positive examples (represented by 3 red crosses), there is a lot of uncertainty about the true extent of the rectangle. (b) Given enough data, the smallest enclosing rectangle (which is the maximum likelihood hypothesis) becomes the most probable, although there are many other similar hypotheses that are almost as probable. Adapted from [Tenenbaum99]. Generated by healthy_levels_plots.ipynb.
3.1.2.1 Likelihood

We assume points are sampled uniformly at random from the support of the rectangle. To simplify the analysis, let us first consider the case of one-dimensional "rectangles", i.e., lines. In the 1d case, the likelihood is $p(\mathcal{D}|\ell, s) = (1/s)^N$ if all points are inside the interval, otherwise it is 0. Hence

$$p(\mathcal{D}|\ell, s) = \begin{cases} s^{-N} & \text{if } \min(\mathcal{D}) \geq \ell \text{ and } \max(\mathcal{D}) \leq \ell + s \\ 0 & \text{otherwise} \end{cases} \tag{3.9}$$

To generalize this to 2d, we assume the observed features are conditionally independent given the hypothesis. Hence the 2d likelihood becomes

$$p(\mathcal{D}|h) = p(\mathcal{D}_1|\ell_1, s_1)\,p(\mathcal{D}_2|\ell_2, s_2) \tag{3.10}$$

where $\mathcal{D}_j = \{y_{nj} : n = 1:N\}$ are the observations for dimension (feature) j = 1, 2.
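Here is a short NumPy sketch of Equations (3.9)-(3.10); the data points and rectangles are made-up values used only to illustrate the truncated-uniform likelihood.

import numpy as np

def lik_1d(data, ell, s):
    """Eq. (3.9): (1/s)^N if all points fall in [ell, ell + s], else 0."""
    data = np.asarray(data)
    inside = (data.min() >= ell) and (data.max() <= ell + s)
    return s ** (-len(data)) if inside else 0.0

def lik_2d(data, h):
    """Eq. (3.10): product of per-dimension likelihoods for h = (l1, l2, s1, s2)."""
    l1, l2, s1, s2 = h
    data = np.asarray(data)
    return lik_1d(data[:, 0], l1, s1) * lik_1d(data[:, 1], l2, s2)

data = np.array([[0.3, 0.4], [0.45, 0.7], [0.5, 0.5]])   # made-up positive examples
print(lik_2d(data, (0.25, 0.35, 0.3, 0.4)))   # rectangle containing the data
print(lik_2d(data, (0.25, 0.35, 0.1, 0.1)))   # too small, so likelihood is 0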
3.1.2.2 Prior

For simplicity, let us assume the prior factorizes, i.e., $p(h) = p(\ell_1)p(\ell_2)p(s_1)p(s_2)$. We will use uninformative priors for each of these terms. As we explain in Main Section 3.5, this means we should use a prior of the form $p(h) \propto \frac{1}{s_1}\frac{1}{s_2}$.
3.1.2.3 Posterior

The posterior is given by

$$p(\ell_1, \ell_2, s_1, s_2|\mathcal{D}) \propto \frac{1}{s_1}\frac{1}{s_2}\,p(\mathcal{D}_1|\ell_1, s_1)\,p(\mathcal{D}_2|\ell_2, s_2) \tag{3.11}$$

We can compute this numerically by discretizing $\mathbb{R}^4$ into a 4d grid, evaluating the numerator pointwise, and normalizing.
Since visualizing a 4d distribution is difficult, we instead draw posterior samples from it, $h^s \sim p(h|\mathcal{D})$, and visualize them as rectangles. In Figure 3.5(a), we show some samples when the number N of observed data points is small — we are uncertain about the right hypothesis. In Figure 3.5(b), we see that for larger N, the samples concentrate on the observed data.
3.1.2.4 Posterior predictive distribution

We now consider how to predict which data points we expect to see in the future, given the data we have seen so far. In particular, we want to know how likely it is that we will see any point $y \in \mathbb{R}^2$. Let us define $y_j^{\min} = \min_n y_{nj}$, $y_j^{\max} = \max_n y_{nj}$, and $r_j = y_j^{\max} - y_j^{\min}$. Then one can show that the posterior predictive distribution is given by

$$p(y|\mathcal{D}) = \left(\frac{1}{(1 + d(y_1)/r_1)(1 + d(y_2)/r_2)}\right)^{N-1} \tag{3.12}$$

where $d(y_j) = 0$ if $y_j^{\min} \leq y_j \leq y_j^{\max}$, and otherwise $d(y_j)$ is the distance to the nearest data point along dimension j. Thus p(y|D) = 1 if y is inside the support of the training data; if y is outside the support, the probability density drops off, at a rate that depends on N.
Note that if N = 1, the predictive distribution is undefined. This is because we cannot infer the extent of a 2d rectangle from just one data point (unless we use a stronger prior).
In Figure 3.6(a), we plot the posterior predictive distribution when we have just seen N = 3 examples; we see that there is a broad generalization gradient, which extends further along the vertical dimension than the horizontal direction. This is because the data has a broader vertical spread than horizontal. In other words, if we have seen a large range in one dimension, we have evidence that the rectangle is quite large in that dimension, but otherwise we prefer compact hypotheses, as follows from the size principle.
In Figure 3.6(b), we plot the distribution for N = 12. We see it is focused on the smallest consistent hypothesis, since the size principle exponentially down-weights hypotheses which are larger than necessary.
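The following NumPy sketch evaluates the posterior predictive of Equation (3.12) directly; the data points are made up, and the function assumes N ≥ 2 and a nonzero range in each dimension.

import numpy as np

def predictive(y, data):
    """Eq. (3.12): p(y|D) for the axis-parallel rectangle model (uninformative prior)."""
    data = np.asarray(data)                              # shape (N, 2)
    N = data.shape[0]
    lo, hi = data.min(axis=0), data.max(axis=0)
    r = hi - lo
    d = np.maximum(lo - y, 0) + np.maximum(y - hi, 0)    # distance outside the support
    return (1.0 / ((1 + d[0] / r[0]) * (1 + d[1] / r[1]))) ** (N - 1)

# Made-up positive examples (cholesterol, insulin), rescaled to [0, 1]
data = [[0.3, 0.4], [0.45, 0.7], [0.5, 0.5]]
print(predictive(np.array([0.4, 0.5]), data))   # inside the support -> 1.0
print(predictive(np.array([0.9, 0.5]), data))   # outside -> decays with distance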
3.1.2.5 Plugin approximation
Now suppose we use a plug-in approximation to the posterior predictive, $p(y|\mathcal{D}) \approx p(y|\hat{\theta})$, where $\hat{\theta}$ is the MLE or MAP estimate, analogous to the discussion in Section 3.1.1.5. In Figure 3.6(c-d), we show the behavior of this approximation. In both cases, it predicts the smallest enclosing rectangle, since that is the one with maximum likelihood. However, this does not extrapolate beyond the range of the observed data. We also see that initially the predictions are narrower, since very little data has been observed, but that the predictions become broader with more data. By contrast, in the Bayesian approach, the initial predictions are broad, since there is a lot of uncertainty, but become narrower with more data. In the limit of large data, both methods converge to the same predictions.

Figure 3.6: Posterior predictive distribution for the healthy levels game. Red crosses are observed data points. Left column: N = 3. Right column: N = 12. First row: Bayesian prediction. Second row: Plug-in prediction using MLE (smallest enclosing rectangle). We see that the Bayesian prediction goes from uncertain to certain as we learn more about the concept given more data, whereas the plug-in prediction goes from narrow to broad, as it is forced to generalize when it sees more data. However, both converge to the same answer. Adapted from [Tenenbaum99]. Generated by healthy_levels_plot.ipynb.

3.2 Informative priors
41 When we have very little data, it is important to choose an informative prior.
42 For example, consider the classic taxicab problem [Jaynes03]: you arrive in a new city, and see
43 a taxi numbered t = 27, and you want to infer the total number T of taxis in the city. We will use
44 a uniform likelihood, p(t|T ) = T1 I (T ≥ t), since we assume that we could have observed any taxi
45 number up the maximum value T . The MLE estimate of T is T̂mle = t = 27. But this does not seem
46 reasonable. Instead, most people would guess that T ∼ 2 × 27 = 54, on the assumption that if the
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
3.2. INFORMATIVE PRIORS

1
Life Spans Life Spans Life Spans
Movie Runtimes Life Movie
Movie Runtimes
Movie Spans Runtimes
Movie Grosses
Grosses Movie Grosses
MoviePoems
Runtimes Poems Poems
Movie Grosses Poems Representatives Pharaohs Cakes
2

Probability

Probability
Probability

Probability
Probability
3
4
5 0 0100 40 0200
40 80 012040 080 120 80 100
12000 0 80 100
200
40 300 200
0600 300
120 00 0
600
100
500 300
0
200
1000 600
500
0 0
1000
300 500
600 1000
0 500 1000 0 30 60 0 50 100 0 60 120
ttotal ttotal t t t tt t ttotal tt ttotal t t ttotal t ttotal ttotal t
total total total total
totaltotal total
total total total total total
6
240 240 240 240
160 160 160 160 80 60 160
7 200 200 200 200
200 200 200 200

8 180 180 180 180


120 120 120 120 60 45 120

Predicted ttotal

Predicted ttotal

total
Predicted ttotal
Predicted ttotal

Predicted ttotal

150 150 150 150


150 150 150 150

Predicted t
9
120 120 120 120
80 80 80 80 40 30 80
100 100 100 100
100 100 100 100
10
50 60 60 50
50 60 60
40 40 50 40 40 20 15 40
11 50 50 50 50

12 0 0 0 0 0 00 0 0 00 0 0 0 0 0 0 0 0
0 50 100
0 50 0 100060 500120 100
6000 50
050
120 100
100
60
0 00 100
120
50 060
40 0120
50 80 100
400 500
80 100
40 080 40 80 0 20 40 0 15 30 0 40 80
13 t t t t t tt t t tt t t t t t t t t

14
15
Figure 3.7: Top: empirical distribution of various durational quantities. Bottom: predicted total duration as
a function of observed duration, p(T |t). Dots are observed median responses of people. Solid line: Bayesian
16
prediction using informed prior. Dotted line: Bayesian prediction using uninformative prior. From Figure 2a
17
of [Griffiths06]. Used with kind permission of Tom Griffiths.
18
19
20
21
taxi you saw was a uniform random sample between 0 and T , then it would probably be close to the
22
middle of the distribution.
23
In general, the conclusions we draw about T will depend strongly on our prior assumptions about
24
what values of T are likely. In the sections below, we discuss different possible informative priors for
25
different problem domains; our presentation is based on [Griffiths06].
26
Once we have chosen a prior, we can compute the posterior as follows:
27
p(t|T )p(T )
28 p(T |t) = (3.13)
p(t)
29
30 where
31 Z ∞ Z ∞
p(T )
32
p(t) = p(t|T )p(T )dT = dT (3.14)
33 0 t T
34
35 We will use the posterior median as a point estimate for T . This is the value T̂ such that
36 Z ∞
37 p(T ≥ T̂ |t) = p(T |t)dT = 0.5 (3.15)
38 T̂

39
Note that the posterior median is often a better summary of the posterior than the posterior mode,
40
for reasons explained in Main Section 7.4.1.
41
42
43
3.2.1 Domain specific priors
44 At the top of Figure 3.7, we show some histograms representing the empirical distribution of various
45 kinds of scalar quantities, specifically: the number of years people live, the number of minutes a
46 movie lasts, the amount of money made (in 1000s of US dollars) by movies, the number of lines of a
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


34

1
2 Gaussian prior Power−law prior Erlang prior

Probability
4
5
6 0 50 100 0 50 100 0 50 100
t t t
7 total total total

8 60 60 60
µ=30 γ=1 β=30
9 µ=25 γ=1.5 β=18
45 45 45
µ=15 γ=2 β=10

total
10

Predicted t
11
30 30 30
12
13 15 15 15
14
15 0 0 0
0 15 30 0 15 30 0 15 30
16 t t t

17
Figure 3.8: Top: three different prior distributions, for three different parameter values. Bottom: corresponding
18
predictive distributions. From Figure 1 of [Griffiths06]. Used with kind permission of Tom Griffiths.
19
20
21
22 poem, and the number of years someone serves in the US house of Representatives. (The sources for
23 these data is listed in [Griffiths06].)
24 At the bottom, we plot p(T |t) as a function of t for each of these domains. The solid dots are the
25 median responses of a group of people when asked to predict T from a single observation t. The
26 solid line is the posterior median computed by a Bayesian model using a domain-appropriate prior
27 (details below). The dotted line is the posterior median computed by a Bayesian model using an
28 uninformative 1/T prior. We see a remarkable correspondence between people and the informed
29 Bayesian model. This suggests that people can implicitly use an appropriate kind of prior for a
30 wide range of problems, as argued in [Griffiths06]. In the sections below, we discuss some suitable
31 parametric priors which catpure this behavior. In [Griffiths06], they also consider some datasets
32 that can only be well-modeled by a non-parametric prior. Bayesian inference works well in that case,
33 too, but we omit this for simplicity.
34
35
3.2.2 Gaussian prior
36
37 Looking at Figure 3.7(a-b), it seems clear that life-spans and movie run-times can be well-modeled
38 by a Gaussian, N (T |µ, σ 2 ). Unfortunately, we cannot compute the posterior median in closed form if
39 we use a Gaussian prior, but we can still evaluate it numerically, by solving a 1d integration problem.
40 The resulting plot of T̂ (t) vs t is shown in Figure 3.8 (bottom left). For values of t much less than the
41 prior mean, µ, the predicted value of T is about equal to µ, so the left part of the curve is flat. For
42 values of t much greater than µ, the predicted value converges to a line slightly above the diagonal,
43 i.e., T̂ (t) = t +  for some small (and decreasing)  > 0.
44 To see why this behavior makes intuitive sense, consider encountering a man at age 18, 39 or 51:
45 in all cases, a reasonable prediction is that he will live to about µ = 75 years. But now imagine
46 meeting a man at age 80: we probably would not expect him to live much longer, so we predict
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
3.2. INFORMATIVE PRIORS

1
2 T̂ (80) ≈ 80 + .
3
4
3.2.3 Power-law prior
5
6 Looking at Figure 3.7(c-d), it seems clear that movie grosses and poem length can be modeled by
7 a power law distribution of the form p(T ) ∝ T −γ for γ > 0. (If γ > 1, this is called a Pareto
8 distribution, see Main Section 2.2.3.5.) Power-laws are characterized by having very long tails.
9 This captures the fact that most movies make very little money, but a few blockbusters make a lot.
10 The number of lines in various poems also has this shape, since there are a few epic poems, such
11 as Homer’s Odyssey, but most are short, like haikus. Wealth has a similarly skewed distribution in
12 many countries, especially in plutocracies such as the USA (see e.g., inequality.org).
13 In the case of a power-law prior, p(T ) ∝ T −γ , we can compute the posterior median analytically.
14 We have
15 Z ∞
1 1 −γ
16 p(t) ∝ T −(γ+1) dT = − T −γ |∞
t = t (3.16)
17 t γ γ
18
Hence the posterior becomes
19
20 T −(γ+1) γtγ
21 p(T |t) = 1 −γ = γ+1 (3.17)
γt
T
22
23
for values of T ≥ t. We can derive the posterior median as follows:
24
Z ∞  γ  γ
25 γtγ t ∞ t
26 p(T > TM |t) = γ+1
dT = − TM
= (3.18)
TM T T TM
27
28 Solving for TM such that P (T > TM |t) = 0.5 gives TM = 21/γ t.
29 This is plotted in Figure 3.8 (bottom middle). We see that the predicted duration is some constant
30 multiple of the observed duration. For the particular value of γ that best fits the empirical distribution
31 of movie grosses, the optimal prediction is about 50% larger than the observed quantity. So if we
32 observe that a movie has made $40M to date, we predict that it will make $60M in total.
33 As Griffiths and Tenenbaum point out, this rule is inappropriate for quantities that follow a
34 Gaussian prior, such as people’s ages. As they write, “Upon meeting a 10-year-old girl and her
35 75-year-old grandfather, we would never predict that the girl will live a total of 15 years (1.5 × 10)
36 and that the grandfather will live to be 112 (1.5 × 75).” This shows that people implicitly know what
37 kind of prior to use when solving prediction problems of this kind.
38
39
40
3.2.4 Erlang prior
41 Looking at Figure 3.7(e), it seems clear that the number of years a US Representative is approximately
42 modeled by a gamma distribution (Main Section 2.2.3.1). Griffiths and Tenenbaum use a special
43 case of the Gamma distributon, where the shape parameter is a = 2; this is known as the Erlang
44 distribution:
45
46 p(T ) = Ga(T |2, 1/β) ∝ T e−T /β (3.19)
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


36

1
2 For the Erlang prior, we can also compute the posterior median analytically. We have
3 Z ∞
4 p(t) ∝ exp(−T /β)dT = −β exp(−T /β)|∞ t = β exp(−t/β) (3.20)
5 t
6
so the posterior has the form
7
8 exp(−T /β) 1
9 p(T |t) = = exp(−(T − t)/β) (3.21)
β exp(−t/β) β
10
11 for values of T ≥ t. We can derive the posterior median as follows:
12 Z ∞
1
13 p(T > TM |t) = exp(−(T − t)/β)dT = − exp(−(T − t)/β)|∞ TM = exp(−(TM − t)/β) (3.22)
14 TM β
15
16
Solving for TM such that p(T > TM |t) = 0.5 gives TM = t + β log 2.
17
This is plotted in Figure 3.8 (bottom right). We see that the best guess is simply the observed
18
value plus a constant, where the constant reflects the average term in office.
19
20 3.3 Tweedie’s formula (Empirical Bayes without estimating the prior)
21
22 In this section, we present Tweedie’s formula [Efron2011] which is a way to estimate the posterior
23 of a quantity without knowing the prior. Instead, we replace the prior with an empirical estimate of
24 the score function, which is the derivative of the log marginal density, as we explain below. This is
25 useful since it allows us to estimate a latent quantity from noisy observations without ever observing
26 the latent quantity itself.
27 Consider the case of a scalar natural parameter η with prior g(η) and exponential family likelihood
28
29
30 y|η ∼ fη (y) = eηy−ψ(η) f0 (y) (3.23)
31
32
where ψ(η) is the cumulant generating function (cgf) that makes fη (y) integrate to 1, and f0 (y) is
33
the density when η = 0. For example, in the case of a 1d Gaussian with fixed variance σ 2 , we have
34
η = µ/σ 2 , ψ(η) = 12 σ 2 η 2 , and f0 (y) = N (y|0, σ 2 ).
35
Bayes rule gives the following posterior
36
fη (y)g(η)
37 g(η|y) = (3.24)
f (y)
38 Z
39 f (y) = fη (y)g(η)dη (3.25)
40 Y
41
Plugging in Equation (3.23) we get
42
43
g(η|y) = eyη−λ(y) [g(η)e−ψ(y) ] (3.26)
44  
f (y)
45 λ(y) = log (3.27)
46 f0 (y)
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
3.3. TWEEDIE’S FORMULA (EMPIRICAL BAYES WITHOUT ESTIMATING THE PRIOR)

1
2 So we see that the posterior is an exponential family with canonical parameter y and cgf λ(y). Letting
3
4
5 `(y) = log f (y), `0 (y) = log f0 (y) (3.28)
6
7
we can therefore derive the posterior moments as follows:
8
E [η|y] = λ0 (y) = `0 (y) − `00 (y), V [η|y] = λ00 (y) = `00 (y) − `000 (y) (3.29)
9
10 For the Gaussian case we have `0 (y) = −y 2 /(2σ 2 2), so
11
12 E [µ|y] = y + σ 2 `0 (y), V [µ|y] = σ 2 (1 + σ 2 `00 (y)) (3.30)
13
14 If we plug in an empirical estimate of the score function `0 (y), we can compute an empirical Bayes
15 estimate of the posterior mean without having to estimate the prior.
16 We can extend this to the multivariate Gaussian case as shown in [Raphan2007]. In particular,
17 suppose the likelihood has the form
18
19 p(y|µ) = N (y|µ, Σ) (3.31)
20
21
with marginal likelihood
22
Z
23 p(y) = p(µ)p(y|µ)dµ (3.32)
24
25 Since ∇y p(y|x) = p(y|x)∇y log p(y|x) we have
26 R
27 ∇y p(y) p(µ)∇y p(y|µ)dµ
= (3.33)
28 p(y) p(y)
R
29
p(µ)Σ−1 p(y|µ)(µ − y)dµ
30 = (3.34)
p(y)
31 Z
−1
32 =Σ p(µ|y)(µ − y)dµ (3.35)
33
34 = Σ−1 [E [µ|y] − y] (3.36)
35
36 and hence
37
38
E [µ|y] = y + Σ∇y log p(y) (3.37)
39
For high dimensional signals, we can use neural networks to approximate the score function, as we
40
discuss in Main Section 24.3.
41
42
43
44
45
46
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


4 Graphical models

4.1 More examples of DGMs


4.1.1 Water sprinkler
Suppose we want to model the dependencies between 5 random variables: C (whether it is cloudy
season or not), R (whether it is raining or not), S (whether the water sprinkler is on or not), W
(whether the grass is wet or not), and L (whether the grass is slippery or not). We know that the
cloudy season makes rain more likely, so we add a C → R arc. We know that the cloudy season
makes turning on a water sprinkler less likely, so we add a C → S arc. We know that either rain or
sprinklers can cause the grass to get wet, so we add S → W and R → W edges. Finally, we know
that wet grass can be slippery, so we add a W → L edge. See Figure 4.1 for the resulting DAG.
Formally, this defines the following joint distribution:

S )p(W |S, R, C)p(L|W, 


p(C, S, R, W, L) = p(C)p(S|C)p(R|C,  S , R, C) (4.1)

where we strike through terms that are not needed due to the conditional independence properties of
the model.
In addition to the graph structure, we need to specify the CPDs. For discrete random variables, we
can represent the CPD as a table, which means we have a separate row (i.e., a separate categorical
distribution) for each conditioning case, i.e., for each combination of parent values. This is known
as a conditional probability table or CPT. We can represent the i’th CPT as a tensor

θijk , p(xi = k|xpa(i) = j) (4.2)


PKi
Thus θi is a row stochastic matrix, that satisfies the properties 0 ≤ θijk ≤ 1 and k=1 θijk = 1
for each row j. Here i indexes nodes, i ∈ [NG ]; k indexes node states, k ∈ [Ki ], Q where Ki is the
number of states for node i; and j indexes joint parent states, j ∈ [Ji ], where Ji = p∈pa(i) Kp . For
example, consider the wet grass node in Figure 4.1. If all nodes are binary, we can represent its CPT
by the table of numbers shown in Figure 4.1(right).
The number of parameters in a CPT is O(K p+1 ), where K is the number of states per node, and
p is the number of parents. Later we will consider more parsimonious representations, with fewer
learnable parameters.
Given the model, we can use it to answer probabilistic queries. For example, one can show (using
the code at sprinkler_pgm.ipynb) that p(R = 1) = 0.5, which means the probability it rained
(before we collect any data) is 50%. This is consistent with the CPT for that node. Now suppose
40

1
2
cloudy
3
4
5
6
7 rain sprinkler
8
9
10
11 wet S R p(W = 0) p(W = 1)
0 0 1.0 0.0
12
1 0 0.1 0.9
13 0 1 0.1 0.9
14 1 1 0.01 0.99
15 slippery
16
17
Figure 4.1: Water sprinkler graphical model. (Left). The DAG. Generated by sprinkler_pgm.ipynb. (Right).
18
The CPT for the W node, assuming all variables are binary.
19
20
21
22 we see that the grass is wet: our belief that it rained changes to p(R = 1|W = 1) = 0.7079. Now
23 suppose we also notice the water sprinkler was turned on: our belief that it rained goes down to
24 p(R = 1|W = 1, S = 1) = 0.3204. This negative mutual interaction between multiple causes of
25 some observations is called the explaining away effect, also known as Berkson’s paradox (see
26 Main Section 4.2.4.2 for details).
27 In general, we can use our joint model to answer all kinds of probabilistic queries. This includes
28 inferring latent quantities (such as whether the water sprinkler turned on or not given that the grass
29 is wet), as well as predicting observed quantities, such as whether the grass will be slippery. It is this
30 ability to answer arbitrary queries that makes PGMs so useful. See Main Chapter 9 for algorithmic
31 details.
32 Note also that inference requires a fully-specified model. This means we need to know the
33 graph structure G and the parameters of the CPDs θ. We discuss how to learn the parameters in
34 Main Section 4.2.7 and the structure in Main Section 30.3.
35
36
4.1.2 Asia network
37
38 In this section, we consider a hypothetical medical model proposed in [Lauritzen88] which is known
39 in the literature as the “Asia network”. (The name comes from the fact that was designed to
40 diagnoise various lung diseases in Western patients returning from a trip to Asia. Note that this
41 example predates the COVID-19 pandemic by many years, and is a purely fictitious model.)
42 Figure 4.2a shows the model, as well as the prior marginal distributions over each node (assumed
43 to be binary). Now suppose the patient reports that they have Dyspnea, aka shortness of breath.
44 We can represent this fact “clamping” the distribution to be 100% probability that Dyspnea=Present,
45 and 0% probability that Dyspnea=Absent. We then propagate this new information through the
46 network to get the updated marginal distributions shown in Figure 4.2b. We see that the probability
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
4.1. MORE EXAMPLES OF DGMS

1
2
3
4
5
6
7
8
9
10
11 (a) (b)
12
13
14
15
16
17
18
19
20
21
(c) (d)
22
23
Figure 4.2: Illustration of belief updating in the “Asia” PGM. The histograms show the marginal distribution
24 of each node. (a) Prior. (b) Posterior after conditioning on Dyspnea=Present. (c) Posterior after also
25 conditioning on VisitToAsia=True. (d) Posterior after also conditioning on Smoking=True. Generated by
26 asia_pgm.ipynb.
27
28
29
of lung cancer has gone up from 5% to 10%, and probability of bronchitis has gone up from 45% to
30
83%.
31
However, it could also be a an undiagnosed case of TB (tuberculosis), which may have been caused
32
by exposure to an infectious lung disease that was prevalent in Asia at the time. So he doctor asks
33
the patient if they have recently been to Asia, and they say yes. Figure 4.2c shows the new belief
34
state of each node. We see that the probability of TB has increased from 2% to 9%. However,
35
Bronchitis remains the most likely explanation of the symptoms.
36
To gain more information the doctor asks if the patient smokes, and they say yes. Figure 4.2d
37
shows the new belief state of each node. Now the probability of cancer and bronchitis have both gone
38
up. In addition, the posterior predicted probability that an X-ray will show an abnormal result has
39
gone up to 24%, so the doctor may decide it is now worth ordering a test to verify this hypothesis.
40
This example illustrates the nature of recursive Bayesian updating, and how it can be useful for
41
active learning and sequential decision making,
42
43
44
4.1.3 The QMR network
45 In this section, we describe the DPGM known as the quick medical reference or QMR network
46 [Shwe91]. This is a model of infectious diseases and is shown (in simplified form) in Figure 4.3. (We
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


42

1
2
3
4 Z1 Z2 Z3
5
6
7
8
9 X1 X2 X3 X4 X5
10
11
Figure 4.3: A small version of the QMR network. All nodes are binary. The hidden nodes zk represent
12
diseases, and the visible nodes xd represent symptoms. In the full network, there are 570 hidden (disease)
13
nodes and 4075 visible (symptom) nodes. The shaded (solid gray) leaf nodes are observed; in this example,
14
symptom x2 is not observed (i.e., we don’t know if it is present or absent). Of course, the hidden diseases are
15 never observed.
16
17 z0 z1 z2 P (xd = 0|z0 , z1 , z2 ) P (xd = 1|z0 , z1 , z2 )
18 1 0 0 θ0 1 − θ0
19 1 1 0 θ0 θ1 1 − θ0 θ1
20 1 0 1 θ0 θ2 1 − θ0 θ2
21 1 1 1 θ0 θ1 θ2 1 − θ0 θ1 θ2
22
23 Table 4.1: Noisy-OR CPD for p(xd |z0 , z1 , z2 ), where z0 = 1 is a leak node.
24
25
26 omit the parameters for clarity, so we don’t use plate notation.) The QMR model is a bipartite
27 graph structure, with hidden diseases (causes) at the top and visible symptoms or findings at the
28 bottom. We can write the distribution as follows:
29
K
Y D
Y
30
p(z, x) = p(zk ) p(xd |xpa(d ) (4.3)
31
k=1 d=1
32
33 where zk represents the k’th disease and xd represents the d’th symptom. This model can be used
34 inside an inference engine to compute the posterior probability of each disease given the observed
35 symptoms, i.e., p(zk |xv ), where xv is the set of visible symptom nodes. (The symptoms which are not
36 observed can be removed from the model, assuming they are missing at random (Main Section 3.11),
37 because they contribute nothing to the likelihood; this is called barren node removal.)
38 We now discuss the parameterization of the model. For simplicity, we assume all nodes are binary.
39 The CPD for the root nodes are just Bernoulli distributions, representing the prior probability of
40 that disease. Representing the CPDs for the leaves (symptoms) using CPTs would require too
41 many parameters, because the fan-in (number of parents) of many leaf nodes is very high. A
42 natural alternative is to use logistic regression to model the CPD, p(xd |zpa(d) ) = Ber(xd |σ(wTd zpa(d) )).
43 However, we use an alternative known as the noisy-OR model, which we explain below,
44 The noisy-OR model assumes that if a parent is on, then the child will usually also be on (since it
45 is an or-gate), but occasionally the “links” from parents to child may fail, independently at random.
46 If a failure occurs, the child will be off, even if the parent is on. To model this more precisely, let
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
4.1. MORE EXAMPLES OF DGMS

1
2 θkd = 1 − qkd be the probability that the k → d link fails. The only way for the child to be off is if
3 all the links from all parents that are on fail independently at random. Thus
4 Y I(z =1)
5 p(xd = 0|z) = θkd k (4.4)
6 k∈pa(d)
7
8 Obviously, p(xd = 1|z) = 1 − p(xd = 0|z). In particular, let us define qkd = 1 − θkd = p(xd = 1|zk =
9 1, z−k = 0); this is the probability that k can activate d “on its own”; this is sometimes called its
10 “causal power” (see e.g., [Korb2011]).
11 If we observe that xd = 1 but all its parents are off, then this contradicts the model. Such a data
12 case would get probability zero under the model, which is problematic, because it is possible that
13 someone exhibits a symptom but does not have any of the specified diseases. To handle this, we add
14 a dummy leak node z0 , which is always on; this represents “all other causes”. The parameter q0d
15 represents the probability that the background leak can cause symptom d on its own. The modified
16 CPD becomes
17 Y z
18
p(xd = 0|z) = θ0d θkdk (4.5)
k∈pa(d)
19
20
See Table 4.1 for a numerical example.
21
If we define wkd , log(θkd ), we can rewrite the CPD as
22
!
23 X
24 p(xd = 1|z) = 1 − exp w0d + zk wkd (4.6)
25 k
26
We see that this is similar to a logistic regression model.
27
It is relatively easy to set the θkd parameters by hand, based on domain expertise, as was done
28
with QMR. Such a model is called a probabilistic expert system. In this book, we focus on
29
learning parameters from data; we discuss how to do this in Main Section 4.2.7.2 (see also [Neal92;
30
Meek97]).
31
32
33 4.1.4 Genetic linkage analysis
34
DPGM’s are widely used in statistical genetics. In this section, we discuss the problem of genetic
35
linkage analysis, in which we try to infer which genes cause a given disease. We explain the method
36
below.
37
38
4.1.4.1 Single locus
39
40 We start with a pedigree graph, which is a DAG that representing the relationship between parents
41 and children, as shown in Figure 4.4(a). Next we construct the DGM. For each person (or animal)
42 i and location or locus j along the genome, we create three nodes: the observed phenotype Pij
43 (which can be a property such as blood type, or just a fragment of DNA that can be measured), and
p
44 two hidden alleles (genes), Gmij and Gij , one inherited from i’s mother (maternal allele) and the
p
45 other from i’s father (paternal allele). Together, the ordered pair Gij = (Gm
ij , Gij ) constitutes i’s
46 hidden genotype at locus j.
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


44

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 (a) (b)
18
19 Figure 4.4: Left: family tree, circles are females, squares are males. Individuals with the disease of interest are
20 highlighted. Right: DGM for locus j = L. Blue node Pij is the phenotype for individual i at locus j. Orange
p/m p/m
21 nodes Gij is the paternal/ maternal allele. Small red nodes Sij are the paternal/ maternal selection
22 switching variables. The founder (root) nodes do not have any parents, and hence do no need switching variables.
23
All nodes are hidden except the blue phenotypes. Adapted from Figure 3 from [Friedman00linkage].
24
Gp Gm p(P = a) p(P = b) p(P = o) p(P = ab)
25
a a 1 0 0 0
26 a b 0 0 0 1
27 a o 1 0 0 0
28 b a 0 0 0 1
b b 0 1 0 0
29
b o 0 1 0 0
30 o a 1 0 0 0
31 o b 0 1 0 0
32 o o 0 0 1 0
33
Table 4.2: CPT which encodes a mapping from genotype to phenotype (bloodtype). This is a deterministic,
34
but many-to-one, mapping.
35
36
37
p
38 Obviously we must add Gm ij → Xij and Gij → Pij arcs representing the fact that genotypes cause
p
39 phenotypes. The CPD p(Pij |Gij , Gij ) is called the penetrance model. As a very simple example,
m
p
40 suppose Pij ∈ {A, B, O, AB} represents person i’s observed bloodtype, and Gm ij , Gij ∈ {A, B, O}
41 is their genotype. We can represent the penetrance model using the deterministic CPD shown in
42 Table 4.2. For example, A dominates O, so if a person has genotype AO or OA, their phenotype will
43 be A.
m/p
44 In addition, we add arcs from i’s mother and father into Gij , reflecting the Mendelian inheri-
45 tance of genetic material from one’s parents. More precisely, let µi = k be i’s mother. For example,
46 in Figure 4.4(b), for individual i = 3, we have µi = 2, since 2 is the mother of 3. The gene Gm ij could
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
4.1. MORE EXAMPLES OF DGMS

1
2 Locus # 1 Locus # 2

3 1 2

4
5 3 4 1 2
6
7
5 6 3 4
8
9
10
11 5 6

12
13
14
15
Figure 4.5: Extension of Figure 4.4 to two loci, showing how the switching variables are spatially cor-
m m p p
related. This is indicated by the Sij → Si,j+1 and Sij → Si,j+1 edges. Adapted from Figure 3 from
16
[Friedman00linkage].
17
18
19
p
20 either be equal to Gm kj or Gkj , that is, i’s maternal allele is a copy of one of its mother’s two alleles.
21 Let Sij be a hidden switching variable that specifies the choice. Then we can use the following CPD,
m

22 known as the inheritance model:


23
  
 I Gm = Gm if Sij
m
=m
p ij kj
24
p(Gmij |Gm
kj , G , S m
ij ) =   (4.7)
kj  I Gm = G p
25
ij kj if Sij
m
=p
26
27 We can define p(Gpij |Gm p p
kj , Gkj , Sij ) similarly, where π = pi is i’s father. The values of the Sij are
28
said to specify the phase of the genotype. The values of Gpi,j , Gm p
i,j , Si,j and Si,j constitute the
m
p p
29
haplotype of person i at locus j. (The genotype Gi,j and Gi,j without the switching variables Si,j
m
30
and Si,j is called the “unphased” genotype.)
m
p
31
Next, we need to specify the prior for the root nodes, p(Gm ij ) and p(Gij ). This is called the founder
32
model, and represents the overall prevalence of difference kinds of alleles in the population. We
33
usually assume independence between the loci for these founder alleles, and give these root nodes
34
uniform priors. Finally, we need to specify priors for the switch variables that control the inheritance
35
process. For now, we will assume there is just a single locus, so we can assume uniform priors for the
36
switches. The resulting DGM is shown in Figure 4.4(b).
37
38
4.1.4.2 Multiple loci
39
40 We get more statistical power if we can measure multiple phenotypes and genotypes. In this case,
41 we must model spatial correlation amonst the genes, since genes that are close on the genome are
42 likely to be coinherited, since there is less likely to be a crossover event between them. We can model
43 this by imposing a two-state Markov chain on the switching variables S’s, where the probability
44 of switching state at locus j is given by θj = 12 (1 − e−2dj ), where dj is the distance between loci j
45 and j + 1. This is called the recombination model. The resulting DGM for two linked loci in
46 Figure 4.5.
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


46

1
2 We can now use this model to determine where along the genome a given disease-causing gene
3 is assumed to lie — this is the genetic linkage analysis task. The method works as follows. First,
4 suppose all the parameters of the model, including the distance between all the marker loci, are
5 known. The only unknown is the location of the disease-causing gene. If there are L marker loci,
6 we construct L + 1 models: in model `, we postulate that the disease gene comes after marker `,
7 for 0 < ` < L + 1. We can estimate the Markov switching parameter θ̂` , and hence the distance d`
8 between the disease gene and its nearest known locus. We measure the quality of that model using
9 its likelihood, p(D|θ̂` ). We then can then pick the model with highest likelihood.
10 Note, however, that computing the likelihood requires marginalizing out all the hidden S and G
11 variables. See [Fishelson02] and the references therein for some exact methods for this task; these
12 are based on the variable elimination algorithm, which we discuss in Main Section 9.5. Unfortunately,
13 for reasons we explain in Main Section 9.5.4, exact methods can be computationally intractable if the
14 number of individuals and/or loci is large. See [Albers06] for an approximate method for computing
15 the likelihood based on the “cluster variation method”.
16 Note that it is possible to extend the above model in multiple ways. For example, we can model
17 evolution amongst phylogenies using a phylogenetic HMM [Siepel03].
18
19
20 4.2 More examples of UGMs
21
22
4.3 Restricted Boltzmann machines (RBMs) in more detail
23
24
In this section, we discuss RBMs in more detail.
25
26
27
4.3.1 Binary RBMs
28
29 The most common form of RBM has binary hidden nodes and binary visible nodes. The joint
30 distribution then has the following form:
31
32 1
33 p(x, z|θ) = exp(−E(x, z; θ)) (4.8)
Z(θ)
34
D X
X K D
X K
X
35
E(x, z; θ) , − xd zk Wdk − x d bd − zk ck (4.9)
36
d=1 k=1 d=1 k=1
37
38
= −(x Wz + xT b + z c)
T T
(4.10)
XX
39 Z(θ) = exp(−E(x, z; θ)) (4.11)
40 x z
41
42 where E is the energy function, W is a D × K weight matrix, b are the visible bias terms, c are
43 the hidden bias terms, and θ = (W, b, c) are all the parameters. For notational simplicity, we will
44 absorb the bias terms into the weight matrix by adding dummy units x0 = 1 and z0 = 1 and setting
45 w0,: = c and w:,0 = b. Note that naively computing Z(θ) takes O(2D 2K ) time but we can reduce
46 this to O(min{D2K , K2D }) time using the structure of the graph.
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
4.3. RESTRICTED BOLTZMANN MACHINES (RBMS) IN MORE DETAIL

1
2 When using a binary RBM, the posterior can be computed as follows:
3
K
Y Y
4
5
p(z|x, θ) = p(zk |x, θ) = Ber(zk |σ(wT:,k x)) (4.12)
k=1 k
6
7
By symmetry, one can show that we can generate data given the hidden variables as follows:
8
9
Y Y
p(x|z, θ) = p(xd |z, θ) = Ber(xd |σ(wTd,: z)) (4.13)
10
d d
11
12 We can write this in matrix-vector notation as follows:
13
14 E [z|x, θ] = σ(WT x) (4.14)
15
E [x|z, θ] = σ(Wz) (4.15)
16
17
The weights in W are called the generative weights, since they are used to generate the observations,
18
and the weights in WT are called the recognition weights, since they are used to recognize the
19
input.
20
From Equation 4.12, we see that we activate hidden node k in proportion to how much the input
21
vector x “looks like” the weight vector w:,k (up to scaling factors). Thus each hidden node captures
22
certain features of the input, as encoded in its weight vector, similar to a feedforward neural network.
23
For example, consider an RBM for text models, where x is a bag of words (i.e., a bit vector over the
24
vocabulary). Let zk = 1 if “topic” k is present in the document. Suppose a document has the topics
25
“sports” and “drugs”. If we “multiply” the predictions of each topic together, the model may give very
26
high probability to the word “doping”, which satisfies both constraints. By contrast, adding together
27
experts can only make the distribution broader. In particular, if we mix together the predictions
28
from “sports” and “drugs”, we might generate words like “cricket” and “addiction”, which come from
29
the union of the two topics, not their intersection.
30
31
32 4.3.2 Categorical RBMs
33
We can extend the binary RBM to categorical visible variables by using a 1-of-C encoding, where
34
C is the number of states for each xd . We define a new energy function as follows [Salak07;
35
Salak10softmax]:
36
37
D X
X K X
C D X
X C K
X
38
E(x, z; θ) , − xcd zk wdk
c
− xcd bcd − zk ck (4.16)
39
d=1 k=1 c=1 d=1 c=1 k=1
40
41 The full conditionals are given by
42
X
43 p(xd = c∗ |z, θ) = softmax({bcd + c C
zk wdk }c=1 ))[c∗ ] (4.17)
44 k
45
XX
p(zk = 1|x, θ) = σ(ck + xcd wdk
c
) (4.18)
46
d c
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


48

1
2 4.3.3 Gaussian RBMs
3
We can generalize the model to handle real-valued data. In particular, a Gaussian RBM has the
4
following energy function:
5
D X
X K D K
1X X
6
7 E(x, z|θ) = − wdk zk xd − (xd − bd )2 − ak zk (4.19)
2
8 d=1 k=1 d=1 k=1
9
10
The parameters of the model are θ = (wdk , ak , bd ). (We have assumed the data is standardized, so
11
we fix the variance to σ 2 = 1.) Compare this to a Gaussian in canonical or information form (see
12
Main Section 2.3.1.4):
13
1
14 Nc (x|η, Λ) ∝ exp(ηT x − xT Λx) (4.20)
2
15
P
16 where η = Λµ. P We see that we have set Λ = I, and η = k zk w:,k . Thus the mean is given by
17 µ = Λ−1 η = k zk w:,k , which is a weighted combination of prototypes. The full conditionals, which
18 are needed for inference and learning, are given by
19 X
20 p(xd |z, θ) = N (xd |bd + wdk zk , 1) (4.21)
21 k
!
22 X
23 p(zk = 1|x, θ) = σ ck + wdk xd (4.22)
24 d

25
More powerful models, which make the (co)variance depend on the hidden states, can also be
26
developed [Ranzato10].
27
28
29
4.3.4 RBMs with Gaussian hidden units
30 If we use Gaussian latent variables and Gaussian visible variables, we get an undirected version of
31 factor analysis (Main Section 28.3.1). Interestingly, this is mathematically equivalent to the standard
32 directed version [Marks01].
33 If we use Gaussian latent variables and categorical observed variables, we get an undirected version
34 of categorical PCA (Main Section 28.3.5). In [Salak07], this was applied to the Netflix collaborative
35 filtering problem, but was found to be significantly inferior to using binary latent variables, which
36 have more expressive power.
37
38
39
40
41
42
43
44
45
46
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
5 Information theory

5.1 Minimizing KL between two Gaussians


5.1.1 Moment projection
In moment projection, we want to solve
q = argmin DKL (p k q) (5.1)
q

We just need to equate the moments of q to the moments of p. For example, suppose p(x) = N (x|µ, Σ)
and q(x) = N (x|m, diag(v)). We have m = µ and vi = Σii . This is the “mode covering” solution.

5.1.2 Information projection


In information projection, we want to solve
q = argmin DKL (q k p) (5.2)
q

For example, suppose


p(x) = N (x|µ, Σ) = N (x|µ, Λ−1 ) (5.3)
and
q(x) = N (x|m, V) = N (x|m, diag(v)) (5.4)
Below we show that the optimal solution is to set m = µ and vi = Λ−1
ii . This is the “mode seeking”
solution.
To derive this result, we write out the KL as follows:
 
1 |Σ|
DKL (q k p) = log T
− D + (µ − m) Σ(µ − m) + tr(ΛV) (5.5)
2 |V|
where D is the dimensionality. To find the optimal m, we set the derivative to 0 and solve. This
gives m = µ. Plugging in this solution, the objective then becomes
 
1 X
J(q) = − log |V| − const + [Λjj Vjj ] (5.6)
2 j
50

1
2 Defining V = diag(σi2 ) we get
3  
4 1 X X 
5 J(q) = − log σj2 − const + σj2 Λjj  (5.7)
2 j j
6
7
∂J(q)
8 Setting ∂σi = 0 we get
9
∂J(q) 1
10 = − + σi Λii = 0 (5.8)
11 ∂σi σi
12
which gives
13
14 σi−2 = Λii (5.9)
15
16 which says that we should match the marginal precisions of the two distributions.
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
6 Optimization

6.1 Proximal methods


In this section, we discuss a class of optimization algorithms called proximal methods that use as
their basic subroutine the proximal operator of a function, as opposed to its gradient or Hessian.
We define this operator below, but essentially it involves solving a convex subproblem.
Compared to gradient methods, proximal method are easier to apply to nonsmooth problems
(e.g., with `1 terms), as well as large scale problems that need to be decomposed and solved in
parallel. These methods are widely used in signal and image processing, and in some applications in
deep learning (e.g., [Bai2019] uses proximal methods for training quantized DNNs, [Yao2020] uses
proximal methods for efficient neural architecture search, [PPO; Wang2019PPO] uses proximal
methods for policy gradient optimization, etc.).
Our presentation is based in part on the tutorial in [proximal]. For another good review, see
[Polson2015].

6.1.1 Proximal operators


Let f : Rn → R ∪ {+∞} be a convex function, where f (x) = ∞ means the point is infeasible. Let
the effective domain of f be the set of feasible points:

dom(f ) = {x ∈ Rn : f (x) < ∞} (6.1)

The proximal operator (also called a proximal mapping) of f , denoted proxf (x) : Rn → Rn ,
is defined by
 
1
proxf (x) = argmin f (z) + ||z − x||22 (6.2)
z 2

This is a strongly convex function and hence has a unique minimizer. This operator is sketched
in Figure 6.1a. We see that points inside the domain move towards the minimum of the function,
whereas points outside the domain move to the boundary and then towards the minimum.
For example, suppose f is the indicator function for the convex set C, i.e.,
(
0 if x ∈ C
f (x) = IC (x) = (6.3)
∞ if x 6∈ C
52

1
2 3.0
3
2.5
4
2.0
5
1.5
6
1.0
7
0.5
8
0.0
9
10 -3 -2 -1 0 1 2 3
11 (a) (b)
12
13 Figure 6.1: (a) Evaluating a proximal operator at various points. The thin lines represent level sets of a
14 convex function; the minimum is at the bottom left. The black line represents the boundary of its domain.
15
Blue points get mapped to red points by the prox operator, so points outside the feasible set get mapped
to the boundary, and points inside the feasible set get mapped to closer to the minimum. From Figure 1
16
of [proximal]. Used with kind permission of Stephen Boyd. (b) Illustration of the Moreau envelope with
17
η = 1 (dotted line) of the absolute value function (solid black line). See text for details. From Figure 1 of
18
[Polson2015]. Used with kind permission of Nicholas Polson.
19
20
21
22
In this case, the proximal operator is equivalent to projection onto the set C:
23
24
projC (x) = argmin ||z − x||2 (6.4)
25 z∈C
26
27 We can therefore think of the prox operator as generalized projection.
28 We will often want to compute the prox operator for a scaled function ηf , for η > 0, which can be
29 written as
30  
1
31 proxηf (x) = argmin f (z) + ||z − x||2 2
(6.5)
32 z 2η
33
34
The solution to the problem in Equation (6.5) the same as the solution to the trust region
35
optimization problem of the form
36
37
argmin f (z) s.t. ||z − x||2 ≤ ρ (6.6)
z
38
39 for appropriate choices of η and ρ. This the proximal projection minimizes the function while staying
40 close to the current iterate. We give other interpretations of the proximal operator below.
41 We can generalize the operator by replacing the Euclidean distance with Mahalanobis distance:
42
 
43 1
proxηf,A (x) = argmin f (z) + (z − x)T A(z − x) (6.7)
44 z 2η
45
46 where A is a psd matrix.
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
6.1. PROXIMAL METHODS

1
2 6.1.1.1 Moreau envelope
3
Let us define the following quadratic approximation to the function f as a function of z, requiring
4
that it touch f at x:
5
6 1
fxη (z) = f (z) + ||z − x||22 (6.8)
7 2η
8
By definition, the location of the minimum of this function is z ∗ (x) = argminz fxη (z) = proxηf (x).
9
10
For example, consider approximating the function f (x) = |x| at x0 = 1.5 using fx10 (z) = |z| + 12 (z −
11
x0 )2 . This is shown in Figure 6.1b: the solid black line is f (x), x0 = 1.5 is the black square, and
12
the light gray line is fx10 (z). The proximal projection of x0 onto f is z ∗ (x0 ) = argminz fx10 (z) = 0.5,
13
which is the minimum of the quadratic, shown by the red cross. This proximal point is closer to the
14
minimum of f (x) than the starting point, x0 .
15
Now let us evaluate the approximation at this proximal point:
16 1 1
fxη (z ∗ (x)) = f (z ∗ ) + ||z ∗ − x||22 = min f (z) + ||z − x||22 , f η (x) (6.9)
17 2η z 2η
18
19
where f η (x) is called the Moreau envelope of f .
20
For example, in Figure 6.1b, we see that fx10 (z ∗ ) = fx10 (0.5) = 1.0, so f 1 (x0 ) = 1.0. This is shown
21
by the blue circle. The dotted line is the locus of blue points as we vary x0 , i.e., the Moreau envelope
22
of f .
23
We see that the Moreau envelope is a smooth lower bound on f , and has the same minimum location
24
as f . Furthermore, it has domain Rn , even when f does not, and it is continuously differentiable,
25
even when f is not. This makes it easier to optimize. For example, the Moreau envelope of f (r) = |r|
26
is the Huber loss function, which is used in robust regression.
27
28 6.1.1.2 Prox operator on a linear approximation yields gradient update
29
Suppose we make a linear approximation of f at the current iterate xt :
30
31 fˆ(x) = f (xt ) + gTt (x − xt ) (6.10)
32
33
where gt = ∇f (xt ). To compute the prox operator, note that
 
34 1 1
35 ∇z fˆxη (z) = ∇z f (xt ) + gTt (z − x) + ||z − xt ||22 = gt + (z − xt ) (6.11)
2η η
36
37 Solving ∇z fˆxη (z) = 0 yields the standard gradient update:
38
39
proxηfˆ(x) = x − ηgt (6.12)
40
Thus a prox step is equivalent to a gradient step on a linearized objective.
41
42
6.1.1.3 Prox operator on a quadratic approximation yields regularized Newton update
43
44 Now suppose we use a second order approximation at xt :
45
1
46 fˆ(x) = f (xt ) + ∇gt (x − xt ) + (x − xt )T Ht (x − xt ) (6.13)
2
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


54

1
2 The prox operator for this is
3 1
4 proxηfˆ(x) = x − (Ht + I)−1 gt (6.14)
η
5
6 6.1.1.4 Prox operator as gradient descent on a smoothed objective
7
8 Prox operators are arguably most useful for nonsmooth functions for which we cannot make a
9 Taylor series approximation. Instead, we will optimize the Moreau envelope, which is a smooth
10 approximation.
11 In particular, from Equation (6.9), we have
12 1
13 f η (x) = f (proxηf (x)) + ||x − proxηf (x))||22 (6.15)

14
15
Hence the gradient of the Moreau envelope is given by
16 1
∇x f η (x) = (x − proxηf (x)) (6.16)
17 η
18
Thus we can rewrite the prox operator as
19
20 proxηf (x) = x − η∇f η (x) (6.17)
21
Thus a prox step is equivalent to a gradient step on the smoothed objective.
22
23
24
6.1.2 Computing proximal operators
25 In this section, we briefly discuss how to compute proximal operators for various functions that
26 are useful in ML, either as regularizers or constraints. More examples can be found in [proximal;
27 Polson2015].
28
29
6.1.2.1 Moreau decomposition
30
31 A useful technique for computing some kinds of proximal operators leverages a result known as
32 Moreau decomposition, which states that
33
x = proxf (x) + proxf ∗ (x) (6.18)
34
35 where f ∗ is the convex conjugate of f (see Section 6.3).
36 For example, suppose f = || · || is a general norm on RD . If can be shown that f ∗ = IB , where
37
B = {x : ||x||∗ ≤ 1} (6.19)
38
39 is the unit ball for the dual norm || · ||∗ , defined by
40
||z||∗ = sup{zT x : ||x|| ≤ 1} (6.20)
41
42 Hence
43
proxλf (x) = x − λproxf ∗ /λ (x/λ) = x − λprojB (x/λ) (6.21)
44
45 Thus there is a close connection between proximal operators of norms and projections onto norm
46 balls that we will leverage below.
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
6.1. PROXIMAL METHODS

1
2 6.1.2.2 Projection onto box constraints
3
Let C = {x : l ≤ x ≤ u} be a box or hyper-rectangle, imposing lower and upper bounds on each
4
element. (These bounds can be infinite for certain elements if we don’t to constrain values along that
5
dimension.) The projection operator is easy to compute elementwise by simply thresholding at the
6
boundaries:
7

ld if xk ≤ lk
8 
9
projC (x)d = xd if lk ≤ xk ≤ uk (6.22)
10 

11
ud if xk ≥ uk
12
For example, if we want to ensure all elements are non-negative, we can use
13
14
projC (x) = x+ = [max(x1 , 0), . . . , max(xD , 0)] (6.23)
15
16
17
6.1.2.3 `1 norm
18 Consider the 1-norm f (x) = ||x||1 . The proximal projection can be computed componentwise. We
19 can solve each 1d problem as follows:
20
21 1
proxλf (x) = argmin λ|z| + (z − x)2 (6.24)
22 z z
23
24 One can show that the solution to this is given by
25 
26 x − λ if x ≥ λ

27 proxλf (x) = 0 if |x| ≥ λ (6.25)


28 x + λ if x ≤ λ
29
30 This is known as the soft thresholding operator, since values less than λ in absolute value are
31 set to 0 (thresholded), but in a differentiable way. This is useful for enforcing sparsity. Note that soft
32 thresholding can be written more compactly as
33
34 SoftThresholdλ (x) = sign(x) (|x| − λ)+ (6.26)
35
36 where x+ = max(x, 0) is the positive part of x. In the vector case, we define SoftThresholdλ (x) to
37 be elementwise soft thresholding.
38
39 6.1.2.4 `2 norm
40 qP
D
41 Now consider the `2 norm f (x) = ||x||2 = d=1 xd . The dual norm for this is also the `2 norm.
2

42 Projecting onto the corresponding unit ball B can be done by simply scaling vectors that lie outside
43 the unit sphere:
44 (
x
45 ||x||2 > 1
projB (x) = ||x||2 (6.27)
46 x ||x||2 ≤ 1
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


56

1
2 Hence by the Moreau decomposition we have
3 (
4 (1 − λ
||x||2 )x if ||x||2 ≥ λ
proxλf (x) = (1 − λ/||x||2 )+ x = (6.28)
5 0 otherwise
6
7 This will set the whole vector to zero if its `2 norm is less than λ. This is therefore called block soft
8 thresholding.
9
10
11 6.1.2.5 Squared `2 norm
PD
12
Now consider using the squared `2 norm (scaled by 0.5), f (x) = 12 ||x||22 = 1
2 d=1 x2d . One can show
13
that
14
15 1
proxλf (x) = x (6.29)
16 1+λ
17
18 This reduces the magnitude of the x vector, but does not enforce sparsity. It is therefore called the
19 shrinkage operator.
20 More generally, if f (x) = 12 xT Ax + bT x + c is a quadratic, with A being positive definite, then
21
22 proxλf (x) = (I + λA)−1 (x − λb) (6.30)
23
24 A special case of this is if f is affine, f (x) = bT x + c. Then we have proxλf (x) = x − λb. We saw an
25 example of this in Equation (6.12).
26
27
6.1.2.6 Nuclear norm
28
29 The nuclear norm, also called the trace norm, of an m × n matrix A is the `1 norm of of its
30 singular values: f (A) = ||A||∗ = ||σ||1 . Using this as a regularizer can result in a low rank matrix.
31 The proximal operator for this is defined by
32
X
33 proxλf (A) = (σi − λ)+ ui vTi (6.31)
34 i
35
P
36 where A = i σi ui viT is the SVD of A. This operation is called singular value thresholding.
37
38
6.1.2.7 Projection onto positive definite cone
39
40 Consider the cone of positive semidefinite matrices C, and let f (A) = IC (A) be the indicator function.
41 The proximal operator corresponds to projecting A onto the cone. This can be computed using
42
X
43 projC (A) = (λi )+ ui uTi (6.32)
44 i
45
P
46 where i λi ui uTi is the eigenvalue decomposition of A. This is useful for optimizing psd matrices.
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
6.1. PROXIMAL METHODS

1
2 6.1.2.8 Projection onto probability simplex
3 PD
Let C = {x : x ≥ 0, d=1 xd = 1} = SD be the probability simplex in D dimensions. We can project
4
onto this using
5
6 projC (x) = (x − ν1)+ (6.33)
7
8 The value ν ∈ R must be found using bisection search. See [proximal] for details. This is useful for
9 optimizing over discrete probability distributions.
10
11 6.1.3 Proximal point methods (PPM)
12
13
A proximal point method (PPM), also called a proximal minimization algorithm, iteratively
14
applies the following update:
15 1
16 θt+1 = proxηt L (θt ) = argmin L(θ) + ||θ − θt ||22 (6.34)
θ 2ηt
17
18 where we assume L : Rn → R ∪ {+∞} is a closed proper convex function. The advantage of this
19 method over minimizing L directly is that sometimes adding quadratic regularization can improve
20 the conditioning of the problem, and hence speed convergence.
21
22 6.1.3.1 Gradient descent is PPM on a linearized objective
23
24
We now show that SGD is PPM on a linearized objective. To see this, let the approximation at the
25
current iterate be
26
L̃ˆt (θ) = L̃t (θt ) + gTt (θ − θt ) (6.35)
27
28 where gt = ∇θ L̃t (θt ). Now we compute a proximal update to this approximate objective:
29
1
30 θt+1 = proxη ˆ (θt ) = argmin L̃ˆt (θ) + ||θ − θt ||22 (6.36)
31 t L̃t
θ 2η t
32
We have
33  
34 1 1
∇θ L̃t (θt ) + gTt (θ − θt ) + ||θ − θt ||22 = gt + (θ − θt ) (6.37)
35 2ηt ηt
36
Setting the gradient to zero yields the SGD step θt+1 = θt − ηt gt .
37
38
39 6.1.3.2 Beyond linear approximations (truncated AdaGrad)
40 Sometimes we can do better than just using PPM with a linear approximation to the objective, at
41 essentially no extra cost, as pointed out in [Asi2019siam; Asi2019aistats; Asi2019pnas]. For
42 example, suppose we know a lower bound on the loss, L̃min t = minθ L̃t (θ). For example, when using
43 squared error, or cross-entropy loss for discrete labels, we have L̃t (θ) ≥ 0. Let us therefore define the
44 truncated model
45  
46 L̃ˆt (θ) = max L̃t (θ) + gTt (θ − θt ), L̃min
t (6.38)
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


58

1
2 50

3 45 0.8

4 40
SGM 0.6
5 35 Truncated
adam
6 trunc-adagrad
30 0.4
7
25
SGM
8 0.2 Truncated
20 adam
9 trunc-adagrad

10−3 10−2 10−1 100 101 102 103 10−3 10−2 10−1 100 101 102 103
10
11 (a) (b)
12
13 Figure 6.2: Illustration of the benefits of using a lower-bounded loss function when training a resnet-128
14 CNN on the CIFAR10 image classification dataset. The curves are as follows: SGM (stochastic gradient
15
method, i.e., SGD), Adam, truncated SGD and truncated AdaGrad. (a) Time to reach an error that
satisifes L(θt ) − L(θ ∗ ) ≤  vs initial learning rate η0 . (b) Top-1 accuracy after 50 epochs vs η0 . The lines
16
represent median performance across 50 random restarts, and shading represents 90% confidence intervals.
17
From Figure 4 of [Asi2019pnas]. Used with kind permission of Hilal Asi.
18
19
20
21
We can further improve things by replacing the Euclidean
Pt norm with a scaled Euclidean norm,
1

22
where the diagonal scaling matrix is given by At = diag( i=1 gi gTi ) 2 , as in AdaGrad [adagrad].
23 If L̃min
t = 0, the resulting proximal update becomes
24 h i 1
25 θt+1 = argmin L̃t (θt ) + gTt (θ − θt ) + (θ − θt )T At (θ − θt ) (6.39)
26 θ + 2ηt
27 L̃t (θt )
= θt − min(ηt , )gt (6.40)
28 gt A−1
T
t gt
29
30 Thus the update is like a standard SGD update, but we truncate the learning rate if it is too big.1
31 [Asi2019pnas] call this truncated AdaGrad. Furthermore, they prove optimizing this trun-
32 cated linear approximation (with or without AdaGrad weighting), instead of the standard linear
33 approximation used by gradient descent, can result in significant benefits. In particular, it is guaran-
34 teed to be stable (under certain technical conditions) for any learning rate, whereas standard GD
35 can “blow up”, even for convex problems.
36 Figure 6.2 shows the benefits of this approach when training a resnet-128 CNN (Main Section 16.2.4)
37 on the CIFAR10 image classification dataset. For SGD and the truncated proximal method, the
38 learning rate is decayed using ηt = η0 t−β with β = 0.6. For Adam and truncated AdaGrad, the
39 learning rate is set to ηt = η0 , since we use diagonal scaling. We see that both truncated methods
40 (regular and AdaGrad version) have good performance for a much broader range of initial learning
41 rate η0 compared to SGD or Adam.
42
43 1. One way to derive this update (suggested by Hilal Asi) is to do case analysison the value of L̃ˆt (θt+1 ), where
44 L̃ˆt is the truncated linear model. If L̃ˆt (θt+1 ) > 0, then setting the gradient to zero yields the usual SGD update,
45 θt+1 = θ − η g . (We assume A = I for simplicity.) Otherwise we must have L̃ˆ (θ
t t t t ) = 0. But we know that
t t+1
46 θt+1 = θt − λgt for some λ, so we solve L̃ˆt (θt − λgt ) = 0 to get λ = L̃ˆt (θt )/||gt ||22 .
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
6.1. PROXIMAL METHODS

1
2 6.1.3.3 Stochastic and incremental PPM
h i
3
PPM can be extended to the stochastic setting, where the goal is to optimize L(θ) = Eq(z) L̃(θ, z) ,
4
5
by using the following stochastic update:
6 1
7 θt+1 = proxηt L̃t (θt ) = argmin L̃t (θ) + ||θ − θt ||22 (6.41)
θ 2ηt
8
9 where L̃t (θ) = L̃(θ, zt ) and zt ∼ q. The resulting method is known as stochastic PPM (see e.g.,
10 [Patrascu2018]). If q is the empirical distribution associated with a finite-sum objective, this is
11 called the incremental proximal point method [Bertsekas2015]. It is often more stable than
12 SGD.
13 In the case where the cost function is a linear least squares problem, one can show [Akyildiz2018]
14 that the IPPM is equivalent to the Kalman filter (Main Section 8.2.2), where the posterior mean is
15 equal to the current parameter estimate, θt . The advantage of this probabilistic perspective is that it
16 also gives us the posterior covariance, which can be used to define a variable-metric distance function
17 inside the prox operator, as in Equation (6.7). We can extend this to nonlinear problems using the
18 extended KF (Main Section 8.3.2).
19
20
6.1.4 Mirror descent
21
22 In this section, we discuss mirror descent [Nemirovski1983; Beck03], which is like gradient
23 descent, but can leverage non-Euclidean geometry to potentially speed up convergence, or enforce
24 certain constaints.
25 Suppose we replace the Euclidean distance term ||θ − θt ||22 in PPM with a Bregman divergence
26 (Main Section 5.1.10), defined as
27  
28
Dh (x, y) = h(x) − h(y) + ∇h(y)T (x − y) (6.42)
29
where h(x) is a strongly convex function. Combined with our linear approximation to the objective,
30
this gives the following update:
31
32 1
θt+1 = argmin L̂(θ) + Dh (θ, θt ) (6.43)
33 θ ηt
34
= argmin ηt gTt θ + Dh (θ, θt ) (6.44)
35 θ
36
37
This is known as mirror descent [Nemirovski1983; Beck03]. This can easily be extended to
38
the stochastic setting in the obvious way.
39
One can show that natural gradient descent (Main Section 6.4) is a form of mirror descent
40
[Raskutti2015]. More precisely, mirror descent in the mean parameter space is equivalent to natural
41
gradient descent in the canonical parameter space.
42
43 6.1.5 Proximal gradient method
44
We are often interested in optimizing a composite objective of the form
45
46 L(θ) = Ls (θ) + Lr (θ) (6.45)
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


60

1
2 where Ls is convex and differentiable (smooth), and Lr is convex but not necessarily differentiable
3 (i.e., it may be non-smooth or “rough”). For example, Lr might be an `1 norm regularization term,
4 and Ls might be the NLL for linear regression (see Section 6.1.5.1).
5 The proximal gradient method is the following update:
6
7
θt+1 = proxηt Lr (θt − ηt ∇Ls (θt )) (6.46)
8
If Lr = IC , this is equivalent to projected gradient descent. If Lr = 0, this is equivalent to
9
gradient descent. If Ls = 0, this is equivalent to a proximal point method.
10
We can create a version of the proximal gradient method with Nesterov acceleration as follows:
11
12 θ̃t+1 = θt + βt (θt − θt−1 ) (6.47)
13
θt+1 = proxηt Lr (θ̃t+1 − ηt ∇Ls (θ̃t+1 )) (6.48)
14
15 See e.g., [Tseng2008].
16 Now we consider the stochastic case, where Ls (θ) = E [Ls (θ, z)]. (We assume Lr is deterministic.)
17 In this setting, we can use the following stochastic update:
18
19 θt+1 = proxηt Lr (θt − ηt ∇Ls (θt , zt )) (6.49)
20
21 where zt ∼ q. This is called the stochastic proximal gradient method. If q is the empirical
22 distribution, this is called the incremental proximal gradient method [Bertsekas2015]. Both
23 methods can also be accelerated (see e.g., [Nitanda2014]).
24 If Ls is not convex, we can compute a locally convex approximation, as in Section 6.1.3.2. (We
25 assume Lr remains convex.) The accelerated version of this is studied in [Li2015apgm]. In the
26 stochastic case, we can similarly make a locally convex approximation to Ls (θ, z). This is studied in
27 [Reddi2016; Li2018prox]. An EKF interpretation in the incremental case (where q = pD ) is given
28 in [Akyildiz2019].
29
30 6.1.5.1 Example: Iterative soft-thresholding algorithm (ISTA) for sparse linear regres-
31 sion
32
Suppose we are interested in fitting a linear regression model with a sparsity-promoting prior on
33
the weights, as in the lasso model (Main Section 15.2.6). One way to implement this is to add the
34 PD
`1 -norm of the parameters as a (non-smooth) penalty term, Lr (θ) = ||θ||1 = d=1 |θd |. Thus the
35
objective is
36
37 1
L(θ) = Ls (θ) + Lr (θ) = ||Xθ − y||22 + λ||θ||1 (6.50)
38 2
39
The proximal gradient descent update can be written as
40
41
θt+1 = SoftThresholdηt λ (θt − ηt ∇Ls (θt )) (6.51)
42
43 where the soft thresholding operator (Equation (6.26)) is applied elementwise, and ∇Ls (θ) = XT (Xθ−
44 y). This is called the iterative soft thresholding algorithm or ISTA [Daubechies2004;
45 Donoho1995]. If we combine this with Nesterov acceleration, we get the method known as “fast
46 ISTA” or FISTA [Beck2009], which is widely used to fit sparse linear models.
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
6.1. PROXIMAL METHODS

1
2 6.1.6 Alternating direction method of multipliers (ADMM)
3
Consider the problem of optimizing L(x) = Ls (x) + Lr (x) where now both Ls and Lr may be non-
4
smooth (but we asssume both are convex). We may want to optimize these problems independently
5
(e.g., so we can do it in parallel), but need to ensure the solutions are consistent.
6
One way to do this is by using the variable splitting trick combined with constrained optimiza-
7
tion:
8
9 minimize Ls (x) + Lr (z) s.t. x − z = 0 (6.52)
10
11 This is called consensus form.
12 The corresponding augmented Langragian is given by
13 ρ
14 Lρ (x, z, y) = Ls (x) + Lr (z) + yT (x − z) + ||x − z||22 (6.53)
2
15
16 where ρ > 0 is the penalty strength, and y ∈ Rn are the dual variables associated with the consistency
17 constraint. We can now perform the following block coordinate descent updates:
18
xt+1 = argmin Lρ (x, zt , yt ) (6.54)
19 x
20
zt+1 = argmin Lρ (xt+1 , z, yt ) (6.55)
21 z
22 yt+1 = yt + ρ(xt+1 − zt+1 ) (6.56)
23
24 We see that the dual variable is the (scaled) running average of the consensus errors.
25 Inserting the definition of Lρ (x, z, y) gives us the following more explicit update equations:
26  ρ 
27 xt+1 = argmin Ls (x) + yTt x + ||x − zt ||22 (6.57)
x 2
28  ρ 
29 zt+1 = argmin Lr (z) − yTt z + ||xt+1 − z||22 (6.58)
z 2
30
31 If we combine the linear and quadratic terms, we get
32  ρ 
33 xt+1 = argmin Ls (x) + ||x − zt + (1/ρ)yt ||22 (6.59)
x 2
34  ρ 
35 zt+1 = argmin Lr (z) + ||xt+1 − z − (1/ρ)yt ||22 (6.60)
z 2
36
37 Finally, if we define ut = (1/ρ)yt and λ = 1/ρ, we can now write this in a more general way:
38
39 xt+1 = proxλLs (zt − ut ) (6.61)
40 zt+1 = proxλLr (xt+1 + ut ) (6.62)
41
ut+1 = ut + xt+1 − zt+1 (6.63)
42
43 This is called the alternating direction method of multipliers or ADMM algorithm. The
44 advantage of this method is that the different terms in the objective (along with any constraints they
45 may have) are handled completely independently, allowing different solvers to be used. Furthermore,
46 the method can be extended to the stochastic setting as shown in [Zhong2014].
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


62

1
2
3
4
5
6
7
8
9
10
11
12
13
14
Figure 6.3: Robust PCA applied to some frames from a surveillance video. First column is input image.
15
Second column is low-rank background model. Third model is sparse foreground model. Last column is derived
16
foreground mask. From Figure 1 of [Bouwmans2017]. Used with kind permission of Thierry Bouwmans.
17
18
19 6.1.6.1 Example: robust PCA
20
21 In this section, we give an example of ADMM from [proximal].
22 Consider the following matrix decomposition problem:
23 J J
X X
24 minimize γj φj (Xj ) s.t. Xj = A (6.64)
X1:J
25 j=1 j=1
26
27
where A ∈ Rm×n is a given data matrix, Xj ∈ Rm×n are the optimization variables, and γj > 0 are
28
trade-off parameters.
29
For example, suppose we want to find a good least squares approximation to A as a sum of a low
30
rank matrix plus a sparse matrix. This is called robust PCA [Candes2011], since the sparse matrix
31
can handle the small number of outliers that might otherwise cause the rank of the approximation to
32
be high. The method is often used to decompose surveillance videos into a low rank model for the
33
static background, and a sparse model for the dynamic foreground objects, such as moving cars or
34
people, as illustrated in Figure 6.3. (See e.g., [Bouwmans2017] for a review.) RPCA can also be
35
used to remove small “outliers”, such as specularities and shadows, from images of faces, to improve
36
face recognition.
37
We can formulate robust PCA as the following optimization problem:
38 minimize ||A − (L + S)||2F + γL ||L||∗ + γS ||S||1 (6.65)
39
40 which is a sparse plus low rank decomposition of the observed data matrix. We can reformulate this
41 to match the form of a canonical matrix decomposition problem by defining X1 = L, X2 = S and
42 X3 = A − (X1 + X2 ), and then using these loss functions:
43
φ1 (X1 ) = ||X1 ||∗ , φ2 (X2 ) = ||X2 ||1 , φ3 (X3 ) = ||X3 ||2F (6.66)
44
45 We can tackle such matrix decomposition problems using ADMM, where we use the split Ls (X) =
P PJ
j γj φj (Xj ) and Lr (X) = IC (X), where X = (X1 , . . . , XJ ) and C = {X1:J : j=1 Xj = A}. The
46
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
6.2. DYNAMIC PROGRAMMING

1
2 overall algorithm becomes
3
1
4 Xj,t+1 = proxηt φj (Xj,t − Xt + A − Ut ) (6.67)
5 N
1
6 Ut+1 = Ut + Xt+1 − A (6.68)
7 J
8
where X is the elementwise average of X1 , . . . , XJ . Note that the Xj can be updated in parallel.
9
Projection onto the `1 norm is discussed in Section 6.1.2.3, projection onto the nuclear norm is
10
discussed in Section 6.1.2.6. projection onto the squared Frobenius norm is the same as projection
11
onto
P the squared Euclidean norm discussed in Section 6.1.2.5, and projection onto the constraint set
12
13 j Xj = A can be done using the averaging operator:

14
1
15 projC (X1 , . . . , XJ ) = (X1 , . . . , XJ ) − X + A (6.69)
J
16
17 An alternative to using `1 minimization in the inner loop is to use hard thresholding [Cherapanamjeri2017].
18 Although not convex, this method can be shown to converge to the global optimum, and is much
19 faster.
20 It is also possible to formulate a non-negative version of robust PCA. Even though NRPCA
21 is not a convex problem, it is possible to find the globally optimal solution [Fattahi2018jmlr;
22 Anderson2019].
23
24
25
6.2 Dynamic programming
26
Dynamic programming is a way to efficiently find the globally optimal solution to certain kinds
27
of optimization problems. The key requirement is that the optimal solution be expressed in terms of
28
the optimal solution to smaller subproblems, which can be reused many times. Note that DP is more
29
of an algorithm “family” rather than a specific algorithm. We give some examples below.
30
31
32 6.2.1 Example: computing Fibonnaci numbers
33
Consider the problem of computing Fibonnaci numbers, defined via the recursive equation
34
35 Fi = Fi−1 + Fi−2 (6.70)
36
37 with base cases F0 = F1 = 1. Thus we have that F2 = 2, F3 = 3, F4 = 5, F5 = 8, etc. A
38 simple recursive algorithm to compute the first n Fibbonaci numbers is shown in Algorithm 6.1.
39 Unfortunately, this takes exponential time. For example, evaluating fib(5) proceeds as follows:
40
41
42 F5 = F4 + F3 (6.71)
43
= (F3 + F2 ) + (F2 + F1 ) (6.72)
44
= ((F2 + F1 ) + (F1 + F0 )) + ((F1 + F0 ) + F1 ) (6.73)
45
46 = (((F1 + F0 ) + F1 ) + (F1 + F0 ))((F1 + F0 ) + F1 ) (6.74)
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


64

1
2 We see that there is a lot of repeated computation. For example, fib(2) is computed 3 times. One
3 way to improve the efficiency is to use memoization, which means memorizing each function value
4 that is computed. This will result in a linear time algorithm. However, the overhead involved can be
5 high.
6 It is usually preferable to try to solve the problem bottom up, solving small subproblems first,
7 and then using their results to help solve larger problems later. A simple way to do this is shown in
8 Algorithm 6.2.
9
10 Algorithm 6.1: Fibbonaci numbers, top down
11
1 function fib(n)
12
2 if n = 0 or n = 1 then
13
3 return 1
14
4 else
15
5 return (fib(n − 1) + fib(n − 2))
16
17
18
19
20
Algorithm 6.2: Fibbonaci numbers, bottom up
21 1 function fib(n)
22 2 F0 := 1, F1 := 2
23 3 for i = 2, . . . , n do
24 4 Fi := Fi−1 + Fi−2
25 5 return Fn
26
27
28
29
6.2.2 ML examples
30
31 There are many applications of DP to ML problems, which we discuss elsewhere in this book. These
32 include the forwards-backwards algorithm for inference in HMMs (Main Section 9.2.3), the Viterbi
33 algorithm for MAP sequence estimation in HMMs (Main Section 9.2.6), inference in more general
34 graphical models (Section 9.2), reinforcement learning (Main Section 34.6), etc.
35
36
6.3 Conjugate duality
37
38 In this section, we briefly discuss conjugate duality, which is a useful way to construct linear lower
39 bounds on non-convex functions. We follow the presentation of [BishopBook].
40
41
42
6.3.1 Introduction
43 Consider an arbitrary continuous function f (x), and suppose we create a linear lower bound on it of
44 the form
45
46 L(x, λ) , λT x − f ∗ (λ) ≤ f (x) (6.75)
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
6.3. CONJUGATE DUALITY

1
2
3 f(x) f(x)
4 *
−f (λ)
5
6
y

y
7
8 λx
9
λx − f*(λ)
10
11
12 x x
13
(a) (b)
14
15 Figure 6.4: Illustration of a conjugate functon. Red line is original function f (x), and the blue line is a linear
16 lower bound λx. To make the bound tight, we find the x where ∇f (x) is parallel to λ, and slide the line up to
17 touch there; the amount we slide up is given by f ∗ (λ). Adapted from Figure 10.11 of [BishopBook].
18
19
20
21
where λ is the slope, which we choose, and f ∗ (λ) is the intercept, which we solve for below. See
22
Figure 6.4(a) for an illustration.
23
For a fixed λ, we can find the point xλ where the lower bound is tight by “sliding” the line upwards
24
until it touches the curve at xλ , as shown in Figure 6.4(b). At xλ , we minimize the distance between
25
the function and the lower bound:
26
27
28
xλ , argmin f (x) − L(x, λ) = argmin f (x) − λT x (6.76)
x x
29
30 Since the bound is tight at this point, we have
31
32
f (xλ ) = L(xλ , λ) = λT xλ − f ∗ (λ) (6.77)
33
34
and hence
35
36
37 f ∗ (λ) = λT xλ − f (xλ ) = max λT x − f (x) (6.78)
x
38
39 The function f ∗ is called the conjugate of f , also known as the Fenchel transform of f . For the
40 special case of differentiable f , f ∗ is called the Legendre transform of f .
41 One reason conjugate functions are useful is that they can be used to create convex lower bounds
42 to non-convex functions. That is, we have L(x, λ) ≤ f (x), with equality at x = xλ , for any function
43 f : RD → R. For any given x, we can optimize over λ to make the bound as tight as possible, giving
44 us a fixed function L(x); this is called a variational approximation. We can then try to maximize
45 this lower bound wrt x instead of maximizing f (x). This method is used extensively in approximate
46 Bayesian inference, as we discuss in Main Chapter 10.
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


66

1
2 1
3
4
5
6
0.5
7
8
9
10
11 0
0 ξ 1.5 3
12
13 (a) (b)
14
15 Figure 6.5: (a) The red curve is f (x) = e−x and the colored lines are linear lower bounds. Each lower bound of
16 slope λ is tangent to the curve at the point xλ = − log(−λ), where f (xλ ) = elog(−λ) = −λ. For the blue curve,
17 this occurs at xλ = ξ. Adapted from Figure 10.10 of [BishopBook]. Generated by opt_lower_bound.ipynb.
18 (b) For a convex function f (x), its epipgraph can be represented as the intersection of half-spaces defined by
19
linear lower bounds of the form f † (λ). Adapted from Figure 13 of [Jaakkola99].
20
21
22 6.3.2 Example: exponential function
23
Let us consider an example. Suppose f (x) = e−x , which is convex. Consider a linear lower bound of
24
the form
25
26 L(x, λ) = λx − f † (λ) (6.79)
27
28 where the conjugate function is given by
29
30
f † (λ) = max λx − f (x) = −λ log(−λ) + λ (6.80)
x
31
32 as illustrated in Figure 6.5(a).
33 To see this, define
34
J(x, λ) = λx − f (x) (6.81)
35
36
We have
37
38 ∂J
= λx − f 0 (x) = λ + e−x (6.82)
39 ∂x
40
Setting the derivative to zero gives
41
42 xλ = arg max J(x, λ) = − log(−λ) (6.83)
43 x

44 Hence
45
46 f † (λ) = J(xλ , λ) = λ(− log(−λ)) − elog(−λ) = −λ log(−λ) + λ (6.84)
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
6.3. CONJUGATE DUALITY

1
2 6.3.3 Conjugate of a conjugate
3
It is interesting to see what happens if we take the conjugate of the conjugate:
4
5 f ∗∗ (x) = max λT x − f † (λ) (6.85)
6 λ

7
If f is convex, then f ∗∗ = f , so f and f † are called conjugate duals. To see why, note that
8
9 f ∗∗ (x) = max L(x, λ) ≤ f (x) (6.86)
λ
10
11 Since we are free to modify λ for each x, we can make the lower bound tight at each x. This perfectly
12 characterizes f , since the epigraph of a convex function is an intersection of half-planes defined by
13 linear lower bounds, as shown in Figure 6.5(b).
14 Let us demonstrate this using the example from Section 6.3.2. We have
15
16 f ∗∗ (x) = max λx − f † (λ) = max λx + λ log(−λ) − λ (6.87)
λ λ
17
18 Define
19
J ∗ (x, λ) = λx − f † (x) = λx + λ log(−λ) − λ (6.88)
20
21
We have
22  
∂ ∗ −1
23
J (x, λ) = x + log(−λ) + λ −1=0 (6.89)
24 ∂λ −λ
25 x = − log(−λ) (6.90)
26 −x
λx = −e (6.91)
27
28 Substituting back we find
29
30 f ∗∗ (x) = J ∗ (x, λx ) = (−e−x )x + (−e−x )(−x) − (−e−x ) = e−x = f (x) (6.92)
31
32 6.3.4 Bounds for the logistic (sigmoid) function
33
34 In this section, we use the results on conjugate duality to derive upper and lower bounds to the
35 logistic function, σ(x) = 1+e1−x .
36
37 6.3.4.1 Exponential upper bound
38
The sigmoid function is neither convex nor concave. However, it is easy to show that f (x) =
39
log σ(x) = − log(1 + e−x ) is concave, by showing that its second derivative is negative. Now, any
40
convex function f (x) can be represented by
41
42 f (x) = min ηx − f † (η) (6.93)
η
43
44 where
45
46 f † (η) = min ηx − f (x) (6.94)
x
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


68

1
2 1 1

3
4 0.8 0.8

5 eta = 0.2
6 0.6 0.6
eta = 0.7
7
8 0.4 0.4
9
xi = 2.5
10 0.2 0.2
11
12 0 0
-6 -4 -2 0 2 4 6 -6 -4 -2 0 2 4 6
13
14 (a) (b)
15
16 Figure 6.6: Illustration of (a) exponental upper bound and (b) quadratic lower bound to the sigmoid function.
17 Generated by sigmoid_upper_bounds.ipynb and sigmoid_lower_bounds.ipynb.
18
19
20 One can show that if f (x) = log σ(x), then
21
f † (η) = −η ln η − (1 − η) ln(1 − η) (6.95)
22
23 which is the binary entropy function. Hence
24
25
log σ(x) ≤ ηx − f † (η) (6.96)

26 σ(x) ≤ exp(ηx − f (η)) (6.97)
27
28
This exponential upper bound on σ(x) is illustrated in Figure 6.6(a).
29
30 6.3.4.2 Quadratic lower bound
31 It is also useful to compute a lower bound on σ(x). If we make this a quadratic lower bound, it will
32 “play nicely” with Gaussian priors, which simplifies the analysis of several models. This approach was
33 first suggested in [Jaakkola96b].
34 First we write
35  
36 log σ(x) = − log(1 + e−x ) = − log e−x/2 (ex/2 + e−x/2 ) (6.98)
37
= x/2 − log(ex/2 + e−x/2 ) (6.99)
38
39 The function f (x) = − log(ex/2 + e−x/2 ) is a convex function of y = x2 , as can be verified by showing
dx2 f (x) > 0. Hence we can create a linear lower bound on f , using the conjugate function
40 d
41

42 f † (η) = max
2
ηx2
− f ( x2 ) (6.100)
x
43
44 We have
45
dx d 1 x
46 0=η− f (x) = η + tanh( ) (6.101)
dx2 dx 4x 2
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
6.4. THE BAYESIAN LEARNING RULE

1
2 The lower bound is tangent at the point xη = ξ, where
3  
1 ξ 1 1
4 η = − tanh( ) = − σ(ξ) − = −λ(ξ) (6.102)
5 4ξ 2 2ξ 2
6
The conjugate function can be rewritten as
7
8
f † (λ(ξ)) = −λ(ξ)ξ 2 − f (ξ) = λ(ξ)ξ 2 + log(eξ/2 + e−ξ/2 ) (6.103)
9
10 So the lower bound on f becomes
11
12 f (x) ≥ −λ(ξ)x2 − g(λ(ξ)) = −λ(ξ)x2 + λ(ξ)ξ 2 − log(eξ/2 + e−ξ/2 ) (6.104)
13
14 and the lower bound on the sigmoid function becomes
15  
16
σ(x) ≥ σ(ξ) exp (x − ξ)/2 − λ(ξ)(x2 − ξ 2 ) (6.105)
17
This is illustrated in Figure 6.6(b).
18
Although a quadratic is not a good representation for the overall shape of a sigmoid, it turns out
19
that when we use the sigmoid as a likelihood function and combine it with a Gaussian prior, we
20
get a Gaussian-like posterior; in this context, the quadratic lower bound works quite well (since a
21
quadratic likelihood times a Gaussian prior will yield an exact Gaussian posterior). See Section 15.1.1
22
for an example, where we use this bound for Bayesian logistic regression.
23
24
25 6.4 The Bayesian learning rule
26
27
In this section, we discuss the “Bayesian learning rule” [BLR], which provides a unified framework
28
for deriving many standard (and non-standard) optimization and inference algorithms used in the
29
ML community.
30
To motivate the BLR, recall the standard empirical risk minimization, or ERM problem,
31
which has the form θ∗ = argminθ `(θ), where
32 N
X
33 `(θ) = `(yn , fθ (xn )) + R(θ) (6.106)
34 n=1
35
36 where fθ (x) is a prediction function, `(y, ŷ) is a loss function, and R(θ) is some kind of regularizer.
37 Although the regularizer can prevent overfitting, the ERM method can still result in parameter
38 estimates that are not robust. A better approach is to fit a distribution over possible parameter
39 values, q(θ). If we minimize the expected loss, we will find parameter settings that will work well
40 even if they are slightly perturbed, as illustrated in Figure 6.7, which helps with robustness and
41 generalization. Of course, if the distribution q collapses to a single delta function, we will end up
42 with the ERM solution. To prevent this, we add a penalty term, that measures the KL divergence
43 from q(θ) to some prior π0 (θ) ∝ exp(−R(θ)). This gives rise to the following BLR objective:
44 "N #
45
X
L(q) = Eq(θ) `(yn , fθ (xn )) + DKL (q(θ) k π0 (θ)) (6.107)
46
n=1
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


70

1
2
3
4
5
6
7
8
9
10
11
12
13
14
Figure 6.7: Illustration of the robustness obtained by using a Bayesian approach to parameter estimation. (a)
15
When the minimum θ∗ lies next to a “wall”, the Bayesian solution shifts away from the boundary to avoid
16
large losses due to perturbations of the parameters. (b) The Bayesian solution prefers flat minima over sharp
17
minima, to avoid large losses due to perturbations of the parameters. From Figure 1 of [BLR]. Used with
18 kind permission of Emtiyaz Khan.
19
20
21
We can rewrite the KL term as
22
23 DKL (q(θ) k π0 (θ)) = Eq(θ) [R(θ)] − H(q(θ)) (6.108)
24
25 and hence can rewrite the BLR objective as follows:
26  
L(q) = Eq(θ) `(θ) − H(q(θ)) (6.109)
27
28 Below we show that different approximations to this objective recover a variety of different methods
29 in the literature.
30
31
6.4.1 Deriving inference algorithms from BLR
32
33 In this section we show how to derive several different inference algorithms from BLR. (We discuss
34 such algorithms in more detail in ??.)
35
36 6.4.1.1 Bayesian inference as optimization
37
38
The BLR objective includes standard exact Bayesian inference as a special case, as first shown in
39
[Zellner1988]. To see this, let us assume the loss function is derived from a log-likelihood:
40
`(y, fθ (x)) = − log p(y|fθ (x)) (6.110)
41
42 Let D = {(xn , yn ) : n = 1 : N } be the data we condition on. The Bayesian posterior can be written
43 as
44
YN
45 1
p(θ|D) = π0 (θ) p(yn |fθ (xn )) (6.111)
46 Z(D) n=1
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
6.4. THE BAYESIAN LEARNING RULE

1
2 This can be derived by minimizing the BLR, since
3 "N #
4
X
L(q) = −Eq(θ) log p(yn |fθ (xn )) + DKL (q(θ) k π0 (θ)) (6.112)
5
n=1
6  
7 q(θ)
= Eq(θ) log π0 (θ) QN  − log Z(D) (6.113)
8
Z(D) n=1 p(yn |fθ (xn ))
9
10 = DKL (q(θ) k p(θ|D)) − log Z(D) (6.114)
11
12
Since Z(D) is a constant, we can minimize the loss by setting q(θ) = p(θ|D).
13
Of course, we can use other kinds of loss, not just log likelihoods. This results in a framework
14
known as generalized Bayesian inference [Bissiri2016; Knoblauch2019; Knoblauch2021].
15
See ?? for more discussion.
16
17 6.4.1.2 Optimization of BLR using natural gradient descent
18
In general, we cannot compute the exact posterior q(θ) = p(θ|D), so we seek an approximation. We
19
will assume that q(θ) is an exponential family distibution, such as a multivariate Gaussian, where
20
the mean represents the standard point estimate of θ (as in ERM), and the covariance represents our
21
uncertainty (as in Bayes). Hence q can be written as follows:
22
23 q(θ) = h(θ) exp[λT T (θ) − A(λ)] (6.115)
24
25 where λ are the natural parameters, T (θ) are the sufficient statistics, A(λ) is the log partition
26 function, and h(θ) is the base measure, which is usually a constant. The BLR loss becomes
27  
28
L(λ) = Eqλ (θ) `(θ) − H(qλ (θ)) (6.116)
29
We can optimize this using natural gradient descent (??). The update becomes
30
h   i
31 ˜ λ Eq (θ) `(θ) − H(qλ )
λt+1 = λt − ηt ∇ (6.117)
λt t
32
33
where ∇˜ λ denotes the natural gradient. We discuss how to compute these natural gradients in ??.
34
In particular, we can convert it to regular gradients wrt the moment parameters µt = µ(λt ). This
35
gives
36
 
37
λt+1 = λt − ηt ∇µ Eqµt (θ) `(θ) + ηt ∇µ H(qµt ) (6.118)
38
39 From ?? we have
40
41 ∇µ H(q) = −λ − ∇µ Eqµ (θ) [log h(θ)] (6.119)
42
43
Hence the update becomes
 
44
λt+1 = λt − ηt ∇µ Eqµt (θ) `(θ) − ηt λt − ηt ∇µ Eqµ (θ) [log h(θ)] (6.120)
45  
46 = (1 − ηt )λt − ηt ∇µ Eqµ (θ) `(θ) + log h(θ) (6.121)
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


72

1
2 For distributions q with constant base measure h(θ), this simplifies to
3  
4 λt+1 = (1 − ηt )λt − ηt ∇µ Eqµ (θ) `(θ) (6.122)
5
6 Hence at the fixed point we have
7  
8 λ∗ = (1 − η)λ∗ − η∇µ Eqµ (θ) `(θ) (6.123)
   
9
λ∗ = ∇µ Eqµ (θ) −`(θ) = ∇ ˜ λ Eq (θ) −`(θ) (6.124)
λ
10
11
12 6.4.1.3 Conjugate variational inference
13
In ?? we show how to do exact inference in conjugate models. We can derive ?? from the BLR by
14
using the fixed point condition in Equation (6.124) to write
15
16
XN
17  
λ∗ = ∇µ Eq∗ −`(θ) = λ0 + ∇µ Eq∗ [log p(yi |θ)] (6.125)
18
i=1
| {z }
19 λ̃i (yi )
20
21 where λ̃i (yi ) are the sufficient statistics for the i’th likelihood term.
22 For models where the joint distribution over the latents factorizes (using a graphical model), we
23 can further decompose this update into a series of local terms. This gives rise to the variational
24 message passing scheme discussed in ??.
25
26
6.4.1.4 Partially conjugate variational inference
27
28 In Supplementary ??, we discuss CVI, which performs variational inference for partially conjugate
29 models, using gradient updates for the non-conjugate parts, and exact Bayesian inference for the
30 conjugate parts.
31
32
33 6.4.2 Deriving optimization algorithms from BLR
34
In this section we show how to derive several different optimization algorithms from BLR. Recall
35
that in BLR, instead of directly minimizing the loss
36
37 N
X
38 `(θ) = `(yn , fθ (xn )) + R(θ) (6.126)
39 n=1
40
41 we will instead minimize
42
 
43 L(λ) = Eq(θ|λ) `(θ) − H(q(θ|λ)) (6.127)
44
45 Below we show that different approximations to this objective recover a variety of different optimization
46 methods that are used in the literature.
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
6.4. THE BAYESIAN LEARNING RULE

1
2 6.4.2.1 Gradient descent
3
In this section, we show how to derive gradient descent as a special case of BLR. We use as our
4
approximate posterior q(θ) = N (θ|m, I). In this case the natural and moment parameters are equal,
5
µ = λ = m. The base measure satisfies the following (from ??):
6
7 2 log h(θ) = −D log(2π) − θT θ (6.128)
8
9 Hence
10
11
∇µ Eq [log h(θ)] = ∇µ (−D log(2π) − µT µ − D) = −µ = −λ = −m (6.129)
12
Thus from Equation (6.121) the BLR update becomes
13
 
14 mt+1 = (1 − ηt )mt + ηt mt − ηt ∇m Eqm (θ) `(θ) (6.130)
15
16 We can remove the expectation using the first order delta method [verHoef2012] to get
17  
∇m Eqm (θ) `(θ) ≈ ∇θ `(θ)|θ=m (6.131)
18
19 Putting these together gives the gradient descent update:
20
21 mt+1 = mt − ηt ∇θ `(θ)|θ=mt (6.132)
22
23 6.4.2.2 Newton’s method
24
25
In this section, we show how to derive Newton’s second order optimization method as a special case
26
of BLR, as first shown in [Khan2018icml].
27
Suppose we assume q(θ) = N (θ|m, S−1 ). The natural parameters are
28 1
29 λ(1) = Sm, λ(2) = − S (6.133)
2
30
31 The mean (moment) parameters are
32
µ(1) = m, µ(2) = S−1 + mmT (6.134)
33
34 Since the base measure is constant (see ??), from Equation (6.122) we have
35  
36 St+1 mt+1 = (1 − ηt )St mt − ηt ∇µ(1) Eqµ (θ) `(θ) (6.135)
 
37 St+1 = (1 − ηt )St + 2ηt ∇µ(2) Eqµ (θ) `(θ) (6.136)
38
39 In ?? we show that
40      
41 ∇µ(1) Eq(θ) `(θ) = Eq(θ) ∇θ `(θ) − Eq(θ) ∇2θ `(θ) m (6.137)
42   1  
∇µ(2) Eq(θ) `(θ) = Eq(θ) ∇2θ `(θ) (6.138)
43 2
44
Hence the update for the precision matrix becomes
45
 
46 St+1 = (1 − ηt )St + ηt Eqt ∇2θ `(θ) (6.139)
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


74

1
2 For the precision weighted mean, we have
3    
St+1 mt+1 = (1 − ηt )St mt − ηt Eqt ∇θ `(θ) + ηt Eqt ∇2θ `(θ) mt (6.140)
4  
5 = St+1 mt − ηt Eqt ∇θ `(θ) (6.141)
6
Hence
7
 
8 mt+1 = mt − ηt S−1
t+1 Eqt ∇θ `(θ) (6.142)
9
10
We can recover Newton’s method in three steps. First set the learning rate to ηt = 1, based on
11
an assumption that the objective is convex. Second, treat the iterate as mt = θt . Third, apply the
12
delta method to get
 
13 St+1 = Eqt ∇2θ `(θ) ≈ ∇2θ `(θ)|θ=mt (6.143)
14
15 and
 
16
Eq ∇θ `(θ) ≈ ∇θ `(θ)|θ=mt (6.144)
17
18 This gives Newton’s update:
19
mt+1 = mt − [∇2m `(mt )]−1 [∇m `(mt )] (6.145)
20
21
22
6.4.2.3 Variational online Gauss-Newton
23 In this section, we describe various second order optimization methods that can be derived from the
24 BLR using a series of simplifications.
25 First, we use a diagonal Gaussian approximation to the posterior, qt (θ) = N (θ|θt , S−1 t ), where
26 St = diag(st ) is a vector of precisions. Following Section 6.4.2.2, we get the following updates:
27
1  
28 θt+1 = θt − ηt Eqt ∇θ `(θ) (6.146)
29 st+1
 
30 st+1 = (1 − ηt )st + ηt Eqt diag(∇2θ `(θ)) (6.147)
31
32
where is elementwise multiplication, and the division by st+1 is also elementwise.
33
Second, we use the delta approximation to replace expectations by plugging in the mean. Third,
34
we use a minibatch approximation to the gradient and diagonal Hessian:
X
35
∇ˆ θ `(θ) = N ∇θ `(yi , fθ (xi )) + ∇θ R(θ) (6.148)
36 M
i∈M
37
X
38 ˆ 2θ `(θ) = N
∇ ∇2θj `(yi , fθ (xi )) + ∇2θj R(θ) (6.149)
39
j
M
i∈M
40
where M is the minibatch size.
41
For some non-convex problems, such as DNNs, the Hessian may be not be positive definite, so we
42
can get better results using a Gauss-Newton approximation, based on the squared gradients instead
43
of the Hessian:
44
X
45 ˆ 2 `(θ) ≈ N
∇ [∇θj `(yi , fθ (xi ))]2 + ∇2θj R(θ) (6.150)
θj
46 M
i∈M
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
6.4. THE BAYESIAN LEARNING RULE

1
2 This is also faster to compute.
3 Putting all this together gives rise to the online Gauss-Newton or OGN method of [Osawa2019nips].
4 If we drop the delta approximation, and work with expectations, we get the variational pnline
5 Gauss-Newton or VOGN method of [Khan2018icml]. We can approximate the expectations by
6 sampling. In particular, VOGN uses the following weight perturbation method
7 h i
8 Eqt ∇ˆ θ `(θ) ≈ ∇
ˆ θ `(θt + t ) (6.151)
9
10
where t ∼ N (0, diag(st )). It is also possible to approximate the Fisher information matrix di-
11
rectly; this results in the variational online generalized Gauss-Newton or VOGGN method
12
of [Osawa2019nips].
13
14
15 6.4.2.4 Adaptive learning rate SGD
16
In this section, we show how to derive an update rule which is very similar to the RMSprop
17
[RMSprop] method, which is widely used in deep learning. The approach we take is similar to that
18
VOGN in Section 6.4.2.3. We use the same diagonal Gaussian approximation, qt (θ) = N (θ|θt , S−1 t ),
19
where St = diag(st ) is a vector of precisions. We then use the delta method to eliminate expectations:
20
21 1
22 θt+1 = θt − ηt ∇θ `(θt ) (6.152)
st+1
23
24
st+1 = (1 − ηt )st + ηt diag(∇2θ `(θt )) (6.153)
25
26
where is elementwise multiplication. If we allow for different learning rates we get
27
1
28 θt+1 = θt − αt ∇θ `(θt ) (6.154)
st+1
29
30
ˆ 2θ `(θt ))
st+1 = (1 − βt )st + βt diag(∇ (6.155)
31
32 Now suppose we replace the diagonal Hessian approximation with the sum of the squares per-sample
33 gradients:
34
35
ˆ
diag(∇2θ `(θt )) ≈ ∇`(θ t)
ˆ
∇`(θ t) (6.156)
36
37 If we also change some scaling factors we can get the RMSprop updates:
38
1 ˆ θ `(θt )
39 θt+1 = θt − α √ ∇ (6.157)
40 vt+1 + c1
41 vt+1 ˆ
= (1 − β)vt + β[∇`(θ t)
ˆ
∇`(θ t )] (6.158)
42
43 This allows us to use standard deep learning optimizers to get a Gaussian approximation to the
44 posterior for the parameters [Osawa2019nips].
45 It is also possible to derive the Adam optimizer [adam] from BLR by adding a momentum term
46 to RMSprop. See [BLR; Aitchison2018] for details.
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


76

1
2 6.4.3 Variational optimization
3
Consider an objective defined in terms of discrete variables. Such objectives are not differentiable
4
and so are hard to optimize. One advantage of BLR is that it optimizes the parameters of a
5
probability distribution, and such expected loss objectives are usually differentiable and smooth. This
6
is called “variational optimization” [BarberVarOpt], since we are optimizing over a probability
7
distribution.
8
For example, consider the case of a binary neural network where θd ∈ {0, 1} indicates if
9
weight d is used or not. We can optimize over the parameters of a Bernoulli distribution, q(θ|λ) =
10 QD
11 d=1 Ber(θd |pd ), where pd ∈ [0, 1] and λd = log(pd /(1 − pd )) is the log odds. This is the basis of the
BayesBiNN approach [Meng2020icml].
12
If we ignore the entropy and regularizer term, we get the following simplified objective:
13
14
Z
15 L(λ) = `(θ)q(θ|λ)dθ (6.159)
16
17 This method has various names: stochastic relaxation [Staines2012; Staines2013; Malago2013],
18 stochastic approximation [Hu2012ac; Hu2012survey], etc. It is closely related to evolution-
19 ary strategies, which we discuss in ??.
20 In the case of functions with continuous domains, we can use a Gaussian for q(θ|µ, Σ). The
21 resulting integral in Equation (6.159) can then sometimes be solved in closed form, as explained in
22 [Mobahi2016]. By starting with a broad variance, and gradually reducing it, we hope the method
23 can avoid poor local optima, similar to simulated annealing (??). However, we generally get better
24 results by including the entropy term, because then we can automatically learn to adapt the variance.
25 In addition, we can often work with natural gradients, which results in faster convergence.
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
Part II

Inference
7 Inference algorithms: an overview
8 Inference for state-space models

8.1 More Kalman filtering


In this section, we discuss various variants and extensions of Kalman filtering.

8.1.1 Example: tracking an object with spiral dynamics


Consider a variant of the 2d tracking problem in Main Section
 29.7.1, where the hidden state is the
position and velocity of the object, zt = ut vt u̇t v̇t . (We use u and v for the two coordinates,
to avoid confusion with the state and observation variables.) We use the following dynamics matrix:
 
0.1 1.1 ∆ 0
−1 1 0 ∆
F= 0
 (8.1)
0 0.1 0 
0 0 0 0.1

The eigenvectors of the top left block of this transition matrix are complex, resulting in cyclical
behavior, as explained in [Strogatz2015]. Furthermore, since the velocities are shrinking at each
step by a factor of 0.1, the cycling behavior becomes a spiral inwards, as illustrated by the line
in Figure 8.1(a). The crosses correspond to noisy measurements of the location, as before. In
Figure 8.1(b-c), we show the results of Kalman filtering and smoothing.

8.1.2 Derivation of RLS


In this section, we explicitly derive the recursive least squares equations.
Recall from Main Section 8.2.2 that the Kalman filter equations are as follows:

Σt|t−1 = Ft Σt−1 FTt + Qt (8.2)


St = Ht Σt|t−1 HTt + Rt (8.3)
Kt = Σt|t−1 HTt S−1
t (8.4)
µt = Ft−1 µt−1 + Kt (yt − Ht Ft−1 µt−1 ) (8.5)
Σt = (I − Kt Ht )Σt|t−1 (8.6)

In the case of RLS we have Ht = uTt , Ft = I, Qt = 0 and Rt = σ 2 . Thus Σt|t−1 = Σt−1 , and the
82

1
2
8 10.0 true state 10.0 true state
3 filtered smoothed
6 7.5 7.5
4
4 5.0 5.0
5 2 2.5 2.5
6 0 0.0 0.0
2 2.5 2.5
7
4 5.0 5.0
8
6 7.5 7.5
9
7.5 5.0 2.5 0.0 2.5 5.0 7.5 10.0 15 10 5 0 5 10 15 15 10 5 0 5 10 15
10
11 (a) (b) (c)
12
13 Figure 8.1: Illustration of Kalman filtering and smoothing for a linear dynamical system. (a) Observed data.
14 (b) Filtering. (c) Smoothing. Generated by kf_spiral.ipynb.
15
16
17
18
remaining equations simplify as follows:
19
20
st = uTt Σt−1 ut + σ 2 (8.7)
1
21
kt = Σt−1 ut (8.8)
22 st
23 1
µt = µt−1 + kt (yt − uTt µt−1 ) = µt−1 + Σt−1 ut (yt − uTt µt−1 ) (8.9)
24 st
25 1
Σt = (I − kt uTt )Σt−1 = Σt−1 − Σt−1 ut uTt Σt−1 (8.10)
26 st
27
28 Note that from Main Equation (8.32), we can also write the Kalman gain as
29
1 1
30
kt = (Σ−1 + ut uTt )−1 ut (8.11)
31 σ 2 t−1 σ 2
32
33
Also, from Main Equation (8.30), we can also write the posterior covariance as
34
35
Σt = Σt−1 − st kt kTt (8.12)
36
37
If we let Vt = Σt /σ 2 , we can further simplify the equations, as follows [Borodachev2016].
38
σ 2 Vt−1 ut (yt − uTt µt−1 ) Vt−1 ut (yt − uTt µt−1 )
39 µt = µt−1 + = µt−1 + (8.13)
40 σ 2 (uTt Vt−1 ut + 1) uTt Vt−1 ut + 1
T
41 Vt−1 ut ut Vt−1
Vt = Vt−1 − T (8.14)
42 ut Vt−1 ut + 1
43
44 We can initialize these recursions using a vague prior, µ0 = 0, Σ0 = ∞I. In this case, the posterior
45 mean will converge to the MLE, and the posterior standard deviations will converge to the standard
46 error of the mean.
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
8.1. MORE KALMAN FILTERING

1
2 8.1.3 Handling unknown observation noise
3
In the case of scalar observations (as often arises in time series forecasting), we can extend the
4
Kalman filter to handle the common situation in which the observation noise variance V = σ 2 is
5
unknown, as described in [West97]. The model is defined as follows:
6
7 p(zt |zt−1 ) = N (zt |Ft zt−1 , V Q∗t ) (8.15)
8
p(yt |zt ) = N (yt |hTt zt , V ) (8.16)
9
10 where Q∗t is the unscaled system noise, and we define Ht = hTt to be the vector that maps the
11 hidden state vector to the scalar observation. Let λ = 1/V be the observation precision. To start the
12 algorithm, we use the following prior:
13 ν0 ν0 τ0
p0 (λ) = Ga( , ) (8.17)
14 2 2

15 p0 (z|λ) = N (µ0 , V Σ0 ) (8.18)
16
where τ0 is the prior mean for σ 2 , and ν0 > 0 is the strength of this prior.
17
We now discuss the belief updating step. We assume that the prior belief state at time t − 1 is
18
νt−1 νt−1 τt−1
19 N (zt−1 , λ|D1:t−1 ) = N (zt−1 |µt−1 , V Σ∗t−1 )Ga(λ| , ) (8.19)
20 2 2
21 The posterior is given by
22 νt νt τt
N (zt , λ|D1:t ) = N (zt |µt , V Σ∗t )Ga(λ| , ) (8.20)
23 2 2
24 where
25
26
µt|t−1 = Ft µt−1 (8.21)
27 Σ∗t|t−1 = Ft Σ∗t−1 Ft + Q∗t (8.22)
28
et = yt − hTt µt|t−1 (8.23)
29
30 s∗t = hTt Σ∗t|t−1 ht + 1 (8.24)
31
kt = Σ∗t|t−1 ht /s∗t (8.25)
32
33
µt = µt−1 + kt et (8.26)
34 Σ∗t = Σ∗t|t−1 − kt kTt s∗t (8.27)
35
νt = νt−1 + 1 (8.28)
36
37 νt τt = νt−1 τt−1 + e2t /s∗t (8.29)
38 If we marginalize out V , the marginal distribution for zt is a Student distribution:
39
40 p(zt |D1:t ) = Tνt (zt |µt , τt Σ∗t ) (8.30)
41 The one-step-ahead posterior predictive density for the observations is given by
42
43
p(yt |y1:t−1 ) = Tνt−1 (yt |ŷy , τt−1 s∗t ) (8.31)
44 These equations only differs from the standard KF equations by the scaling term τt (or τt−1 for
45 the predictive), and the use of a Student distribution instead of a Gaussian. However, as νt increases
46 over time, the Student distribution will rapidly converge to a Gaussian.
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


84

1
2 8.1.4 Predictive coding as Kalman filtering
3
In the field of neuroscience, a popular theoretical model for how the brain works is known as predic-
4
tive coding (see e.g., [Rao1999; Friston2003; Spratling2017; Millidge2021pc; Marino2021]).
5
This posits that the core function of the brain is simply to minimize prediction error at each layer
6
of a hierarchical model, and at each moment in time. There is considerable bological evidence for
7
this (see above references). Furthermore, it turns out that the predictive coding algorithm, when
8
applied to a linear Gaussian state-space model, is equivalent to the Kalman filter, as shown in
9
[Millidge2021phd].
10
To see this, we adopt the framework of inference as optimization, QT as used in variational inference.
11
The joint distribution is given by p(y1:T , z1:T ) = p(y1 |z1 )p(z1 ) t=2 p(yt |zt )p(zt |zt−1 ). Our goal is
12
to approximate the filtering distribution, p(zt |y1:t ). We will use a fully factorized approximation of
13 QT
the form q(z1:T ) = t=1 q(zt ). Following Main Section 10.1.1.1, the variational free energy (VFE) is
14 PT
15
given by F(ψ) = t=1 Ft (ψ t ), where
16
17
Ft (ψ t ) = Eq(zt−1 ) [DKL (q(zt ) k p(zt , yt |zt−1 ))] (8.32)
18
We will use a Gaussian approximation for q at each step. Furthermore, we will use the Laplace
19
approximation, which derives the covariance from the Hessian at the mean. Thus we have q(zt ) =
20
N (zt |µt , Σ(µt )), where ψ t = µt is the variational parameter which we need to compute. (Once we
21
have computed µt , we can derive Σ.)
22
Since the posterior is fully factorized, we can focus on a single time step. The VFE is given by
23
 
24
Ft (µt ) = −Eq(zt |µt ) log p(yt , zt |µt−1 ) − H (q(zt |µt )) (8.33)
25
26 Since the entropy of a Gaussian is independent of the mean, we can drop this second term. For the
27 first term, we use the Laplace approximation, which computes a second order Taylor series around
28 the mode:
29
     
30 E log p(yt , zt |µt−1 ) ≈ E log p(yt , zt |µt−1 ) + E ∇zt p(yt , zt |µt−1 )|zt =µt (zt − µt ) (8.34)
31  2 
+ E ∇zt p(yt , zt |µt−1 )|zt =µt (zt − µt )2
(8.35)
32
33 = log p(yt , µt |µt−1 ) + ∇zt p(yt , zt |µt−1 )|zt =µt E [(zt − µt )] (8.36)
| {z }
34 0
 
35 + ∇2zt p(yt , zt |µt−1 )|zt =µt E (zt − µt )2 (8.37)
36 | {z }
Σ
37
38 We can drop the second and third terms, since they are independent of µt . Thus we just need to
39 solve
40
41 µ∗t = argmin Ft (µt ) (8.38)
µt
42
43 Ft (µt ) = log p(yt , µt |µt−1 ) (8.39)
44
= −(yt − Hµt )Σ−1
y (yy − Hµt ) (8.40)
45
46
T
+ (µt − Fµt−1 − But−1 ) (µt − Fµt−1 − But−1 ) T
Σ−1
z (8.41)
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
8.2. MORE EXTENDED KALMAN FILTERING

1
2 We will solve this problem by gradient descent. The form of the gradient turns out to be very simple,
3 and involves two prediction error terms: one from the past state estimate, z = µt − Fµt−1 − But−1 ,
4 and one from the current observation, y = yt − Hµt :
5
6 ∇Ft (µt ) = 2HT Σ−1 T −1 T −T
y yt − (H Σy H + H Σy H)µt (8.42)
7 + (Σ−1 −1 −1
z + Σz −T)µt − 2Σz Fµt−1 − 2Σz But−1 (8.43)
8
9
= 2HT Σ−1 T −1 −1 −1
y Hµt − 2H Σy Hµt + 2Σz µt − 2Σz Fµt−1 − 2Σ−1
z But−1 (8.44)
10 = −HT Σ−1 −1
y [yt − Hµt ] + Σz [µt − Fµt−1 − But−1 ] (8.45)
11
= −HT Σ−1 −1
y y + Σz z (8.46)
12
13
Thus minimizing (precision weighted) prediction errors is equivalent to minimizing the VFE.1 In
14
this case the objective is convex, so we can find the global optimum. Furthermore, the resulting
15
Gaussian posterior is exact for this model class, and thus predictive coding gives the same results
16
as Kalman filtering. However, the advantage of predictive coding is that it is easy to extend to
17
hierarchical and nonlinear models: we just have to minimize the VFE using gradient descent (see
18
e.g., [Hosseini2020pc]).
19
Furthermore, we can also optimize the VFE with respect to the model parameters, as in variational
20
EM. In the case of linear Gaussian state-space models, [Millidge2021phd] show that for the dynamics
21
matrix the gradient is ∇F Ft = −Σz y µTt−1 , for the control matrix the gradient is ∇B Ft = −Σz y uTt−1 ,
22
and for the observation matrix the gradient is ∇H Ft = −Σy y µTt . These expressions can be
23
generalized to nonlinear models. Indeed, predictive coding can in fact approximate backpropagation
24
for many kinds of model [Millidge2020predictive].
25
Gradient descent using these predicting coding takes the form of a Hebbian update rule, in
26
which we set the new parameter to the old one plus a term that is a multiplication of the two
27
quantities available at each end of the synaptic connection, namely the prediction error  as input,
28
and the value µ (or θ) of the neuron as output. However, there are still several aspects of this
29
model that are biologically implausible, such as assuming symmetric weights (since both H and HT
30
are needed, the former to compute y and the latter to compute ∇µt Ft ), the need for one-to-one
31
alignment of error signals and parameter values, and the need (in the nonlinear case) for computing
32
the derivative of the activation function. In [Millidge2021] they develop an approximate, more
33
biologically plausible version of predictive coding that relaxes these requirements, and which does
34
not seem to hurt empirical performance too much.
35
36
37 8.2 More extended Kalman filtering
38
39
8.2.1 Derivation of the EKF
40 The derivation of the EKF is similar to the derivation of the Kalman filter (Main Section 8.2.2.4),
41 except we also need to apply the linear approximation from Main Section 8.3.1.
42
43 1. Scaling the error terms by the inverse variance can p be seen as a form of normaization. To see this, consider the
44 standardization operator: standardize(x) = (x − E [x])/ V [x]. It has been argued that that the widespread presence
of neural circuity for performing normalization, together with the upwards and downwards connections between brain
45
regions, adds support for the claim that the brain implements predictive coding (see e.g., [Rao1999predictive;
46 Friston2003; Spratling2017; Millidge2021pc; Marino2021]).
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


86

1
2 First we approximate the joint of zt−1 and zt = ft (zt−1 ) + qt−1 to get
3  
z
4
p(zt−1 , zt |y1:t−1 ) ≈ N ( t−1 |m0 , Σ0 ) (8.47)
5 zt
 
6 µt−1
m0 = (8.48)
7 f (µt−1 )
8  
Σt−1 Σt−1 FTt
9 Σ0 = (8.49)
Ft Σt−1 Ft Σt−1 FTt + Qt−1
10
11
From this we can derive the marginal p(zt |y1:t−1 ), which gives us the predict step.
12
For the update step, we first consider a Gaussian approximation to p(zt , yt |y1:t−1 ), where yt =
13
ht (zt ) + rt :
14
15
 
z
16 p(zt , yt |y1:t−1 ) ≈ N ( t |m00 , Σ00 ) (8.50)
yt
17  
µt|t−1
18 m00 = (8.51)
19 h(µt|t−1 )
 
20 Σt|t−1 Σt|t−1 HTt
Σ00 = (8.52)
21 Ht Σt|t−1 Ht Σt|t−1 HTt + Rt−1
22
23 Finally, we use Main Equation (2.78) to get the posterior
24
25 p(zt |yt , y1:t−1 ) ≈ N (zt |µt , Σt ) (8.53)
−1
26
µt = µt|t−1 + Σt|t−1 HTt (Ht Σt|t−1 HT + Rt ) [yt − h(µt|t−1 )] (8.54)
27
28 Σt = Σt|t−1 − Σt|t−1 HTt (Ht Σt|t−1 HT + Rt )−1 Ht Σt|t−1 (8.55)
29
30 This gives us the update step.
31
32 8.2.2 Example: Tracking a pendulum
33
34 Consider a simple pendulum of unit mass and length swinging from a fixed attachment, as in
35 Figure 8.2a. Such an object is in principle entirely deterministic in its behavior. However, in the real
36 world, there are often unknown forces at work (e.g., air turbulence, friction). We will model these by
37 a continuous time random white noise process w(t). This gives rise to the following differential
38 equation [Sarkka13]:
39
d2 α
40 = −g sin(α) + w(t) (8.56)
41 dt2
42
We can write this as a nonlinear SSM by defining the state to be z1 (t) = α(t) and z2 (t) = dα(t)/dt.
43
Thus
44
   
45 dz z2 0
= + w(t) (8.57)
46 dt −g sin(z1 ) 1
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
8.3. EXPONENTIAL-FAMILY EKF

1 5 5
True Angle True Angle
2 Measurements Measurements
4 EKF Estimate 4 UKF Estimate

3 3 3

2 2
4

Pendulum angle

Pendulum angle
1 1
5 α 0 0
6
w(t) 1 1
7
g 2 2

8 3
0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0 4.5
3
0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0 4.5
Time Time
9
10
(a) (b) (c)
11
Figure 8.2: (a) Illustration of a pendulum swinging. g is the force of gravity, w(t) is a random external force,
12
and α is the angle wrt the vertical. Adapted from Figure 3.10 in [Sarkka13]. (b) Extended Kalman filter.
13 Generated by ekf_pendulum.ipynb. (b) Unscented Kalman filter. Generated by ukf_pendulum.ipynb.
14
15
16 If we discretize this step size ∆, we get the following formulation [Sarkka13]:
17    
z1,t z1,t−1 + z2,t−1 ∆
18 = +qt (8.58)
z2,t z2,t−1 − g sin(z1,t−1 )∆
19 | {z } | {z }
20 zt f (zt−1 )
21
where qt ∼ N (0, Q) with
22
3 2
!
23 ∆ ∆

24
Q = qc 3
∆2
2 (8.59)
2 ∆
25
26 where q c is the spectral density (continuous time variance) of the continuous-time noise process.
27 If we observe the angular position, we get the linear observation model h(zt ) = αt = zt,1 . If we only
28 observe the horizontal position, we get the nonlinear observation model h(zt ) = sin(αt ) = sin(zt,1 ).
29 To apply the EKF to this problem, we need to compute the following Jacobian matrices:
 
30 1 ∆ 
31 F(z) = , H(z) = cos(z1 ) 0 (8.60)
−g cos(z1 )∆ 1
32
33
The results are shown in Figure 8.2b.
34
35 8.3 Exponential-family EKF
36
37 In this section, we present an extension of the EKF to the case where the observation model is in the
38 exponential family, as proposed in [Ollivier2018]. We call this the Exponential family EKF or
39 EEKF. This allows us to apply the EKF for online parameter estimation of classification models, as
40 we illustrate in Section 8.3.3.
41
42 8.3.1 Modeling assumptions
43
44
We assume the dynamics model is the usual nonlinear model plus Gaussian noise, with optional
45
inputs ut :
46 zt = f (zt−1 , ut ) + N (0, Qt ) (8.61)
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


88

1
2 We assume the observation model is
3
4 p(yt |zt ) = Expfam(yt |ŷt ) (8.62)
5
where the mean (moment) parameter of the exponential family is computed deterministically using a
6
nonlinear observation model:
7
8
ŷt = h(zt , ut ) (8.63)
9
10 The standard EKF corresponds to the special case of a Gaussian output with fixed observation
11 covariance Rt , with ŷt being the mean.
12
13
8.3.2 Algorithm
14
15 The EEKF algorithm is as follows. First, the prediction step:
16
17 µt|t−1 = f (µt−1 , ut ) (8.64)
18 ∂f
Ft = |(µ ,u ) (8.65)
19 ∂z t−1 t
20 Σt|t−1 = Ft Σt−1 FTt + Qt (8.66)
21
ŷt = h(µt|t−1 , ut ) (8.67)
22
23
Second, after seeing observation yt , we compute the following:
24
25 et = T (yt ) − ŷt (8.68)
26
Rt = Cov [T (y)|ŷt ] (8.69)
27
28 where T (y) is the vector of sufficient statistics, and et is the error or innovation term. (For a Gaussian
29 observation model with fixed noise, we have T (y) = y, so et = yt − ŷt , as usual.)
30 Finally we perform the update:
31
32 ∂h
Ht = |(µ ,u ) (8.70)
33 ∂z t|t−1 t
34 Kt = Σt|t−1 Ht (Ht Σt|t−1 HTt + Rt )−1
T
(8.71)
35
Σt = (I − Kt Ht )Σt|t−1 (8.72)
36
37
µt = µt|t−1 + Kt et (8.73)
38
In [Ollivier2018], they show that this is equivalent to an online version of natural gradient descent
39
(Main Section 6.4).
40
41
42
8.3.3 EEKF for training logistic regression
43 For example, consider the case where y is a a class label with C possible values. (We drop the time
44 index for brevity.) Following Main Section 2.4.2.2, Let
45
46 T (y) = [I (y = 1) , . . . , I (y = C − 1)] (8.74)
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
8.3. EXPONENTIAL-FAMILY EKF

1
MCMC Predictive distribution Laplace Predictive distribution EEKF Predictive Distribution
2
3
4
5
6
7
8
9
10
11 (a) (b) (c)
12
13 Figure 8.3: Bayesian inference applied to a 2d binary logistic regression problem, p(y = 1|x) = σ(w0 + w1 x1 +
14 w2 x2 ). We show the training data and the posterior predictive produced by different methods. (a) Offline
15 MCMC approximation. (b) Offline Laplace approximation. (c) Online EEKF approximation at the final step
16
of inference. Generated by eekf_logistic_regression.ipynb.
17
18
19
20
be the (C − 1)-dimensional vector of sufficient statistics, and let ŷ = [p1 , . . . , pC−1 ] be the corre-
21
sponding Ppredicted probabilities of each class label. The probability of the C’th class is given by
22 C−1
pC = 1 − c=1 ŷc ; we avoid including this to ensure that R is not singular. The (C − 1) × (C − 1)
23
covariance matrix R is given by
24
25
26 Rij = diag(pi ) − pi pj (8.75)
27
28 Now consider the simpler case where we have two class labels, so C = 2. In this case, T (y) =
29 I (y = 1), and ŷ = p(y = 1) = p. The covariance matrix of the observation noise becomes the scalar
30 r = p(1 − p). Of course, we can make the output probabilities depend on the input covariates, as
31 follows:
32
33 p(yt |zt , ut ) = Ber(yt |σ(zTt ut )) (8.76)
34
35
We assume the parameters zt are static, so Qt = 0. The 2d data is shown in Figure 8.3a. We
36
sequentially compute the posterior using the EEKF, and compare to the offline estimate computed
37
using a Laplace approximation (where the MAP estimate is computed using BFGS) and an MCMC
38
approximation, which we take as “ground truth”. In Figure 8.3c, we see that the resulting posterior
39
predictive distributions are similar. Finally, in Figure 8.4, we visualize how the posterior marginals
40
converge over time. (See also Main Section 8.6.3, where we solve this same problem using ADF.)
41
42
43
8.3.4 EEKF performs online natural gradient descent
44
45 In this section, we show an exact equivalence, due to [Ollivier2018], between the EKF for exponential
46 family likelihoods (Section 8.3) and an online version of natural gradient descent (Main Section 6.4).
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


90

1
w0 batch (Laplace) 1.0 w1 batch (Laplace) w2 batch (Laplace)
2 w0 online (EEKF) w1 online (EEKF) 0.50 w2 online (EEKF)
1
3 0.5 0.25

4 0.0 0.00
0
5 0.25
weights

weights

weights
0.5
6 1 0.50

7 1.0 0.75
2 1.00
8 1.5
1.25
9
10 20 30 40 50 10 20 30 40 50 10 20 30 40 50
number samples number samples number samples
10
11 (a) (b) (c)
12
13 Figure 8.4: Marginal posteriors over time for the EEKF method. The horizontal line is the offline MAP
14 estimate. Generated by eekf_logistic_regression.ipynb.
15
16
17
8.3.4.1 Statement of the equivalence
18
19 We define online natural gradient descent as follows. Given a set of labeled pairs, (ut , yt ), the goal is
20 to minimize the loss function
21 X
22 L(θ) = Lt (yt ) (8.77)
23 t
24
where
25
26
Lt (yt ) = − log p(y|ŷt ) (8.78)
27
28
and
29
30
ŷt = h(θ, ut ) (8.79)
31
32 is the prediction from some model, such as a neural network.
33 At each step of the algorithm, we perform the following updates, where Jt is the Fisher information
34 matrix (Main Section 3.3.4), and θt is the current parameter estimate:
35
 
36 Jt = (1 − γt )Jt−1 + γt Ep(y|ŷ) ∇θ Lt (y)∇θ Lt (y)T (8.80)
37
38
θt = θt−1 − ηt J−1
t ∇θ Lt (y) (8.81)
39
where ηt is the learning rate and γt is the Fisher matrix decay rate.
40
41 Theorem 1 (EKF performs natural gradient descent). The abvoe NGD algorithm is identical to
42 performing EEKF, with θt = zt , under the following conditions: static dynamics with zt+1 =
43 f (zt , ut ) = zt , exponential family observations with mean parameters ŷt = h(zt , ut ), learning rate
44 ηt = γt = 1/(t + 1), and Fisher matrix set to Jt = Σ−1 t /(t + 1).
45
46 See Section 8.3.4.3 for the proof of this claim.
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
8.3. EXPONENTIAL-FAMILY EKF

1
2 8.3.4.2 Fading memory
3
For many problems, a learning rate schedule of the form ηt = 1/(t + 1) results in overly small updates
4
later in the sequence, resulting in very slow progress. It is therefore useful to be able to setting the
5
learning rates to larger values, such as a contant ηt = η0 . (We continue to assume that γt = ηt , for
6
the equivalence to hold.)
7
We can emulate this generic learning rate with an EKF by using the fading memory trick
8
[Haykin01], in which we update
9
10
Σt−1 = Σt−1 /(1 − λt ) (8.82)
11
12
before the prediction step, where
13
14 ηt−1
1 − λt = − ηt−1 (8.83)
15 ηt
16
17 If we set ηt = η0 , then we find λt = η0 ; if we set ηt = 1/(t + const), then we find λt = 0.
18 Fading memory is equivalent to adding artificial process noise Qt with a value proportional to
19 Σt−1 . This has the effect of putting less weight on older observations. This is equivalent to NGD
20 using a Fisher matrix of the form Kt = ηt Σ−1 t .
21 The learning rate also controls the weight given to the prior. If we use the fading memory trick,
22 the effect of the prior in the initial time step decreases exponentially with time, which can result
23 in overfitting (especially since we are downweighting past observations). We may therefore want to
24 artificially increase the initial uncertainty, Σ0 . This can be emulated in NGD by regularizing the
25 Fisher matrix, see Proposition 4 of [Ollivier2018] for details.
26
27 8.3.4.3 Proof of the claim
28
29 In this section, we prove the equivalence between EKF and NGD. We start by proving this lemma.
30
Lemma 1. The error term of the EEKF is given by
31
32
et , T (yt ) − ŷt = Rt ∇ŷt log p(yt |ŷt ) (8.84)
33
34
Proof. For the case of Gaussian observations, this is easy to see, since T (yt ) = yt and
35
36 1
log p(yt |ŷt ) = − (yt − ŷt )T R−1
t (yt − ŷt ) (8.85)
37 2
38
so
39
40
∇ŷt log p(yt |ŷt ) = R−1
t (T (yt ) − ŷt ) (8.86)
41
42
Now consider the general exponential family with natural parameters η and moment parameters
43
ŷ. From the chain rule and Main Equation (2.247), We have
44
45 ∂ log p(y|η) ∂ ŷ ∂ log p(y|ŷ)
46 = = T (y) − E [T (y)] (8.87)
∂η ∂η ∂ ŷ
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


92

1
2 From Main Equation (3.78) and Main Equation (3.72) we have
3
∂ ŷ
4 = Fη = Cov [T (y)] = R (8.88)
5
∂η
6 Hence
7
∂ log p(y|η) ∂ log p(y|ŷ)
8 =R = T (y) − E [T (y)] (8.89)
9 ∂η ∂ ŷ
10
from which we get Equation (8.84).
11
12 Now we prove another lemma.
13
Lemma 2. The Kalman gain matrix of the EKF satsifies
14
15 Kt Rt = Σt HTt (8.90)
16
17 Proof. Using the definition of Kt we have
18
19
Kt = Σt|t−1 HTt (Rt + Ht Σt|t−1 HTt )−1 (8.91)
20 Kt Rt = Kt (Rt + Ht Σt|t−1 HTt ) − Kt Ht Σt|t−1 HTt (8.92)
21
= Σt|t−1 HTt− Kt Ht Σt|t−1 HTt (8.93)
22
23 = (I − Kt Ht )Σt|t−1 HTt = Σt HTt (8.94)
24
where we used Equation (8.72) for Σt in the last line.
25
26 Now we prove our first theorem,
27
28 Theorem 2 (The EKF performs preconditioned gradient descent). The update step in the EKF
29 corresponds to the following gradient step
30
µt = µt|t−1 − Σt ∇µt|t−1 Lt (yt ) (8.95)
31
32 where
33
34 Lt (y) = − log p(y|ŷt ) = − log p(y|h(µt|t−1 , ut )) (8.96)
35
36
Proof. By definition of the EKF, we have µt = µt|t−1 + Kt et . By Lemma 1 and Lemma 2 we have
37
µt = µt|t−1 + Kt et (8.97)
38
39 = µt|t−1 + Kt Rt ∇ŷt Lt (yt ) (8.98)
40
= µt|t−1 + Σt HTt ∇ŷt Lt (yt ) (8.99)
41
∂ ŷt
42 But Ht = ∂µt|t−1 , so
43
44
µt = µt|t−1 + Σt ∇µt|t−1 Lt (yt ) (8.100)
45
46
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
8.3. EXPONENTIAL-FAMILY EKF

1
We now prove another lemma.

Lemma 3 (Information filter). The EKF update can be written in information form as follows:

Σ_t^{−1} = Σ_{t|t−1}^{−1} + H_t^T R_t^{−1} H_t     (8.101)

Furthermore, for static dynamical systems, where f(z, u) = z and Q_t = 0, the EKF becomes

Σ_t^{−1} = Σ_{t−1}^{−1} + H_t^T R_t^{−1} H_t     (8.102)
µ_t = µ_{t−1} − Σ_t ∇_{µ_{t−1}} L_t(y_t)     (8.103)

Proof. We have

Σ_t = (I − K_t H_t) Σ_{t|t−1}     (8.104)
    = Σ_{t|t−1} − Σ_{t|t−1} H_t^T (H_t Σ_{t|t−1} H_t^T + R_t)^{−1} H_t Σ_{t|t−1}     (8.105)
    = (Σ_{t|t−1}^{−1} + H_t^T R_t^{−1} H_t)^{−1}     (8.106)

where we used the matrix inversion lemma in the last line. The second claim follows easily since for a static model we have µ_{t|t−1} = µ_{t−1} and Σ_{t|t−1} = Σ_{t−1}.
21
Now we prove another lemma.

Lemma 4. For exponential family models, the H_t^T R_t^{−1} H_t term in the information filter is equal to the Fisher information matrix:

H_t^T R_t^{−1} H_t = E_{p(y|ŷ_t)} [g_t g_t^T]     (8.107)

where g_t = ∇_{µ_{t|t−1}} L_t(y).

Proof. We will omit time indices for brevity. We have

∂L(y)/∂µ = (∂L(y)/∂ŷ) (∂ŷ/∂µ) = (∂L(y)/∂ŷ) H     (8.108)

Hence

E_{p(y)} [∇_µ L_t(y) ∇_µ L_t(y)^T] = H^T E_{p(y)} [∇_ŷ L_t(y) ∇_ŷ L_t(y)^T] H     (8.109)

But the middle term is the FIM wrt the mean parameters, which, from Main Equation (3.80), is R^{−1}.
40
Finally we are able to prove Theorem 1.

Proof. From Lemma 3 and Lemma 4 we have

Σ_t^{−1} = Σ_{t−1}^{−1} + E_{p(y|ŷ_t)} [∇_{µ_{t−1}} L_t(y) ∇_{µ_{t−1}} L_t(y)^T]     (8.110)
µ_t = µ_{t−1} − Σ_t ∇_{µ_{t−1}} L_t(y)     (8.111)

If we define J_t = Σ_t^{−1}/(t + 1), this becomes

J_t = (t/(t+1)) J_{t−1} + (1/(t+1)) E_{p(y|ŷ_t)} [∇_{µ_{t−1}} L_t(y) ∇_{µ_{t−1}} L_t(y)^T]     (8.112)
µ_t = µ_{t−1} − (1/(t+1)) J_t^{−1} ∇_{µ_{t−1}} L_t(y)     (8.113)
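To make this recursion concrete, here is a minimal NumPy sketch (not taken from the book's code) of the static-case update in Equations (8.110)-(8.111). The callbacks grad_fn and fisher_fn are hypothetical user-supplied functions returning the gradient of L_t and the expected outer product (Fisher term) for the current observation.

import numpy as np

def ngd_ekf_step(mu, Sigma, grad_fn, fisher_fn):
    # One static-model update in information form, cf. Eqs. (8.110)-(8.111).
    # grad_fn(mu): gradient of L_t(y_t) at mu (vector).
    # fisher_fn(mu): E_{p(y|yhat_t)}[g g^T] at mu (matrix).
    Lambda = np.linalg.inv(Sigma)           # precision of previous posterior
    Lambda_new = Lambda + fisher_fn(mu)     # Eq. (8.110)
    Sigma_new = np.linalg.inv(Lambda_new)
    mu_new = mu - Sigma_new @ grad_fn(mu)   # Eq. (8.111)
    return mu_new, Sigma_new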
9 Inference for graphical models

9.1 Belief propagation on trees


9.1.1 BP for polytrees
In this section, we generalize the forwards-backwards algorithm (i.e., two-filter smoothing) for chain-
structured PGMs to work on a polytree, which is a directed graph whose undirected “backbone” is
a tree, i.e., a graph with no loops. (That is, a polytree is a directed tree with multiple root nodes,
in which a node may have multiple parents, whereas in a singly rooted tree, each node has a single
parent.) This algorithm is called belief propagation and is due to [Pearl88].
We consider the case of a general discrete node X with parents Ui and children Yj . We partition the evidence in the graph, e, into the evidence upstream of node X, e+_X, and all the rest, e−_X. Thus e+_X contains all the evidence separated from X if its incoming arcs were deleted, and e−_X contains the evidence below X and the evidence in X itself, if any. The posterior on node X can be computed as follows:

belX(x) ≜ p(X = x|e) = c0 λX(x) πX(x)     (9.1)
λX(x) ≜ p(e−_X |X = x)     (9.2)
πX(x) ≜ p(X = x|e+_X)     (9.3)

where c0 is a normalizing constant.


Consider the graph shown in Figure 9.1. We will use the notation e+ U1 →X to denote the evidence
above the edge from U1 to X (i.e., in the “triangle” above U1 ), and e−X→Y1 to denote the evidence
below the edge from X to Y1 (i.e., in the triangle below Y1 ). We use eX to denote the local evidence
attached to node X (if any).
We can compute λX as follows, using the fact that X’s children are independent given X. In
particular, the evidence in the subtrees rooted at each child, and the evidence in X itself (if any),
are conditionally independent given X.

p(e−_X |X = x) = p(e_X |X = x) p(e−_{X→Y1} |X) p(e−_{X→Y2} |X)     (9.4)

If we define the λ “message” that a node X sends to its parents Ui as

λ_{X→Ui}(ui) ≜ p(e−_{Ui→X} |Ui = ui)     (9.5)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
Figure 9.1: Message passing on a polytree.
21
22
23
we can write in general that

λX(x) = λ_{X→X}(x) × ∏_j λ_{Yj→X}(x)     (9.6)

where λ_{X→X}(x) = p(e_X |X = x). For leaves, we just write λ_{X→Ui}(ui) = 1, since there is no evidence below X.
We compute πX by introducing X’s parents, to break the dependence on the upstream evidence, and then summing them out. We partition the evidence above X into the evidence in each subtree above each parent Ui .

p(X = x|e+_X) = Σ_{u1,u2} p(X = x, U1 = u1 , U2 = u2 |e+_X)     (9.7)
             = Σ_{u1,u2} p(X = x|u1 , u2 ) p(u1 , u2 |e+_{U1→X} , e+_{U2→X})     (9.8)
             = Σ_{u1,u2} p(X = x|u1 , u2 ) p(u1 |e+_{U1→X}) p(u2 |e+_{U2→X})     (9.9)

If we define the π “message” that a node X sends to its children Yj as

Π_{X→Yj}(x) ≜ p(X = x|e+_{X→Yj})     (9.10)

we can write in general that

πX(x) = Σ_u p(X = x|u) ∏_i Π_{Ui→X}(ui)     (9.11)
47

1
2 For root nodes, we write πX (x) = p(X = x), which is just the prior (independent of the evidence).
3
4
9.1.1.1 Computing the messages
5
We now describe how to recursively compute the messages. We initially focus on the example in Figure 9.1. First we compute the λ message.

λ_{X→U1}(u1) = p(e−_X , e+_{U2→X} |u1)     (9.12)
    (all the evidence except in the U1 triangle)     (9.13)
  = Σ_x Σ_{u2} p(e−_X , e+_{U2→X} |u1 , u2 , x) p(u2 , x|u1)     (9.14)
  = Σ_x Σ_{u2} p(e−_X |x) p(e+_{U2→X} |u2) p(u2 , x|u1)     (9.15)
    (since X separates the U2 triangle from e−_X , and U2 separates the U2 triangle from U1)     (9.16)
  = c Σ_x Σ_{u2} p(e−_X |x) [p(u2 |e+_{U2→X}) / p(u2)] p(x|u2 , u1) p(u2 |u1)     (9.17)
    (using Bayes’ rule, where c = p(e+_{U2→X}) is a constant)     (9.18)
  = c Σ_x Σ_{u2} p(e−_X |x) p(u2 |e+_{U2→X}) p(x|u2 , u1)     (9.19)
    (since U1 and U2 are marginally independent)     (9.20)
  = c Σ_x Σ_{u2} λX(x) Π_{U2→X}(u2) p(x|u2 , u1)     (9.21)

In general, we have

λ_{X→Ui}(ui) = c Σ_x λX(x) [ Σ_{uk : k≠i} p(X = x|u) ∏_{k≠i} Π_{Uk→X}(uk) ]     (9.22)

If the graph is a rooted tree (as opposed to a polytree), each node has a unique parent, and this simplifies to

λ_{X→Ui}(ui) = c Σ_x λX(x) p(X = x|ui)     (9.23)

Finally, we derive the π messages. We note that e+_{X→Yj} = e − e−_{X→Yj}, so Π_{X→Yj}(x) is equal to belX(x) when the evidence e−_{X→Yj} is suppressed:

Π_{X→Yj}(x) = c0 πX(x) λ_{X→X}(x) ∏_{k≠j} λ_{Yk→X}(x)     (9.24)


1
2
3
4
5
6
7
8
9
10
11
12
13
14 (a) (b)
15
16 Figure 9.2: (a) An undirected graph. (b) Its corresponding junction tree.
17
18
19 9.1.1.2 Message passing protocol
20
21 We must now specify the order in which to send the messages. If the graph is a polytree, we can
22 pick an arbitrary node as root. In the first pass, we send messages to it. If we go with an arrow, the
23 messages are π messages; if we go against an arrow, the messages are λ messages. On the second
24 pass, we send messages down from the root.
25 If the graph is a regular tree (not a polytree), there already is a single root. Hence the first pass
26 will only consist of sending λ messages, and the second pass will only consist of sending π messages.
This is analogous to a reversed version of the forwards-backwards algorithm, where we first send backwards likelihood messages to the root (node z1) and then send forwards posterior messages to the end of the chain (node zT).
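To make the two-pass schedule concrete, here is a minimal NumPy sketch of Pearl-style message passing on a rooted tree (Equations 9.6, 9.11 and 9.23). The data structures (parent, children, cpt, prior_root, local_ev) are hypothetical stand-ins, and nodes are assumed to be numbered in topological order (parents before children); this is an illustrative sketch rather than the algorithm as implemented in any particular library.

import numpy as np

def bp_rooted_tree(parent, children, cpt, prior_root, local_ev):
    # parent[i]: index of i's parent (None for the root); children[i]: list of children.
    # cpt[i][u, x] = p(X_i = x | parent = u); local_ev[i][x] = p(e_i | X_i = x).
    n = len(parent)
    lam = [np.asarray(local_ev[i], dtype=float) for i in range(n)]
    lam_msg = [None] * n
    # Collect (upward) pass: lambda messages from leaves to root, Eqs. (9.6) and (9.23).
    for i in reversed(range(n)):
        for c in children[i]:
            lam[i] = lam[i] * lam_msg[c]
        if parent[i] is not None:
            lam_msg[i] = cpt[i] @ lam[i]        # sum_x p(x|u) * lambda_i(x)
    # Distribute (downward) pass: pi messages from root to leaves, Eq. (9.11).
    bel = [None] * n
    for i in range(n):
        if parent[i] is None:
            pi = np.asarray(prior_root, dtype=float)
        else:
            p = parent[i]
            pi_msg = bel[p] / np.maximum(lam_msg[i], 1e-30)  # divide out i's own message
            pi = pi_msg @ cpt[i]                # sum_u Pi(u) p(x|u)
        bel[i] = lam[i] * pi
        bel[i] = bel[i] / bel[i].sum()
    return bel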
30
31
9.2 The junction tree algorithm (JTA)
32
33 The junction tree algorithm or JTA is a generalization of variable elimination that lets us
34 efficiently compute all the posterior marginals without repeating redundant work, thus avoiding the
35 problems mentioned in Main Section 9.5.5. The basic idea is to convert the graph into a tree, and
36 then to run belief propagation on the tree. We summarize the main ideas below. For more details,
37 see e.g., [Lauritzen96; Huang1996; Cowell99; Jensen07; KollerBook; Vandenberghe2015].
38
39
9.2.1 Tree decompositions
40
41 A junction tree, also called a join tree or clique tree, is a tree-structured graph, derived from
42 the original graph, which satisfies certain key properties that we describe below; these properties
43 ensure that local message passing results in global consistency. Note that junction trees have many
44 applications in mathematics beyond probabilistic inference (see e.g., [Vandenberghe2015]). Note
45 also that we can create a directed version of a junction tree, known as a Bayes tree, which is useful
46 for incremental inference [Dellaert2017].
47

1
2 The process of converting a graph into a junction tree is called tree decomposition [Halin1976;
3 Robertson1984; Heinz2013; Satzinger2015; Chekuri2014; Vandenberghe2015], which we
4 summarize below.
5 Intuitively, we can convert a graph into a tree by grouping together nodes in the original graph
6 to make “meganodes” until we end up with a tree, as illustrated in Figure 9.2. More formally, we
7 say that T = (VT , ET ) is a tree decomposition of an undirected graph G = (V, E) if it satisfies the
8 following properties:
9
• ∪t∈VT Xt = V. Thus each graph vertex is associated with at least one tree node.
10
11 • For each edge (u, v) ∈ E there exists a node t ∈ VT such that u ∈ Xt and v ∈ Xt . (For example, in
12 Figure 9.2, we see that the edge a − b in G is contained in the meganode abc in T .)
13
• For each v ∈ V, the set {t : v ∈ Xt } is a subtree of T . (For example, in Figure 9.2, we see that the
14
set of meganodes in the tree containing graph node c forms the subtree (abc) − (acf ) − (cde).) Put
15
another way, if Xi and Xj both contain a vertex v, then all the nodes Xk of the tree on the unique
16
path from Xi to Xj also contain v, i.e., for any node Xk on the path from Xi to Xj , we have
17
Xi ∩ Xj ⊆ Xk . This is called the running intersection property. (For example, in Figure 9.2,
18
if Xi = (abc) and Xj = (af g), then we see that Xi ∩ Xj = {a} is contained in node Xk = (acf ).)
19
A tree that satisfies these properties is also called a junction tree or jtree. The width of a jtree is defined to be the size of the largest meganode

width(T ) = max_{t∈T} |Xt |     (9.25)

For example, the width of the jtree in Figure 9.2(b) is 3.
There are many possible tree decompositions of a graph, as we discuss below. We therefore define the treewidth of a graph G as the minimum width of any tree decomposition for G minus 1:

treewidth(G) ≜ min_{T∈T(G)} width(T ) − 1     (9.26)
30
31 We see that the treewidth of a tree is 1, and the treewidth of the graph in Figure 9.2(a) is 2.
32
33 9.2.1.1 Why create a tree decomposition?
34
Before we discuss how to compute a tree decomposition, we pause and explain why we want to do
35
this. The reason is that trees have a number of properties that make them useful for computational
36
purposes. In particular, given a pair of nodes, u, v ∈ V, we can always find a single node s ∈ V on
37
the path from u to v that is a separator, i.e., that partitions the graph into two subgraphs, one
38
containing u and the other containing v. This is conducive to using algorithms based on dynamic
39
programming, where we recursively solve the subproblems defined on the two subtrees, and then
40
combine their solutions via the separator node s. This is useful for graphical model inference (see
41
Main Section 9.6), solving sparse systems of linear equations (see e.g., [Paskin03jtree]), etc.
42
43
9.2.1.2 Computing a tree decomposition
44
45 We now describe an algorithm known as triangulation or elimination for constructing a junction
46 tree from an undirected graph. We first choose an ordering of the nodes, π. (See Main Section 9.5.3
47


1
2
3
4
5
6
7
8
9
10
11
12
13
14 (a) (b)
15
16
17
18
19
20
21
22
23
24
25
26
27
(c) (d)
28
29 Figure 9.3: (a-b) Illustration of two steps of graph triangulation using the elimination order (a,b,c,d,e,f,g,h)
30 applied to the graph in Figure 9.2a. The node being eliminated is shown with a darker border. Cliques are
31 numbered by the vertex that created them. The dotted a-f line is a fill-in edge created when node g is eliminated.
32 (c) Corresponding set of maximal cliques of the chordal graph. (d) Resulting junction graph.
33
34
35
36 for a discussion of how to choose a good elimination ordering.) We then work backwards in this
37 ordering, eliminating the nodes one at a time. We initially let U = {1, . . . , N } be the set of all
38 uneliminated nodes, and set the counter to i = N . At each step i, we pick node vi = πi , we create the
39 set Ni = nbri ∩ U of uneliminated neighbors and the set Ci = vi ∪ Ni , we add fill-in edges between
40 all nodes in Ci to make it a clique, we eliminate vi by removing it from U, and we decrement i by 1,
41 until all nodes are eliminated.
42 We illustrate this method by applying it to the graph in Figure 9.3, using the ordering π =
43 (a, b, c, d, e, f, g, h). We initialize with i = 8, and start by eliminating vi = π(8) = h, as shown
44 in Figure 9.3(a). We create the set C8 = {g, h} from node vi and all its uneliminated neighbors.
45 Then we add fill-in edges between them, if necessary. (In this case all the nodes in C8 are already
46 connected.) In the next step, we eliminate vi = π(7) = g, and create the clique C7 = {a, f, g},
47

1
2 adding the fill-in edge a − f , as shown in Figure 9.3(b). We continue in this way until all nodes are
3 eliminated, as shown in Figure 9.3(c).
4 If we add the fill-in edges back to the original graph, the resulting graph will be chordal, which
5 means that every undirected cycle X1 − X2 · · · Xk − X1 of length k ≥ 4 has a chord. The largest loop
6 in a chordal graph is length 3. Consequently chordal graphs are sometimes called triangulated.
7 Figure 9.3(d) illustrates the maximal cliques of the resulting chordal graph. In general, computing
8 the maximal cliques of a graph is NP-hard, but in the case of a chordal graph, the process is easy:
9 at step i of the elimination algorithm, we create clique Ci by connecting vi to all its uneliminated
neighbors; if this clique is contained in an already created clique, we simply discard it; otherwise we
11 add it to our list of cliques. For example, when triangulating the graph in Figure 9.3, we drop clique
12 C4 = {c, d} since it is already contained in C5 = {c, d, e}. Similarly we drop cliques C2 = {a, b} and
13 C1 = {a}.
14 There are several ways to create a jtree from this set of cliques. One approach is as follows: create
15 a junction graph, in which we add an edge between i and j if Ci ∩ Cj 6= ∅. We set the weight of this
16 edge to be |Ci ∩ Cj |, i.e., the number of variables they have in common. One can show [Jensen94;
17 Aji00] that any maximal weight spanning tree (MST) of the junction graph is a junction tree. This
18 is illustrated in Figure 9.3d, which corresponds to the jtree in Figure 9.2b.
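As a concrete (if naive) sketch of this construction, the following uses networkx. The elimination ordering is a hypothetical caller-supplied input (no attempt is made to find a good one), and the code simply mirrors the steps described above: eliminate in reverse order, add fill-in edges, keep the maximal cliques, and take a maximum-weight spanning tree of the junction graph.

import itertools
import networkx as nx

def junction_tree(graph, order):
    # graph: undirected nx.Graph; order: an elimination ordering of its nodes.
    g = graph.copy()
    cliques = []
    for v in reversed(order):                              # eliminate nodes one at a time
        nbrs = list(g.neighbors(v))
        g.add_edges_from(itertools.combinations(nbrs, 2))  # fill-in edges
        c = frozenset([v] + nbrs)
        if not any(c <= c2 for c2 in cliques):             # keep only maximal cliques
            cliques.append(c)
        g.remove_node(v)
    # Junction graph: connect overlapping cliques, weight edges by separator size,
    # then any maximum-weight spanning tree is a junction tree.
    jg = nx.Graph()
    jg.add_nodes_from(cliques)
    for c1, c2 in itertools.combinations(cliques, 2):
        if c1 & c2:
            jg.add_edge(c1, c2, weight=len(c1 & c2))
    return nx.maximum_spanning_tree(jg)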
19
20 9.2.1.3 Computing a jtree from a directed graphical model
21
22
In this section, we show how to create a junction tree from a DPGM. For example, consider the
23
“student” network from Main Figure 4.38(a). We can “moralize” this (by connecting unmarried
24
parents with a common child, and then dropping all edge orientations), to get the undirected graph in
25
Main Figure 4.38(b). We can then derive a tree decomposition by applying the variable elimination
26
algorithm from Main Section 9.5. The difference is that this time, we keep track of all the fill-in edges,
27
and add them to the original graph, in order to make it chordal. We then extract the maximal cliques
28
and convert them into a tree. The corresponding tree decomposition is illustrated in Figure 9.4. We
29
see that the nodes of the jtree T are cliques of the chordal graph:
30 C(T ) = {C, D}, {G, I, D}, {G, S, I}, {G, J, S, L}, {H, G, J} (9.27)
31
32
9.2.1.4 Tree decompositions of some common graph structures
33
34 In Figure 9.5, we illustrate the tree decomposition of several common graph structures which
35 arise when using neural networks and graphical models. The resulting decomposition can be
36 used to trade off time and memory, by storing checkpoints to partition the graph into subgraphs,
37 and then recomputing intermediate quantities on demand; for details, see e.g., [Griewank2008;
38 Binder97island; Zweig00; Chen2016sublinear]. For example, for a linear chain, we can reduce
39 the memory from O(T ) to O(log T ), if we are willing to increase the runtime from O(T ) to O(T log T ).
40 Another common graph structure is a 2d grid. If the grid has size w × h, then the treewidth is
41 min(w, h). To see this, note that we can convert the grid into a chain by grouping together all the
42 nodes in each column or each row, depending on which is smaller. (See [Lipton79] for the formal
43 proof.)
44 Note that a graph may look like it is triangulated, even though it is not. For example, Figure 9.6(a)
45 is made of little triangles, but it is not triangulated, since it contains the chordless 5-cycle 1-2-3-4-5-1.
46 A triangulated version of this graph is shown in Figure 9.6(b), in which we add two fill-in edges.
47


1
2 Coherence Coherence
3
4
Difficulty Intelligence Difficulty Intelligence
5
6
7 Grade SAT Grade SAT
8
9
Letter Letter
10
11 Job Job

12 Happy Happy
13 (a) (b)
14
15
16
C,D G,I,D G,S,I G,J,S,L H,G,J
17 D G,I G,S G,J
18
19 (c)
20
Figure 9.4: (a) A triangulated version of the (moralized) student graph from Main Figure 4.38(b). The extra
21
fill-in edges (such as G-S) are derived from the elimination ordering used in Main Figure 9.18. (b) The
22
maximal cliques. (c) The junction tree. From Figure 9.11 of [KollerBook]. Used with kind permission of
23
Daphne Koller.
24
25
26
27
28
9.2.2 Message passing on a junction tree
29
30 In this section, we discuss how to extend the belief propagation algorithm of Main Section 9.3 to
31 work with junction trees. This will let us compute the exact marginals in time linear in the size of
32 the tree. We focus on the Lauritzen-Spiegelhalter algorithm [Lauritzen88], although there are
33 many other variants (see e.g., [Huang1996; Jensen07]).
34
35
36
9.2.2.1 Potential functions
37
38 We present the algorithm abstractly in terms of potential functions φi associated with each node
39 (clique) in the junction tree. A potential function is just a non-negative function of its arguments. If
40 the arguments are discrete, we can represent potentials as multi-dimensional arrays (tensors). We
41 discuss the Gaussian case in Main Section 2.3.3, and the general case in Section 9.2.3.
42 We assume each potential has an identity element, and that there is a way to multiply, divide and
43 marginalize potentials. For the discrete case, the identity element is a vector of all 1s. To explain
44 marginalization, suppose clique i has domain Ci , let Sij = Ci ∩ Cj be the separator between node
45 i and j. Let us partition the domain of φi into Sij and Ci0 = Ci \ Sij , where Ci0 are the variables
46 that are unique to i and not shared with Sij . We denote marginalization of potential φi (Ci ) onto the
47

1
2
3
4
5
6
7
8
9
10
11
Figure 9.5: Examples of optimal tree decompositions for some common graph structures. Adapted from
12
https: // bit. ly/ 2m5vauG .
13
14
15
16
17
18
19
20
21
22
23
24 (a) (b)
25
Figure 9.6: A triangulated graph is not just composed of little triangles. Left: this graph is not triangulated,
26
despite appearances, since it contains a chordless 5-cycle 1-2-3-4-5-1. Right: one possible triangulation, by
27
adding the 1-3 and 1-4 fill-in edges. Adapted from [Armstrong05]
28
29
30
domain Sij as follows
31
X
32
φij (Sij ) = φi (Ci ) ↓ Sij = φi (Ci0 , Sij ) (9.28)
33
Ci0 ∈Ci \Sij
34
35 We define multiplication elementwise, by broadcast φij over Cj0 .
36
37 (φj ∗ φij )(Cj ) = φj (Cj0 , Sij )φij (Sij ) (9.29)
38
where Cj = (Cj0 , Sij . We define division similarly:
39
40
φj φj (Cj0 , Sij )
41 ( )(Cj ) = (9.30)
φij φij (Sij )
42
43 where 0/0 = 0. (The Shafer-Shenoy version of the algorithm, from [Shafer90], avoids division by
44 keeping track of the individual terms and multiplying all but one of them on demand.)
45 We can intepret division as computing a conditional distribution, since φ∗j = φj /φij = p(Cj0 , Sij )/p(Sij ) =
46 p(Cj0 |Sij ). Similarly we can interpret multiplication as adding updated information back in. To see
47


1
2
3
// Collect to root
for each node n in post-order
    p = parent(n)
    φnp = φn ↓ Snp
    φp = φp ∗ φnp

// Distribute from root
for each node n in pre-order
    for each child c of n
        φc = φc / φnc
        φnc = φn ↓ Snc
        φc = φc ∗ φnc
14
15
Figure 9.7: Message passing on a (directed) junction tree.
16
17
18 this, let φ∗ij = p(Sij |e) be the new separator potential, where e is some evidence. Let φ∗∗ ∗ ∗
j = φj ∗ φij
19 be the resulting of dividing out the old separator and multiplying in the new separator. Then
j ∝ p(Cj , Sij |e). So we have successfully passed information from i to j. We will leverage this
φ∗∗
20 0

21 result below.
22
23 9.2.2.2 Initialization
24
25 To initialize the junction tree potentials, we first assign each factor Fk to a unique node j = Ak such
26 that the domain of node j contains all of Fk ’s variables. Let A−1 i = {k : Ak = i} be all the factors
27 assigned to node i. We set the node potentials to
28
Y
φi = Fk (9.31)
29
k∈A−1
i
30
31 where φi i = 1 if no factors are assigned to i. We set the separator potentials to φij = 1.
32
33 9.2.2.3 Calibration
34
35
We now describe a simple serial ordering for sending messages on the junction tree. We first pick
36
an arbitrary node as root. Then the algorithm has two phases, similar to forwards and backwards
37
passes over a chain (see Main Section 9.2.3).
38
In the collect evidence phase, we visit nodes in post-order (children before parents), and each
39
node n sends a message to its parents p, until we reach the root. The parent p first divides out any
40
information it received (via the separator) from its child n by computing
41 φp
φp = (9.32)
42 φnp
43
However, since the separator potentials are initialized to 1s, this operation is not strictly necessary.
44
Next we compute the message from child to parent by computing an updated separator potential:
45
46 φnp = φn ↓ Snp (9.33)
47

1
2 Finally the parent “absorbs the flow” from its child by computing
3
4 φp = φp ∗ φnp (9.34)
5
In the distribute evidence phase we visit nodes in pre-order (parents before children), and each
6
node n sends a message to each child c, starting with the root. In particular, each child divides out
7
message it previously sent to its parent n:
8
9 φc
10 φc = (9.35)
φnc
11
12 Then we compute the new message from parent to child:
13
14
φnc = φn ↓ Snc (9.36)
15
Finally the child absorbs this new information:
16
17 φc = φc ∗ φnc (9.37)
18
19 The overall process is sometimes called “calibrating” the jtree [Lauritzen88]. See Figure 9.7 for
20 the pseudocode.
21
22
9.2.3 The generalized distributive law
23
24 We have seen how we can define potentials for discrete and Gaussian distributions, and we can then
25 use message passing on a junction tree to efficiently compute posterior marginals, as well as the
26 likelihood of the data. For example, consider a graphical model with pairwise potentials unrolled for
27 4 time steps. The partition function is defined by
28 XXXX
29 Z= ψ12 (x1 , x2 )ψ23 (x2 , x3 )ψ34 (x3 , x4 ) (9.38)
30 x1 x2 x3 x4

31
We can distribute sums over products to compute this more cheaply as follows:
32
XX X X
33 Z= ψ12 (x1 , x2 ) ψ23 (x2 , x3 ) ψ34 (x3 , x4 ) (9.39)
34 x1 x2 x3 x4
35
36 By defining suitable implementations of the sum and multiplication operations, we can use this same
37 trick to solve a variety of problems. This general formulation is called the generalized distributive
38 law [Aji00].
39 The key property we require is that the local clique functions ψc are associated with a commutative
40 semi-ring. This is a set K, together with two binary operations called “+” and “×”, which satisfy
41 the following three axioms:
42
1. The operation “+” is associative and commutative, and there is an additive identity element called
43
“0” such that k + 0 = k for all k ∈ K.
44
45 2. The operation “×” is associative and commutative, and there is a multiplicative identity element
46 called “1” such that k × 1 = k for all k ∈ K.
47


1
Domain       “+”         “×”        Name
[0, ∞)       (+, 0)      (×, 1)     sum-product
[0, ∞)       (max, 0)    (×, 1)     max-product
(−∞, ∞]      (min, ∞)    (+, 0)     min-sum
{T, F}       (∨, F)      (∧, T)     Boolean satisfiability
7
Table 9.1: Some commutative semirings.
8
9
10
11 X1 X2 X2 X2 X3 X3 X3 X4
12
13
14 X1 X2 X3 X4
15
16
X1 Y1 X2 Y2 X3 Y3 X4 Y4
17
18
19
Figure 9.8: The junction tree derived from an HMM of length T = 4.
20
21
22
23 3. The distributive law holds, i.e.,
24
25 (a × b) + (a × c) = a × (b + c) (9.40)
26
27 for all triples (a, b, c) from K.
28
29 There are many such semi-rings; see Table 9.1 for some examples. We can therefore use the JTA
30 to solve many kinds of problems, such as: computing posterior marginals (as we have seen); com-
31 puting posterior samples [Dawid92]; computing the N most probable assignments [Nilsson98];
32 constraint satisfaction problems [Bistarelli97; Dechter03; Dechter2019]; logical reasoning prob-
33 lems [Amir05]; solving linear systems of the form Ax = b where A is a sparse matrix [Blair92;
34 Paskin03jtree; Bickson09]; etc. See [Lauritzen97] for more details.
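To illustrate the generalized distributive law, here is a minimal sketch that eliminates the variables of a pairwise chain (as in Equations 9.38-9.39) under a pluggable commutative semiring. The factor representation (a list of K × K arrays for ψ_{t,t+1}) is a hypothetical choice made for this example.

import numpy as np

def chain_eliminate(pairwise, add_reduce, mul_op, one):
    # Compute (+)_{x_1..x_T} (x)_t psi_t(x_t, x_{t+1}) for a pairwise chain.
    # pairwise[t]: K x K array for psi_{t,t+1}.  Example semirings:
    #   sum-product: add_reduce=np.sum, mul_op=np.multiply, one=1.0
    #   max-product: add_reduce=np.max, mul_op=np.multiply, one=1.0
    #   min-sum:     add_reduce=np.min, mul_op=np.add,      one=0.0
    K = pairwise[-1].shape[1]
    msg = np.full(K, one)                              # "message" from beyond the last node
    for psi in reversed(pairwise):
        # combine psi(x_t, x_{t+1}) with the incoming message, then eliminate x_{t+1}
        msg = add_reduce(mul_op(psi, msg[None, :]), axis=1)
    return add_reduce(msg)                             # finally eliminate x_1

Using the sum-product instantiation recovers the partition function Z of Equation (9.38); swapping in max-product gives the value of the best joint assignment, exactly as the distributive-law argument promises.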
35
36
9.2.4 JTA applied to a chain
37
38 It is interesting to see what happens if we apply the junction tree algorithm to a chain structured
39 graph such as an HMM. A detailed discussion can be found in [Smyth97], but the basic idea is
40 as follows. First note that for a pairwise graph, the cliques are the edges, and the separators are
41 the nodes, as shown in Figure 9.8. We initialize the potentials as follows: we set ψs = 1 for all the
42 separators, we set ψc (xt−1 , xt ) = p(xt |xt−1 ) for clique c = (Xt−1 , Xt ) , and we set ψc (xt , yt ) = p(yt |xt )
43 for clique c = (Xt , Yt ).
44 Next we send messages from left to right along the “backbone”, and from observed child leaves
45 up to the backbone. Consider the clique j = (Xt−1 , Xt ) and its two children, i = (Xt−2 , Xt−1 ), and
46 i0 = (Xt , Yt ). To compute the new clique potential for j, we first marginalize the clique potentials for
47

1
2 i onto Sij and for i0 onto Si0 j to get
3 X

4 ψij (Xt−1 ) = ψi (Xt−2 , Xt−1 ) = p(Xt−1 |y1:t−1 ) = αt−1 )(Xt ) (9.41)
5 Xt−2
6
X
ψi∗0 j (Xt ) = ψi (Xt , Yt ) ∝ p(Yt |Xt ) = λt (Xt ) (9.42)
7
Yt
8
9 We then absorb messages from these separator potentials to compute the new clique potential:
10
11

ψij (Xt−1 ) ψi∗0 j (Xt )
ψi∗ (Xt−1 , Xt ) ∝ ψi (Xt−1 , Xt ) (9.43)
12 ψij (Xt−1 ) ψi0 j (Xt )
13
αt−1 (Xt−1 ) λt (Xt )
14 = A(Xt−1 , Xt ) ∝ p(Xt−1 , Xt |y1:t ) (9.44)
1 1
15
16 which we recognize as the filtered two-slice marginal.
17 Now consider the backwards pass. Let k = (Xt , Xt+1 ) be the parent of j. We send a message from
18 k to j via their shared separator Sk,j (Xt ) to get the final potential:
19
∗∗
20 ψk,j (Xt )
ψj∗∗ (Xt−1 , Xt ) ∝ ψj∗ (Xt−1 , Xt ) ∗ (9.45)
21 ψk,j (Xt )
22
γt (Xt )
23 = [A(Xt−1 , Xt )αt−1 (Xt−1 )λt (Xt )] (9.46)
24
αt (Xt )
25 ∝ p(Xt−1 , Xt |y1:T ) (9.47)
26
27 where αt (Xt ) = p(Xt |y1:t ) and γt (Xt ) = p(Xt |y1:T ) are the separator potentials for Sjk on the
28 forwards and backwards passes. This matches the two slice smothed marginal in Main Equation (9.35).
29
30 9.2.5 JTA for general temporal graphical models
31
32
In this section, we discuss how to perform exact inference in temporal graphical models, which
33
includes dynamic Bayes nets (Main Section 29.5.5) and their undirected analogs.
34
The simplest approach to inference in such models is to flatten the model into a chain, by defining
35
a mega-variable zt whose state space is the cross product of all the individual hidden variables in
36
slice t, and then to compute the corresponding transition matrix. For example, suppose we have two
37
independent binary chains, with transition matrices given by
38    
a b e f
39 A1 = , A2 = (9.48)
c d g h
40
41
Then the transition matrix of the flattened model has the following Kronecker product form:
42
 
43 ae af be bf
44 ag ah bg bh 
A = A1 ⊗ A2 =    (9.49)
45 ce cf de df 
46 cg ch dg dh
47


1
2
3
4
5
6
7
8
9
10 Figure 9.9: The cliques in the junction tree decomposition as we advance the frontier from one time-slice to
11
the next in a 3-chain factorial HMM model.
12
13
14 For example, the probability of going from state (1,2) to (2,1) is p(1 → 2) × p(2 → 1) = b × g. We
note that this is not a sparse matrix, even though the chains are completely independent.
16 One can use this expanded matrix inside the forwards-backwards algorithm to compute p(z1,t , z2,t |y1:T , θ),
17 from which the marginals of each chain, p(zi,t |y1:T , θ), can easily be derived. If each hidden node has
18 K states, and there are M hidden nodes per time step, the transition matrix has size (K M ) × (K M ),
19 so this method takes O(T K 2M ) time, which is often unacceptably slow.
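As a tiny illustration (with made-up numbers) of why the flattened approach blows up, the Kronecker-product transition matrix of Equation (9.49) can be built directly with np.kron:

import numpy as np

A1 = np.array([[0.9, 0.1], [0.2, 0.8]])   # hypothetical 2-state chains
A2 = np.array([[0.7, 0.3], [0.4, 0.6]])
A = np.kron(A1, A2)                        # Eq. (9.49): flattened (K^M x K^M) transition matrix
assert A.shape == (4, 4)
# For M chains with K states each, the flattened matrix is K^M x K^M and dense,
# which is where the O(T K^{2M}) cost of the naive filter comes from.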
20 Of course, the above method ignores the structure within each time slice. For example, the above
21 flattened matrix does not exploit the fact that the chains are completely independent. By using the
22 JTA, we can derive a more efficient algorithm. For example, consider the 3-chain factorial HMM
23 (FHMM) in Main Figure 29.19a. All the hidden variables xmt within a time slice become correlated
24 due to the observed common child yt (explaining away), so the exact belief state p(x1:M,t |y1:t , θ) will
25 necessarily have size O(K M ). However, rather than multiplying this large vector by a (K M ) × (K M )
26 matrix, we can update the belief state one variable at a time, as illustrated in Figure 9.9. This takes
27 O(T M K M +1 ) time (see [Ghahramani97] for the details of the algorithm). This method has been
28 called the frontier algorithm [Zweig96b], since it sweeps a “frontier” across the network (forwards
29 and backwards); however, this is just a special case of the JTA. For a detailed discussion of how to
30 apply the JTA to temporal graphical models, see [Bilmes10].
31 Although the JTA for FHMMs is better than the naive approach to inference, it still takes time
32 exponential in the number of hidden nodes per chain (ignoring any transient nodes that do not
33 connect across time). For the FHMM, this is unavoidable, since all the hidden variables immediately
34 become correlated within a single time slice due to the observed common child yt . What about for
35 graphs with sparser structure? For example, consider the coupled HMM in Main Section 29.5.4. Here
36 each hidden node only depends on two nearest neighbors and some local evidence. Thus initially
37 the belief state can be factored. However, after T = M time steps, the belief state becomes fully
38 correlated, because there is now a direct path of influence between variables in non-neighboring
39 chains. This is known as the entanglement problem, and it means that, in general, exact inference
40 in temporal graphical models is exponential in the number of (persistent) hidden variables. Looking
carefully at Main Figure 29.19a, this is perhaps not so surprising, since the model looks like a short
42 and wide grid-structured graph, for which exact inference is known to be intractable in general.
43 Fortunately, we can still leverage sparse graph structure when performing approximate inference.
44 The intuition is that although all the variables may be correlated, the correlation between distant
45 variables is likely to be weak. In Section 10.3.2, we derive a structured mean field approximation for
46 FHMMs, which exploits the parallel chain structure. This only takes O(T M K 2 I) time, where I is
47

1
2 the number of iterations of the inference algorithm (typically I ∼ 10 suffices for good performance).
3 See also [Boyen98] for an approach based on assumed density filtering.
4 Note that the situation with linear dynamical systems is somewhat different. In that context,
5 combining multiple hidden random variables merely increases the size of the state space additively
6 rather than multiplicatively. Thus inference takes O(T (CL)3 ) time, if there are T time steps, and C
7 hidden chains each with dimensionality L. Furthermore, two independent chains combine to produce
8 a sparse block-diagonal transition weight matrix, rather than a dense Kronecker product matrix, so
9 the structural information is not “lost”.
10
11
12
9.3 MAP estimation for discrete PGMs
13
In this section, we consider the problem of finding the most probable configuration of variables in a
14
probabilistic graphical model, i.e., our goal is to find a MAP assignment x∗ = arg maxx∈X V p(x),
15
where X = {1, . . . , K} is the discrete state space of each node, V is the number of nodes, and the
16
distribution is defined according to a Markov Random field (Main Section 4.3) with pairwise cliques,
17
one per edge:
18
 
19 X 
1 X
20 p(x) = exp θs (xs ) + θst (xs , xt ) (9.50)
21 Z  
s∈V (s,t)∈E
22
23 Here V = {x1 , . . . , xV } are the nodes, E are the edges, θs and θst are the node and edge potentials,
24 and Z is the partition function:
25
 
26 X X X 
27 Z= exp θs (xs ) + θst (xs , xt ) (9.51)
x
 
28 s∈V (s,t)∈E
29
30 Since we just want the MAP configuration, we can ignore Z, and just compute
31 X X
32 x∗ = argmax θs (xs ) + θst (xs , xt ) (9.52)
x
33 s∈V (s,t)∈E
34
35 We can compute this exactly using dynamic programming as we explain in Section 9.2; However, this
36 takes time exponential in the treewidth of the graph, which is often too slow. In this section, we
37 focus on approximate methods that can scale to intractable models. We only give a brief description
38 here; more details can be found in [Monster; KollerBook].
39
40 9.3.1 Notation
41
42
To simplify the presentation, we write the distribution in the following form:
43
1
44 p(x) = exp(−E(x)) (9.53)
Z(θ)
45
46 E(x) , −θT T (x) (9.54)
47


1
2 where θ = ({θs;j }, {θs,t;j,k }) are all the node and edge parameters (the canonical parameters), and
3 T (x) = ({I (xs = j)}, {I (xs = j, xt = k)}) are all the node and edge indicator functions (the sufficient
4 statistics). Note: we use s, t ∈ V to index nodes and j, k ∈ X to index states.
5 The mean of the sufficient statistics are known as the mean parameters of the model, and are
6 given by
7
8
µ = E [T (x)] = ({p(xs = j)}s , {p(xs = j, xt = k)}s6=t ) = ({µs;j }s , {µst;jk }s6=t ) (9.55)
9 This is a vector of length d = KV + K 2 E, where K = |X | is the number of states, V = |V| is
the number of nodes, and E = |E| is the number of edges. Since µ completely characterizes the distribution p(x), we sometimes treat µ as a distribution itself.
12 Equation (9.55) is called the standard overcomplete representation. It is called “overcomplete”
13 because it ignores the sum-to-one constraints. In some cases, it is convenient to remove this redundancy.
14 For example, consider an Ising model where Xs ∈ {0, 1}. The model can be written as
15  
16
1  X X 
17 p(x) = exp θ s xs + θst xs xt (9.56)
Z(θ)  
18 s∈V (s,t)∈E
19
20
Hence we can use the following minimal parameterization
21
T (x) = (xs , s ∈ V ; xs xt , (s, t) ∈ E) ∈ Rd (9.57)
22
23 where d = V +E. The corresponding mean parameters are µs = p(xs = 1) and µst = p(xs = 1, xt = 1).
24
25 9.3.2 The marginal polytope
26
27 The space of allowable µ vectors is called the marginal polytope, and is denoted M(G), where G
28 is the structure of the graph. This is defined to be the set of all mean parameters for the given model
29 that can be generated from a valid probability distribution:
30 X X
31
M(G) , {µ ∈ Rd : ∃p s.t. µ = T (x)p(x) for some p(x) ≥ 0, p(x) = 1} (9.58)
x x
32
33 For example, consider an Ising model. If we have just two nodes connected as X1 − X2 , one can
34 show that we have the following minimal set of constraints: 0 ≤ µ12 , 0 ≤ µ12 ≤ µ1 , 0 ≤ µ12 ≤ µ2 ,
35 and 1 + µ12 − µ1 − µ2 ≥ 0. We can write these in matrix-vector form as
36    
[  0   0   1 ]                [  0 ]
[  1   0  −1 ]   [ µ1  ]      [  0 ]
[  0   1  −1 ]   [ µ2  ]  ≥   [  0 ]     (9.59)
[ −1  −1   1 ]   [ µ12 ]      [ −1 ]
41
These four constraints define a series of half-planes, whose intersection defines a polytope, as shown
42
in Figure 9.10(a).
43
Since M(G) is obtained by taking a convex combination of the T (x) vectors, it can also be written
44
as the convex hull of these vectors:
45
46 M(G) = conv{T1 (x), . . . , Td (x)} (9.60)
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
9.3. MAP ESTIMATION FOR DISCRETE PGMS

1
2
3
4
5
6
7
8
9
(a) (b) (c)
10
11 Figure 9.10: (a) Illustration of the marginal polytope for an Ising model with two variables. (b) Cartoon
12 illustration of the set MF (G), which is a nonconvex inner bound on the marginal polytope M(G). MF (G) is
13 used by mean field. (c) Cartoon illustration of the relationship between M(G) and L(G), which is used by loopy
14 BP. The set L(G) is always an outer bound on M(G), and the inclusion M(G) ⊂ L(G) is strict whenever G
15 has loops. Both sets are polytopes, which can be defined as an intersection of half-planes (defined by facets),
16
or as the convex hull of the vertices. L(G) actually has fewer facets than M(G), despite the picture. In fact,
L(G) has O(|X ||V | + |X |2 |E|) facets, where |X | is the number of states per variable, |V | is the number of
17
variables, and |E| is the number of edges. By contrast, M(G) has O(|X ||V | ) facets. On the other hand, L(G)
18
has more vertices than M(G), despite the picture, since L(G) contains all the binary vector extreme points
19
µ ∈ M(G), plus additional fractional extreme points. From Figures 3.6, 5.4 and 4.2 of [Monster]. Used
20 with kind permission of Martin Wainwright.
21
22
23
For example, for a 2 node MRF X1 − X2 with binary states, we have
24
25
M(G) = conv{(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 1)} (9.61)
26
27 These are the four black dots in Figure 9.10(a). We see that the convex hull defines the same volume
28 as the intersection of half-spaces.
29
30
9.3.3 Linear programming relaxation
31
32 We can write the MAP estimation problem as follows:
33
34 max θT T (x) = max θT µ (9.62)
x∈X V µ∈M(G)
35
36 To see why this equation is true, note that we can just set µ to be a degenerate distribution with
37 µ(xs ) = I (xs = x∗s ), where x∗s is the optimal assigment of node s. Thus we can “emulate” the task of
38 optimizing over discrete assignments by optimizing over probability distributions µ. Furthermore,
39 the non-degenerate (“soft”) distributions will not correspond to corners of the polytope, and hence
40 will not maximize a linear function.
41 It seems like we have an easy problem to solve, since the objective in Equation (9.62) is linear in
42 µ, and the constraint set M(G) is convex. The trouble is, M(G) in general has a number of facets
43 that is exponential in the number of nodes.
44 A standard strategy in combinatorial optimization is to relax the constraints. In this case, instead
45 of requiring probability vector µ to live in the marginal polytope M(G), we allow it to live inside a
46 simpler, convex enclosing set L(G), which we define in Section 9.3.3.1. Thus we try to maximize the
47


1
2 following upper bound on the original objective:
3
4
τ ∗ = argmax θT τ (9.63)
τ ∈L(G)
5
6 This is called a linear programming relaxation of the problem. If the solution τ ∗ is integral, it
7 corresponds to the exact MAP estimate; this will be the case when the graph is a tree. In general,
8 τ ∗ will be fractional; we can derive an approximate MAP estimate by rounding (see [Werner07] for
9 details).
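To make the relaxation concrete, here is a toy sketch (using cvxpy, with made-up potentials) of the LP of Equation (9.63) for a single binary edge, with the local-consistency constraints of Section 9.3.3.1. The variable and parameter names are our own; for a tree (here, one edge) the optimum is integral, illustrating the claim above.

import numpy as np
import cvxpy as cp

theta_s = np.array([0.0, 1.0])                   # hypothetical node potentials
theta_t = np.array([0.0, -0.5])
theta_st = np.array([[2.0, 0.0], [0.0, 2.0]])    # attractive edge potential

tau_s = cp.Variable(2, nonneg=True)
tau_t = cp.Variable(2, nonneg=True)
tau_st = cp.Variable((2, 2), nonneg=True)
cons = [cp.sum(tau_s) == 1, cp.sum(tau_t) == 1,      # normalization (9.64)
        cp.sum(tau_st, axis=1) == tau_s,             # marginalization (9.65)
        cp.sum(tau_st, axis=0) == tau_t]
obj = cp.Maximize(theta_s @ tau_s + theta_t @ tau_t + cp.sum(cp.multiply(theta_st, tau_st)))
cp.Problem(obj, cons).solve()
print(tau_s.value.round(3), tau_t.value.round(3))    # vertex of L(G), here an integral labeling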
10
11 9.3.3.1 A convex outer approximation to the marginal polytope
12
13 Consider a set of probability vectors τ that satisfy the following local consistency constraints:
14 X
15
τs (xs ) = 1 (9.64)
xs
16 X
17 τst (xs , xt ) = τs (xs ) (9.65)
18 xt
19
The first constraint is called the normalization constraint, and the second is called the marginalization
20
constraint. We then define the set
21
22
L(G) , {τ ≥ 0 : (Equation (9.64)) holds ∀s ∈ V, (Equation (9.65)) holds ∀(s, t) ∈ E} (9.66)
23
24 The set L(G) is also a polytope, but it only has O(|V | + |E|) constraints. It is a convex outer
25 approximation on M(G), as shown in Figure 9.10(c). (By contrast, the mean field approximation,
26 which we discuss in Main Section 10.3, is a non-convex inner approximation, as we discuss in
27 Main Section 10.3.)
28 We call the terms τs , τst ∈ L(G) pseudo marginals, since they may not correspond to marginals
29 of any valid probability distribution. As an example of this, consider Figure 9.11(a). The picture
30 shows a set of pseudo node and edge marginals, which satisfy the local consistency requirements.
31 However, they are not globally consistent. To see why, note that τ12 implies p(X1 = X2 ) = 0.8, τ23
32 implies p(X2 = X3 ) = 0.8, but τ13 implies p(X1 = X3 ) = 0.2, which is not possible (see [Monster]
33 for a formal proof). Indeed, Figure 9.11(b) shows that L(G) contains points that are not in M(G).
34 We claim that M(G) ⊆ L(G), with equality iff G is a tree. To see this, first consider an element
35 µ ∈ M(G). Any such vector must satisfy the normalization and marginalization constraints, hence
36 M(G) ⊆ L(G).
37 Now consider the converse. Suppose T is a tree, and let µ ∈ L(T ). By definition, this satisfies the
38 normalization and marginalization constraints. However, any tree can be represented in the form
39
Y Y µst (xs , xt )
40
pµ (x) = µs (xs ) (9.67)
41 µs (xs )µt (xt )
s∈V (s,t)∈E
42
Hence satisfying normalization and local consistency is enough to define a valid distribution for any
44 tree. Hence µ ∈ M(T ) as well.
45 In contrast, if the graph has loops, we have that M(G) 6= L(G). See Figure 9.11(b) for an example
46 of this fact. The importance of this observation will become clear in Section 10.4.3.
47

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
Figure 9.11: (a) Illustration of pairwise UGM on binary nodes, together with a set of pseudo marginals that
are not globally consistent. (b) A slice of the marginal polytope illustrating the set of feasible edge marginals,
16
assuming the node marginals are clamped at µ1 = µ2 = µ3 = 0.5. From Figure 4.1 of [Monster]. Used with
17
kind permission of Martin Wainwright.
18
19
20
21
22
23
24
25
26
27
28 True Disparities 0 1 ’
29
30 Figure 9.12: Illustration of belief propagation for stereo depth estimation applied to the Venus image from the
31 Middlebury stereo benchmark dataset [Scharstein02]. Left column: image and true disparities. Remaining
32
columns: initial estimate, estimate after 1 iteration, and estimate at convergence. Top row: Gaussian edge
potentials using a continuous state space. Bottom row: robust edge potentials using a quantized state space.
33
From Figure 4 of [Sudderth08bp]. Used with kind permission of Erik Sudderth.
34
35
36
37
38
39
9.3.3.2 Algorithms
40
41 Our task is to solve Equation (9.63), which requires maximizing a linear function over a simple
42 convex polytope. For this, we could use a generic linear programming package. However, this is often
43 very slow.
44 Fortunately, one can show that a simple algorithm, that sends messages between nodes in the graph,
45 can be used to compute τ ∗ . In particular, the tree reweighted belief propagation algorithm can
46 be used; see Section 10.4.5.3 for details.
47


1
2 9.3.3.3 Application to stereo depth estimation
3
Belief propagation is often applied to low-level computer vision problems (see e.g., [Szeliski10;
4
Blake11; Prince12]). For example, Figure 9.12 illustrates its application to the problem of stereo
5
depth estimation given a pair of monocular images (only one is shown). The value xi is the
6
distance of pixel i from the camera (quantized to a certain number of values). The goal is to infer
7
these values from noisy measurements. We quantize the state space, rather than using a Gaussian
8
model, in order to avoid oversmoothing at discontinuities, which occur at object boundaries, as
9
illustrated in Figure 9.12. (We can also use a hybrid discrete-continuous state space, as discussed in
10
[Yamaguchi2012], but we can no longer apply BP.)
11
Not surprisingly, people have recently applied deep learning to this problem. For example,
12
[Xu2019stereo] describes a differentiable version of message passing (Main Section 9.4), which is
13
fast and can be trained end-to-end. However, it requires labeled data for training, i.e., pixel-wise
14
ground truth depth values. For this particular problem, such data can be collected from depth
15
cameras, but for other problems, BP on “unsupervised” MRFs may be needed.
16
17
18
9.3.4 Graphcuts
19 In this section, we show how to find MAP state estimates, or equivalently, minimum energy con-
20 figurations, by using the maxflow / mincut algorithm for graphs. This class of methods is
21 known as graphcuts and is very widely used, especially in computer vision applications (see e.g.,
22 [Boykov2004]).
23 We will start by considering the case of MRFs with binary nodes and a restricted class of potentials;
24 in this case, graphcuts will find the exact global optimum. We then consider the case of multiple
25 states per node; we can approximately solve this case by solving a series of binary subproblems, as
26 we will see.
27
28 9.3.4.1 Graphcuts for the Ising model
29
30 Let us start by considering a binary MRF where the edge energies have the following form:

31
0 if xu = xv
Euv (xu , xv ) = (9.68)
32
λuv if xu 6= xv
33
34 where λst ≥ 0 is the edge cost. This encourages neighboring nodes to have the same value (since we
35 are trying to minimize energy). Since we are free to add any constant we like to the overall energy
36 without affecting the MAP state estimate, let us rescale the local energy terms such that either
37 Eu (1) = 0 or Eu (0) = 0.
38 Now let us construct a graph which has the same set of nodes as the MRF, plus two distinguished
39 nodes: the source s and the sink t. If Eu (1) = 0, we add the edge xu → t with cost Eu (0). Similarly,
40 If Eu (0) = 0, we add the edge s → xu with cost Eu (1). Finally, for every pair of variables that are
41 connected in the MRF, we add edges xu → xv and xv → xu , both with cost λu,v ≥ 0. Figure 9.13
42 illustrates this construction for an MRF with 4 nodes and the following parameters:
43
E1 (0) = 7, E2 (1) = 2, E3 (1) = 1, E4 (1) = 6 λ1,2 = 6, λ2,3 = 6, λ3,4 = 2, λ1,4 = 1 (9.69)
44
45 Having constructed the graph, we compute a minimal s − t cut. This is a partition of the nodes into
46 two sets, Xs and Xt , such that s ∈ Xs and t ∈ Xt . We then find the partition which minimizes the
47

1
2
t
3
4 7

5 6
z1 z2
6
7
8 1 6

9 2
10 z4 z3
2
11
12 6 1

13
s
14
15
Figure 9.13: Illustration of graphcuts applied to an MRF with 4 nodes. Dashed lines are ones which contribute
to the cost of the cut (for bidirected edges, we only count one of the costs). Here the min cut has cost 6. From
16
Figure 13.5 from [KollerBook]. Used with kind permission of Daphne Koller.
17
18
19
sum of the cost of the edges between nodes on different sides of the partition:
20
X
21 cost(Xs , Xt ) = cost(xu , xv ) (9.70)
22 xu ∈Xs ,xv ∈Xt
23
In Figure 9.13, we see that the min-cut has cost 6. Minimizing the cost in this graph is equivalent
24
to minimizing the energy in the MRF. Hence nodes that are assigned to s have an optimal state of
25
0, and the nodes that are assigned to t have an optimal state of 1. In Figure 9.13, we see that the
26
optimal MAP estimate is (1, 1, 1, 0).
27
Thus we have converted the MAP estimation problem to a standard graph theory problem for
28
which efficient solvers exist (see e.g., [CLR90]).
29
30
31
9.3.4.2 Graphcuts for binary MRFs with submodular potentials
32 We now discuss how to extend the graphcuts construction to binary MRFs with more general kinds
33 of potential functions. In particular, suppose each pairwise energy satisfies the following condition:
34
35 Euv (1, 1) + Euv (0, 0) ≤ Euv (1, 0) + Euv (0, 1) (9.71)
36
In other words, the sum of the diagonal energies is less than the sum of the off-diagonal energies.
37
In this case, we say the energies are submodular (Main Section 6.9). An example of a submodular
38
energy is an Ising model where λuv > 0. This is also known as an attractive MRF or associative
39
MRF, since the model “wants” neighboring states to be the same.
40
It is possible to modify the graph construction process for this setting, and then apply graphcuts,
41
such that the resulting estimate is the global optimum [Greig89].
42
43
9.3.4.3 Graphcuts for nonbinary metric MRFs
44
45 We now discuss how to use graphcuts for approximate MAP estimation in MRFs where each node
46 can have multiple states [Boykov01].
47

(a) initial labeling (b) standard move (c) α-β-swap (d) α-expansion

Figure 9.14: (a) An image with 3 labels. (b) A standard local move (e.g., given by iterative conditional modes) just flips the label of one pixel. (c) An α − β swap allows all nodes that are currently labeled as α to be relabeled as β if this decreases the energy. (d) An α expansion allows all nodes that are not currently labeled as α to be relabeled as α if this decreases the energy. From Figure 2 of [Boykov01]. Used with kind permission of Ramin Zabih.

One approach is to use alpha expansion. At each step, it picks one of the available labels or states and calls it α; then it solves a binary subproblem where each variable can choose to remain in its current state, or to become state α (see Figure 9.14(d) for an illustration).
Another approach is to use alpha-beta swap. At each step, two labels are chosen, call them α and β. All the nodes currently labeled α can change to β (and vice versa) if this reduces the energy (see Figure 9.14(c) for an illustration).
In order to solve these binary subproblems optimally, we need to ensure the potentials for these subproblems are submodular. This will be the case if the pairwise energies form a metric. We call such a model a metric MRF. For example, suppose the states have a natural ordering, as commonly arises if they are a discretization of an underlying continuous space. In this case, we can define a metric of the form E(xs , xt ) = min(δ, ||xs − xt ||) or a semi-metric of the form E(xs , xt ) = min(δ, (xs − xt )²), for some constant δ > 0. This energy encourages neighbors to have similar labels, but never “punishes” them by more than δ. (This δ term prevents over-smoothing, which we illustrate in Figure 9.12.)

9.3.4.4 Application to stereo depth estimation

Graphcuts is often applied to low-level computer vision problems, such as stereo depth estimation, which we discussed in Section 9.3.3.3. Figure 9.15 compares graphcuts (both swap and expansion versions) to two other algorithms (simulated annealing, and a patch matching method based on normalized cross correlation) on the famous Tsukuba test image. The graphcuts approach works the best on this example, as well as others [Szeliski08; Tappen2003]. It also tends to outperform belief propagation (results not shown) in terms of speed and accuracy on stereo problems [Szeliski08; Tappen2003], as well as other problems such as CRF labeling of LIDAR point cloud data [Landrieu2017].
as others both a α-β swap and an α-expansion.
Tappen2003]. It also tends to
39
outperform belief propagation (results not shown) in terms of speed and accuracy on stereo problems
40
[Szeliski08; Tappen2003], as well as other problems such as CRF labeling of LIDAR point cloud
41
data [Landrieu2017].
42
43
44
45
46
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
9.3. MAP ESTIMATION FOR DISCRETE PGMS

1
2
3
4
5
6
7
8
Figure 9.15: An example of stereo depth estimation using MAP estimation in a pairwise discrete MRF. (a) Left image, of size 384 × 288 pixels, from the University of Tsukuba. (The corresponding right image is similar, but not shown.) (b) Ground truth depth map, quantized to 15 levels. (c-f ): MAP estimates using different methods: (c) α − β swap, (d) α expansion, (e) normalized cross correlation, (f ) simulated annealing. From Figure 10 of [Boykov01]. Used with kind permission of Ramin Zabih.


10 Variational inference

10.1 More Gaussian VI


In this section, we give more examples of Gaussian variational inference.

10.1.1 Example: Full-rank vs diagonal GVI on 1d linear regression


In this section, we give a comparison of HMC (Main Section 12.5) and Gaussian VI using both a
mean field and full rank approximation. We use a simple example from [rethinking2].1 Here the
goal is to predict (log) GDP G of various countries (in the year 2000) as a function of two input
variables: the ruggedness R of the country’s terrain, and whether the country is in Africa or not (A).
Specifically, we use the following 1d regression model:

yi ∼ N (µi , σ²)     (10.1)
µi = α + βᵀxi     (10.2)
α ∼ N (0, 10)     (10.3)
βj ∼ N (0, 1)     (10.4)
σ ∼ Unif(0, 10)     (10.5)

where xi = (Ri , Ai , Ai × Ri ) are the features, and yi = Gi is the response. Note that this is almost a
conjugate model, except for the non-conjugate prior on σ.
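As a rough sketch (not the book's notebook code, which is based on Pyro/NumPyro), the model above can be written in NumPyro as follows; the variable names are our own.

import numpyro
import numpyro.distributions as dist

def model(x, y=None):
    # x: (N, 3) feature matrix (R, A, A*R); y: (N,) log-GDP responses.
    alpha = numpyro.sample("alpha", dist.Normal(0.0, 10.0))
    beta = numpyro.sample("beta", dist.Normal(0.0, 1.0).expand([3]))
    sigma = numpyro.sample("sigma", dist.Uniform(0.0, 10.0))
    mu = alpha + x @ beta
    numpyro.sample("y", dist.Normal(mu, sigma), obs=y)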
We first use HMC, which is often considered the “gold standard” of posterior inference. The resulting
model fit is shown in Figure 10.1. This shows that GDP increases as a function of ruggedness for
African countries, but decreases for non-African countries. (The reasons for this are unclear, but
[Nunn2012] suggests that it is because more rugged African countries were less exploited by the slave
trade, and hence are now wealthier.)
Now we consider a variational approximation to the posterior, of the form p(θ|D) ≈ q(θ) =
q(θ|µ, Σ). (Since the standard deviation σ must lie in the interval [0, 10] due to the uniform prior, we first transform it to the unconstrained value τ = logit(σ/10) before applying the Gaussian
approximation, as explained in Main Section 10.2.2.)
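To make this concrete, here is a minimal NumPyro sketch of the model and the two Gaussian guides (diagonal and full covariance). The referenced notebook, linreg_bayes_svi_hmc.ipynb, may differ in its details, and the variable names below are ours.

import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal, AutoMultivariateNormal
from jax import random

def model(x, y=None):
    # x: [N, 3] features (R, A, A*R); y: [N] log GDP
    alpha = numpyro.sample("alpha", dist.Normal(0.0, 10.0))
    beta = numpyro.sample("beta", dist.Normal(0.0, 1.0).expand([3]).to_event(1))
    sigma = numpyro.sample("sigma", dist.Uniform(0.0, 10.0))  # transformed to R internally
    mu = alpha + x @ beta
    numpyro.sample("y", dist.Normal(mu, sigma), obs=y)

# Diagonal (mean field) Gaussian guide; swap in AutoMultivariateNormal for full rank.
guide = AutoNormal(model)
svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 5000, x, y)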
Suppose we initially choose a diagonal Gaussian approximation. In Figure 10.2a, we compare the
marginals of this posterior approximation (for the bias term and the 3 regression coefficients) with
the “exact” posterior from HMC. We see that the variational marginals have roughly the same mean,

1. We choose this example since it is used as the introductory example in the Pyro tutorial.
[Figure 10.1 shows two panels, "Non-African nations" and "African nations"; the x-axis is the Terrain Ruggedness Index and the y-axis is log GDP (2000).]
Figure 10.1: Posterior predictive distribution for the linear model applied to the Africa data. Dark shaded region is the 95% credible interval for µi. The light shaded region is the 95% credible interval for yi. Adapted from Figure 8.5 of [rethinking2]. Generated by linreg_bayes_svi_hmc.ipynb.
[Figure 10.2 shows the marginal posterior densities of the regression coefficients (bias, weight 0, weight 1, weight 2), comparing HMC to the diagonal (a) and full-covariance (b) Gaussian approximations.]
Figure 10.2: Posterior marginals for the linear model applied to the Africa data. (a) Blue is HMC, orange is Gaussian approximation with diagonal covariance. (b) Blue is HMC, orange is Gaussian approximation with full covariance. Generated by linreg_bayes_svi_hmc.ipynb.
[Figure 10.3 shows cross-sections of the posterior: joint densities of pairs of the regression coefficients (bA, bR, bAR), comparing HMC to the diagonal (a) and full-covariance (b) Gaussian approximations.]
Figure 10.3: Joint posterior of pairs of variables for the linear model applied to the Africa data. (a) Blue is HMC, orange is Gaussian approximation with diagonal covariance. (b) Blue is HMC, orange is Gaussian approximation with full covariance. Generated by linreg_bayes_svi_hmc.ipynb.
[Figure 10.4 shows three panels: MCMC predictive distribution, FFVB predictive distribution, and NAGVAC predictive distribution.]
Figure 10.4: Bayesian inference applied to a 2d binary logistic regression problem, p(y = 1|x) = σ(w0 + w1 x1 + w2 x2). We show the training data and the posterior predictive produced by different methods. (a) MCMC approximation. (b) VB approximation using full covariance matrix (Cholesky decomposition). (c) VB using rank 1 approximation. Generated by vb_gauss_biclusters_demo.ipynb.
but their variances are too small, meaning they are overconfident. Furthermore, the variational approximation neglects any posterior correlations, as shown in Figure 10.3a.
We can improve the quality of the approximation by using a full covariance Gaussian. The resulting posterior marginals are shown in Figure 10.2b, and some bivariate posteriors are shown in Figure 10.3b. We see that the posterior approximation is now much more accurate.
Interestingly, both variational approximations give a similar predictive distribution to the HMC one in Figure 10.1. However, in some statistical problems we care about interpreting the parameters themselves (e.g., to assess the strength of the dependence on ruggedness), so a more accurate approximation is necessary to avoid reaching invalid conclusions.
10.1.2 Example: Full-rank vs rank-1 GVI for logistic regression
In this section, we compare full-rank GVI, rank-1 GVI and HMC on a simple 2d binary logistic regression problem. The results are shown in Figure 10.4. We see that the predictive distribution from the VI posterior is similar to that produced by MCMC.
10.1.3 Structured (sparse) Gaussian VI

In many problems, the target posterior can be represented in terms of a factor graph (see Main Section 4.6.1). That is, we assume the negative log unnormalized joint probability (energy) can be decomposed as follows:

− log p(z, D) = φ(z) = Σ_{c=1}^C φc(zc)  (10.6)

where φc is the c'th clique potential. Note that the same variables can occur in multiple potential functions, but the dependency structure can be represented by a sparse graph G. Below we show that the optimal Gaussian variational posterior q(z) = N(z|µ, Λ⁻¹) will have a precision matrix Λ with the same sparsity structure as G. Furthermore, the natural gradient of the ELBO will also enjoy the same sparsity structure. This allows us to use VI with the same accuracy as using a full-rank covariance matrix but with efficiency closer to mean field (the cost per gradient step is potentially only linear in the number of latent variables, depending on the treewidth of G).
10.1.3.1 Sparsity of the ELBO

To see why the optimal q is sparse, recall that the negative ELBO consists of two terms: the expected energy, −Eq(z)[log p(z, D)], minus the entropy, H(N(µ, Λ⁻¹)). That is,

V(ψ) = Eq(z|µ,Λ)[φ(z)] + (1/2) log |Λ|  (10.7)

where ψ = (µ, Λ). To compute the first term, we only need the marginals q(zc) for each clique, as we see from Equation (10.6), so any dependencies with variables outside of zc are irrelevant. Thus the optimal q will have the same sparsity structure as G, since this maximizes the entropy (no unnecessary constraints). In other words, the optimal Gaussian q will be the following Gaussian MRF (see Main Section 4.3.5):

q(z|µ, Λ) ∝ Π_{c=1}^C Nc(zc|Λc µc, Λc)  (10.8)

where Nc(x|h, K) is a Gaussian distribution in canonical (information) form with precision K and precision weighted mean h. (For a more formal proof of this result, see e.g., [Barfoot2020sparse; Courts2021].)
10.1.3.2 Sparsity of the natural gradient of the ELBO

In [Barfoot2020sparse], they show that the natural gradient of the ELBO also inherits the same sparsity structure as G. In particular, at iteration i, they derive the following updates for the variational parameters:

Λ^{i+1} = E_{q^i}[∂²φ(z)/(∂zᵀ∂z)] = Σ_{c=1}^C Pcᵀ E_{q_c^i}[∂²φc(zc)/(∂zcᵀ∂zc)] Pc  (10.9)
Λ^{i+1} δ^{i+1} = −E_{q^i}[∂φ(z)/∂zᵀ] = −Σ_{c=1}^C Pcᵀ E_{q_c^i}[∂φc(zc)/∂zcᵀ]  (10.10)
µ^{i+1} = µ^i + δ^{i+1}  (10.11)

where Pc is the projection matrix that extracts zc from z, i.e., zc = Pc z. We can calculate the gradient gc and Hessian Hc of each factor c using automatic differentiation. We can then compute δ^{i+1} = −(Λ^{i+1})⁻¹ E_{q^i}[g] by solving a sparse linear system.
10.1.3.3 Computing posterior expectations

Finally, we discuss how to compute the expectations needed to evaluate the (gradient of the) ELBO. (We drop the i superscript for brevity.) One approach, discussed in [Barfoot2020sparse], is to use quadrature. This requires access to the marginals qc(zc) = N(zc|µc, Σc) for each factor. Note that Σc ≠ (Λc)⁻¹, so we cannot just invert each local block of the precision matrix. Instead we must compute the covariance for the full joint, Σ = Λ⁻¹, and then extract the relevant blocks, Σc. Fortunately, there are various methods, such as Takahashi's algorithm [Takahashi1973; Barfoot2020fundamental], for efficiently computing the blocks Σc without first needing to compute all of Σ. Alternatively, we can just use message passing on a Gaussian junction tree, as explained in Main Section 2.3.3.
10.1.3.4 Gaussian VI for nonlinear least squares problems
We now consider the special case where the energy function can be written as a nonlinear least squares objective:

φ(z) = (1/2) e(z)ᵀ W⁻¹ e(z) = (1/2) Σ_{c=1}^C ec(zc)ᵀ Wc⁻¹ ec(zc)  (10.12)

where e(z) = [ec(zc)]_{c=1}^C is a vector of error terms, zc ∈ R^{dc}, ec(zc) ∈ R^{nc}, W = diag(W1, . . . , WC), and Wc ∈ R^{nc×nc}. In this case, [Barfoot2020sparse] propose the following alternative objective that is more conservative (entropic):

V′(ψ) = (1/2) Eq[e(z)]ᵀ W⁻¹ Eq[e(z)] + (1/2) log |Λ|  (10.13)

This can be optimized using a Gauss-Newton method, which avoids the need to compute the Hessian of each factor. Let us define the expected error vector at iteration i as

eⁱ = E_{qⁱ}[e(z)] = [E_{q_cⁱ}[ec(zc)]]_c = [e_cⁱ]_c  (10.14)

Similarly the expected Jacobian at iteration i is

Eⁱ = E_{qⁱ}[∂e(z)/∂z] = [E_{q_cⁱ}[∂ec(zc)/∂zc]]_c = [J_cⁱ]_c  (10.15)

where J_cⁱ ∈ R^{nc×dc} is the Jacobian matrix of ec(zc) wrt inputs zcj for j = 1 : dc. Then the updates are as follows:

Λ^{i+1} = (Eⁱ)ᵀ W⁻¹ Eⁱ = [(J_cⁱ)ᵀ Wc⁻¹ J_cⁱ]_c  (10.16)
Λ^{i+1} δ^{i+1} = −(Eⁱ)ᵀ W⁻¹ eⁱ = [−(J_cⁱ)ᵀ Wc⁻¹ e_cⁱ]_c  (10.17)
µ^{i+1} = µⁱ + δ^{i+1}  (10.18)
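To make these updates concrete, here is a minimal NumPy sketch of one Gauss-Newton style step (10.16)-(10.18) for a toy problem with a single factor (so Pc = I). The function names err_fn, jac_fn and W_inv, the Monte Carlo approximation of the expectations, and the small jitter term are illustrative assumptions, not the reference implementation of [Barfoot2020sparse].

import numpy as np

def gauss_newton_gvi_step(mu, Lambda, err_fn, jac_fn, W_inv, n_samples=100, rng=None):
    """One update of Eqs. (10.16)-(10.18) for a single-factor model.

    err_fn(z) -> error vector e(z), shape [n]
    jac_fn(z) -> Jacobian de/dz,    shape [n, d]
    W_inv     -> inverse weight matrix W^{-1}, shape [n, n]
    Expectations over q(z) = N(mu, Lambda^{-1}) are approximated by Monte Carlo;
    quadrature could be used instead.
    """
    rng = np.random.default_rng() if rng is None else rng
    Sigma = np.linalg.inv(Lambda)                           # covariance of current q
    zs = rng.multivariate_normal(mu, Sigma, size=n_samples)
    e_bar = np.mean([err_fn(z) for z in zs], axis=0)        # Eq. (10.14)
    J_bar = np.mean([jac_fn(z) for z in zs], axis=0)        # Eq. (10.15)
    Lambda_new = J_bar.T @ W_inv @ J_bar                    # Eq. (10.16)
    rhs = -J_bar.T @ W_inv @ e_bar                          # Eq. (10.17), descent direction
    delta = np.linalg.solve(Lambda_new + 1e-9 * np.eye(len(mu)), rhs)
    mu_new = mu + delta                                     # Eq. (10.18)
    return mu_new, Lambda_new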
10.2 Online variational inference

In this section, we discuss how to perform online variational inference. In particular, we discuss the streaming variational Bayes (SVB) approach of [Broderick2013], in which, at step t, we compute the new posterior using the previous posterior as the prior:

ψt = argmin_ψ Eq(θ|ψ)[ℓt(θ)] + DKL(q(θ|ψ) ∥ q(θ|ψt−1))  (10.19)
   = argmin_ψ Eq(θ|ψ)[ℓt(θ) + log q(θ|ψ) − log q(θ|ψt−1)]  (10.20)

where the negative of the objective in (10.19) is denoted Łt(ψ), and ℓt(θ) = − log p(Dt|θ) is the negative log likelihood (or, more generally, some loss function) of the data batch at step t. This approach is also called variational continual learning or VCL [Nguyen2018]. (We discuss continual learning in Main Section 19.7.)
10.2.1 FOO-VB

In this section, we discuss a particular implementation of sequential VI called FOO-VB, which stands for "Fixed-point Operator for Online Variational Bayes" [Zeno2021]. This assumes Gaussian priors and posteriors. In particular, let

q(θ|ψt) = N(θ|µ, Σ),  q(θ|ψt−1) = N(θ|m, V)  (10.21)

In this case, we can write the ELBO as follows:

Łt(µ, Σ) = (1/2)[log(det(V)/det(Σ)) − D + tr(V⁻¹Σ) + (m − µ)ᵀV⁻¹(m − µ)] + Eq(θ|µ,Σ)[ℓt(θ)]  (10.22)

where D is the dimensionality of θ.
Let Σ = LLᵀ. We can compute the new variational parameters by solving the joint first order stationary conditions, ∇µ Łt(µ, L) = 0 and ∇L Łt(µ, L) = 0. For the derivatives of the KL term, we use the identities

∂tr(V⁻¹Σ)/∂Lij = 2 Σ_n V⁻¹_{in} L_{nj}  (10.23)
∂ log |det(L)|/∂Lij = L⁻ᵀ_{ij}  (10.24)

For the derivatives of the expected loss, we use the reparameterization trick, θ = µ + Lε, and the following identities:

Eq(θ|µ,L)[ℓt(θ)] = Eε[ℓt(θ(ε))]  (10.25)
∂Eε[ℓt(θ)]/∂Lij = Eε[(∂ℓt(θ)/∂θi) εj]  (10.26)

Note that the expectation depends on the unknown variational parameters for qt, so we get a fixed point equation which we need to iterate. As a faster alternative, [Zeno2018; Zeno2021] propose to evaluate the expectations using the variational parameters from the previous step, which then gives the new parameters in a single step, similar to EM.
We now derive the update equations. From ∇µ Łt(µ, L) = 0 we get

0 = −V⁻¹(m − µ) + Eε[∇ℓt(θ)]  (10.27)
µ = m − V Eε[∇ℓt(θ)]  (10.28)

From ∇L Łt(µ, L) = 0, we get

0 = −(L⁻ᵀ)ij + Σ_n V⁻¹_{i,n} L_{n,j} + Eε[(∂ℓt(θ)/∂θi) εj]  (10.29)

In matrix form, we have

0 = −L⁻ᵀ + V⁻¹L + Eε[∇ℓt(θ) εᵀ]  (10.30)

Explicitly solving for L in the case of a general (or low rank) matrix Σ is somewhat complicated; for the details, see [Zeno2021]. Fortunately, in the case of a diagonal approximation, things simplify significantly, as we discuss in Section 10.2.2.
10.2.2 Bayesian gradient descent

In this section, we discuss a simplification of FOO-VB to the diagonal case. In [Zeno2018], they call the resulting algorithm "Bayesian gradient descent", and they show it works well on some continual learning problems (see Main Section 19.7).
Let V = diag(vi²), Σ = diag(σi²), so L = diag(σi). Also, let gi = ∂ℓt(θ)/∂θi, which depends on ε. Then Equation (10.28) becomes

µi = mi − η vi² Eε[gi(ε)]  (10.31)

where we have included an explicit learning rate η to compensate for the fact that the fixed point equation update is approximate. For the variance terms, Equation (10.30) becomes

0 = −1/diag(σi) + diag(σi)/diag(vi²) + Eε[gi εi]  (10.32)

This is a quadratic equation for each σi:

(1/vi²) σi² + Eε[gi εi] σi − 1 = 0  (10.33)

the solution of which is given by the following (since σi > 0):

σi = [−Eε[gi εi] + √((Eε[gi εi])² + 4/vi²)] / (2/vi²) = −(1/2) vi² Eε[gi εi] + (1/2) vi² √((Eε[gi εi])² + 4/vi²)  (10.34)
   = −(1/2) vi² Eε[gi εi] + √((vi⁴/4)((Eε[gi εi])² + 4/vi²)) = −(1/2) vi² Eε[gi εi] + vi √(1 + ((1/2) vi Eε[gi εi])²)  (10.35)

We can approximate the above expectations using K Monte Carlo samples. Thus the overall algorithm is very similar to standard SGD, except we compute the gradient K times, and we update µ ∈ R^D and σ ∈ R^D rather than θ ∈ R^D. See Algorithm 10.1 for the pseudocode, and see [Kurle2020] for a related algorithm.

Algorithm 10.1: One step of Bayesian gradient descent
1 Function (µt, σt, Łt) = BGD-update(µt−1, σt−1, Dt; η, K):
2   for k = 1 : K do
3     Sample εk ∼ N(0, I)
4     θk = µt−1 + σt−1 ⊙ εk
5     gk = ∇θ[− log p(Dt|θ)] |θk
6   for i = 1 : D do
7     E1i = (1/K) Σ_{k=1}^K gik
8     E2i = (1/K) Σ_{k=1}^K gik εik
9     µt,i = µt−1,i − η σ²t−1,i E1i
10    σt,i = σt−1,i √(1 + ((1/2) σt−1,i E2i)²) − (1/2) σ²t−1,i E2i
11  for k = 1 : K do
12    θk = µt + σt ⊙ εk
13    ℓtk = − log p(Dt|θk)
14  Łt = −[(1/K) Σ_{k=1}^K ℓtk + DKL(N(µt, σt) ∥ N(µt−1, σt−1))]
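Below is a minimal NumPy sketch of the core of Algorithm 10.1 (lines 1-10). It assumes the user supplies grad_neg_log_lik(theta, batch), returning the gradient of − log p(Dt|θ); this name, and the plain Monte Carlo averaging with K samples, are illustrative assumptions rather than the reference implementation of [Zeno2018].

import numpy as np

def bgd_update(mu, sigma, batch, grad_neg_log_lik, eta=1.0, K=10, rng=None):
    """One step of Bayesian gradient descent with a diagonal Gaussian posterior."""
    rng = np.random.default_rng() if rng is None else rng
    D = mu.shape[0]
    eps = rng.standard_normal((K, D))                  # eps^k ~ N(0, I)
    thetas = mu + sigma * eps                          # theta^k = mu + sigma * eps^k
    grads = np.stack([grad_neg_log_lik(t, batch) for t in thetas])   # g^k
    E1 = grads.mean(axis=0)                            # E1_i = (1/K) sum_k g_i^k
    E2 = (grads * eps).mean(axis=0)                    # E2_i = (1/K) sum_k g_i^k eps_i^k
    mu_new = mu - eta * sigma**2 * E1                  # Algorithm 10.1, line 9
    sigma_new = sigma * np.sqrt(1.0 + (0.5 * sigma * E2)**2) - 0.5 * sigma**2 * E2   # line 10
    return mu_new, sigma_new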
10.3 Beyond mean field
In this section, we discuss various improvements to VI that go beyond the mean field approximation.
10.3.1 Exploiting partial conjugacy

If the full conditionals of the joint model are conjugate distributions, we can use the VMP approach of Main Section 10.3.7 to approximate the posterior one term at a time, similar to Gibbs sampling (Main Section 12.3). However, in many models, some parts of the joint distribution are conjugate, and some are non-conjugate. In [Khan2017aistats] they proposed the conjugate-computation variational inference or CVI method to tackle models of this form. They exploit the partial conjugacy to perform some updates in closed form, and perform the remaining updates using stochastic approximations.
To explain the method in more detail, let us assume the joint distribution has the form

p(y, z) ∝ p̃nc(y, z) p̃c(y, z)  (10.36)

where z are all the latents (global or local), y are all the observables (data)², pc is the conjugate part, pnc is the non-conjugate part, and the tilde symbols indicate that these distributions may not be normalized wrt z. More precisely, we assume the conjugate part is an exponential family model of the following form:

p̃c(y, z) = h(z) exp[T(z)ᵀη − Ac(η)]  (10.37)

where η is a known vector of natural parameters. (Any unknown model parameters should be included in the latent state z, as we illustrate below.) We also assume that the variational posterior is an exponential family model with the same sufficient statistics, but different parameters:

q(z|λ) = h(z) exp[T(z)ᵀλ − A(λ)]  (10.38)

The mean parameters are given by µ = Eq[T(z)]. We assume the sufficient statistics are minimal, so that there is a unique 1:1 mapping between λ and µ, given by

µ = ∇λ A(λ)  (10.39)
λ = ∇µ A∗(µ)  (10.40)

where A∗ is the conjugate of A (see Main Section 2.4.4). The ELBO is given by

Ł(λ) = Eq[log p(y, z) − log q(z|λ)]  (10.41)
Ł(µ) = Ł(λ(µ))  (10.42)

The simplest way to fit this variational posterior is to perform SGD on the ELBO wrt the natural parameters:

λt+1 = λt + ηt ∇λ Ł(λt)  (10.43)

(Note the + sign in front of the gradient, since we are maximizing the ELBO.) The above gradient update is equivalent to solving the following optimization problem:

λt+1 = argmax_{λ∈Ω} (∇λ Ł(λt))ᵀλ − (1/(2ηt)) ||λ − λt||₂²  (10.44)

where the objective on the right is denoted J(λ), Ω is the space of valid natural parameters, and || · ||₂ is the Euclidean norm. To see this, note that the first order optimality conditions satisfy

∇λ J(λ) = ∇λ Ł(λt) − (1/(2ηt))(2λ − 2λt) = 0  (10.45)

from which we get Equation (10.43).
We can replace the Euclidean distance with a more general proximity function, such as the Bregman divergence between the distributions (see Main Section 5.1.10). This gives rise to the mirror descent algorithm (Section 6.1.4). We can also perform updates in the mean parameter space. In [Raskutti2015], they show that this is equivalent to performing natural gradient updates in the natural parameter space. Thus this method is sometimes called natural gradient VI or NGVI. Combining these two steps gives the following update equation:

µt+1 = argmax_{µ∈M} (∇µ Ł(µt))ᵀµ − (1/ηt) BA∗(µ||µt)  (10.46)

where M is the space of valid mean parameters and ηt > 0 is a stepsize.
In [Khan2017aistats], they show that the above update is equivalent to performing exact Bayesian inference in the following conjugate model:

q(z|λt+1) ∝ exp(T(z)ᵀλ̃t) p̃c(y, z)  (10.47)

We can think of the first term as an exponential family approximation to the non-conjugate part of the model, using local variational natural parameters λ̃t. (These are similar to the site parameters used in expectation propagation, Main Section 10.7.) These can be computed using the following recursive update:

λ̃t = (1 − ηt)λ̃t−1 + ηt ∇µ Eq(z|µt)[log p̃nc(y, z)]  (10.48)

where λ̃0 = 0 and λ̃1 = η. (Details on how to compute this derivative are given in Main Section 6.4.5.) Once we have "conjugated" the non-conjugate part, the natural parameter of the new variational posterior is obtained by

λt+1 = λ̃t + η  (10.49)

This corresponds to a multiplicative update of the form

qt+1(z) ∝ qt(z)^{1−ηt} [exp(λ̃tᵀ T(z))]^{ηt}  (10.50)

We give some examples of this below.

2. We denote observables by y since the examples we consider later on are conditional models, where x denotes the inputs.
10.3.1.1 Example: Gaussian process with non-Gaussian likelihoods

In Main Chapter 18, we discuss Gaussian processes, which are a popular model for non-parametric regression. Given a set of N inputs xn ∈ X and outputs yn ∈ R, we define the following joint Gaussian distribution:

p(y1:N, z1:N|X) = [Π_{n=1}^N N(yn|zn, σ²)] N(z|0, K)  (10.51)

where K is the kernel matrix computed using Kij = K(xi, xj), and zn = f(xn) is the unknown function value for input n. Since this model is jointly Gaussian, we can easily compute the exact posterior p(z|y) in O(N³) time. (Faster approximations are also possible, see Main Section 18.5.)
One challenge with GPs arises when the likelihood function p(yn|zn) is non-Gaussian, as occurs with classification problems. To tackle this, we will use CVI. Since the conjugate part of the model is a Gaussian, we require that the variational approximation also be Gaussian, so we use q(z|λ) = N(z|λ^{(1)}, λ^{(2)}).
Since the likelihood term factorizes across data points n = 1 : N, we will only need to compute marginals of this variational posterior. From Main Section 2.4.2.3 we know that the sufficient statistics and natural parameters of a univariate Gaussian are given by

T(zn) = [zn, zn²]  (10.52)
λn = [mn/vn, −1/(2vn)]  (10.53)

The corresponding moment parameters are

µn = [mn, mn² + vn]  (10.54)
mn = vn λn^{(1)}  (10.55)
vn = −1/(2λn^{(2)})  (10.56)

We need to compute the gradient terms ∇µn E_{N(zn|µn)}[log p(yn|zn)]. We can do this by sampling zn from the local Gaussian posterior, and then pushing gradients inside, using the results from Main Section 6.4.5.1. Let the resulting stochastic gradients at step t be ĝn,t^{(1)} and ĝn,t^{(2)}. We can then update the likelihood approximation as follows:

λ̃n,t^{(i)} = (1 − ηt) λ̃n,t−1^{(i)} + ηt ĝn,t^{(i)}  (10.57)

We can also perform a "doubly stochastic" approximation (as in Main ??) by just updating a random subset of these terms. Once we have updated the likelihood, we can update the posterior using

q(z|λt+1) ∝ [Π_{n=1}^N exp(zn λ̃n,t^{(1)} + zn² λ̃n,t^{(2)})] N(z|0, K)  (10.58)
          ∝ [Π_{n=1}^N Nc(zn|λ̃n,t^{(1)}, λ̃n,t^{(2)})] N(z|0, K)  (10.59)
          = [Π_{n=1}^N N(zn|m̃n,t, ṽn,t)] N(z|0, K)  (10.60)
          = [Π_{n=1}^N N(m̃n,t|zn, ṽn,t)] N(z|0, K)  (10.61)

where m̃n,t and ṽn,t are derived from λ̃n,t. We can think of this as Gaussianizing the likelihood at each step, where we replace the observations yn by pseudo-observations m̃n,t and use a variational variance ṽn,t. This lets us use exact GP regression updates in the inner loop. See [Shi2019gp; Chang2020gp] for details.
45 variance ṽn,t . This lets us use exact GP regression updates in the inner loop. See [Shi2019gp;
46 Chang2020gp] for details.
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


130

1
2 10.3.1.2 Example: Bayesian logistic regression
3
In this section, we discuss how to compute a Gaussian approximation to p(w|D) for a binary logistic
4
regression model with a Gaussian prior on the weights. We will use CVI in which we “Gaussianize”
5
the likelihoods, and then perform closed form Bayesian linear regression in the inner loop. This is
6
similar to the approach used in Main Section 15.3.8, where we derive a quadratic lower bound to the
7
log likelihood. However, such “local VI” methods are not guaranteed to converge to a local maximum
8
of the ELBO [Khan12thesis], unlike the CVI method.
9
The joint distribution has the form
10
"N #
11 Y
12 p(y1:N , w|X) = p(yn |zn ) N (w|0, δI) (10.62)
13 n=1
14
where zn = wT xn is the local latent, and δ > 0 is the prior variance (analogous to an `2 regularizer).
15
We compute the local Gaussian likelihood terms λ̃n as in in Section 10.3.1.1. We then have the
16
following variational joint:
17
"N #
18 Y
19 q(w|λt+1 ) ∝ T
N (m̃n,t |w xn , ṽn,t ) N (w|0, δI) (10.63)
20 n=1
21 This corresponds to a Bayesian linear regression problem with pseudo-observations m̃n,t and variational
22 variance ṽn,t .
23
24
10.3.1.3 Example: Kalman smoothing with GLM likelihoods
25
26 We can extend the above examples to perform posterior inference in a linear-Gaussian state-space
27 model (Main Section 29.6) with generalized linear model (GLM) likelihoods: we alternate between
28 Gaussianizing the likelihoods and running the Kalman smoother (Main Section 8.2.3).
29
30 10.3.2 Structured mean for factorial HMMs
31
32 Consider the factorial HMM model [Ghahramani97] introduced in Main Section 29.5.3. Suppose
33 there are M chains, each of length T , and suppose each hidden node has K states, as shown in
34 Figure 10.5(a). We will derive a structured mean field algorithm that takes O(T M K 2 I) time, where
35 I is the number of mean field iterations (typically I ∼ 10 suffices for good performance).
36 We can write the exact posterior in the following form:
37 1
38 p(z|x) = exp(−E(z, x)) (10.64)
Z(x)
39
T
!T !
40 1X X X
E(z, x) = xt − Wm ztm Σ−1 xt − Wm ztm
41 2 t=1 m m
42
M
X T X
X M
43
− zT1m π̃ m − zTt−1,m Ãm zt,m (10.65)
44
m=1 t=2 m=1
45
46 where Ãm , log Am and π̃ m , log π m , where the log is applied elementwise.
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
10.3. BEYOND MEAN FIELD

1
2
z1,1 z1,2 z1,3
3 z1,1 z1,2 z1,3 z1,1 z1,2 z1,3
4
5 z2,1 z2,2 z2,3
6 z2,1 z2,2 z2,3 z2,1 z2,2 z2,3
7 z3,1 z3,2 z3,3
8
9 z3,1 z3,2 z3,3
x1 x2 x3 z3,1 z3,2 z3,3
10
11 (a) (b) (c)
12
13 Figure 10.5: (a) A factorial HMM with 3 chains. (b) A fully factorized approximation. (c) A product-of-chains
14 approximation. Adapted from Figure 2 of [Ghahramani97].
15
16
17 We can approximate the posterior as a product of marginals, as in Figure 10.5(b), but a better
18 approximation is to use a product of chains, as in Figure 10.5(c). Each chain can be tractably
19 updated individually, using the forwards-backwards algorithm (Main Section 9.2.3). More precisely,
20 we assume
21 YM YT
1
22 q(z; ψ) = q(z1m ; ψ 1m ) q(ztm |zt−1,m ; ψ tm ) (10.66)
23
Zq (x) m=1 t=2
24 K
Y
25 q(z1m ; ψ 1m ) = (ψ1mk πmk )z1mk (10.67)
26 k=1
 ztmk
27 K K
Y Y
28 q(ztm |zt−1,m ; ψ tm ) = ψtmk (Amjk )zt−1,m,j  (10.68)
29 k=1 j=1
30
31
Here the variational parameter ψtmk plays the role of an approximate local evidence, averaging out
32
the effects of the other chains. This is in contrast to the exact local evidence, which couples all the
33
chains together.
34
By separating out the approximate local evidence terms, we can rewrite the above as q(z) =
Zq (x) exp(−Eq (z, x)), where
1
35
36
T X
X M M
X T X
X M
37
Eq (z, x) = − zTtm ψ̃ tm − zT1m π̃ m − zTt−1,m Ãm zt,m (10.69)
38
t=1 m=1 m=1 t=2 m=1
39
40 where ψ̃ tm = log ψ tm . We see that this has the same temporal factors as the exact log joint in
41 Equation (10.65), but the local evidence terms are different: the dependence on the visible data x
42 has been replaced by dependence on “virtual data” ψ.
43 The objective function is given by
44
DKL (q k p̃) = Eq [log q − log p̃] (10.70)
45
46 = −Eq [Eq (z, x)] − log Zq (x) + Eq [E(z, x)] + log Z(x) (10.71)
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


132

1
2 where q = q(z|x) and p̃ = p(z|x). In [Ghahramani97] they show that we can optimize this using
3 coordinate descent, where each update step is given by
4
 
1
5 ψ tm = exp WTm Σ−1 x̃tm − δ m (10.72)
2
6
7 δ m , diag(WTm Σ−1 Wm ) (10.73)
8 M
X
9 x̃tm , xt − W` E [zt,` ] (10.74)
10 `6=m
11 The intuitive interpretation of x̃tm is that it is the observation xt minus the predicted effect from
12 all the other chains apart from m. This is then used to compute the approximate local evidence,
13 ψ tm . Having computed the variational local evidence terms for each chain, we can perform forwards-
14 backwards in parallel, using these approximate local evidence terms to compute q(zt,m ) for each m
15 and t.
16 The update cost is O(T M K 2 ) for a full “sweep” over all the variational parameters, since we have
17 to run forwards-backwards M times, for each chain independently. This is the same cost as a fully
18 factorized approximation, but is much more accurate.
19
20
21
10.4 VI for graphical model inference
22
In this section, we discuss exact and approximate inference for discrete PGMs from a variational
23
perspective, following [Monster].
24
Similar to Section 9.3, we will assume a pairwise MRF of the form
25  
26
1 X X 
27 pθ (z|x) = exp θs (zs ) + θst (zs , zt ) (10.75)
Z  
28 s∈V (s,t)∈E
29
We can write this as an exponential family model, p(z|x) = p̃(z)/Z, where Z = log p(x), p̃(z) =
30
T (z)T θ, θ = ({θs;j }, {θs,t;j,k }) are all the node and edge parameters (the canonical parameters), and
31
T (z) = ({I (zs = j)}, {I (zs = j, zt = k)}) are all the node and edge indicator functions (the sufficient
32
statistics). Note: we use s, t ∈ V to index nodes and j, k ∈ X to index states.
33
34
35
10.4.1 Exact inference as VI
36 We know that the ELBO is a lower bound on the log marginal likelihood:
37
38
L(q) = Eq(z) [log p̃(z)] + H (q) ≤ log Z (10.76)
39 Let µ = Eq [T (z)] be the mean parameters of the variational distribution. Then we can rewrite this
40 as
41
42
L(µ) = θT µ + H (µ) ≤ log Z (10.77)
43 The set of all valid (unrestricted) mean parameters µ is the marginal polytope corresponding to
44 the graph, M(G), as explained in Section 9.3.2. Optimizing over this set recovers q = p, and hence
45
46
max θ T µ + H (µ) = log Z (10.78)
µ∈M(G)
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
10.4. VI FOR GRAPHICAL MODEL INFERENCE

1
Method Definition Objective Opt. Domain Section
2
Exact maxµ∈M(G) θ T µ + H (µ) = log Z Concave Marginal polytope, convex Section 10.4.1
3 Mean field maxµ∈MF (G) θ T µ + HMF (µ) ≤ log Z Concave Nonconvex inner approx. Section 10.4.2
4 Loopy BP maxτ ∈L(G) θ T τ + HBethe (τ ) ≈ log Z Non-concave Convex outer approx. Section 10.4.3
TRBP maxτ ∈L(G) θ T τ + HTRBP (τ ) ≥ log Z Concave Convex outer approx. Section 10.4.5
5
6
Table 10.1: Summary of some variational inference methods for graphical models. TRBP is tree-reweighted
7 belief propagation.
8
9
10
Equation (10.78) seems easy to optimize: the objective is concave, since it is the sum of a linear
11
function and a concave function (see Main Figure 5.4 to see why entropy is concave); furthermore,
12
we are maximizing this over a convex set, M(G). Hence there is a unique global optimum. However,
13
the entropy is typically intractable to compute, since it requires summing over all states. We discuss
14
approximations below. See Table 10.1 for a high level summary of the methods we discuss.
15
16
10.4.2 Mean field VI
17
18 The mean field approximation to the entropy is simply
19 X
20 HMF (µ) = H (µs ) (10.79)
21 s
22
which follows from the factorization assumption. Thus the mean field objective is
23
24 LMF (µ) = θT µ + HMF (µ) ≤ log Z (10.80)
25
26 This is a concave lower bound on log Z. We will maximize this over a a simpler, but non-convex,
27 inner approximation to M(G), as we now show.
28 First, let F be an edge subgraph of the original graph G, and let I(F ) ⊆ I be the subset of
29 sufficient statistics associated with the cliques of F . Let Ω be the set of canonical parameters for the
30 full model, and define the canonical parameter space for the submodel as follows:
31
Ω(F ) , {θ ∈ Ω : θα = 0 ∀α ∈ I \ I(F )} (10.81)
32
33 In other words, we require that the natural parameters associated with the sufficient statistics α
34 outside of our chosen class to be zero. For example, in the case of a fully factorized approximation,
35 F0 , we remove all edges from the graph, giving
36
37 Ω(F0 ) , {θ ∈ Ω : θst = 0 ∀(s, t) ∈ E} (10.82)
38
39
In the case of structured mean field (Main Section 10.4.1), we set θst = 0 for edges which are not in
40
our tractable subgraph.
41
Next, we define the mean parameter space of the restricted model as follows:
42
MF (G) , {µ ∈ Rd : µ = Eθ [T (z)] for some θ ∈ Ω(F )} (10.83)
43
44 This is called an inner approximation to the marginal polytope, since MF (G) ⊆ M(G). See
45 Figure 9.10(b) for a sketch. Note that MF (G) is a non-convex polytope, which results in multiple
46 local optima.
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


134

1
2 Thus the mean field problem becomes
3
4 max θT µ + HMF (µ) (10.84)
µ∈MF (G)
5
6 This requires maximizing a concave objective over a non-convex set. It is typically optimized using
7 coordinate ascent, since it is easy to optimize a scalar concave function over the marginal distribution
8 for each node.
9
10
11 10.4.3 Loopy belief propagation as VI
12
Recall from Section 10.4.1 that exact inference can be posed as solving the following optimization
13
problem: maxµ∈M(G) θ T µ + H (µ), where M(G) is the marginal polytope corresponding to the graph
14
(see Section 9.3.2 for details). Since this set has exponentially many facets, it is intractable to
15
optimize over.
16
In Section 10.4.2, we discussed the mean field approximation, which uses a nonconvex inner
17
approximation, MF (G), obtained by dropping some edges from the graphical model, thus enforcing
18
a factorization of the posterior. We also approximated the entropy by using the entropy of each
19
marginal.
20
In this section, we will consider a convex outer approximation, L(G), based on pseudo marginals,
21
as in Section 9.3.3.1. We also need to approximate the entropy (which was not needed when
22
performing MAP estimation, discussed in Section 9.3.3). We discuss this entropy approximation in
23
Section 10.4.3.1, and then show how we can use this to approximate log Z. Finally we show that
24
loopy belief propagation attempts to optimize this approximation.
25
26
10.4.3.1 Bethe free energy

From Equation (9.67), we know that a joint distribution over a tree-structured graphical model can be represented exactly by the following:

pµ(x) = Π_{s∈V} µs(xs) Π_{(s,t)∈E} µst(xs, xt) / (µs(xs) µt(xt))  (10.85)

This satisfies the normalization and pairwise marginalization constraints of the outer approximation by construction.
From Equation 10.85, we can write the exact entropy of any tree structured distribution µ ∈ M(T) as follows:

H(µ) = Σ_{s∈V} Hs(µs) − Σ_{(s,t)∈E} Ist(µst)  (10.86)
Hs(µs) = − Σ_{xs∈Xs} µs(xs) log µs(xs)  (10.87)
Ist(µst) = Σ_{(xs,xt)∈Xs×Xt} µst(xs, xt) log [µst(xs, xt) / (µs(xs) µt(xt))]  (10.88)

Note that we can rewrite the mutual information term in the form Ist(µst) = Hs(µs) + Ht(µt) − Hst(µst), and hence we get the following alternative but equivalent expression:

H(µ) = − Σ_{s∈V} (ds − 1) Hs(µs) + Σ_{(s,t)∈E} Hst(µst)  (10.89)

where ds is the degree (number of neighbors) for node s.
The Bethe³ approximation to the entropy is simply the use of Equation 10.86 even when we don't have a tree:

HBethe(τ) = Σ_{s∈V} Hs(τs) − Σ_{(s,t)∈E} Ist(τst)  (10.90)

We define the Bethe free energy as the expected energy minus approximate entropy:

FBethe(τ) ≜ −(θᵀτ + HBethe(τ)) ≈ − log Z  (10.91)

Thus our final objective becomes

max_{τ∈L(G)} θᵀτ + HBethe(τ)  (10.92)

We call this the Bethe variational problem or BVP. The space we are optimizing over is a convex set, but the objective itself is not concave (since HBethe is not concave). Thus there can be multiple local optima. Also, the entropy approximation is not a bound (either upper or lower) on the true entropy. Thus the value obtained by the BVP is just an approximation to log Z(θ). However, in the case of trees, the approximation is exact. Also, in the case of models with attractive potentials, the resulting value turns out to be an upper bound [Sudderth08]. In Section 10.4.5, we discuss how to modify the algorithm so it always minimizes an upper bound for any model.

3. Hans Bethe was a German-American physicist, 1906–2005.
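Given a set of node and edge pseudo-marginals, the Bethe free energy (10.90)-(10.91) is easy to evaluate directly; the following NumPy sketch does so (the data-structure conventions are ours):

import numpy as np

def bethe_free_energy(node_marg, edge_marg, node_pot, edge_pot, edges):
    """Evaluate F_Bethe(tau) = -(theta^T tau + H_Bethe(tau)), Eqs. (10.90)-(10.91).

    node_marg: [S, K]; edge_marg: dict (s,t) -> [K, K]; potentials in the same format."""
    eps = 1e-12
    avg = 0.0      # theta^T tau
    entropy = 0.0  # H_Bethe
    for s in range(node_marg.shape[0]):
        mu_s = node_marg[s]
        avg += np.sum(mu_s * node_pot[s])
        entropy += -np.sum(mu_s * np.log(mu_s + eps))                     # H_s
    for (s, t) in edges:
        mu_st = edge_marg[(s, t)]
        avg += np.sum(mu_st * edge_pot[(s, t)])
        outer = np.outer(node_marg[s], node_marg[t])
        entropy -= np.sum(mu_st * np.log((mu_st + eps) / (outer + eps)))  # - I_st
    return -(avg + entropy)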
10.4.3.2 LBP messages are Lagrange multipliers

In this subsection, we will show that any fixed point of the LBP algorithm defines a stationary point of the above constrained objective. Let us define the normalization constraint as Css(τ) ≜ −1 + Σ_{xs} τs(xs), and the marginalization constraint as Cts(xs; τ) ≜ τs(xs) − Σ_{xt} τst(xs, xt) for each edge t → s. We can now write the Lagrangian as

L(τ, λ; θ) ≜ θᵀτ + HBethe(τ) + Σ_s λss Css(τ) + Σ_{s,t} [Σ_{xs} λts(xs) Cts(xs; τ) + Σ_{xt} λst(xt) Cst(xt; τ)]  (10.93)

(The constraint that τ ≥ 0 is not explicitly enforced, but one can show that it will hold at the optimum since θ > 0.) Some simple algebra then shows that ∇τ L = 0 yields

log τs(xs) = λss + θs(xs) + Σ_{t∈nbr(s)} λts(xs)  (10.94)
log [τst(xs, xt) / (τ̃s(xs) τ̃t(xt))] = θst(xs, xt) − λts(xs) − λst(xt)  (10.95)

where we have defined τ̃s(xs) ≜ Σ_{xt} τst(xs, xt). Using the fact that the marginalization constraint implies τ̃s(xs) = τs(xs), we get

log τst(xs, xt) = λss + λtt + θst(xs, xt) + θs(xs) + θt(xt) + Σ_{u∈nbr(s)\t} λus(xs) + Σ_{u∈nbr(t)\s} λut(xt)  (10.96)

To make the connection to message passing, define mt→s(xs) = exp(λts(xs)). With this notation, we can rewrite the above equations (after taking exponents of both sides) as follows:

τs(xs) ∝ exp(θs(xs)) Π_{t∈nbr(s)} mt→s(xs)  (10.97)
τst(xs, xt) ∝ exp(θst(xs, xt) + θs(xs) + θt(xt)) × Π_{u∈nbr(s)\t} mu→s(xs) Π_{u∈nbr(t)\s} mu→t(xt)  (10.98)

where the λ terms and irrelevant constants are absorbed into the constant of proportionality. We see that this is equivalent to the usual expression for the node and edge marginals in LBP.
To derive an equation for the messages in terms of other messages (rather than in terms of λts), we enforce the marginalization condition Σ_{xt} τst(xs, xt) = τs(xs). Then one can show that

mt→s(xs) ∝ Σ_{xt} exp(θst(xs, xt) + θt(xt)) Π_{u∈nbr(t)\s} mu→t(xt)  (10.99)

We see that this is equivalent to the usual expression for the messages in LBP.
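For reference, here is a minimal NumPy sketch of the message updates (10.99) and node beliefs (10.97) for a pairwise MRF. The data structures and the (optional) damping in probability space are our own illustrative choices.

import numpy as np

def lbp(node_pot, edge_pot, edges, n_iters=50, damping=0.5):
    """Loopy BP implementing Eqs. (10.97) and (10.99); potentials are in log space."""
    S, K = node_pot.shape
    nbrs = {s: set() for s in range(S)}
    pots = {}
    for (s, t) in edges:
        nbrs[s].add(t); nbrs[t].add(s)
        pots[(s, t)] = edge_pot[(s, t)]        # rows indexed by x_s, columns by x_t
        pots[(t, s)] = edge_pot[(s, t)].T
    msgs = {key: np.full(K, 1.0 / K) for key in pots}   # msgs[(t, s)] = m_{t->s}(x_s)
    for _ in range(n_iters):
        new_msgs = {}
        for (t, s) in msgs:
            # m_{t->s}(x_s) propto sum_{x_t} exp(theta_st + theta_t) prod_{u in nbr(t)\s} m_{u->t}(x_t)
            incoming = np.prod([msgs[(u, t)] for u in nbrs[t] if u != s], axis=0) \
                if len(nbrs[t]) > 1 else np.ones(K)
            m = np.exp(pots[(s, t)]) @ (np.exp(node_pot[t]) * incoming)
            m = m / m.sum()
            new_msgs[(t, s)] = damping * msgs[(t, s)] + (1 - damping) * m
        msgs = new_msgs
    beliefs = np.exp(node_pot).copy()
    for (t, s) in msgs:
        beliefs[s] *= msgs[(t, s)]
    beliefs /= beliefs.sum(axis=1, keepdims=True)
    return msgs, beliefs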
10.4.3.3 Kikuchi free energy

We have shown that LBP minimizes the Bethe free energy. In this section, we show that generalized BP (Main Section 9.4.6) minimizes the Kikuchi free energy; we define this below, but the key idea is that it is a tighter approximation to log Z.
In more detail, define Lt(G) to be the set of all pseudo-marginals such that normalization and marginalization constraints hold on a hyper-graph whose largest hyper-edge is of size t + 1. For example, in Main Figure 9.14, we impose constraints of the form

Σ_{x1,x2} τ1245(x1, x2, x4, x5) = τ45(x4, x5),  Σ_{x6} τ56(x5, x6) = τ5(x5),  . . .  (10.100)

Furthermore, we approximate the entropy as follows:

HKikuchi(τ) ≜ Σ_{g∈E} c(g) Hg(τg)  (10.101)

where Hg(τg) is the entropy of the joint (pseudo) distribution on the vertices in set g, and c(g) is called the overcounting number of set g. These are related to Mobius numbers in set theory. Rather than giving a precise definition, we just give a simple example. For the graph in Main Figure 9.14, we have

HKikuchi(τ) = [H1245 + H2356 + H4578 + H5689] − [H25 + H45 + H56 + H58] + H5  (10.102)

Putting these two approximations together, we can define the Kikuchi free energy⁴ as follows:

FKikuchi(τ) ≜ −(θᵀτ + HKikuchi(τ)) ≈ − log Z  (10.103)

Our variational problem becomes

max_{τ∈L(G)} θᵀτ + HKikuchi(τ)  (10.104)

Just as with the Bethe free energy, this is not a concave objective. There are several possible algorithms for finding a local optimum of this objective, including generalized belief propagation. For details, see e.g., [Monster] or [KollerBook].

4. Ryoichi Kikuchi is a Japanese physicist.
10.4.4 Convex belief propagation

The mean field energy functional is concave, but it is maximized over a non-convex inner approximation to the marginal polytope. The Bethe and Kikuchi energy functionals are not concave, but they are maximized over a convex outer approximation to the marginal polytope. Consequently, for both MF and LBP, the optimization problem has multiple optima, so the methods are sensitive to the initial conditions. Given that the exact formulation Equation (10.78) is a concave objective maximized over a convex set, it is natural to try to come up with an approximation of a similar form, without local optima.
Convex belief propagation involves working with a set of tractable submodels, F, such as trees or planar graphs. For each model F ⊂ G, the entropy is higher, H(µ(F)) ≥ H(µ(G)), since F has fewer constraints. Consequently, any convex combination of such subgraphs will have higher entropy, too:

H(µ(G)) ≤ Σ_{F∈F} ρ(F) H(µ(F)) ≜ H(µ, ρ)  (10.105)

where ρ(F) ≥ 0 and Σ_F ρ(F) = 1. Furthermore, H(µ, ρ) is a concave function of µ.
Having defined an upper bound on the entropy, we now consider a convex outer bound on the marginal polytope of mean parameters. We want to ensure we can evaluate the entropy of any vector τ in this set, so we restrict it so that the projection of τ onto each subgraph F lives in the projection of M onto F:

L(G; F) ≜ {τ ∈ R^d : τ(F) ∈ M(F) ∀F ∈ F}  (10.106)

This is a convex set since each M(F) is a projection of a convex set. Hence we define our problem as

max_{τ∈L(G;F)} τᵀθ + H(τ, ρ)  (10.107)

[Figure 10.6 shows a graph with labeled edges b, e, and f, together with three of its spanning trees.]
Figure 10.6: (a) A graph. (b-d) Some of its spanning trees. From Figure 7.1 of [Monster]. Used with kind permission of Martin Wainwright.
This is a concave objective being maximized over a convex set, and hence has a unique optimum. Furthermore, the result is always an upper bound on log Z, because the entropy is an upper bound, and we are optimizing over a larger set than the marginal polytope.
It remains to specify the set of tractable submodels, F, and the distribution ρ. We discuss some options below.
10.4.5 Tree-reweighted belief propagation

In this section, we discuss tree reweighted BP [Wainwright05map; Kolmogorov06], which is a form of convex BP which uses spanning trees as the set of tractable models F, as we describe below.

10.4.5.1 Spanning tree polytope

It remains to specify the set of tractable submodels, F, and the distribution ρ. We will consider the case where F is all spanning trees of a graph. For any given tree, the entropy is given by Equation 10.86. To compute the upper bound, obtained by averaging over all trees, note that the terms Σ_F ρ(F) H(µ(F)s) for single nodes will just be Hs, since node s appears in every tree, and Σ_F ρ(F) = 1. But the mutual information term Ist receives weight ρst = Eρ[I((s, t) ∈ E(T))], known as the edge appearance probability. Hence we have the following upper bound on the entropy:

H(µ) ≤ Σ_{s∈V} Hs(µs) − Σ_{(s,t)∈E} ρst Ist(µst) ≜ HTRBP(µ)  (10.108)

This is called the tree reweighted BP approximation [Wainwright05map; Kolmogorov06]. This is similar to the Bethe approximation to the entropy except for the crucial ρst weights. So long as ρst > 0 for all edges (s, t), this gives a valid concave upper bound on the exact entropy.
The edge appearance probabilities live in a space called the spanning tree polytope. This is because they are constrained to arise from a distribution over trees. Figure 10.6 gives an example of a graph and three of its spanning trees. Suppose each tree has equal weight under ρ. The edge f occurs in 1 of the 3 trees, so ρf = 1/3. The edge e occurs in 2 of the 3 trees, so ρe = 2/3. The edge b appears in all of the trees, so ρb = 1. And so on. Ideally we can find a distribution ρ, or equivalently edge probabilities in the spanning tree polytope, that make the above bound as tight as possible. An algorithm to do this is described in [Wainwright05]. A simpler approach is to use all single edges with weight ρe = 1/E.
What about the set we are optimizing over? We require µ(T) ∈ M(T) for each tree T, which means enforcing normalization and local consistency. Since we have to do this for every tree, we are enforcing normalization and local consistency on every edge. Thus we are effectively optimizing in the pseudo-marginal polytope L(G). So our final optimization problem is as follows:

max_{τ∈L(G)} τᵀθ + HTRBP(τ) ≥ log Z  (10.109)
10.4.5.2 Message passing implementation

The simplest way to minimize Equation (10.109) is a modification of belief propagation known as tree reweighted belief propagation. The message from t to s is now a function of all messages sent from other neighbors v to t, as before, but now it is also a function of the message sent from s to t. Specifically, we have the following [Monster]:

mt→s(xs) ∝ Σ_{xt} exp((1/ρst) θst(xs, xt) + θt(xt)) · [Π_{v∈nbr(t)\s} [mv→t(xt)]^{ρvt}] / [ms→t(xt)]^{1−ρts}  (10.110)

At convergence, the node and edge pseudo marginals are given by

τs(xs) ∝ exp(θs(xs)) Π_{v∈nbr(s)} [mv→s(xs)]^{ρvs}  (10.111)
τst(xs, xt) ∝ ϕst(xs, xt) · [Π_{v∈nbr(s)\t} [mv→s(xs)]^{ρvs} Π_{v∈nbr(t)\s} [mv→t(xt)]^{ρvt}] / ([mt→s(xs)]^{1−ρst} [ms→t(xt)]^{1−ρts})  (10.112)
ϕst(xs, xt) ≜ exp((1/ρst) θst(xs, xt) + θs(xs) + θt(xt))  (10.113)

If ρst = 1 for all edges (s, t) ∈ E, the algorithm reduces to the standard LBP algorithm. However, the condition ρst = 1 implies every edge is present in every spanning tree with probability 1, which is only possible if the original graph is a tree. Hence the method is only equivalent to standard LBP on trees, where the method is of course exact.
In general, this message passing scheme is not guaranteed to converge to the unique global optimum. One can devise double-loop methods that are guaranteed to converge [Hazan08], but in practice, using damped updates as in Main Equation (9.77) is often sufficient to ensure convergence.
10.4.5.3 Max-product version

We can modify TRBP to solve the MAP estimation problem (as opposed to estimating posterior marginals) by replacing the sums with maxes in Equation (10.110) (see [Monster] for details). This is guaranteed to converge to the LP relaxation discussed in Section 9.3.3 under a suitable scheduling known as sequential tree-reweighted message passing [Kolmogorov06].
10.4.6 Other tractable versions of convex BP

It is possible to upper bound the entropy using convex combinations of other kinds of tractable models besides trees. One example is a planar MRF (one where the graph has no edges that cross), with binary nodes and no external field, i.e., the model has the form p(x) ∝ exp(Σ_{(s,t)∈E} θst xs xt). It turns out that it is possible to perform exact inference in this model. Hence one can use convex combinations of such graphs, which can sometimes yield more accurate results than TRBP, albeit at higher computational cost. See [Globerson07] for details, and [Schraudolph10] for a related exact method for planar Ising models.
11 Monte Carlo inference

12 Markov Chain Monte Carlo (MCMC) inference

13 Sequential Monte Carlo (SMC) inference

13.1 More applications of particle filtering


In this section, we give some examples of particle filtering applied to some state estimation problems
in different kinds of state-space models. We focus on using the simplest kind of SMC algorithm,
namely the bootstrap filter (Main Section 13.2.3.1).

13.1.1 1d pendulum model with outliers


In this section, we consider the pendulum example from Section 8.2.2. Rather than Gaussian
observation noise, we assume that some fraction p = 0.4 of the observations are outliers, coming from
a Unif(−2, 2) distribution. (These could represent a faulty sensor, for example.) In this case, the
bootstrap filter is more robust than deterministic filters, as shown in Figure 13.1, since it can handle
multimodal posteriors induced by uncertainty about which observations are signal and which are
noise. By contrast, EKF and UKF assume a unimodal (Gaussian) posterior, which is very sensitive
to outliers.
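The following NumPy sketch shows the structure of the bootstrap filter used here. The helper functions init_fn, dynamics_fn and loglik_fn are assumed to be supplied by the user (the pendulum notebook, pendulum_1d.ipynb, may organize things differently).

import numpy as np

def bootstrap_filter(y, init_fn, dynamics_fn, loglik_fn, n_particles=1000, seed=0):
    """Bootstrap particle filter for a generic state-space model.

    init_fn(rng, N)     -> initial particles, shape [N, D]
    dynamics_fn(rng, z) -> samples from p(z_t | z_{t-1}), shape [N, D]
    loglik_fn(y_t, z)   -> log p(y_t | z_t) per particle, shape [N]"""
    rng = np.random.default_rng(seed)
    z = init_fn(rng, n_particles)
    means = []
    for t in range(len(y)):
        z = dynamics_fn(rng, z)                       # propagate: proposal = dynamics
        logw = loglik_fn(y[t], z)                     # reweight by the likelihood
        logw -= logw.max()
        w = np.exp(logw); w /= w.sum()
        means.append(np.sum(w[:, None] * z, axis=0))  # posterior mean estimate
        idx = rng.choice(n_particles, size=n_particles, p=w)   # multinomial resampling
        z = z[idx]
    return np.array(means), z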

13.1.2 Visual object tracking


In Main Section 13.1.2, we tracked an object given noisy measurements of its location, as estimated
by some kind of distance sensor. A harder problem is to track an object just given a sequence of
frames from a video camera. This is called visual tracking. In this section we consider an example
where the object is a remote-controlled helicopter [Nummiaro03]. We will use a simple linear
motion model for the centroid of the object, and a color histogram for the likelihood model, using
Bhattacharya distance to compare histograms.
Figure 13.2 shows some example frames. The system uses S = 250 particles, with an effective
sample size of 134. (a) shows the belief state at frame 1. The system has had to resample 5 times to
keep the effective sample size above the threshold of 150; (b) shows the belief state at frame 251;
the red lines show the estimated location of the center of the object over the last 250 frames. (c)
shows that the system can handle visual clutter (the hat of the human operator), as long as it
does not have the same color as the target object; (d) shows that the system is confused between the
grey of the helicopter and the grey of the building (the posterior is bimodal, but the green ellipse,
representing the posterior mean and covariance, is in between the two modes); (e) shows that the
probability mass has shifted to the wrong mode: i.e., the system has lost track; (f) shows the particles
spread out over the gray building; recovery of the object is very unlikely from this state using this
[Figure 13.1 shows three panels of filtered state estimates, labeled Extended KF (noisy), Unscented KF (noisy), and Bootstrap PF (noisy), each compared to the true angle.]
Figure 13.1: Filtering algorithms applied to the noisy pendulum model with 40% outliers (shown in red). (a) EKF. (b) UKF. (c) Bootstrap filter. Generated by pendulum_1d.ipynb.
Figure 13.2: Example of particle filtering applied to visual object tracking, based on color histograms. Blue dots are posterior samples, green ellipse is Gaussian approximation to posterior. (a-c) Successful tracking. (d): Tracker gets distracted by an outlier gray patch in the background, and moves the posterior mean away from the object. (e-f): Losing track. See text for details. Used with kind permission of Sebastien Paris.
proposal.
The simplest way to improve performance of this method is to use more particles. A more efficient approach is to perform tracking by detection, by running an object detector over the image every few frames, and to use these as proposals (see Main Section 13.3). This provides a way to combine discriminative, bottom-up object detection (which can fail in the presence of occlusion) with generative, top-down tracking (which can fail if there are unpredictable motions, or new objects entering the scene). See e.g., [Hess2009; VanGool2009; Gurkan2019; Okada2019] for further details.
13.1.3 Online parameter estimation

It is tempting to use particle filtering to perform online Bayesian inference for the parameters of a model p(yt|xt, θ), just as we did using the Kalman filter for linear regression (Main Section 29.7.2) and the EKF for MLPs (Main Section 17.5.2). However, this technique will not work. The reason is that static parameters correspond to a dynamical model with zero system noise, p(θt|θt−1) = δ(θt − θt−1). However, such a deterministic model causes problems for particle filtering, because the particles can only be reweighted by the likelihood, but cannot be moved by the deterministic transition model. Thus the diversity in the trajectories rapidly goes to zero, and the posterior collapses [Kantas2015].
It is possible to add artificial process noise, but this causes the influence of earlier observations to decay exponentially with time, and also "washes out" the initial prior. In Main Section 13.6.3, we present a solution to this problem based on SMC samplers, which generalize the particle filter by allowing static variables to be turned into a sequence by adding auxiliary random variables.
13.1.4 Monte Carlo robot localization

Consider a mobile robot wandering around an indoor environment. We will assume that it already has a map of the world, represented in the form of an occupancy grid, which just specifies whether each grid cell is empty space or occupied by something solid like a wall. The goal is for the robot to estimate its location. (See also Main Section 13.4.3, where we discuss the problem of simultaneously localizing and mapping the environment.) This can be solved optimally using an HMM filter (also called a histogram filter [Jonschkowski2016]), since we are assuming the state space is discrete. However, since the number of states, K, is often very large, the O(K²) time complexity per update is prohibitive. We can use a particle filter as a sparse approximation to the belief state. This is known as Monte Carlo localization [Thrun06].
Figure 13.3 gives an example of the method in action. The robot uses a sonar range finder, so it can only sense distance to obstacles. It starts out with a uniform prior, reflecting the fact that the owner of the robot may have turned it on in an arbitrary location. (Figuring out where you are, starting from a uniform prior, is called global localization.) After the first scan, which indicates two walls on either side, the belief state is shown in (b). The posterior is still fairly broad, since the robot could be in any location where the walls are fairly close by, such as a corridor or any of the narrow rooms. After moving to location 2, the robot is pretty sure it must be in a corridor and not a room, as shown in (c). After moving to location 3, the sensor is able to detect the end of the corridor. However, due to symmetry, it is not sure if it is in location I (the true location) or location II. (This is an example of perceptual aliasing, which refers to the fact that different things may look the same.) After moving to locations 4 and 5, it is finally able to figure out precisely where it is (not
Figure 13.3: Illustration of Monte Carlo localization for a mobile robot in an office environment using a sonar sensor. From Figure 8.7 of [Thrun06]. Used with kind permission of Sebastian Thrun.
shown). The whole process is analogous to someone getting lost in an office building, and wandering the corridors until they see a sign they recognize.
13.2 Particle MCMC methods

In this section, we discuss some other sampling techniques that leverage the fact that SMC can give an unbiased estimate of the normalization constant Z for the target distribution. This can be useful for sampling with models where the exact likelihood is intractable. These are called pseudo-marginal methods [Andrieu2009].
To be more precise, note that the SMC algorithm can be seen as mapping a stream of random numbers u into a set of samples, z_{1:T}^{1:Ns}. We need random numbers u_{z,1:T}^{1:Ns} to specify the hidden states that are sampled at each step (using the inverse CDF of the proposal), and random numbers u_{a,1:T−1}^{1:Ns} to control the ancestor indices that are chosen (using the resampling algorithm), where each u_{z,t}^i, u_{a,t}^i ∼ Unif(0, 1). The normalization constant is also a function of these random numbers, so we denote it Ẑt(u), where

Ẑt(u) = Π_{s=1}^t (1/Ns) Σ_{n=1}^{Ns} w̃_s^n(u)  (13.1)
1
One can show (see e.g., [Naesseth2019]) that

$$\mathbb{E}\left[\hat{Z}_t(u)\right] = Z_t \quad (13.2)$$

where the expectation is wrt the distribution of u, denoted τ(u). (Note that u can be represented by a random seed.) This allows us to plug SMC inside other MCMC algorithms, as we show below. Such methods are often used by probabilistic programming systems (see e.g., [Zhou2020pps]), since PPLs often define models with many latent variables implicitly (via sampling statements), as discussed in Main Section 4.6.6.
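As a small illustration of Equation (13.1), the following sketch accumulates the normalization-constant estimate from the unnormalized SMC weights in log space; the per-step weight arrays are assumed to come from whatever SMC implementation is being used.

```python
import numpy as np

def smc_log_normalizer(unnormalized_log_weights):
    """Estimate log Z_t from SMC output, as in Equation (13.1).

    unnormalized_log_weights: list of length t; element s holds the N_s
    unnormalized log-weights produced at step s (before resampling).
    Returns log of prod_s (1/N_s) sum_n w_s^n, computed stably in log space.
    """
    log_Z = 0.0
    for log_w in unnormalized_log_weights:
        log_w = np.asarray(log_w)
        m = log_w.max()
        log_Z += m + np.log(np.mean(np.exp(log_w - m)))  # log of (1/N_s) sum_n w
    return log_Z
```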
13.2.1 Particle Marginal Metropolis-Hastings

Suppose we want to compute the parameter posterior p(θ|y) = p(y|θ)p(θ)/p(y) for a latent variable model with prior p(θ) and likelihood $p(y|\theta) = \int p(y, h|\theta)\, dh$, where h are latent variables (e.g., from an SSM). We can use Metropolis-Hastings (Main Section 12.2) to avoid having to compute the partition function p(y). However, in many cases it is intractable to compute the likelihood p(y|θ) itself, due to the need to integrate over h. This makes it hard to compute the MH acceptance probability

$$A = \min\left(1, \frac{p(y|\theta')p(\theta')q(\theta^{j-1}|\theta')}{p(y|\theta^{j-1})p(\theta^{j-1})q(\theta'|\theta^{j-1})}\right) \quad (13.3)$$

where $\theta^{j-1}$ is the parameter vector at iteration j − 1, and we are proposing θ' from $q(\theta'|\theta^{j-1})$. However, we can use SMC to compute $\hat{Z}(\theta)$ as an unbiased approximation to p(y|θ), which can be used to evaluate the MH acceptance ratio:

$$A = \min\left(1, \frac{\hat{Z}(u',\theta')p(\theta')q(\theta^{j-1}|\theta')}{\hat{Z}(u^{j-1},\theta^{j-1})p(\theta^{j-1})q(\theta'|\theta^{j-1})}\right) \quad (13.4)$$

More precisely, we apply MH to an extended space, where we sample both the parameters θ and the randomness u for SMC.

We can generalize the above to return samples of the latent states as well as the latent parameters, by sampling a single trajectory from

$$p(h_{1:T}|\theta, y) \approx \hat{p}(h|\theta, y, u) = \sum_{i=1}^{N_s} W_T^i\, \delta(h_{1:T} - h_{1:T}^i) \quad (13.5)$$

by using the internal samples generated by SMC. Thus we can sample θ and h jointly. This is called the particle marginal Metropolis-Hastings (PMMH) algorithm [Andrieu2010]. See Algorithm 13.1 for the pseudocode. See e.g. [Dahlin2015] for more practical details.
13.2.2 Particle Independent Metropolis-Hastings

Now suppose we just want to sample the latent states h, with the parameters θ being fixed. In this case we can simplify the PMMH algorithm by not sampling θ. Since the latent states h are now sampled independently of the state of the Markov chain, this is called the particle independent MH (PIMH) algorithm. The acceptance ratio term also simplifies, since we can drop all terms involving θ. See Algorithm 13.2 for the pseudocode.
Algorithm 13.1: Particle Marginal Metropolis-Hastings

1 for j = 1 : J do
2   Sample θ' ∼ q(θ'|θ^{j−1}), u' ∼ τ(u'), h' ∼ p̂(h'|θ', y, u')
3   Compute Ẑ(u', θ') using SMC
4   Compute A using Equation (13.4)
5   Sample u ∼ Unif(0, 1)
6   if u < A then
7     Set θ^j = θ', u^j = u', h^j = h'
8   else
9     Set θ^j = θ^{j−1}, u^j = u^{j−1}, h^j = h^{j−1}
Algorithm 13.2: Particle Independent Metropolis-Hastings

1 for j = 1 : J do
2   Sample u' ∼ τ(u'), h' ∼ p̂(h'|θ, y, u')
3   Compute Ẑ(u', θ) using SMC
4   Compute A = min(1, Ẑ(u', θ) / Ẑ(u^{j−1}, θ))
5   Sample u ∼ Unif(0, 1)
6   if u < A then
7     Set u^j = u', h^j = h'
8   else
9     Set u^j = u^{j−1}, h^j = h^{j−1}
One might wonder what the advantage of PIMH is over just using SMC. The answer is that PIMH can return unbiased estimates of smoothing expectations, such as

$$\pi(\varphi) = \int \varphi(h_{1:T})\, \pi(h_{1:T}|\theta, y)\, dh_{1:T} \quad (13.6)$$

whereas estimating this directly with SMC results in a consistent but biased estimate (in contrast to the estimate of Z, which is unbiased). For details, see [Middleton2019].
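The following is a minimal sketch of PMMH (Algorithm 13.1) with a symmetric Gaussian random-walk proposal, so the q terms in Equation (13.4) cancel. The function `smc_log_evidence(theta, seed)` is an assumed user-supplied routine that runs SMC with the given random seed (playing the role of u) and returns the log of the unbiased evidence estimate, log Ẑ(u, θ).

```python
import numpy as np

def pmmh(smc_log_evidence, log_prior, theta0, n_iters, prop_scale, rng):
    """Particle marginal Metropolis-Hastings (Algorithm 13.1), symmetric proposal."""
    theta, seed = np.asarray(theta0, dtype=float), rng.integers(1 << 31)
    log_Z = smc_log_evidence(theta, seed)      # log of unbiased estimate of p(y|theta)
    samples = []
    for _ in range(n_iters):
        theta_prop = theta + prop_scale * rng.normal(size=theta.shape)
        seed_prop = rng.integers(1 << 31)      # fresh SMC randomness u'
        log_Z_prop = smc_log_evidence(theta_prop, seed_prop)
        log_A = (log_Z_prop + log_prior(theta_prop)) - (log_Z + log_prior(theta))
        if np.log(rng.uniform()) < log_A:      # accept per Equation (13.4)
            theta, seed, log_Z = theta_prop, seed_prop, log_Z_prop
        samples.append(theta.copy())
    return np.array(samples)
```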
13.2.3 Particle Gibbs
In PMMH, we define a transition kernel that, given $(\theta^{(j-1)}, h^{(j-1)})$, generates a sample $(\theta^{(j)}, h^{(j)})$, while leaving the target distribution invariant. Another way to perform this task is to use particle Gibbs sampling, which avoids needing to specify any proposal distributions. In this approach, we first sample N − 1 trajectories $h_{1:T}^{1:N-1} \sim p(h|\theta^{(j-1)}, y)$ using conditional SMC, keeping the N'th trajectory fixed at the retained particle $h_{1:T}^N = h^{(j-1)}$. We then sample a new value for $h^{(j)}$ from the empirical distribution $\hat{\pi}_T(h_{1:T}^{1:N})$. Finally we sample $\theta^{(j)} \sim p(\theta|h^{(j)})$. For details, see [Andrieu2010].
Another variant, known as particle Gibbs with ancestor sampling, is discussed in [Lindsten2014]; it is particularly well-suited to state-space models.
Part III

Prediction
14 Predictive models: an overview
15 Generalized linear models

15.1 Variational inference for logistic regression


In this section we discuss a variational approach to Bayesian inference for logistic regression models
based on local bounds to the likelihood. We will use a Gaussian prior, p(w) = N (w|µ0 , V0 ). We
will create a “Gaussian-like” lower bound to the likelihood, which becomes conjugate to this prior.
We then iteratively improve this lower bound.

15.1.1 Binary logistic regression


In this section, we discuss VI for binary logistic regression. Our presentation follows [BishopBook].
Let us first rewrite the likelihood for a single observation as follows:

$$p(y_n|x_n, w) = \sigma(\eta_n)^{y_n}\,(1 - \sigma(\eta_n))^{1-y_n} \quad (15.1)$$
$$= \left(\frac{1}{1+e^{-\eta_n}}\right)^{y_n}\left(1 - \frac{1}{1+e^{-\eta_n}}\right)^{1-y_n} \quad (15.2)$$
$$= e^{\eta_n y_n}\,\frac{e^{-\eta_n}}{1+e^{-\eta_n}} = e^{\eta_n y_n}\,\sigma(-\eta_n) \quad (15.3)$$

where $\eta_n = w^T x_n$ are the logits. This is not conjugate to the Gaussian prior. So we will use the following "Gaussian-like" variational lower bound to the sigmoid function, proposed in [Jaakkola96b; Jaakkola00]:

$$\sigma(\eta_n) \geq \sigma(\psi_n)\exp\left[(\eta_n - \psi_n)/2 - \lambda(\psi_n)(\eta_n^2 - \psi_n^2)\right] \quad (15.4)$$

where $\psi_n$ is the variational parameter for datapoint n, and

$$\lambda(\psi) \triangleq \frac{1}{4\psi}\tanh(\psi/2) = \frac{1}{2\psi}\left(\sigma(\psi) - \frac{1}{2}\right) \quad (15.5)$$

We shall refer to this as the JJ bound, after its inventors, Jaakkola and Jordan. See Figure 15.1(a)
for a plot, and see Section 6.3.4.2 for a derivation.
Using this bound, we can write

$$p(y_n|x_n, w) = e^{\eta_n y_n}\sigma(-\eta_n) \geq e^{\eta_n y_n}\sigma(\psi_n)\exp\left[-(\eta_n + \psi_n)/2 - \lambda(\psi_n)(\eta_n^2 - \psi_n^2)\right] \quad (15.6)$$
Figure 15.1: Quadratic lower bounds on the sigmoid (logistic) function. In solid red, we plot σ(x) vs x. In dotted blue, we plot the lower bound L(x, ψ) vs x for ψ = 2.5. (a) JJ bound. This is tight at ψ = ±2.5. (b) Bohning bound (Section 15.1.2.2). This is tight at ψ = 2.5. Generated by sigmoid_lower_bounds.ipynb.
We can now lower bound the log joint as follows:

$$\log p(y|X, w) + \log p(w) \geq -\frac{1}{2}(w - \mu_0)^T V_0^{-1}(w - \mu_0) \quad (15.7)$$
$$\qquad + \sum_{n=1}^{N}\left[\eta_n(y_n - 1/2) - \lambda(\psi_n)\, w^T(x_n x_n^T)\, w\right] \quad (15.8)$$

Since this is a quadratic function of w, we can derive a Gaussian posterior approximation as follows:

$$q(w|\psi) = \mathcal{N}(w|\mu_N, V_N) \quad (15.9)$$
$$\mu_N = V_N\left(V_0^{-1}\mu_0 + \sum_{n=1}^{N}(y_n - 1/2)\, x_n\right) \quad (15.10)$$
$$V_N^{-1} = V_0^{-1} + 2\sum_{n=1}^{N}\lambda(\psi_n)\, x_n x_n^T \quad (15.11)$$

This is more flexible than a Laplace approximation, since the variational parameters ψ can be used to optimize the curvature of the posterior covariance. To find the optimal ψ, we can maximize the ELBO, which is given by

$$\log p(y|X) = \log\int p(y|X, w)p(w)\, dw \geq \log\int h(w, \psi)p(w)\, dw = Ł(\psi) \quad (15.12)$$

where

$$h(w, \psi) = \prod_{n=1}^{N}\sigma(\psi_n)\exp\left[\eta_n y_n - (\eta_n + \psi_n)/2 - \lambda(\psi_n)(\eta_n^2 - \psi_n^2)\right] \quad (15.13)$$

We can evaluate the lower bound analytically to get

$$Ł(\psi) = \frac{1}{2}\log\frac{|V_N|}{|V_0|} + \frac{1}{2}\mu_N^T V_N^{-1}\mu_N - \frac{1}{2}\mu_0^T V_0^{-1}\mu_0 + \sum_{n=1}^{N}\left[\log\sigma(\psi_n) - \frac{1}{2}\psi_n + \lambda(\psi_n)\psi_n^2\right] \quad (15.14)$$

If we solve for $\nabla_\psi Ł(\psi) = 0$, we get the following iterative update equation for each variational parameter:

$$(\psi_n^{\mathrm{new}})^2 = x_n^T\,\mathbb{E}\left[ww^T\right] x_n = x_n^T\left(V_N + \mu_N\mu_N^T\right) x_n \quad (15.15)$$

Once we have estimated $\psi_n$, we can plug it into the above Gaussian approximation q(w|ψ).
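The updates in Equations (15.10), (15.11) and (15.15) are easy to implement. Here is a minimal sketch in plain NumPy, using dense matrix inverses for clarity rather than speed.

```python
import numpy as np

def vi_logreg_jj(X, y, mu0, V0, n_iters=20):
    """Variational Bayes for logistic regression using the JJ bound.

    Iterates Equations (15.10)-(15.11) and (15.15).
    X: (N, D) inputs, y: (N,) labels in {0, 1}.
    """
    N, D = X.shape
    V0inv = np.linalg.inv(V0)
    psi = np.ones(N)                       # one variational parameter per datapoint
    lam = lambda p: np.tanh(p / 2.0) / (4.0 * p)
    for _ in range(n_iters):
        Vn_inv = V0inv + 2.0 * (X.T * lam(psi)) @ X              # Equation (15.11)
        Vn = np.linalg.inv(Vn_inv)
        mun = Vn @ (V0inv @ mu0 + X.T @ (y - 0.5))               # Equation (15.10)
        E_wwT = Vn + np.outer(mun, mun)
        psi = np.sqrt(np.einsum('nd,de,ne->n', X, E_wwT, X))     # Equation (15.15)
    return mun, Vn, psi
```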
15.1.2 Multinomial logistic regression
In this section we discuss how to approximate the posterior p(w|D) for multinomial logistic regression using variational inference, extending the approach of Main Section 15.3.8 to the multi-class case. The key idea is to create a "Gaussian-like" lower bound on the multi-class logistic regression likelihood due to [Bohning92]. We can then compute the variational posterior in closed form. This will let us deterministically optimize the ELBO.

Let $y_i \in \{0, 1\}^C$ be a one-hot label vector, and define the logits for example i to be

$$\eta_i = [x_i^T w_1, \ldots, x_i^T w_C] \quad (15.16)$$

If we define $X_i = I \otimes x_i$, where ⊗ is the Kronecker product, and I is the C × C identity matrix, then we can write the logits as $\eta_i = X_i w$. (For example, if C = 2 and $x_i = [1, 2, 3]$, we have $X_i = [1, 2, 3, 0, 0, 0;\; 0, 0, 0, 1, 2, 3]$.) Then the likelihood is given by

$$p(y|X, w) = \prod_{i=1}^{N}\exp[y_i^T\eta_i - \mathrm{lse}(\eta_i)] \quad (15.17)$$

where lse() is the log-sum-exp function

$$\mathrm{lse}(\eta_i) \triangleq \log\left(\sum_{c=1}^{C}\exp(\eta_{ic})\right) \quad (15.18)$$

For identifiability, we can set $w_C = 0$, so

$$\mathrm{lse}(\eta_i) = \log\left(1 + \sum_{m=1}^{M}\exp(\eta_{im})\right) \quad (15.19)$$

where M = C − 1. (We subtract 1 so that in the binary case, M = 1.)

15.1.2.1 Bohning's quadratic bound to the log-sum-exp function

The above likelihood is not conjugate to the Gaussian prior. However, we can now convert it to a quadratic form. Consider a Taylor series expansion of the log-sum-exp function around $\psi_i \in \mathbb{R}^M$:

$$\mathrm{lse}(\eta_i) = \mathrm{lse}(\psi_i) + (\eta_i - \psi_i)^T g(\psi_i) + \frac{1}{2}(\eta_i - \psi_i)^T H(\psi_i)(\eta_i - \psi_i) \quad (15.20)$$
$$g(\psi_i) = \exp[\psi_i - \mathrm{lse}(\psi_i)] = \mathrm{softmax}(\psi_i) \quad (15.21)$$
$$H(\psi_i) = \mathrm{diag}(g(\psi_i)) - g(\psi_i)g(\psi_i)^T \quad (15.22)$$
where g and H are the gradient and Hessian of lse, and $\psi_i \in \mathbb{R}^M$, where M = C − 1 is the number of classes minus 1. An upper bound to lse can be found by replacing the Hessian matrix $H(\psi_i)$ with a matrix $A_i$ such that $A_i \succeq H(\psi_i)$ for all $\psi_i$. [Bohning92] showed that this can be achieved if we use the matrix $A_i = \frac{1}{2}\left[I_M - \frac{1}{M+1}\mathbf{1}_M\mathbf{1}_M^T\right]$. In the binary case, this becomes $A_i = \frac{1}{2}(1 - \frac{1}{2}) = \frac{1}{4}$. Note that $A_i$ is independent of $\psi_i$; however, we still write it as $A_i$ (rather than dropping the i subscript), since other bounds that we consider below will have a data-dependent curvature term. The upper bound on lse therefore becomes

$$\mathrm{lse}(\eta_i) \leq \frac{1}{2}\eta_i^T A_i\eta_i - b_i^T\eta_i + c_i \quad (15.23)$$
$$A_i = \frac{1}{2}\left[I_M - \frac{1}{M+1}\mathbf{1}_M\mathbf{1}_M^T\right] \quad (15.24)$$
$$b_i = A_i\psi_i - g(\psi_i) \quad (15.25)$$
$$c_i = \frac{1}{2}\psi_i^T A_i\psi_i - g(\psi_i)^T\psi_i + \mathrm{lse}(\psi_i) \quad (15.26)$$

where $\psi_i \in \mathbb{R}^M$ is a vector of variational parameters. We can use the above result to get the following lower bound on the softmax likelihood:

$$\log p(y_i|x_i, w) \geq y_i^T X_i w - \frac{1}{2}w^T X_i^T A_i X_i w + b_i^T X_i w - c_i \quad (15.27)$$

To simplify notation, define the pseudo-measurement

$$\tilde{y}_i \triangleq A_i^{-1}(b_i + y_i) \quad (15.28)$$

Then we can get a "Gaussianized" version of the observation model:

$$p(y_i|x_i, w) \geq f(x_i, \psi_i)\, \mathcal{N}(\tilde{y}_i|X_i w, A_i^{-1}) \quad (15.29)$$

where $f(x_i, \psi_i)$ is some function that does not depend on w. Given this, it is easy to compute the posterior $q(w) = \mathcal{N}(m_N, V_N)$, using Bayes rule for Gaussians.

Given the posterior, we can write the ELBO as follows:

$$Ł(\psi) \triangleq -D_{KL}(q(w)\,\|\,p(w)) + \mathbb{E}_q\left[\sum_{i=1}^{N}\log p(y_i|x_i, w)\right] \quad (15.30)$$
$$= -D_{KL}(q(w)\,\|\,p(w)) + \mathbb{E}_q\left[\sum_{i=1}^{N} y_i^T\eta_i - \mathrm{lse}(\eta_i)\right] \quad (15.31)$$
$$= -D_{KL}(q(w)\,\|\,p(w)) + \sum_{i=1}^{N} y_i^T\,\mathbb{E}_q[\eta_i] - \sum_{i=1}^{N}\mathbb{E}_q[\mathrm{lse}(\eta_i)] \quad (15.32)$$

where $p(w) = \mathcal{N}(w|m_0, V_0)$ is the prior and $q(w) = \mathcal{N}(w|m_N, V_N)$ is the approximate posterior. The first term is just the KL divergence between two Gaussians, which is given by

$$-D_{KL}(\mathcal{N}(m_N, V_N)\,\|\,\mathcal{N}(m_0, V_0)) = -\frac{1}{2}\left[\mathrm{tr}(V_N V_0^{-1}) - \log|V_N V_0^{-1}| + (m_N - m_0)^T V_0^{-1}(m_N - m_0) - DM\right] \quad (15.33)$$

where DM is the dimensionality of the Gaussian, and we assume a prior of the form $p(w) = \mathcal{N}(m_0, V_0)$, where typically $\mu_0 = \mathbf{0}_{DM}$, and $V_0$ is block diagonal. The second term is simply

$$\sum_{i=1}^{N} y_i^T\,\mathbb{E}_q[\eta_i] = \sum_{i=1}^{N} y_i^T\tilde{m}_i \quad (15.34)$$

where $\tilde{m}_i \triangleq X_i m_N$. The final term can be lower bounded by taking expectations of our quadratic upper bound on lse as follows:

$$-\sum_{i=1}^{N}\mathbb{E}_q[\mathrm{lse}(\eta_i)] \geq \sum_{i=1}^{N}\left[-\frac{1}{2}\mathrm{tr}(A_i\tilde{V}_i) - \frac{1}{2}\tilde{m}_i^T A_i\tilde{m}_i + b_i^T\tilde{m}_i - c_i\right] \quad (15.35)$$

where $\tilde{V}_i \triangleq X_i V_N X_i^T$. Hence we have

$$Ł(\psi) \geq -\frac{1}{2}\left[\mathrm{tr}(V_N V_0^{-1}) - \log|V_N V_0^{-1}| + (m_N - m_0)^T V_0^{-1}(m_N - m_0) - DM\right] + \sum_{i=1}^{N}\left[y_i^T\tilde{m}_i - \frac{1}{2}\mathrm{tr}(A_i\tilde{V}_i) - \frac{1}{2}\tilde{m}_i^T A_i\tilde{m}_i + b_i^T\tilde{m}_i - c_i\right] \quad (15.36)$$

We will use coordinate ascent to optimize this lower bound. That is, we update the variational posterior parameters $V_N$ and $m_N$, and then the variational likelihood parameters $\psi_i$. We leave the detailed derivation as an exercise, and just state the results. We have

$$V_N = \left(V_0^{-1} + \sum_{i=1}^{N} X_i^T A_i X_i\right)^{-1} \quad (15.37)$$
$$m_N = V_N\left(V_0^{-1}m_0 + \sum_{i=1}^{N} X_i^T(y_i + b_i)\right) \quad (15.38)$$
$$\psi_i = \tilde{m}_i = X_i m_N \quad (15.39)$$

We can exploit the fact that $A_i$ is a constant matrix, plus the fact that $X_i$ has block structure, to simplify the first two terms as follows:

$$V_N = \left(V_0^{-1} + A \otimes \sum_{i=1}^{N} x_i x_i^T\right)^{-1} \quad (15.40)$$
$$m_N = V_N\left(V_0^{-1}m_0 + \sum_{i=1}^{N}(y_i + b_i)\otimes x_i\right) \quad (15.41)$$

where ⊗ denotes the Kronecker product.
15.1.2.2 Bohning's bound in the binary case

If we have binary data, then $y_i \in \{0, 1\}$, M = 1 and $\eta_i = w^T x_i$ where $w \in \mathbb{R}^D$ is a weight vector (not matrix). In this case, the Bohning bound becomes

$$\log(1 + e^\eta) \leq \frac{1}{2}a\eta^2 - b\eta + c \quad (15.42)$$
$$a = \frac{1}{4} \quad (15.43)$$
$$b = a\psi - (1 + e^{-\psi})^{-1} \quad (15.44)$$
$$c = \frac{1}{2}a\psi^2 - (1 + e^{-\psi})^{-1}\psi + \log(1 + e^\psi) \quad (15.45)$$

It is possible to derive an alternative quadratic bound for this case, as shown in Section 6.3.4.2. This has the following form

$$\log(1 + e^\eta) \leq \lambda(\psi)(\eta^2 - \psi^2) + \frac{1}{2}(\eta - \psi) + \log(1 + e^\psi) \quad (15.46)$$
$$\lambda(\psi) \triangleq \frac{1}{4\psi}\tanh(\psi/2) = \frac{1}{2\psi}\left(\sigma(\psi) - \frac{1}{2}\right) \quad (15.47)$$

To facilitate comparison with Bohning's bound, let us rewrite the JJ bound as a quadratic form as follows

$$\log(1 + e^\eta) \leq \frac{1}{2}a(\psi)\eta^2 - b(\psi)\eta + c(\psi) \quad (15.48)$$
$$a(\psi) = 2\lambda(\psi) \quad (15.49)$$
$$b(\psi) = -\frac{1}{2} \quad (15.50)$$
$$c(\psi) = -\lambda(\psi)\psi^2 - \frac{1}{2}\psi + \log(1 + e^\psi) \quad (15.51)$$

The JJ bound has an adaptive curvature term, since a depends on ψ. In addition, it is tight at two points, as is evident from Figure 15.1(a). By contrast, the Bohning bound is a constant curvature bound, and is only tight at one point, as is evident from Figure 15.1(b). Nevertheless, the Bohning bound is simpler, and somewhat faster to compute, since $V_N$ is a constant, independent of the variational parameters Ψ.
36 bound is simpler, and somewhat faster to compute, since VN is a constant, independent of the
37 variational parameters Ψ.
38
39
15.1.2.3 Other bounds
40
41 It is possible to devise bounds that are even more accurate than the JJ bound, and which work for
42 the multiclass case, by using a piecewise quadratic upper bound to lse, as described in [Marlin11].
43 By increasing the number of pieces, the bound can be made arbitrarily tight.
44 It is also possible to come up with approximations that are not bounds. For example, [Shekhovtsov2019]
45 gives a simple approximation for the output of a softmax layer when applied to a stochastic input
46 (characterized in terms of its first two moments).
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
15.2. CONVERTING MULTINOMIAL LOGISTIC REGRESSION TO POISSON REGRESSION

1
15.2 Converting multinomial logistic regression to Poisson regression

It is possible to represent a multinomial logistic regression model with K outputs as K separate Poisson regression models. (Although the Poisson models are fit separately, they are implicitly coupled, since the counts must sum to $N_n$ across all K outcomes.) This fact can enable more efficient training when the number of categories is large [Taddy2015].

To see why this relationship is true, we follow the presentation of [rethinking2]. We assume K = 2 for notational brevity (i.e., binomial regression). Assume we have m trials, with counts $y_1$ and $y_2$ of each outcome type. The multinomial likelihood has the form

$$p(y_1, y_2|m, \mu_1, \mu_2) = \frac{m!}{y_1!\,y_2!}\,\mu_1^{y_1}\mu_2^{y_2} \quad (15.52)$$

Now consider a product of two Poisson likelihoods, for each set of counts:

$$p(y_1, y_2|\lambda_1, \lambda_2) = p(y_1|\lambda_1)\,p(y_2|\lambda_2) = \frac{e^{-\lambda_1}\lambda_1^{y_1}}{y_1!}\,\frac{e^{-\lambda_2}\lambda_2^{y_2}}{y_2!} \quad (15.53)$$

We now show that these are equivalent, under a suitable setting of the parameters. Let $\Lambda = \lambda_1 + \lambda_2$ be the expected total number of counts of any type, $\mu_1 = \lambda_1/\Lambda$ and $\mu_2 = \lambda_2/\Lambda$. Substituting into the binomial likelihood gives

$$p(y_1, y_2|m, \mu_1, \mu_2) = \frac{m!}{y_1!\,y_2!}\left(\frac{\lambda_1}{\Lambda}\right)^{y_1}\left(\frac{\lambda_2}{\Lambda}\right)^{y_2} = \frac{m!}{\Lambda^{y_1}\Lambda^{y_2}}\,\frac{\lambda_1^{y_1}}{y_1!}\,\frac{\lambda_2^{y_2}}{y_2!} \quad (15.54)$$
$$= \frac{m!}{\Lambda^m}\,\frac{e^{-\lambda_1}\lambda_1^{y_1}}{e^{-\lambda_1}\,y_1!}\,\frac{e^{-\lambda_2}\lambda_2^{y_2}}{e^{-\lambda_2}\,y_2!} \quad (15.55)$$
$$= \underbrace{\frac{m!}{e^{-\Lambda}\Lambda^m}}_{p(m)^{-1}}\,\underbrace{\frac{e^{-\lambda_1}\lambda_1^{y_1}}{y_1!}}_{p(y_1)}\,\underbrace{\frac{e^{-\lambda_2}\lambda_2^{y_2}}{y_2!}}_{p(y_2)} \quad (15.56)$$

The final expression says that $p(y_1, y_2|m) = p(y_1)p(y_2)/p(m)$, which makes sense.
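This identity is easy to verify numerically; the following sketch (using SciPy distributions, with arbitrary rates) checks it for every (y1, y2) pair with y1 + y2 = m.

```python
import numpy as np
from scipy.stats import binom, poisson

lam1, lam2, m = 3.0, 7.0, 10            # arbitrary rates; m = y1 + y2 trials
Lam = lam1 + lam2
for y1 in range(m + 1):
    y2 = m - y1
    lhs = binom.pmf(y1, m, lam1 / Lam)                       # binomial likelihood
    rhs = (poisson.pmf(y1, lam1) * poisson.pmf(y2, lam2)
           / poisson.pmf(m, Lam))                            # p(y1) p(y2) / p(m)
    assert np.isclose(lhs, rhs)
```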
31
32
15.2.1 Beta-binomial logistic regression
33
34 In some cases, there is more variability in the observed counts than we might expect from just
35 a binomial model, even after taking into account the observed predictors. This is called over-
36 dispersion, and is usually due to unobserved factors that are omitted from the model. In such cases,
37 we can use a beta-binomial model instead of a binomial model:
38
yi ∼ BetaBinom(mi , αi , βi ) (15.57)
39
40 αi = πi κ (15.58)
41 βi = (1 − πi )κ (15.59)
42 T
πi = σ(w xi ) (15.60)
43
44 Note that we have parameterized the model in terms of its mean rate,
45 αi
46 πi = (15.61)
αi + βi
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


164

1
2 and shape,
3
4
κi = αi + βi (15.62)
5
We choose to make the mean depend on the inputs (covariates), but to treat the shape (which is like
6
a precision term) as a shared constant.
7
As we discuss in [book1], the beta-binomial distribution as a continuous mixture distribution of
8
the following form:
9
Z
10
BetaBinom(y|m, α, β) = Bin(y|m, µ)Beta(µ|α, β)dµ (15.63)
11
12
13 In the regression context, we can interpret this as follows: rather than just predicting the mean
14 directly, we predict the mean and variance. This allows for each individual example to have more
15 variability than we might otherwise expect.
16 If the shape parameter κ is less than 2, then the distribution is an inverted U-shape which strongly
17 favors probabilities of 0 or 1 (see Main Figure 2.3b). We generally want to avoid this, which we can
18 do by ensuring κ > 2.
19 Following [rethinking2], let us use this model to reanalyze the Berkeley admissions data from ??.
20 We saw that there was a lot of variability in the outcomes, due to the different admissions rates of each
21 department. Suppose we just regress on the gender, i.e., xi = (I (GENDERi = 1) , I (GENDERi = 2)),
22 and w = (α1 , α2 ) are the corresponding logits. If we use a binomial regression model, we can be
23 misled into thinking there is gender bias. But if we use the more robust beta-binomial model, we
24 avoid this false conclusion, as we show below.
25 We fit the following model:
26
Ai ∼ BetaBinom(Ni , πi , κ) (15.64)
27
28 logit(πi ) = αGENDER[i] (15.65)
29 αj ∼ N (0, 1.5) (15.66)
30 κ=φ+2 (15.67)
31
φ ∼ Expon(1) (15.68)
32
33 (To ensure that κ > 2, we use a trick and define it as κ = φ + 2, where we put an exponential prior
34 (which has a lower bound of 0) on φ.)
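A sketch of how this model might be written in numpyro is shown below; the variable names and data encoding are illustrative, and the actual implementation is in logreg_ucb_admissions_numpyro.ipynb.

```python
import jax
import numpyro
import numpyro.distributions as dist

def beta_binomial_model(gender, applications, admit=None):
    # One intercept per gender, Equation (15.66).
    alpha = numpyro.sample("alpha", dist.Normal(0.0, 1.5).expand([2]))
    phi = numpyro.sample("phi", dist.Exponential(1.0))        # Equation (15.68)
    kappa = phi + 2.0                                          # Equation (15.67)
    pi = jax.nn.sigmoid(alpha[gender])                         # Equation (15.65)
    numpyro.sample(
        "admit",
        dist.BetaBinomial(pi * kappa, (1.0 - pi) * kappa, total_count=applications),
        obs=admit,
    )
```

This model can then be fit with NUTS via numpyro.infer.MCMC in the usual way.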
35 We fit this model (using HMC) and plot the results in Figure 15.2. In Figure 15.2a, we show the
36 posterior predictive distribution; we see that is quite broad, so the model is no longer overconfident.
37 In Figure 15.2b, we plot p(σ(αj )|D), which is the posterior over the rate of admissions for men and
38 women. We see that there is considerable uncertainty in these value, so now we avoid the false
39 conclusion that one is significantly higher than the other. However, the model is so vague in its
40 predictions as to be useless. In Section 15.2.3, we fix this problem by using a multi-level logistic
41 regression model.
42
43
15.2.2 Poisson regression
44
45 Let us revisit the Berkeley admissions example from ?? using Poisson regression. We use a simplified
46 form of the model, in which we just model the outcome counts without using any features, such as
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
15.2. CONVERTING MULTINOMIAL LOGISTIC REGRESSION TO POISSON REGRESSION

1
2 distribution of admission rates
3.0
male
3 female
0.8 2.5
4
0.6 2.0
5

Density
1.5
6 0.4
7 1.0
0.2
8 0.5
9 0.0 0.0
2 4 6 8 10 12 0.0 0.2 0.4 0.6 0.8 1.0
10 probability admit

11 (a) (b)
12
13 Figure 15.2: Results of fitting beta-binomial regression model to Berkeley admissions data. (b) Posterior
14 predictive distribution (black) superimposed on empirical data (blue). The hollow circle is the posterior
15 predicted mean acceptance rate, E [Ai |D]; the vertical lines are 1 standard deviation around this mean,
16
std [Ai |D]; the + signs indicate the 89% predictive interval. (b) Samples from the posterior distribution for
the admissions rate for men (blue) and women (red). Thick curve is posterior mean. Adapted from Figure
17
12.1 of [rethinking2]. Generated by logreg_ucb_admissions_numpyro.ipynb.
18
19
20
21
gender or department. That is, the model has the form
22
yj,n ∼ Poi(λj ) (15.69)
23
24
λj = e αj
(15.70)
25 αj ∼ N (0, 1.5) (15.71)
26
27 for j = 1 : 2 and n = 1 : 12. Let λi = E [λi |Di ], where D1 = y1,1:N is the vector of admission
28 counts, and D2 = y2,1:N is the vector of rejection counts (so mn = y1,n + y2,n is the total number of
29 applications for case n). The expected acceptance rate across the entire dataset is
30
λ1 146.2
31 = = 0.38 (15.72)
32 λ1 + λ2 146.2 + 230.9
33
Let us compare this to a binomial regression model of the form
34
35 yn ∼ Bin(mn , µ) (15.73)
36
µ = σ(α) (15.74)
37
α ∼ N (0, 1.5) (15.75)
38
39
Let α = E [α|D], where D = (y1,1:N , m1:N ). The expected acceptance rate across the entire dataset
40
is σ(α) = 0.38, which matches Equation (15.72). (See logreg_ucb_admissions_numpyro.ipynb for
41
the code.)
42
43
44
15.2.3 GLMM (hierarchical Bayes) regression
45 Let us revisit the Berkeley admissions dataset from ??, where there are 12 examples, corresponding
46 to male and female admissions to 6 departments. Thus the data is grouped both by gender and
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


166

1
2 department. Recall that Ai is the number of students admitted in example i, Ni is the number
3 of applicants, µi is the expected rate of admissions (the variable of interest), and DEPT[i] is the
4 department (6 possible values). For pedagogical reasons, we replace the categorical variable GENDER[i]
5 with the binary indicator MALE[i]. We can create a model with varying intercept and varying
6 slope as follows:
7
Ai ∼ Bin(Ni , µi ) (15.76)
8
9
logit(µi ) = αDEPT[i] + βDEPT[i] × MALE[i] (15.77)
10 This has 12 parameters, as does the original formulation in ??. However, these are not independent
11 degrees of freedom. In particular, the intercept and slope are correlated, as we see in ?? (higher
12 admissions means steeper slope). We can capture this using the following prior:
13  
α
14
(αj , βj ) ∼ N ( , Σ) (15.78)
15 β
16 α ∼ N (0, 4) (15.79)
17
β ∼ N (0, 1) (15.80)
18
19 Σ = diag(σ) R diag(σ) (15.81)
20 R ∼ LKJ(2) (15.82)
21 2
Y
22 σ∼ N+ (σd |0, 1) (15.83)
23 d=1
24
25
We can write this more compactly in the following way.1 . We define u = (α, β), and wj = (αj , βj ),
26
and then use this model:
27 log(µi ) = wDEPT[i] [0] + wDEPT[i] [1] × MALE[i] (15.84)
28
wj ∼ N (u, Σ) (15.85)
29
30 u ∼ N (0, diag(4, 1)) (15.86)
31
See Figure 15.3(a) for the graphical model.
32
Following the discussion in Main Section 12.6.5, it is advisable to rewrite the model in a non-centered
33
form. Thus we write
34
35 wj = u + σLzj (15.87)
36
where L = chol(R) is the Cholesky factor for the correlation matrix R, and $z_j \sim \mathcal{N}(0, I_2)$. Thus the
37
model becomes the following:2 .
38
39 zj ∼ N (0, I2 ) (15.88)
40 vj = diag(σ)Lzj (15.89)
41
u ∼ N (0, diag(4, 1)) (15.90)
42
43 log(µi ) = u[0] + v[DEPT[i], 0] + (u[1] + v[DEPT[i], 1]) × MALE[i] (15.91)
44
1. In https://bit.ly/3mP1QWH, this is referred to as glmm4. Note that we use w instead of v, and we use u instead
45 of v .
µ
46 2. In https://bit.ly/3mP1QWH, this is referred to as glmm5.
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
15.2. CONVERTING MULTINOMIAL LOGISTIC REGRESSION TO POISSON REGRESSION

1
Posterior Predictive Check with 90% CI
2 di mi actual rate
mean ± std
0.8
3
4 σ R 0.7

5 αdi βd i 0.6
6 u Σ
0.5

admit rate
7
wj µi 0.4
8
9 0.3
J
Ai
10 0.2

11 0.1
Ni
12
N 2 4 6 8 10 12
13 cases

14 (a) (b)
15
16 Figure 15.3: (a) Generalized linear mixed model for inputs di (department) and mi (male), and output Ai
17 (number of admissionss), given Ni (number of applicants). (b) Results of fitting this model to the UCB
18 dataset. Generated by logreg_ucb_admissions_numpyro.ipynb.
19
20
This is the version of the model that is implemented in the numpyro code.
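For concreteness, here is a sketch of this non-centered parameterization in numpyro (again, names and the exact prior scales are illustrative; the actual implementation is in the notebook referenced above).

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def glmm_noncentered(dept, male, applications, admit=None):
    # Hyper-parameters: u ~ N(0, diag(4, 1)), sigma half-normal, R ~ LKJ(2).
    u = numpyro.sample("u", dist.Normal(jnp.zeros(2), jnp.array([2.0, 1.0])))
    sigma = numpyro.sample("sigma", dist.HalfNormal(jnp.ones(2)))
    L = numpyro.sample("L", dist.LKJCholesky(2, concentration=2.0))
    # Non-centered department effects, Equations (15.88)-(15.89); 6 departments assumed.
    z = numpyro.sample("z", dist.Normal(0.0, 1.0).expand([6, 2]))
    v = (sigma[:, None] * L @ z.T).T           # rows are v_j = diag(sigma) L z_j
    # Linear predictor, Equation (15.91).
    logit_mu = (u[0] + v[dept, 0]) + (u[1] + v[dept, 1]) * male
    numpyro.sample("admit",
                   dist.Binomial(total_count=applications, logits=logit_mu),
                   obs=admit)
```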
22 The results of fitting this model are shown in Figure 15.3(b). The fit is slightly better than in ??,
23 especially for the second column (females in department 2), where the observed value is now inside
24 the predictive interval.
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


16 Deep neural networks

16.1 More canonical examples of neural networks


16.1.1 Transformers
The high level structure is shown in Figure 16.1. We explain the encoder and decoder below.

16.1.1.1 Encoder
The details of the transformer encoder block are shown in Figure 16.2. The embedded input tokens X are passed through an attention layer (typically multi-headed), and the output Z is added to the input X using a residual connection. More precisely, if the input is $X = (x_1, \ldots, x_n)$ for $x_i \in \mathbb{R}^d$, we compute the following [Yun2020iclr]:

$$x_i = x_i + \sum_{j=1}^{n} K_{ij}\, W_V x_j \quad (16.1)$$

where $K = \mathrm{softmax}(A)$, and $A_{ij} = (W_Q x_i)^T(W_K x_j)$. (In [Sander2022], they explore a variant of this, known as the sinkformer, where they use Sinkhorn's algorithm to ensure K is stochastically normalized across columns as well as rows.) The output of self attention is then passed into a layer normalization layer, which normalizes and learns an affine transformation for each dimension, to ensure all hidden units have comparable magnitude. (This is necessary because the attention masks might upweight just a few locations, resulting in a skewed distribution of values.) Then the output vectors at each location are mapped through an MLP, composed of 1 linear layer, a skip connection and a normalization layer.

The overall encoder is N copies of this encoder block. The result is an encoding $H_x \in \mathbb{R}^{T_x \times D}$ of the input, where $T_x$ is the number of input tokens, and D is the dimensionality of the attention vectors.
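The following is a minimal single-head NumPy sketch of this encoder block; it omits multi-head attention, the 1/sqrt(d) scaling, and the learned affine parameters of layer normalization.

```python
import numpy as np

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-5):
    mu, var = x.mean(-1, keepdims=True), x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)       # learned scale/shift omitted

def encoder_block(X, Wq, Wk, Wv, W1, W2):
    """Single-head sketch: attention + residual + layer norm + MLP."""
    A = (X @ Wq.T) @ (X @ Wk.T).T              # A_ij = (Wq x_i)^T (Wk x_j)
    K = softmax(A, axis=-1)                    # row-normalized attention weights
    Z = X + K @ (X @ Wv.T)                     # residual connection, Equation (16.1)
    Z = layer_norm(Z)
    H = Z + np.maximum(Z @ W1, 0.0) @ W2       # MLP with skip connection
    return layer_norm(H)
```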

16.1.1.2 Decoder
Once the input has been encoded, the output is generated by the decoder. The first part of the
decoder is the decoder attention block, that attends to all previously generated tokens, y1:t−1 , and
computes the encoding Hy ∈ RTy ×D . This block uses masked attention, so that output t can only
attend to locations prior to t in Y.
170

1
2
3
4
5
6
7
8
9
10
11
Figure 16.1: High level structure of the encoder-decoder transformer architecture. From https: // jalammar.
12
github. io/ illustrated-transformer/ . Used with kind permission of Jay Alammar.
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30 Figure 16.2: The encoder block of a transformer for two input tokens. From https: // jalammar. github.
31 io/ illustrated-transformer/ . Used with kind permission of Jay Alammar.
32
33
34
The second part of the decoder is the encoder-decoder attention block, that attends to both
35
the encoding of the input, Hx , and the previously generated outputs, Hy . These are combined to
36
compute Z = Attn(Q = Hy , K = Hx , V = Hx ), which compares the output to the input. The joint
37
encoding of the state Z is then passed through an MLP layer. The full decoder repeats this decoder
38
block N times.
39
At the end of the decoder, the final output is mapped to a sequence of Ty output logits via a final
40
linear layer.
41
42
16.1.1.3 Putting it all together
43
44 We can combine the encoder and decoder as shown in Figure 16.3. There is one more detail we need
45 to discuss. This concerns the fact that the attention operation pools information across all locations,
46 so the transformer is invariant to the ordering of the inputs. To overcome this, it is standard to add
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
16.1. MORE CANONICAL EXAMPLES OF NEURAL NETWORKS

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
Figure 16.3: A transformer model where we use 2 encoder blocks and 2 decoder blocks. (The second decoder
block is not expanded.) We assume there are 2 input and 2 output tokens. From https: // jalammar. github.
22
io/ illustrated-transformer/ . Used with kind permission of Jay Alammar.
23
24
25
26 positional encoding vectors to the input tokens x ∈ RTx ×D . That is, we replace x with x + u,
27 where u ∈ RTx ×D is a (possibly learned) vector, where ui is some encoding of the fact that xi comes
28 from the i’th location in the N -dimensional sequence.
29
30
31
16.1.2 Graph neural networks (GNNs)
In this section, we discuss graph neural networks or GNNs. Our presentation is based on [Sanchez-lengeling2021], which in turn is a summary of the message passing neural network framework of [gilmer2017neural] and the Graph Nets framework of [battaglia2018relational].

We assume the graph is represented as a set of N nodes or vertices, each associated with a feature vector to create the matrix $V \in \mathbb{R}^{N \times D_v}$; a set of E edges, each associated with a feature vector to create the matrix $E \in \mathbb{R}^{E \times D_e}$; and a global feature vector $u \in \mathbb{R}^{D_u}$, representing overall properties of the graph, such as its size. (We can think of u as the features associated with a global or master node.) The topology of the graph can be represented as an N × N adjacency matrix, but since this is usually very sparse (see Figure 16.4 for an example), a more compact representation is just to store the list of edges in an adjacency list (see Figure 16.5 for an example).
42
43
16.1.2.1 Basics of GNNs
44
45 A GNN adopts a “graph in, graph out” philosophy, similar to how transformers map from sequences
46 to sequences. A basic GNN layer updates the embedding vectors associated with the nodes, edges
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


172

4-4 1-4 2-2 3-4
12
13 Image Pixels Adjacency Matrix Graph
14
15 Figure 16.4: Left: Illustration of a 5 × 5 image, where each pixel is either off (light yellow) or on (dark yellow).
16
Each non-border pixel has 8 nearest neighbors. We highlight the node at location (2,1), where the top-left is
(0,0). Middle: The corresponding adjacency matrix, which is sparse and banded. Right: Visualization of the
17
graph structure. Dark nodes correspind to pixels that are on, light nodes correspond to pixels that are off.
18
Dark edges correspond to the neighbors of the node at (2,1). From [Sanchez-lengeling2021]. Used with
19
kind permission of Benjamin Sanchez-Lengeling.
20
21
5 Nodes
22 6
[0, 1, 1, 0, 0, 1, 1, 1]
23 4 Edges
24 7 [2, 1, 1, 1, 2, 1, 1]
25
Adjacency List
26
3 [[1, 0], [2, 0], [4, 3], [6, 2],
27 [7, 3], [7, 4], [7, 5]]
0
28 Global
29 2
1 0
30
31 Figure 16.5: A simple graph where each node has 2 types (0=light yellow, 1=dark yellow), each edge has 2
32 types (1=gray, 2=blue), and the global feature vector is a constant (0=red). We represent the topology using an
33 adjaceny list. From [Sanchez-lengeling2021]. Used with kind permission of Benjamin Sanchez-Lengeling.
34
Layer N Layer N+1
35 graph in graph out

36 Un Un+1
37
38 Vn Vn+1
39
40 En En+1
41
Graph Independent Layer
42
43 update function

44
45
Figure 16.6: A basic GNN layer. We update the embedding vectors U, V and V using the global, node and edge
functions f . From [Sanchez-lengeling2021]. Used with kind permission of Benjamin Sanchez-Lengeling.
46
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
16.1. MORE CANONICAL EXAMPLES OF NEURAL NETWORKS

1
2
3
4
5
6
+ +
7
8
9
Aggregate information Aggregate information
10 from adjacent edges from adjacent edges

11 (a) (b)
12
13 Figure 16.7: Aggregating edge information into two different nodes. From [Sanchez-lengeling2021]. Used
14 with kind permission of Benjamin Sanchez-Lengeling.
15
16 Input Graph GNN Layers Transformed Graph Classification layer Prediction
node, edge, or global
17 predictions

18
19
20
21
22
23
Figure 16.8: An end-to-end GNN classifier. From [Sanchez-lengeling2021]. Used with kind permission of
24
Benjamin Sanchez-Lengeling.
25
26
27
28
29
and whole graph, as illustrated in Figure 16.6. The update functions are typically simple MLPs, that
30
are applied independently to each embedding vector.
31
To leverage the graph structure, we can combine information using a pooling operation. That is,
32
for each node n, we extract the feature vectors associated with its edges, and combine it with its
33
local feature vector using a permutation invariant operation such as summation or averaging. See
34
Figure 16.7 for an illustration. We denote this pooling operation by ρEn →Vn . We can similarly pool
35
from nodes to edges, ρVn →En , or from nodes to globals, ρVn →Un , etc.
36
The overall GNN is composed of GNN layers and pooling layers. At the end of the network, we
37
can use the final embeddings to classify nodes, edges, or the whole graph. See Figure 16.8 for an
38
illustration.
39
40
41
16.1.2.2 Message passing
42
43 Instead of transforming each vector independently and then pooling, we can first pool the information
44 for each node (or edge) and then update its vector representation. That is, for node i, we gather
45 information from all neighboring nodes, {hj : j ∈ nbr(i)}; we aggregate these vectors with the
46 local vector using an operation such as sum; and then we compute the new state using an update
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


174

1
Layer N Layer N+1
2 graph in graph out
3
Un Un+1
4
5
6
Vn Vn+1
7
8 En En+1
9
10
Graph Nets Layer

11
update function
12
13
aggregation function

14
Figure 16.9: Message passing in one layer of a GNN. First the global node Un and the local nodes Vn send
15
messages to the edges En , which get updated to give En+1 . Then the nodes get updated to give Vn+1 . Finally
16
the global node gets updated to give Un+1 . From [Sanchez-lengeling2021]. Used with kind permission of
17
Benjamin Sanchez-Lengeling.
18
19
20
21
function, such as
22 X
23
h0i = ReLU(Uhi + Vhj ) (16.2)
j∈nbr(i)
24
25
See Figure 16.11a for a visualization.
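A minimal sketch of this update, using an explicit edge list and a sum aggregator, is shown below (U and V are the learned weight matrices of Equation (16.2)).

```python
import numpy as np

def gnn_layer(H, edges, U, V):
    """One message-passing update, Equation (16.2): h_i' = ReLU(U h_i + sum_j V h_j).

    H: (N, D) node embeddings; edges: list of (i, j) pairs (include both directions
    for an undirected graph).
    """
    agg = np.zeros_like(H)
    for i, j in edges:
        agg[i] += H[j]                 # sum the messages arriving at node i
    return np.maximum(H @ U.T + agg @ V.T, 0.0)
```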
26
The above operation can be viewed as a form of “message passing”, in which the values of
27
neighboring nodes hj are sent to node i and then combined. It is more general than belief propagation
28
(Main Section 9.3), since the messages are not restricted to represent probability distributions (see
29
Main Section 9.4.10 for more discussion).
30
After K message passing layers, each node will have received information from neighbors which
31
are K steps away in the graph. This can be “short circuited” by sending messages through the global
32
node, which acts as a kind of bottleneck. See Figure 16.9 for an illustration.
33
34
35
16.1.2.3 More complex types of graphs
36 We can easily generalize this framework to handle other graph types. For example, multigraphs
37 have multiple edge types between each pair of nodes. For example, in a knowledge graph, we
38 might have edge types “spouse-of”, “employed-by” or “born-in”. See Figure 16.10(left) for an example.
39 In hypergraphs, each edge may connect more than two nodes. For example, in a knowledge graph,
40 we might want to specify the three-way relation “parents-of(c, m, f)”, for child c, mother m and father
41 f . We can “reify” such hyperedges into hypernodes, as shown in Figure 16.10(right).
42
43
16.1.2.4 Graph attention networks
44
45 When performing message passing, we can generalize the linear combination used in Equation (16.2)
46 to use a weighted combination instead, where the weights are computed an attention mecha-
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
16.1. MORE CANONICAL EXAMPLES OF NEURAL NETWORKS

1
2 Multigraph
3 (Heterogeneous edges)
4 Global Node
Hypernode
5 to global
edges
6
7
Hypernodes
8 Node to
to hypernode
9 edges
10
11 Edge types Nodes
12
13
14
15
Figure 16.10: Left: a multigraph can have different edge types. Right: a hypergraph can have edges which
connect multiple nodes. From [Sanchez-lengeling2021]. Used with kind permission of Benjamin Sanchez-
16
Lengeling.
17
18
19
20
nism (Main Section 16.2.7). The resulting model is called a graph attention network or GAT
21
[velickovic2017graph]. This allows the effective topology of the graph to be context dependent.
22
23 16.1.2.5 Transformers are fully connected GNNs
24
Suppose we create a fully connected graph in which each node represents a word in a sentence. Let
25
us use this to construct a GNN composed of GAT layers, where we use multi-headed scaled dot
26
product attention. Suppose we combine each GAT block with layer normalization and an MLP. The
27
resulting block is shown in Figure 16.11b. We see that this is identical to the transformer encoder
28
block shown in Figure 16.2. This construction shows that transformers are just a special case of
29
GNNs [Joshi2020].
30
The advantage of this observation is that it naturally suggests ways to overcome the O(N 2 )
31
complexity of transformers. For example, in Transformer-XL [transformer-xl], we create blocks
32
of nodes, and connect these together, as shown in Figure 16.12(top right). In binary partition
33
transformer or BPT [BPT], we also create blocks of nodes, but add them as virtual “hypernodes”,
34
as shown in Figure 16.12(bottom). There are many other approaches to reducing the O(N 2 ) cost
35
(see e.g., [book1]), but the GNN perspective is a helpful one.
36
37
38
39
40
41
42
43
44
45
46
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


176

1
2
3
4
5 Token-wise
Feedforward
6 sub-layer

7
8
9
10
11
12 Multi-head
Attention
sub-layer
13
Legend
14
Non-learnt Operations
15 Learnable Weights
Softmax Operation
16 Learnable Normalization
Non-linear Activation
17
18
19
(a) (b)
20
21 Figure 16.11: (a) A graph neural network aggregation block. Here h`i is the hidden representation for node i
in layer `, and N (i) are i’s neighbors. The output is given by h`+1 = ReLU(U` hi + j∈nbr(i) V` h`j ). (b) A
P
22
i
23 transformer encoder block. Here h`i is the hidden representation for word i in layer `, and S are all the words
24 in the sentence. The output is given by h`+1i = Attn(Q` h`i , {K` hj , V` h`j }). From [Joshi2020]. Used with
25 kind permission of Chaitanya Joshi.
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41 Figure 16.12: Graph connectivity for different types of transformer. Top left: in a vanilla Transformer,
42 every node is connected to every other node. Top right: in Transformer-XL, nodes are grouped into blocks.
43 Bottom: in BPT, we use a binary partitioning of the graph to create virtual node clusters. From https: //
44 graphdeeplearning. github. io/ post/ transformers-are-gnns/ . Used with kind permission of Chaitanya
45 Joshi.
46
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
17 Bayesian neural networks

17.1 More details on EKF for training MLPs


The suggestion to use the EKF to train MLPs was first made in [Singhal1988]. We give a summary
below, based on [Puskorius2003].

17.1.1 Global EKF


On the left, we use notation from [Puskorius2003], and on the right we use our notation.
$$\hat{w}_k = \mu_{k|k-1} \quad (17.1)$$
$$P_k = \Sigma_{k|k-1} \quad (17.2)$$
$$H_k = \mathrm{Jac}(h)(\mu_{k|k-1}) \quad (17.3)$$
$$A_k = S_k^{-1} = (R_k + H_k\Sigma_{k|k-1}H_k^T)^{-1} \quad (17.4)$$
$$K_k = \Sigma_{k|k-1}H_k^T S_k^{-1} \quad (17.5)$$
$$\xi_k = y_k - h(\mu_{k|k-1}) \quad (17.6)$$
$$\hat{w}_{k+1} = \mu_{k+1|k} = \mu_{k|k} = \mu_{k|k-1} + K_k\xi_k \quad (17.7)$$
$$P_{k+1} = \Sigma_{k+1|k} = \Sigma_{k|k} + Q_k = \Sigma_{k|k-1} - K_k H_k\Sigma_{k|k-1} + Q_k \quad (17.8)$$

Suppose there are N outputs and M parameters (size of $w_k$), so $H_k$ is N × M. Computing the matrix inversion $S_k^{-1}$ takes $O(N^3)$ time, computing the matrix multiplication $H_k^T S_k^{-1}$ takes $O(MN^2)$ time, computing the matrix multiplication $\Sigma_{k|k-1}(H_k^T S_k^{-1})$ takes $O(M^2 N)$ time, and computing the matrix multiplication $H_k\Sigma_{k|k-1}H_k^T$ takes $O(N^2 M + NM^2)$ time. Computing the Jacobian takes O(NM) time, which is N times slower than standard backprop (which uses a scalar output representing the loss). The total time is therefore $O(N^3 + N^2M + NM^2)$. The memory usage is $O(M^2)$.

The learning rate of the algorithm is controlled by the artificial process noise, $Q_k = qI$. [Puskorius2003] recommend annealing this from a large to a small value over time, to ensure more rapid convergence. (We should keep q > 0 to ensure the posterior is always positive definite.)
(We should keep q > 0 to ensure the posterior is always positive definite.)

17.1.2 Decoupled EKF


We can speed up the method using the “decoupled EKF” [Puskorius1991; Murtuza1994], which
partitions the posterior covariance into groups. We give a summary below, based on [Puskorius2003].
178

1
2 Suppose there are g groups, and let µit|t represent the mean of the i’th group (of size Mi ), Σik|k its
3 covariance, Hik the Jacobian wrt the i’th groups weights (of size N × Mi ) and Kik the corresponding
4 Kalman gain. Then we have (in our notation)
5
6 µik|k−1 = µik−1|k−1 (17.9)
7
Σik|k−1 = Σik−1|k−1 + Qik−1 (17.10)
8
Xg
9
Sk = Rk + (Hjk )T Σjk|k−1 Hjk (17.11)
10
j=1
11
12 Kik = Σik|k−1 Hik S−1
k (17.12)
13 ξk = yk − h(µk|k−1 ) (17.13)
14 i
µk|k = µik|k−1 + Kik ξ k (17.14)
15
16 Σik|k = Σik|k−1 − Kik Hik Σik|k−1 (17.15)
17
(17.16)
18
Pg
19 ThePtime complexity is reduced to O(N 3 + N 2 M + N i=1 Mi2 ), and the space complexity
g
20 is O( i=1 Mi2 ). The term “fully decoupled” refers to a diagonal approximation to the posterior
21 covariance, which is similar in spirit to diagonal pre-conditioning methods such as Adam. The term
22 “node decoupled EKF” refers to a block diagonal approximation, where the blocks correspond to all
23 the weights feeding into a single neuron (since these are highly correlated).
24 In [Puskorius2003], they give a serial scheme for reducing the complexity when N is large (e.g.,
PG
25
multi-class classification). The new time complexity is O(N 2 G + i=1 Mi2 ), where G is the number
26
of nodes in the network.
27
28
29
17.1.3 Mini-batch EKF
30 A “multi-stream” (i.e., minibatch) extension was presented in [Feldkamp1994]. As explained in
31 [Puskorius2003], this amounts to stacking Ns observations into a single large observation vector,
32 denoted yk:l , where l = k + Ns − 1, and then stacking the Jacobians Hk:l . We then perform the
33 update (possibly decoupled) as above. Note that this is more expensive than just averaging gradients,
34 as is done by mini-batch SGD.
35 Minibatch EKF is potentially less accurate than updating after each example, since the linearization
36 is computed at the previous posterior, µk−1 , even for examples at the end of the minibatch, namely
37 at time l  k. Furthermore, it may be more expensive, due to the need to invert Sk , which has size
38 N Ns × N Ns . However, it may be less sensitive to the ordering of the data.
39
40
41
42
43
44
45
46
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
18 Gaussian processes

18.1 Deep GPs


A deep Gaussian process or DGP is a composition of GPs [Damianou2013]. (See [Jakkala2021] for a recent survey.) More formally, a DGP of L layers is a hierarchical model of the form

$$\mathrm{DGP}(x) = f_L \circ \cdots \circ f_1(x), \quad f_i(\cdot) = [f_i^{(1)}(\cdot), \ldots, f_i^{(H_i)}(\cdot)], \quad f_i^{(j)} \sim \mathrm{GP}(0, K_i(\cdot,\cdot)) \quad (18.1)$$

This is similar to a deep neural network, except the hidden nodes are now hidden functions.
A natural question is: what is gained by this approach compared to a standard GP? Although
conventional single-layer GPs are nonparametric, and can model any function (assuming the use of a
non-degenerate kernel) with enough data, in practice their performance is limited by the choice of
kernel. This can be partially overcome by using a DGP, as we show in Section 18.1.0.2. Unfortunately,
posterior inference in DGPs is challenging, as we discuss in Section 18.1.0.3.
In Section 18.1.0.4, we discuss the expressive power of infinitely wide DGPs, and in Section 18.1.0.5
we discuss connections with DNNs.

18.1.0.1 Construction of a deep GP


In this section we give an example of a 2 layer DGP, following the presentation in [Pleiss2021].
(j)
Let f1 ∼ GP(0, K1 ) for j = 1 : H1 , where H1 is the number of hidden units, and f2 ∼ GP(0, K2 ).
Assume we have labeled training data X = (x1 , . . . , xN ) and y = (y1 , . . . , yN ). Define F1 =
[f1 (x1 ), . . . , f1 (xN )] and f2 = [f2 (f1 (x1 )), . . . , f2 (f1 (xN ))]. Let x∗ be a test input and define
f1∗ = f1 (x∗ ) and f2= f2 (f1 (x∗ )). The corresponding joint distribution over all the random variables
is given by

p(f2∗ , f2 , F1 , f1 , y) = p(f2∗ |f2 , f1∗ , F1 )p(f2 |F1 , f1∗ )p(f1∗ , F1 )p(y|f2 ) (18.2)

where we drop the dependence on X and x∗ for brevity. This is illustrated by the graphical model in
Figure 18.1, where we define K2 = K2 (F1 , F1 ), k2 ∗ = K2 (F1 , f1∗ ), and k2∗∗ = K2 (f1∗ , f1∗ ).

18.1.0.2 Example: 1d step function


Suppose we have data from a piecewise constant function. (This can often happen when modeling
certain physical processes, which can exhibit saturation effects.) Figure 18.2a shows what happens if
we fit data from such a step function using a standard GP with an RBF (Gaussian) kernel. Obviously
180

1
2 X F1 K2 f2 y
3
4
5
6 k∗2

7
8
9 x∗ f∗1 k2∗∗ f2∗
10
11
Figure 18.1: Graphical model corresponding to a deep GP with 2 layers. The dashed nodes are deterministic
12
funcitons of their parents, and represent kernel matrices. The shaded nodes are observed, the unshaded nodes
13
are hidden. Adapted from Figure 5 of [Pleiss2021].
14
15 Mean
Data
16 Confidence
1 Samples
17
1
18
y

19 0

y
Mean
20 Data
Confidence 0
21
−1
22
23 −2 −1 0 1 2 −2 −1 0 1 2
24 x x
25 (a) (b)
26
27 Figure 18.2: Some data (red points) sampled from a step function fit with (a) a standard GP with RBF kernel
28 and (b) a deep GP with 4 layers of RBF kernels. The solid blue line is the posterior mean. The pink shaded
29 area represents the posterior variance (µ(x) ± 2σ(x)). The thin blue dots in (b) represent posterior samples.
30
Generated by deepgp_stepdata.
31
32
33 this method oversmooths the function and does not “pick up on” the underlying discontinuinity. It is
34 possible to learn kernels that can capture such discontinuous (non-stationary) behavior by learning
35 to warp the input with a neural net before passing into the RBF kernel (see Main Figure 18.26).
36 Another approach is to learn a sequence of smooth mappings which together capture the overall
37 complex behavior, analogous to the approach in deep learning. Suppose we fit a 4 layer DGP with a
38 single hidden unit at each layer; we will use an RBF kernel. Thus the kernel at level 1 is K1 (x, x0 ) =
39 exp(−||x − x0 ||2 /(2D)), the kernel at level 2 is K2 (f1 (x), f1 (x0 )) = exp(−||f1 (x) − f1 (x0 )||2 /(2H1 )),
40 etc.
41 We can perform posterior inference in this model to compute p(f∗ |x∗ , X, y) for a set of test
42 points x∗ (see Section 18.1.0.3 for the details). Figure 18.2b shows the resulting posterior predictive
43 distribution. We see that the predictions away from the data capture two plausible modes: either
44 the signal continues at the level y = 0 or at y = 1. (The posterior mean, shown by the solid blue line,
45 is a poor summary of the predictive distribution in this case, since it lies between these two modes.)
46 This is an example of non-trivial extrapolation behavior outside of the support of the data.
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
18.1. DEEP GPS

1
2
3 1 1
4

layer2
5 0
layer1

6 0
7 −1
8 −1
9
−2 −1 0 1 2 −2 −1 0 1 2
10
input layer1
11
(a) (b)
12
13
14 1 1
15

output
layer3

16 0
0 Mean
17 Confidence
Data
18 −1
19
−1
20
−2 −1 0 1 2 −2 −1 0 1 2
21
layer2 layer3
22
(c) (d)
23
24 Figure 18.3: Illustration of the functions learned at each layer of the DGP. (a) Input to layer 1. (b) Layer 2
25 to layer 2. (c) Layer 2 to layer 3. (d) Layer 3 to output. Generated by deepgp_stepdata.ipynb
26
27
28 Figure 18.3 shows the individual functions learned at each layer (these are all maps from 1d to 1d).
29 We see that the functons are individually smooth (since they are derived from an RBF kernel), but
30 collectively they define non-smooth behavior.
31
32
18.1.0.3 Posterior inference
33
34 In Equation (18.2), we defined the joint distribution defined by a (2 layer) DGP. We can condition
35 on y to convert this into a joint posterior, as follows:
36
37 p(f2∗ , f2 , F1 , f1 |y) = p(f2∗ |f2 , f1∗ , F1 , y)p(f2 |F1 , f1∗ , y)p(f1∗ , F1 |y) (18.3)
38 = p(f2∗ |f2 , f1∗ , F1 )p(f2 |F1 , y)p(f1∗ , F1 |y) (18.4)
39
40 where the simplifications in the second line follow from the conditional independencies encoded in
41 Figure 18.1. Note that f2 and f2∗ depend on F1 and f1∗ only through K2 , k2∗ and k2∗∗ , where
42
p(f2 |K2 ) ∼ N (0, K2 ), p(f2∗ |k2∗∗ , k2∗ , K2 , f2 ) ∼ N ((k2∗ )T K−1 ∗∗ ∗ T −1 ∗
2 f2 , k2 − (k2 ) K2 k2 ) (18.5)
43
44
Hence
45
46 p(f2∗ , f2 , F1 , f1 |y) = p(f2∗ |f2 , K2 , k2∗ , k2∗∗ )p(f2 |K2 , y)p(K2 , k2∗ , k2∗∗ |y) (18.6)
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


182

1
2 For prediction we only care about f2∗ , so we marginalize out the other variables. The posterior
3 mean is given by
4   
5
Ef2∗ |y [f2∗ ] = EK2 ,k2∗ ,k2∗∗ |y Ef2 |K2 ,y Ef2∗ |f2 ,K2 ,k2∗ ,k2∗∗ [f2∗ ] (18.7)
  ∗ T −1 
6 = EK2 ,k2∗ ,k2∗∗ |y Ef2 |K2 ,y (k2 ) K2 f2 (18.8)
7  
8
= EK2 ,k2∗ |y k2∗ K−1
2 Ef2 |K2 ,y [f2 ]
 (18.9)
9 | {z }
10 α

11
Since K2 and k2∗ are deterministic transformations of f1 (x∗ ), f1 (x1 ), . . . , f1 (xN ), we can rewrite this
12
as
13
"N #
14 X
15 Ef2∗ |y [f2∗ ] = Ef1 (x∗ ),f1 (x1 ),...,f1 (xN )|y αi K2 (f1 (xi ), f1 (x∗ )) (18.10)
16 i=1

17 We see from the above that inference in a DGP is, in general, very expensive, due to the need
18 to marginalize over a lot of variables, corresponding to all the hidden function values at each layer
19 at each data point. In [Salimbeni2017], they propose an approach to approximate inference in
20 DGPs based on the sparse variational method of Main Section 10.1.1.1. The key assumption is that
21 each layer has a set of inducing points, along with corresponding inducing values, that simplifies the
22 dependence between unknown function values within each layer. However, the dependence between
23 layers is modeled exactly. In [Dutordoir2021] they show that the posterior mean of such a sparse
24 variational approximation can be computed by performing a forwards pass through a ReLU DNN.
25
26
18.1.0.4 Behavior in the limit of infinite width
27
28 Consider the case of a DGP where the depth is 2. The posterior mean of the predicted output at
29 a test point is given by Equation (18.10). We see that this is a mixture of data-dependent kernel
30 functions, since both K2 and k2 depend on the data y. This is what makes deep GPs more expressive
31 than single layer GPs, where the kernel is fixed. However, [Pleiss2021] show that, in the limit
32 H1 → ∞, the posterior over the kernels for the layer 2 features  becomes
 independent of the data,
33 i.e., p(K2 , k2∗ |y) = δ(K2 − Klim )δ(k2∗ − klim

), where Klim = E f2 f2T and klim

= E [f2 f2∗ ], where the
34 expectations depend on X but not y. Consequently the posterior predictive mean reduces to
35
N
X
36
lim Ef2∗ |y [f2∗ ] = αi Klim (xi , x∗ ) (18.11)
37 H1 →∞
i=1
38
39 which is the same form as a single layer GP.
40 As a concrete example, consider a 2 layer DGP with an RBF kernel at each layer. Thus the kernel
41 at level 1 is K1 (x, x0 ) = exp(−||x − x0 ||2 /(2D)), and the kernel at level 2 is K2 (f1 (x), f1 (x0 )) =
42 exp(−||f1 (x) − f1 (x0 )||2 /(2H1 )). Let us fit this model to a noisy step function. In Figure 18.4 we
43 show the results as we increase the width of the hidden layer. When the width is 1, we see that
44 the covariance of the resulting DGP, K2 (f1 (x), f1 (x0 )), is nonstationary. In particular, there are
45 long-range correlations near x = ±1 (since the function is constant in this region), but short range
46 correlations near x = 0 (since the functon is changing rapidly in this region). However, as the width
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
18.1. DEEP GPS

1
Posterior F’n (Width = 1) Posterior F’n (Width = 16) Posterior F’n (Width = 256) Posterior F’n (Width = ∞)
2

f2 (f1 (x))
1 Obs.
0 Posterior
3
−1
4 −1.0 −0.5 0.0 0.5 1.0 −1.0 −0.5 0.0 0.5 1.0 −1.0 −0.5 0.0 0.5 1.0 −1.0 −0.5 0.0 0.5 1.0
x x x x
5
6 (a)
Posterior Kern. (Width = 1) Posterior Kern. (Width = 16) Posterior Kern. (Width = 256) Posterior Kern. (Width = ∞)
−1.0
7
8 −0.5

9 0.0
x

10 0.5

11 1.0
−1.0 −0.5 0.0 0.5 1.0 −1.0 −0.5 0.0 0.5 1.0 −1.0 −0.5 0.0 0.5 1.0 −1.0 −0.5 0.0 0.5 1.0
12 x x x x
13 (b)
14
15 Figure 18.4: (a) Posterior of 2-layer RBF deep GP fit to a noisy step function. Columns represent width of 1,
16 16, 256 and infinity. (b) Average posterior covariance of the DGP, given by Ef1 (x),f1 (x0 )|y [K2 (f1 (x), f1 (x0 ))].
17 As the width increases, the covariance becomes stationary, as shown by the kernel’s constant diagonals. From
18
Figure 1 of [Pleiss2021]. Used with kind permission of Geoff Pleiss.
19
20
increases, we lose this nonstationarity, as shown by the constant diagonals of the kernel matrix. Indeed,
21
in [Pleiss2021] they prove that the limiting kernel is Klim (x, x0 ) = exp(exp(−||x − x0 ||2 /(2D)) − 1),
22
which is stationary.
23
In [Pleiss2021], they also show that increasing the width makes the marginals more Gaussian,
24
due to central-limit like behavior. However, increasing the depth makes the marginals less Gaussian,
25
and causes them to have sharper peaks and heavier tails. Thus one often gets best results with a
26
deep GP if it is deep but narrow.
18.1.0.5 Connection with Bayesian neural networks

A Bayesian neural network (BNN) is a DNN in which we place priors over the parameters (see Main Section 17.1). One can show (see e.g., [Ober2021]) that BNNs are a degenerate form of deep GPs. For example, consider a 2 layer MLP, f2(f1(x)), with f1 : R^D → R^{H1} and f2 : R^{H1} → R, defined by

f1^(i)(x) = (w1^(i))ᵀ x + β b1^(i),   f2(z) = (1/√H1) w2ᵀ φ(z) + β b2    (18.12)

where β > 0 is a scaling constant, and W1, b1, w2, b2 are Gaussian. The first layer is a linear regression model, and hence (from the results in Main Section 18.3.3) corresponds to a GP with a linear kernel of the form K1(x, x′) = xᵀx′. The second layer is also a linear regression model but applied to features φ(z). Hence (from the results in Main Section 18.3.3) this corresponds to a GP with a linear kernel of the form K2(z, z′) = φ(z)ᵀφ(z′). Thus each layer of the model corresponds to a (degenerate) GP, and hence the overall model is a (degenerate) DGP. (The term "degenerate" refers to the fact that the covariance matrices only have a finite number of non-zero eigenvalues, due to the use of a finite set of basis functions.) Consequently we can use the results from Section 18.1.0.4 to conclude that infinitely wide DNNs also reduce to a single layer GP, as we already established in Main Section 18.7.1.

In practice we use finite-width DNNs. The width should be wide enough to approximate a standard GP at each layer, but should not be too wide, otherwise the corresponding kernels of the resulting deep GP will no longer be adapted to the data, i.e., there will not be any "feature learning". See e.g., [Aitchison2020; Pleiss2021; Zavatone-Veth2021] for details.
18.2 GPs and SSMs

Consider a Matern kernel of order ν = 3/2 with length scale ℓ and variance σ²:

K(τ; σ², ℓ) = σ² (1 + √3 τ/ℓ) exp(−√3 τ/ℓ)    (18.13)

For this kernel, we define

F = [[0, 1], [−λ², −2λ]],   L = [0, 1]ᵀ,   H = [1, 0]    (18.14)

Consider a Matern kernel of order ν = 5/2 with length scale ℓ and variance σ²:

K(τ; σ², ℓ) = σ² (1 + √5 τ/ℓ + 5τ²/(3ℓ²)) exp(−√5 τ/ℓ)    (18.15)

We define λ = √(2ν)/ℓ, p = ν − 1/2, and

F = [[0, 1, 0], [0, 0, 1], [−λ³, −3λ², −3λ]],   L = [0, 0, 1]ᵀ,   H = [1, 0, 0]    (18.16)

q = 2σ² π^{1/2} λ^{2p+1} Γ(p + 1) / Γ(p + 1/2)    (18.17)

If ∆k = tk − tk−1, the LG-SSM becomes

zk = Ak−1 zk−1 + N(0, Qk−1)    (18.18)
yk = H zk + N(0, σn²)    (18.19)

where

Φ(τ) = expm(Fτ)    (18.20)
Ak−1 = Φ(∆k)    (18.21)
Qk−1 = ∫₀^{∆k} Φ(∆k − τ) L q Lᵀ Φ(∆k − τ)ᵀ dτ    (18.22)
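As a concrete illustration, the following sketch (my own code, not from the book's repository) builds the discrete-time matrices Ak−1 and Qk−1 of Equations (18.21)–(18.22) for the Matern-3/2 kernel, using numerical quadrature for the process noise integral; the function and variable names are placeholders.

```python
# Convert a Matern-3/2 GP prior into an LG-SSM, following Eqs. (18.13)-(18.22).
import numpy as np
from scipy.linalg import expm
from scipy.integrate import quad_vec
from scipy.special import gamma

def matern32_ssm(ell, sigma2, dt):
    lam = np.sqrt(3.0) / ell                      # lambda = sqrt(2*nu)/ell with nu = 3/2
    F = np.array([[0.0, 1.0], [-lam**2, -2 * lam]])
    L = np.array([[0.0], [1.0]])
    H = np.array([[1.0, 0.0]])
    p = 1.0                                       # p = nu - 1/2
    q = 2 * sigma2 * np.sqrt(np.pi) * lam**(2 * p + 1) * gamma(p + 1) / gamma(p + 0.5)
    A = expm(F * dt)                              # Eq. (18.21)
    # Eq. (18.22): Q = int_0^dt Phi(dt - tau) L q L^T Phi(dt - tau)^T dtau
    integrand = lambda tau: expm(F * (dt - tau)) @ L @ (q * L.T) @ expm(F * (dt - tau)).T
    Q, _ = quad_vec(integrand, 0.0, dt)
    return A, Q, H

A, Q, H = matern32_ssm(ell=1.0, sigma2=1.0, dt=0.1)
print(A, Q, sep="\n")
```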
19 Beyond the iid assumption
Part IV

Generation
20 Generative models: an overview
21 Variational autoencoders

21.0.1 VAEs with missing data


Sometimes we may have missing data, in which case parts of the data vector x ∈ RD are unknown.
In ?? we saw a special case of this when we discussed multimodal VAEs. In this section we allow for
arbitrary patterns of missingness.
To model the missing data, let m ∈ {0, 1}D be a binary vector where mj = 1 if xj is missing, and
mj = 0 otherwise. Let X = {x(n) } and M = {m(n) } be N × D matrices. Furthermore, let Xo be
the observed parts of X and Xh be the hidden parts. If we assume p(M|Xo , Xh ) = p(M), we say the
data is missing completely at random or MCAR, since the missingness does not depend on the
hidden or observed features. If we assume p(M|Xo , Xh ) = p(M|Xo ), we say the data is missing at
random or MAR, since the missingness does not depend on the hidden features, but may depend
on the visible features. If neither of these assumptions hold, we say the data is not missing at
random or NMAR.
In the MCAR and MAR cases, we can ignore the missingness mechanism, since it tells us nothing
about the hidden features. However, in the NMAR case, we need to model the missing data
mechanism, since the lack of information may be informative. For example, the fact that someone
did not fill out an answer to a sensitive question on a survey (e.g., “Do you have COVID?”) could
be informative about the underlying value. See e.g., [Little87; Marlin08] for more information on
missing data models.
In the context of VAEs, we can model the MCAR scenario by treating the missing values as latent variables. This is illustrated in Figure 21.1(a). Since missing leaf nodes in a directed graphical model do not affect their parents, we can simply ignore them when computing the posterior p(z^(i)|x_o^(i)), where x_o^(i) are the observed parts of example i. However, when using an amortized inference network, it can be difficult to handle missing inputs, since the model is usually trained to compute p(z^(i)|x_{1:d}^(i)).
One solution to this is to use the product of experts approach discussed in the context of multi-modal
VAEs in ??. However, this is designed for the case where whole blocks (corresponding to different
modalities) are missing, and will not work well if there are arbitrary missing patterns (e.g., pixels
that get dropped out due to occlusion or scratches on the lens). In addition, this method will not
work for the NMAR case.
An alternative approach, proposed in [Collier2020], is to explicitly include the missingness
indicators into the model, as shown in Figure 21.1(b). We assume the model always generates each
xj for j = 1 : d, but we only get to see the “corrupted” versions x̃j . If mj = 0 then x̃j = xj , but if
mj = 1, then x̃j is a special value, such as 0, unrelated to xj . We can model any correlation between
the missingness elements (components of m) by using another latent variable zm.
Figure 21.1: Illustration of different VAE variants for handling missing data: (a) missingness as latents; (b) MCAR corruption process; (c) MNAR corruption process. Greyed nodes are observed, white nodes are latent. From Figure 1 of [Collier2020]. Used with kind permission of Mark Collier.
This model can easily be extended to the NMAR case by letting m depend on the latent factors for the observed data, z, as well as the usual missingness latent factors zm, as shown in Figure 21.1(c).
We modify the VAE to be conditional on the missingness pattern, so the VAE decoder has the form p(xo|z, m), and the encoder has the form q(z|xo, m). However, we assume the prior is p(z) as usual, independent of m. We can compute a lower bound on the log marginal likelihood of the observed data, given the missingness, as follows:

log p(xo|m) = log ∫∫ p(xo, xm|z, m) p(z) dxm dz    (21.1)
            = log ∫ p(xo|z, m) p(z) dz    (21.2)
            = log ∫ p(xo|z, m) p(z) [q(z|x̃, m)/q(z|x̃, m)] dz    (21.3)
            = log E_{q(z|x̃,m)}[p(xo|z, m) p(z)/q(z|x̃, m)]    (21.4)
            ≥ E_{q(z|x̃,m)}[log p(xo|z, m)] − DKL(q(z|x̃, m) ∥ p(z))    (21.5)

We can fit this model in the usual way.
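As a concrete illustration, here is a minimal sketch (my own code, not that of [Collier2020]) of the masked ELBO of Equation (21.5) for an MCAR corruption process; the network sizes and architecture are arbitrary placeholder choices.

```python
# Masked VAE ELBO: the likelihood term sums over observed pixels only, and the
# encoder conditions on the corrupted input x_tilde together with the mask m.
import torch
import torch.nn as nn

class MaskedVAE(nn.Module):
    def __init__(self, d=784, h=256, z=32):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(2 * d, h), nn.ReLU(), nn.Linear(h, 2 * z))
        self.dec = nn.Sequential(nn.Linear(z + d, h), nn.ReLU(), nn.Linear(h, d))

    def elbo(self, x, m):
        x_tilde = x * (1 - m)                       # corrupted input: missing entries zeroed
        mu, log_var = self.enc(torch.cat([x_tilde, m], -1)).chunk(2, -1)
        z = mu + torch.randn_like(mu) * (0.5 * log_var).exp()   # reparameterization
        logits = self.dec(torch.cat([z, m], -1))    # decoder for p(x_o | z, m)
        rec = -nn.functional.binary_cross_entropy_with_logits(logits, x, reduction="none")
        rec = (rec * (1 - m)).sum(-1)               # log p(x_o | z, m): observed entries only
        kl = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum(-1)
        return (rec - kl).mean()

vae = MaskedVAE()
x = torch.rand(8, 784).round()                      # toy binary data
m = (torch.rand(8, 784) < 0.3).float()              # 1 = missing (MCAR mask)
loss = -vae.elbo(x, m)
loss.backward()
```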
Figure 21.2: Imputing missing pixels given a masked-out image, using a VAE with an MCAR assumption. The columns show the original image x, the mask m, the corrupted image x̃, and the VAE reconstruction. From Figure 2 of [Collier2020]. Used with kind permission of Mark Collier.

22 Auto-regressive models
23 Normalizing flows
24 Energy-based models
25 Denoising diffusion models
26 Generative adversarial networks
Part V

Discovery
27 Discovery methods: an overview
28 Latent factor models

28.1 Inference in topic models


In this section, we discuss some methods for performing inference in LDA (Latent Dirichlet Allocation)
models, defined in Main Section 28.5.1.

28.1.1 Collapsed Gibbs sampling for LDA


In this section, we discuss how to perform inference using MCMC.
The simplest approach is to use Gibbs sampling. The full conditionals are as follows:

p(mil = k|·) ∝ exp[log πik + log w_{k,xil}]    (28.1)
p(πi|·) = Dir({αk + Σ_l I(mil = k)})    (28.2)
p(wk|·) = Dir({γv + Σ_i Σ_l I(xil = v, mil = k)})    (28.3)

However, one can get better performance by analytically integrating out the π i ’s and the wk ’s,
both of which have a Dirichlet distribution, and just sampling the discrete mil ’s. This approach
was first suggested in [Griffiths04], and is an example of collapsed Gibbs sampling. Figure 28.1(b)
shows that now all the mil variables are fully correlated. However, we can sample them one at a
time, as we explain below.
First, we need some notation. Let Nivk = Σ_{l=1}^{Li} I(mil = k, xil = v) be the number of times word v is assigned to topic k in document i. Let Nik = Σ_v Nivk be the number of times any word from document i has been assigned to topic k. Let Nvk = Σ_i Nivk be the number of times word v has been assigned to topic k in any document. Let Nk = Σ_v Nvk be the number of words assigned to topic k. Finally, let Li = Σ_k Nik be the number of words in document i; this is observed.
We can now derive the marginal prior. By applying Main Equation (3.94), one can show that

p(m|α) = ∏_i ∫ [∏_{l=1}^{Li} Cat(mil|πi)] Dir(πi|α 1_K) dπi    (28.4)
       = (Γ(Kα)/Γ(α)^K)^N ∏_{i=1}^N [∏_{k=1}^K Γ(Nik + α)] / Γ(Li + Kα)    (28.5)
Figure 28.1: (a) LDA unrolled for N documents. (b) Collapsed LDA, where we integrate out the continuous latents zn and the continuous topic parameters W.
By similar reasoning, one can show

p(x|m, β) = ∏_k ∫ [∏_{il: mil=k} Cat(xil|wk)] Dir(wk|β 1_V) dwk    (28.6)
          = (Γ(Vβ)/Γ(β)^V)^K ∏_{k=1}^K [∏_{v=1}^V Γ(Nvk + β)] / Γ(Nk + Vβ)    (28.7)

From the above equations, and using the fact that Γ(x + 1)/Γ(x) = x, we can derive the full conditional for p(mil|m_{−i,l}). Define N⁻_ivk to be the same as Nivk except it is computed by summing over all locations in document i except for mil. Also, let xil = v. Then

p(mi,l = k|m_{−i,l}, x, α, β) ∝ [(N⁻_{v,k} + β)/(N⁻_k + Vβ)] · [(N⁻_{i,k} + α)/(Li + Kα)]    (28.8)

We see that a word in a document is assigned to a topic based both on how often that word is generated by the topic (first term), and also on how often that topic is used in that document (second term).
Given Equation (28.8), we can implement the collapsed Gibbs sampler as follows. We randomly assign a topic to each word, mil ∈ {1, . . . , K}. We can then sample a new topic as follows: for a given word in the corpus, decrement the relevant counts, based on the topic assigned to the current word; draw a new topic from Equation (28.8); update the count matrices; and repeat. This algorithm can be made efficient since the count matrices are very sparse [Li2014lda].
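To make the update concrete, here is a minimal sketch of such a collapsed Gibbs sampler (my own illustrative code, not the book's implementation); names such as docs and n_iters are placeholders.

```python
# Collapsed Gibbs sampling for LDA, implementing the conditional of Eq. (28.8).
# docs[i] is a list of word ids, K is the number of topics, V the vocabulary size,
# and alpha, beta are the symmetric Dirichlet hyperparameters.
import numpy as np

def collapsed_gibbs_lda(docs, K, V, alpha=1.0, beta=0.1, n_iters=200, seed=0):
    rng = np.random.default_rng(seed)
    N_vk = np.zeros((V, K))            # word-topic counts
    N_ik = np.zeros((len(docs), K))    # document-topic counts
    N_k = np.zeros(K)                  # topic counts
    assign = []                        # m_{il}: topic of each token
    for i, doc in enumerate(docs):     # random initialization
        zs = rng.integers(K, size=len(doc))
        assign.append(zs)
        for v, k in zip(doc, zs):
            N_vk[v, k] += 1; N_ik[i, k] += 1; N_k[k] += 1
    for _ in range(n_iters):
        for i, doc in enumerate(docs):
            for l, v in enumerate(doc):
                k = assign[i][l]       # decrement counts for the current token
                N_vk[v, k] -= 1; N_ik[i, k] -= 1; N_k[k] -= 1
                # Eq. (28.8); the constant denominator L_i + K*alpha can be dropped
                p = (N_vk[v] + beta) / (N_k + V * beta) * (N_ik[i] + alpha)
                k = rng.choice(K, p=p / p.sum())
                assign[i][l] = k       # increment counts for the new assignment
                N_vk[v, k] += 1; N_ik[i, k] += 1; N_k[k] += 1
    return assign, N_vk, N_ik

docs = [[0, 1, 2, 2], [3, 4, 2], [0, 1, 1, 2]]   # toy corpus with V = 5 words
assign, N_vk, N_ik = collapsed_gibbs_lda(docs, K=2, V=5)
```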
This process is illustrated in Figure 28.2 on a small example with two topics and five words. The left part of the figure illustrates 16 documents that were sampled from the LDA model using p(money|k = 1) = p(loan|k = 1) = p(bank|k = 1) = 1/3 and p(river|k = 2) = p(stream|k = 2) = p(bank|k = 2) = 1/3. For example, we see that the first document contains the word "bank" 4 times (indicated by the four dots in row 1 of the "bank" column), as well as various other financial terms. The right part of the figure shows the state of the Gibbs sampler after 64 iterations.
Figure 28.2: Illustration of (collapsed) Gibbs sampling applied to a small LDA example. There are N = 16 documents, each containing a variable number of words drawn from a vocabulary of V = 5 words (river, stream, bank, money, loan). There are two topics. A white dot means the word is assigned to topic 1, a black dot means the word is assigned to topic 2. (a) The initial random assignment of states. (b) A sample from the posterior after 64 steps of Gibbs sampling. From Figure 7 of [Steyvers07]. Used with kind permission of Tom Griffiths.

The "correct" topic has been assigned to each token in most cases. For example, in document 1, we see that the word "bank" has been correctly assigned to the financial topic, based on the presence of the words "money" and "loan". The posterior mean estimate of the parameters is given by p̂(money|k = 1) = 0.32, p̂(loan|k = 1) = 0.29, p̂(bank|k = 1) = 0.39, p̂(river|k = 2) = 0.25, p̂(stream|k = 2) = 0.4, and p̂(bank|k = 2) = 0.35, which is impressively accurate, given that there are only 16 training examples.
28.1.2 Variational inference for LDA

A faster alternative to MCMC is to use variational EM, which we discuss in general terms in Main Section 6.5.6.1. There are several ways to apply this to LDA, which we discuss in the following sections.

28.1.2.1 Sequence version

In this section, we focus on a version in which we unroll the model, and work with a latent variable for each word. Following [Blei03], we will use a fully factorized (mean field) approximation of the form

q(zn, sn) = Dir(zn|z̃n) ∏_l Cat(snl|Ñnl)    (28.9)

where z̃n are the variational parameters for the approximate posterior over zn, and Ñnl are the variational parameters for the approximate posterior over snl. We will follow the usual mean field recipe. For q(snl), we use Bayes' rule, but where we need to take expectations over the prior:

Ñnlk ∝ w_{d,k} exp(E[log znk])    (28.10)

where d = xnl, and

E[log znk] = ψk(z̃n) ≜ ψ(z̃nk) − ψ(Σ_{k′} z̃nk′)    (28.11)

where ψ is the digamma function. The update for q(zn) is obtained by adding up the expected counts:

z̃nk = αk + Σ_l Ñnlk    (28.12)

The M step is obtained by adding up the expected counts and normalizing:

ŵdk ∝ βd + Σ_{n=1}^N Σ_{l=1}^{Ln} Ñnlk I(xnl = d)    (28.13)

28.1.2.2 Count version

Note that the E step takes O((Σ_n Ln) Nw Nz) space to store the Ñnlk. It is much more space efficient to perform inference in the mPCA version of the model, which works with counts; these only take O(N Nw Nz) space, which is a big savings if documents are long. (By contrast, the collapsed Gibbs sampler must work explicitly with the snl variables.)
Following the discussion in Main Section 28.4.1, we will work with the variables zn and Nn, where Nn = [Nndk] is the matrix of counts, which can be derived from s_{n,1:Ln}. We will again use a fully factorized (mean field) approximation of the form

q(zn, Nn) = Dir(zn|z̃n) ∏_d M(Nnd|xnd, Ñnd)    (28.14)

where xnd = Σ_{l=1}^{Ln} I(xnl = d) is the total number of times token d occurs in document n.
The E step becomes

z̃nk = αk + Σ_d xnd Ñndk    (28.15)
Ñndk ∝ wdk exp(E[log znk])    (28.16)

The M step becomes

ŵdk ∝ βd + Σ_n xnd Ñndk    (28.17)

28.1.2.3 Bayesian version

We now modify the algorithm to use variational Bayes (VB) instead of EM, i.e., we infer the parameters as well as the latent variables. There are two advantages to this. First, by setting β ≪ 1, VB will encourage W to be sparse (as in Main Section 10.3.6.6). Second, we will be able to generalize this to the online learning setting, as we discuss below.
Our new posterior approximation becomes

q(zn, Nn, W) = Dir(zn|z̃n) [∏_d M(Nnd|xnd, Ñnd)] [∏_k Dir(wk|w̃k)]    (28.18)
Algorithm 28.1: Batch VB for LDA

1  Input: {xnd}, Nz, α, β
2  Estimate w̃dk using EM for multinomial mixtures
3  while not converged do
4      // E step
5      adk = 0  // expected sufficient statistics
6      for each document n = 1 : N do
7          (z̃n, Ñn) = VB-Estep(xn, W̃, α)
8          adk += xnd Ñndk
9      // M step
10     for each topic k = 1 : Nz do
11         w̃dk = βd + adk
12 function (z̃n, Ñn) = VB-Estep(xn, W̃, α)
13     Initialize z̃nk = αk
14     repeat
15         z̃n^old = z̃n, z̃nk = αk
16         for each word d = 1 : Nw do
17             for each topic k = 1 : Nz do
18                 Ñndk = exp(ψk(w̃d) + ψk(z̃n^old))
19             Ñnd = normalize(Ñnd)
20             z̃n += xnd Ñnd
21     until converged

The update for Ñndk changes to the following:

Ñndk ∝ exp(E[log wdk] + E[log znk])    (28.19)

The M step is the same as before:

ŵdk ∝ βd + Σ_n xnd Ñndk    (28.20)

No normalization is required, since we are just updating the pseudocounts. The overall algorithm is summarized in Algorithm 28.1.
28.1.2.4 Online (SVI) version

In the batch version, the E step takes O(N Nz Nw) per mean field update. This can be slow if we have many documents. This can be reduced by using stochastic variational inference, as discussed in Main Section 10.1.4. We perform an E step in the usual way. We then compute the variational parameters for W treating the expected sufficient statistics from the single data case as if the whole data set had those statistics. Finally, we make a partial update for the variational parameters for W, putting weight ηt on the new estimate and weight 1 − ηt on the old estimate. The step size ηt decays over time, according to some schedule, as in SGD. The overall algorithm is summarized in Algorithm 28.2. In practice, we should use mini-batches. In [Hoffman10], they used a batch of size 256–4096.

Algorithm 28.2: Online VB for LDA

1  Input: {xnd}, Nz, α, β, LR schedule
2  Initialize w̃dk randomly
3  for t = 1 : ∞ do
4      Set step size ηt
5      Pick document n
6      (z̃n, Ñn) = VB-Estep(xn, W̃, α)
7      w̃dk^new = βd + N xnd Ñndk
8      w̃dk = (1 − ηt) w̃dk + ηt w̃dk^new
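The following is a minimal sketch (my own illustration, not the code of [Hoffman10]) of the count-version VB E step and one stochastic update from Algorithm 28.2; the helper names and hyperparameter values are placeholders.

```python
# One online VB (SVI) update for LDA. W_tilde[d, k] are the topic pseudocounts,
# and digamma gives the E[log] terms used in Eqs. (28.15)-(28.16) and (28.19).
import numpy as np
from scipy.special import digamma

def vb_e_step(x_n, W_tilde, alpha, n_inner=50):
    # x_n: (V,) word counts for one document; W_tilde: (V, K)
    V, K = W_tilde.shape
    Elog_w = digamma(W_tilde) - digamma(W_tilde.sum(axis=0, keepdims=True))
    z_tilde = np.full(K, alpha) + x_n.sum() / K
    for _ in range(n_inner):
        Elog_z = digamma(z_tilde) - digamma(z_tilde.sum())
        N_tilde = np.exp(Elog_w + Elog_z)             # (V, K), unnormalized
        N_tilde /= N_tilde.sum(axis=1, keepdims=True)
        z_tilde = alpha + x_n @ N_tilde
    return z_tilde, N_tilde

def svi_update(x_n, W_tilde, alpha, beta, N_docs, step):
    _, N_tilde = vb_e_step(x_n, W_tilde, alpha)
    W_new = beta + N_docs * x_n[:, None] * N_tilde    # line 7 of Algorithm 28.2
    return (1 - step) * W_tilde + step * W_new        # line 8 of Algorithm 28.2

rng = np.random.default_rng(0)
V, K, N_docs = 100, 10, 5000
W_tilde = rng.gamma(1.0, 1.0, size=(V, K))
for t in range(1, 100):
    x_n = rng.multinomial(50, np.ones(V) / V)         # a synthetic "document"
    W_tilde = svi_update(x_n, W_tilde, alpha=0.1, beta=0.01, N_docs=N_docs,
                         step=(t + 10) ** -0.7)
```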
Figure 28.3: Test perplexity vs number of training documents for batch and online VB-LDA. From Figure 1 of [Hoffman10]. Used with kind permission of David Blei.

Figure 28.3 plots the perplexity on a test set of size 1000 vs the number of analyzed documents (E steps), where the data is drawn from (English) Wikipedia. The figure shows that online variational inference is much faster than offline inference, yet produces similar results.
29 State-space models

29.1 Continuous time SSMs


In this section, we briefly discuss continuous time dynamical systems.

29.1.1 Ordinary differential equations


We first consider a 1d system, whose state at time t is z(t). We assume this evolves according to the
following nonlinear differential equation:

dz/dt = f(t, z)    (29.1)

We assume the observations occur at discrete time steps ti; we can estimate the hidden state at these time steps, and then evolve the system dynamics in continuous time until the next measurement arrives. To compute zi+1 from zi, we use

zi+1 = zi + ∫_{ti}^{ti+1} f(t, z) dt    (29.2)

To compute the integral, we will use the second-order Runge-Kutta method, with step size ∆ = ti+1 − ti equal to the sampling interval. This gives rise to the following update:

k1 = f(ti, zi)    (29.3)
k2 = f(ti + ∆, zi + k1 ∆)    (29.4)
zi+1 = zi + (∆/2)(k1 + k2)    (29.5)

The term k1 is the slope at zi, and the term k2 is the (approximate) slope at zi+1, so (1/2)(k1 + k2) is the average slope. Thus zi+1 is the initial value zi plus the step size ∆ times the average slope, as illustrated in Figure 29.1.
When we have a vector-valued state-space, we need to solve a multidimensional integral. However, if the components are conditionally independent given the previous state, we can reduce this to a set of independent 1d integrals.
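In code, one RK2 step looks as follows (a minimal sketch of Equations (29.3)–(29.5), not the book's implementation):

```python
# Second-order Runge-Kutta (Heun) update, applied componentwise to a vector state.
import numpy as np

def rk2_step(f, t, z, dt):
    k1 = f(t, z)                 # slope at the start of the interval
    k2 = f(t + dt, z + dt * k1)  # slope at the predicted endpoint
    return z + 0.5 * dt * (k1 + k2)

# Example: the circular dynamics of Section 29.1.3, f(z) = (z2, -z1).
f = lambda t, z: np.array([z[1], -z[0]])
z, dt = np.array([1.0, 0.0]), 0.01
for i in range(750):
    z = rk2_step(f, i * dt, z, dt)
print(z, np.linalg.norm(z))      # the radius stays close to 1
```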
Figure 29.1: Illustration of one step of the second-order Runge-Kutta method with step size h.
29.1.2 Example: Noiseless 1d spring-mass system

In this section, we consider an example from Wikipedia¹ of a spring mass system operating in 1d. Like many physical systems, this is best modeled in continuous time, although we will later discretize it.
Let x(t) be the position of an object which is attached by a spring to a wall, and let ẋ(t) and ẍ(t) be its velocity and acceleration. By Newton's laws of motion, we have the following ordinary differential equation:

m ẍ(t) = u(t) − b ẋ(t) − c x(t)    (29.6)

where u(t) is an externally applied force (e.g., someone tugging on the object), b is the viscous friction coefficient, c is the spring constant, and m is the mass of the object. See Figure 29.2 for the setup. We assume that we only observe the position, and not the velocity.
We now proceed to represent this as a first order Markov system. For simplicity, we ignore the noise (Qt = Rt = 0). We define the state space to contain the position and velocity, z(t) = [x(t), ẋ(t)]. Thus the model becomes

ż(t) = F z(t) + B u(t)    (29.7)
y(t) = H z(t) + D u(t)    (29.8)

where

[ẋ(t); ẍ(t)] = [[0, 1], [−c/m, −b/m]] [x(t); ẋ(t)] + [0; 1/m] u(t)    (29.9)
y(t) = [1, 0] [x(t); ẋ(t)]    (29.10)

1. https://en.wikipedia.org/wiki/State-space_representation#Moving_object_example
Figure 29.2: Illustration of the spring mass system.
To simulate from this system, we need to evaluate the state at a set of discrete time intervals, tk = k∆, where ∆ is the sampling period or step size. There are many ways to discretize an ODE.² Here we discuss the generalized bilinear transform [Zhang2007]. In this approach, we specify a step size ∆ and compute

zk = (I − α∆F)^{−1}(I + (1 − α)∆F) zk−1 + ∆(I − α∆F)^{−1} B uk    (29.11)

where we denote the first (discretized dynamics) matrix by Ā and the second (discretized input) matrix by B̄. If we set α = 0, we recover Euler's method, which simplifies to

zk = (I + ∆F) zk−1 + ∆B uk    (29.12)

If we set α = 1, we recover the backward Euler method. If we set α = 1/2 we get the bilinear method, which preserves the stability of the system [Zhang2007]; we will use this in Section 29.2. Regardless of how we do the discretization, the resulting discrete time SSM becomes

zk = Ā zk−1 + B̄ uk    (29.13)
yk = H zk + D uk    (29.14)

Now consider simulating a system where we periodically "tug" on the object, so the force increases and then decreases for a short period, as shown in the top row of Figure 29.3. We can discretize the dynamics and compute the corresponding state and observation at integer time points. The result is shown in the bottom row of Figure 29.3. We see that the object's location changes smoothly, since it integrates the force over time.

2. See discussion at https://en.wikipedia.org/wiki/Discretization.
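Here is a minimal sketch (my own code, not ssm_spring_demo.ipynb) of the generalized bilinear transform of Equation (29.11) applied to the spring-mass system; the parameter values are arbitrary placeholders.

```python
# Discretize the continuous-time spring-mass SSM and simulate a brief "tug".
import numpy as np

def discretize(F, B, dt, alpha=0.5):
    # alpha = 0: Euler; alpha = 1/2: bilinear; alpha = 1: backward Euler.
    n = F.shape[0]
    M = np.linalg.inv(np.eye(n) - alpha * dt * F)
    A_bar = M @ (np.eye(n) + (1 - alpha) * dt * F)
    B_bar = dt * M @ B
    return A_bar, B_bar

m, b, c = 1.0, 0.5, 2.0                     # mass, friction, spring constant
F = np.array([[0.0, 1.0], [-c / m, -b / m]])
B = np.array([[0.0], [1.0 / m]])
H = np.array([[1.0, 0.0]])

A_bar, B_bar = discretize(F, B, dt=0.1, alpha=0.5)
z = np.zeros((2, 1))
ys = []
for k in range(100):
    u = 1.0 if 10 <= k < 20 else 0.0        # a brief force applied to the object
    z = A_bar @ z + B_bar * u               # Eq. (29.13)
    ys.append((H @ z).item())               # observed position, Eq. (29.14) with D = 0
```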
29.1.3 Example: tracking a moving object in continuous time

In this section, we consider a variant of the example in Section 8.1.1. We modify the dynamics so that energy is conserved, by ensuring the velocity is constant, and by working in continuous time.

Figure 29.3: Signals generated by the spring mass system. Top row shows the input force. Bottom row shows the observed location of the end-effector. Adapted from a figure by Sasha Rush. Generated by ssm_spring_demo.ipynb.
Thus the particle moves in a circle, rather than spiralling in towards the origin. (See [Strogatz2015] for details.) The dynamics model becomes

zt = F zt−1 + εt    (29.15)
F = [[0, 1], [−1, 0]]    (29.16)

where εt ∼ N(0, Q) is the system noise. We set Q = 0.001 I, so the noise is negligible.
To see why this results in circular dynamics, we follow a derivation from Gerardo Durán-Martín. We can represent a point in the plane using polar coordinates, (r, θ), or Euclidean coordinates, (x, y). We can switch between these using standard trigonometric identities:

xt = rt cos θt,   yt = rt sin θt    (29.17)

The (noiseless) dynamical system has the following form:

[ẋt; ẏt] = [yt; −xt]    (29.18)

We now show that this implies that

ṙ = 0,   θ̇ = −1    (29.19)

which means the radius is constant, and the angle changes at a constant rate. To see why, first note that

rt² = xt² + yt²    (29.20)
(d/dt) rt² = (d/dt)(xt² + yt²)    (29.21)
2 rt ṙt = 2 xt ẋt + 2 yt ẏt    (29.22)
ṙt = (xt ẋt + yt ẏt)/rt    (29.23)
Figure 29.4: Illustration of Kalman filtering applied to a 2d linear dynamical system in continuous time. (a) True underlying state and observed data. (b) Estimated state. Generated by kf_continuous_circle.ipynb.
Also,

tan θt = yt/xt    (29.24)
(d/dt) tan θt = (d/dt)(yt/xt)    (29.25)
θ̇t sec²(θt) = ẏt/xt − (yt/xt²) ẋt = (xt ẏt − ẋt yt)/xt²    (29.26)
θ̇t = (xt ẏt − ẋt yt)/rt²    (29.27)

Plugging into Equation (29.18), and using the fact that cos² + sin² = 1, we have

ṙ = [(r cos θ)(r sin θ) − (r sin θ)(r cos θ)]/r = 0    (29.28)
θ̇ = [−(r cos θ)(r cos θ) − (r sin θ)(r sin θ)]/r² = −r²(cos²θ + sin²θ)/r² = −1    (29.29)

In most applications, we cannot directly see the underlying states. Instead, we just get noisy observations. Thus we will assume the following observation model:

yt = H zt + δt    (29.30)
H = [[1, 0], [0, 1]]    (29.31)

where δt ∼ N(0, R) is the measurement noise. We set R = 0.01 I.
We sample from this model and apply the Kalman filter to the resulting synthetic data, to estimate the underlying hidden states.

Figure 29.5: Illustration of extended Kalman filtering applied to a 2d nonlinear dynamical system. (a) True underlying state and observed data. (b) Estimated state. Generated by ekf_continuous.ipynb.
To ensure energy conservation, we integrate the dynamics between each observation in continuous time, using the RK2 method in Section 29.1.1. The result is shown in Figure 29.4. We see that the method is able to filter out the noise, and track the underlying hidden state.

29.1.4 Example: tracking a particle in 2d

Consider a particle moving in continuous time in 2d space with the following nonlinear dynamics:

z′ = f_z(z) = [z2; z1 − z1³]    (29.32)

This is a mildly nonlinear version of the model in Section 29.1.3. We can use the RK2 method of Section 29.1.1 applied to each component separately to compute the latent trajectory of the particle. We use a step size of h = dt = 0.01 and simulate from t = 0 to T = 7.5. Thus the number of integration steps is N = T/h = 750.
We sample the system at K = 70 evenly spaced time steps, and generate noisy observations using yt = h(zt) + N(0, R), where h(z) = z is the identity function. See Figure 29.5(a) for the hidden trajectory and corresponding noisy observations.
In Figure 29.5(b), we condition on the noisy observations, and compute p(zt|y1:t) using the EKF. We see that the method can successfully filter out the noise.
29.2 Structured State Space Sequence model (S4)

In this section, we briefly discuss a new sequence-to-sequence model known as the Structured State Space Sequence model or S4 [S4; hippo; Goel2022]. Our presentation of S4 is based in part on the excellent tutorial [RushS4].
An S4 model is basically a deep stack of (noiseless) linear SSMs (Main Section 29.6). In between each layer we add pointwise nonlinearities and a linear mapping.
Figure 29.6: Illustration of one S4 block. The input X has H sequences, each of length L. These get mapped (in parallel for each of the sequences) to the output Y by a (noiseless) linear SSM with state size N and parameters (A, B, C, D). The output Y is mapped pointwise through a nonlinearity φ to create Y′. The channels are then linearly combined by weight matrix W and added to the input (via a skip connection) to get the final output Y′′, which is another set of H sequences of length L.
Because SSMs are recurrent first-order models, we can easily generate (sample) from them in O(L) time, where L is the length of the sequence. This is much faster than the O(L²) required by standard transformers (Main Section 16.3.5). However, because these SSMs are linear, it turns out that we can compute all the hidden representations given known inputs in parallel using convolution; this makes the models fast to train. Finally, since S4 models are derived from an underlying continuous time process, they can easily be applied to observations at different temporal frequencies. Empirically S4 has been found to be much better at modeling long range dependencies compared to transformers, which (at the time of writing, namely January 2022) are considered state of the art.
The basic building block, known as a Linear State Space Layer (LSSL), is the following continuous time linear dynamical system that maps an input sequence u(t) ∈ R to an output sequence y(t) ∈ R via a sequence of hidden states z(t) ∈ R^N:

ż(t) = A z(t) + B u(t)    (29.33)
y(t) = C z(t) + D u(t)    (29.34)

Henceforth we will omit the skip connection corresponding to the D term for brevity. We can convert this to a discrete time system using the generalized bilinear transform discussed in Section 29.1.2. If we set α = 1/2 in Equation (29.11) we get the bilinear method, which preserves the stability of the

system [Zhang2007]. The result is

zt = Ā zt−1 + B̄ ut    (29.35)
yt = C̄ zt    (29.36)
Ā = (I − ∆/2 · A)^{−1}(I + ∆/2 · A) ∈ R^{N×N}    (29.37)
B̄ = (I − ∆/2 · A)^{−1} ∆B ∈ R^N    (29.38)
C̄ = C ∈ R^{N×1}    (29.39)

We now discuss what is happening inside the LSSL layer. Let us assume the initial state is z−1 = 0. We can unroll the recursion to get

z0 = B̄u0,  z1 = ĀB̄u0 + B̄u1,  z2 = Ā²B̄u0 + ĀB̄u1 + B̄u2, · · ·    (29.40)
y0 = C̄B̄u0,  y1 = C̄ĀB̄u0 + C̄B̄u1,  y2 = C̄Ā²B̄u0 + C̄ĀB̄u1 + C̄B̄u2, · · ·    (29.41)

We see that zt is computing a weighted sum of all the past inputs, where the weights are controlled by powers of the Ā matrix. (See also the discussion of subspace identification techniques in Main Section 29.8.2.) It turns out that we can define A to have a special structure, known as HiPPO (High-order Polynomial Projection Operator) [hippo], such that (1) it ensures zt embeds "relevant" parts of the past history in a compact manner, (2) it enables recursive computation of z1:L in O(N) time per step, instead of the naive O(N²) time for the matrix vector multiply; and (3) it only has O(3N) (complex-valued) parameters to learn.
However, recursive computation still takes time linear in L. At training time, when all inputs to each location are already available, we can further speed things up by recognizing that the output sequence can be computed in parallel using a convolution:

yk = C̄Ā^k B̄ u0 + C̄Ā^{k−1} B̄ u1 + · · · + C̄B̄ uk    (29.42)
y = K̄ ∗ u    (29.43)
K̄ = (C̄B̄, C̄ĀB̄, · · · , C̄Ā^{L−1} B̄)    (29.44)

We can compute the convolution kernel K̄ in O(N + L) time and space, using the S4 representation, and then use the FFT to efficiently compute the output. Unfortunately the details of how to do this are rather complicated, so we refer the reader to [S4].
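To make the equivalence concrete, the following sketch (my own code, not the implementation of [S4]) checks that the recurrent view of Equations (29.35)–(29.36) and the convolutional view of Equations (29.42)–(29.44) produce the same outputs for a small random system.

```python
# Recurrent vs convolutional evaluation of a discretized LSSL.
import numpy as np

def discretize(A, B, C, dt):
    # Bilinear transform, Eqs. (29.37)-(29.39).
    N = A.shape[0]
    M = np.linalg.inv(np.eye(N) - dt / 2 * A)
    return M @ (np.eye(N) + dt / 2 * A), M @ (dt * B), C

def run_recurrent(Ab, Bb, Cb, u):
    z, ys = np.zeros(Ab.shape[0]), []
    for uk in u:                              # Eq. (29.35)-(29.36)
        z = Ab @ z + Bb * uk
        ys.append(Cb @ z)
    return np.array(ys)

def run_convolutional(Ab, Bb, Cb, u):
    L = len(u)
    K = np.array([Cb @ np.linalg.matrix_power(Ab, k) @ Bb for k in range(L)])
    return np.convolve(u, K)[:L]              # Eq. (29.43), non-circular convolution

rng = np.random.default_rng(0)
N, L = 4, 16
A, B, C = rng.normal(size=(N, N)), rng.normal(size=N), rng.normal(size=N)
Ab, Bb, Cb = discretize(A, B, C, dt=0.1)
u = rng.normal(size=L)
print(np.allclose(run_recurrent(Ab, Bb, Cb, u), run_convolutional(Ab, Bb, Cb, u)))
```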
Once we have constructed an LSSL layer, we can process a stack of H sequences independently in parallel by replicating the above process with different parameters, for h = 1 : H. If we let each of the H C matrices be of size N × M, so they return a vector of channels instead of a scalar at each location, the overall mapping is from u1:L ∈ R^{H×L} to y1:L ∈ R^{HM×L}. We can add a pointwise nonlinearity to the output and then apply a projection matrix W ∈ R^{MH×H} to linearly combine the channels and map the result back to size y1:L ∈ R^{H×L}, as shown in Figure 29.6. This overall block can then be repeated a desired number of times. The input to the whole model is an encoder matrix which embeds each token, and the output is a decoder that creates the softmax layer at each location, as in the transformer. We can now learn the A, B, C, D, and W matrices (as well as the step size ∆) for each layer using backpropagation, using whatever loss function we want on the output layer.
30 Graph learning

30.1 Latent variable models for graphs


Graphs arise in many application areas, such as modeling social networks, protein-protein interaction
networks, or patterns of disease transmission between people or animals. There are usually two
primary goals when analyzing such data: first, try to discover some “interesting structure” in the
graph, such as clusters or communities; second, try to predict which links might occur in the future
(e.g., who will make friends with whom). In this section, we focus on the former. More precisely, we
will consider a variety of latent variable models for observed graphs.

30.1.1 Stochastic block model


In Figure 30.1(a) we show a directed graph on 9 nodes. There is no apparent structure. However,
if we look more deeply, we see it is possible to partition the nodes into three groups or blocks,
B1 = {1, 4, 6}, B2 = {2, 3, 5, 8}, and B3 = {7, 9}, such that most of the connections go from nodes in
B1 to B2 , or from B2 to B3 , or from B3 to B1 . This is illustrated in Figure 30.1(b).
The problem is easier to understand if we plot the adjacency matrices. Figure 30.2(a) shows the
matrix for the graph with the nodes in their original ordering. Figure 30.2(b) shows the matrix for
the graph with the nodes in their permuted ordering. It is clear that there is block structure.
We can make a generative model of block structured graphs as follows. First, for every node,
sample a latent block qi ∼ Cat(π), where πk is the probability of choosing block k, for k = 1 : K.
Second, choose the probability of connecting group a to group b, for all pairs of groups; let us denote
this probability by ηa,b . This can come from a beta prior. Finally, generate each edge Rij using the
following model:

p(Rij = r|qi = a, qj = b, η) = Ber(r|ηa,b ) (30.1)

This is called the stochastic block model [Nowicki01]. Figure 30.4(a) illustrates the model as a
DGM, and Figure 30.2 illustrates how this model can be used to cluster the nodes in our example.
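As a concrete illustration, here is a minimal sketch (my own code, not from [Nowicki01] or [Kemp06]) of sampling a graph from this generative process; the particular π and η values are arbitrary placeholders loosely matching the structure in Figure 30.1(b).

```python
# Sample a directed graph from the stochastic block model of Eq. (30.1):
# sample a block for each node, then sample each edge with a probability
# that depends only on the blocks of its endpoints.
import numpy as np

def sample_sbm(n_nodes, pi, eta, seed=0):
    rng = np.random.default_rng(seed)
    q = rng.choice(len(pi), size=n_nodes, p=pi)            # q_i ~ Cat(pi)
    R = rng.random((n_nodes, n_nodes)) < eta[q[:, None], q[None, :]]
    np.fill_diagonal(R, False)                             # no self-loops
    return q, R.astype(int)

# Three blocks with mostly 1->2, 2->3, 3->1 connections.
pi = np.array([1 / 3, 1 / 3, 1 / 3])
eta = np.array([[0.1, 0.9, 0.1],
                [0.1, 0.1, 0.9],
                [0.9, 0.1, 0.1]])
q, R = sample_sbm(9, pi, eta)
```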
Note that this is quite different from a conventional clustering problem. For example, we see
that all the nodes in block 3 are grouped together, even though there are no connections between
them. What they share is the property that they “like to” connect to nodes in block 1, and to receive
connections from nodes in block 2. Figure 30.3 illustrates the power of the model for generating many
different kinds of graph structure. For example, some social networks have hierarchical structure,
which can be modeled by clustering people into different social strata, whereas others consist of a set
of cliques.
Figure 30.1: (a) A directed graph. (b) The same graph, with the nodes partitioned into 3 groups, making the block structure more apparent.

Figure 30.2: (a) Adjacency matrix for the graph in Figure 30.1(a). (b) Rows and columns are shown permuted to show the block structure. We show how the stochastic block model can generate this graph. From Figure 1 of [Kemp06]. Used with kind permission of Charles Kemp.

Unlike a standard mixture model, it is not possible to fit this model using exact EM, because all the latent qi variables become correlated. However, one can use variational EM [Airoldi08], collapsed Gibbs sampling [Kemp06], etc. We omit the details (which are similar to the LDA case).
In [Kemp06], they lifted the restriction that the number of blocks K be fixed, by replacing the Dirichlet prior on π by a Dirichlet process (see Supplementary ??). This is known as the infinite relational model. See Section 30.1.3 for details.
If we have features associated with each node, we can make a discriminative version of this model, for example by defining

p(Rij = r|qi = a, qj = b, xi, xj, θ) = Ber(r|w_{a,b}ᵀ f(xi, xj))    (30.2)

where f(xi, xj) is some way of combining the feature vectors. For example, we could use concatenation, [xi, xj], or elementwise product xi ⊗ xj, as in supervised LDA. Then the overall model is like a relational extension of the mixture of experts model.
Figure 30.3: Some examples of graphs generated using the stochastic block model with different kinds of connectivity patterns between the blocks. The abstract graphs (between blocks) represent a ring, a dominance hierarchy, a common-cause structure, and a common-effect structure. From Figure 4 of [Kemp10]. Used with kind permission of Charles Kemp.

Figure 30.4: (a) Stochastic block model. (b) Mixed membership stochastic block model.

30.1.2 Mixed membership stochastic block model
In [Airoldi08], they lifted the restriction that each node only belong to one cluster. That is, they replaced qi ∈ {1, . . . , K} with πi ∈ SK. This is known as the mixed membership stochastic block model, and is similar in spirit to fuzzy clustering or soft clustering. Note that πik is not the same as p(zi = k|D); the former represents ontological uncertainty (to what degree does each object belong to a cluster) whereas the latter represents epistemological uncertainty (which cluster does an object belong to).

Figure 30.5: (a) Who-likes-whom graph for Sampson's monks. (b) Mixed membership of each monk in one of three groups. From Figures 2-3 of [Airoldi08]. Used with kind permission of Edo Airoldi.
If we want to combine epistemological and ontological uncertainty, we can compute p(πi|D).
In more detail, the generative process is as follows. First, each node picks a distribution over blocks, πi ∼ Dir(α). Second, choose the probability of connecting group a to group b, for all pairs of groups, η_{a,b} ∼ Beta(α, β). Third, for each edge, sample two discrete variables, one for each direction:

q_{i→j} ∼ Cat(πi),   q_{i←j} ∼ Cat(πj)    (30.3)

Finally, generate each edge Rij using the following model:

p(Rij = 1|q_{i→j} = a, q_{i←j} = b, η) = η_{a,b}    (30.4)

See Figure 30.4(b) for the DGM.
Unlike the regular stochastic block model, each node can play a different role, depending on who it is connecting to. As an illustration of this, we will consider a dataset that is widely used in the social networks analysis literature. The data concerns who-likes-whom amongst a group of 18 monks. It was collected by hand in 1968 by Sampson [Sampson68] over a period of months. (These days, in the era of social media such as Facebook, a social network with only 18 people is trivially small, but the methods we are discussing can be made to scale.) Figure 30.5(a) plots the raw data, and Figure 30.5(b) plots E[πi] for each monk, where K = 3. We see that most of the monks belong to one of the three clusters, known as the "young turks", the "outcasts" and the "loyal opposition". However, some individuals, notably monk 15, belong to two clusters; Sampson called these monks the "waverers". It is interesting to see that the model can recover the same kinds of insights as Sampson derived by hand.
One prevalent problem in social network analysis is missing data. For example, if Rij = 0, it may be due to the fact that persons i and j have not had an opportunity to interact, or that data is not available for that interaction, as opposed to the fact that these people don't want to interact. In other words, absence of evidence is not evidence of absence. We can model this by modifying the observation model so that with probability ρ we generate a 0 from the background model, and we only force the model to explain observed 0s with probability 1 − ρ. In other words, we robustify the observation model to allow for outliers, as follows:

p(Rij = r|q_{i→j} = a, q_{i←j} = b, η) = ρ δ0(r) + (1 − ρ) Ber(r|η_{a,b})    (30.5)
See [Airoldi08] for details.

30.1.3 Infinite relational model
The stochastic block model is defined for graphs, in which each pair of nodes may or may not have an edge. We can easily extend this to hyper-graphs, which is useful for modeling relational data. For example, suppose we want to model a family tree. We might write R1(i, j, k) = 1 if adults i and j are the parents of child k, where R1 is the "parent-of" relation. Here i and j are entities of type T^1 (adults), and k is an entity of type T^2 (child), so the type signature of R1 is T^1 × T^1 × T^2 → {0, 1}.
To define the probability of relations holding between entities, we can associate a latent cluster variable q_i^t ∈ {1, . . . , Kt} with each entity i of each type t. We then define the probability of the relation holding between specific entities by looking up the probability of the relation holding between the corresponding entity clusters. Continuing our example above, we have

p(R1(i, j, k)|q_i^1 = a, q_j^1 = b, q_k^2 = c, η) = Ber(R1(i, j, k)|η_{a,b,c})    (30.6)

We can also have real-valued relations, where each edge has a weight. For example, we can write

p(R1(i, j, k)|q_i^1 = a, q_j^1 = b, q_k^2 = c, µ) = N(R1(i, j, k)|µ_{a,b,c}, σ²)    (30.7)

where µ_{a,b,c} captures the average response for that group of clusters. We can also add entity-specific offset terms:

p(R1(i, j, k)|q_i^1 = a, q_j^1 = b, q_k^2 = c, µ) = N(R1(i, j, k)|µ_{a,b,c} + µi + µj + µk, σ²)    (30.8)

This model was proposed in [Banerjee07], who fit the model using an alternating minimization procedure.
If we allow the number of clusters Kt for each type of entity to be unbounded, by using a Dirichlet process, the model is called the infinite relational model (IRM) [Kemp06], also known as an infinite hidden relational model (IHRM) [Xu06]. We can fit this model with variational Bayes [Xu06; Xu07] or collapsed Gibbs sampling [Kemp06]. Rather than go into algorithmic detail, we just sketch some interesting applications.

30.1.3.1 Learning ontologies

An ontology refers to an organisation of knowledge. In AI, ontologies are often built by hand (see e.g., [Russell10]), but it is interesting to try and learn them from data. In [Kemp06], they show how this can be done using the IRM.
The data comes from the Unified Medical Language System [McCray03], which defines a semantic network with 135 concepts (such as "disease or syndrome", "diagnostic procedure", "animal"), and 49 binary predicates (such as "affects", "prevents"). We can represent this as a ternary relation R : T^1 × T^1 × T^2 → {0, 1}, where T^1 is the set of concepts and T^2 is the set of binary predicates. The result is a 3d cube. We can then apply the IRM to partition the cube into regions of roughly homogeneous response. The system found 14 concept clusters and 21 predicate clusters. Some of these are shown in Figure 30.6. The system learns, for example, that biological functions affect organisms (since η_{a,b,c} ≈ 1 where a represents the biological function cluster, b represents the organism cluster, and c represents the affects cluster).

Figure 30.6: Illustration of an ontology learned by IRM applied to the Unified Medical Language System. The boxes represent 7 of the 14 concept clusters. Predicates that belong to the same cluster are grouped together, and associated with edges to which they pertain. All links with weight above 0.8 have been included. From Figure 9 of [Kemp10]. Used with kind permission of Charles Kemp.
30.1.3.2 Clustering based on relations and features

We can also use IRM to cluster objects based on their relations and their features. For example, [Kemp06] consider a political dataset (from 1965) consisting of 14 countries, 54 binary predicates representing interaction types between countries (e.g., "sends tourists to", "economic aid"), and 90 features (e.g., "communist", "monarchy"). To create a binary dataset, real-valued features were thresholded at their mean, and categorical variables were dummy-encoded. The data has 3 types: T^1 represents countries, T^2 represents interactions, and T^3 represents features. We have two relations: R1 : T^1 × T^1 × T^2 → {0, 1}, and R2 : T^1 × T^3 → {0, 1}. (This problem therefore combines aspects of both the biclustering model and the ontology discovery model.) When given multiple relations, the IRM treats them as conditionally independent. In this case, we have

p(R1, R2|q^1, q^2, q^3, θ) = p(R1|q^1, q^2, θ) p(R2|q^1, q^3, θ)    (30.9)

The results are shown in Figure 30.7. The IRM divides the 90 features into 5 clusters, the first of which contains "noncommunist", which captures one of the most important aspects of this Cold-War era dataset. It also clusters the 14 countries into 5 clusters, reflecting natural geo-political groupings (e.g., USA and UK, or the Communist Bloc), and the 54 predicates into 18 clusters, reflecting similar relationships (e.g., "negative behavior" and "accusations").
Figure 30.7: Illustration of IRM applied to some political data containing features and pairwise interactions. Top row (a): the partition of the countries into 5 clusters and the features into 5 clusters. Every second column is labelled with the name of the corresponding feature. Small squares at bottom (b-i): these are 8 of the 18 clusters of interaction types. From Figure 6 of [Kemp06]. Used with kind permission of Charles Kemp.

30.2 Learning tree structures

Since the problem of structure learning for general graphs is NP-hard [Chickering96np], we start by considering the special case of trees. Trees are special because we can learn their structure efficiently, as we discuss below, and because, once we have learned the tree, we can use them for efficient exact inference, as discussed in Main Section 9.3.2.
30.2.1 Chow-Liu algorithm

In this section, we consider undirected trees with pairwise potentials. The likelihood can be represented as follows:

p(x|T) = ∏_{t∈V} p(xt) ∏_{(s,t)∈E} p(xs, xt)/[p(xs) p(xt)]    (30.10)

where p(xs, xt) is an edge marginal and p(xt) is a node marginal. Hence we can write the log-likelihood for a tree as follows:

log p(D|θ, T) = Σ_t Σ_k Ntk log p(xt = k|θ) + Σ_{s,t} Σ_{j,k} Nstjk log [p(xs = j, xt = k|θ) / (p(xs = j|θ) p(xt = k|θ))]    (30.11)
Figure 30.8: The MLE tree estimated from the 20-newsgroup data. Generated by chow_liu_tree_demo.ipynb.
where N_{stjk} is the number of times node s is in state j and node t is in state k, and N_{tk} is the number of times node t is in state k. We can rewrite these counts in terms of the empirical distribution: N_{stjk} = N p_D(x_s = j, x_t = k) and N_{tk} = N p_D(x_t = k). Setting θ to the MLEs, this becomes

\[
\frac{\log p(\mathcal{D} \mid \theta, T)}{N} = \sum_{t \in V} \sum_{k} p_{\mathcal{D}}(x_t = k) \log p_{\mathcal{D}}(x_t = k)
\tag{30.12}
\]
\[
+ \sum_{(s,t) \in E(T)} \mathbb{I}(x_s, x_t \mid \hat{\theta}_{st})
\tag{30.13}
\]

where I(x_s, x_t | θ̂_{st}) ≥ 0 is the mutual information between x_s and x_t given the empirical distribution:

\[
\mathbb{I}(x_s, x_t \mid \hat{\theta}_{st}) = \sum_{j} \sum_{k} p_{\mathcal{D}}(x_s = j, x_t = k) \log \frac{p_{\mathcal{D}}(x_s = j, x_t = k)}{p_{\mathcal{D}}(x_s = j)\, p_{\mathcal{D}}(x_t = k)}
\tag{30.14}
\]
Since the first term in Equation (30.13) is independent of the topology T, we can ignore it when learning structure. Thus the tree topology that maximizes the likelihood can be found by computing the maximum weight spanning tree, where the edge weights are the pairwise mutual informations, I(x_s, x_t | θ̂_{st}). This is called the Chow-Liu algorithm [Chow68].
There are several algorithms for finding a max spanning tree (MST). The two best known are Prim’s algorithm and Kruskal’s algorithm. Both can be implemented to run in O(E log V) time, where E = V^2 is the number of edges and V is the number of nodes. See e.g., [Sedgewick11] for details. Thus the overall running time is O(N V^2 + V^2 log V), where the first term is the cost of computing the sufficient statistics.
Figure 30.8 gives an example of the method in action, applied to the binary 20 newsgroups data shown in Main Figure 5.8. The tree has been arbitrarily rooted at the node representing “email”. The connections that are learned seem intuitively reasonable.
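To make the procedure concrete, the following is a minimal sketch of Chow-Liu for discrete data, using empirical pairwise mutual informations as edge weights and networkx's Kruskal-based maximum spanning tree routine; the random data matrix is a stand-in for a real dataset, not the notebook code referenced above.

```python
# A minimal Chow-Liu sketch for discrete data; X is a hypothetical data matrix
# (rows = cases, columns = nodes), not the 20-newsgroup dataset itself.
import numpy as np
import networkx as nx

def mutual_information(xs, xt):
    """Empirical mutual information (in nats) between two discrete vectors."""
    joint = np.zeros((xs.max() + 1, xt.max() + 1))
    np.add.at(joint, (xs, xt), 1.0)
    joint /= joint.sum()
    outer = np.outer(joint.sum(axis=1), joint.sum(axis=0))
    nz = joint > 0
    return float(np.sum(joint[nz] * np.log(joint[nz] / outer[nz])))

def chow_liu_tree(X):
    """Return the maximum-weight spanning tree, with pairwise MI as edge weights."""
    D = X.shape[1]
    G = nx.Graph()
    for s in range(D):
        for t in range(s + 1, D):
            G.add_edge(s, t, weight=mutual_information(X[:, s], X[:, t]))
    return nx.maximum_spanning_tree(G)   # Kruskal by default

X = np.random.randint(0, 2, size=(500, 6))        # fake binary data on 6 nodes
print(sorted(chow_liu_tree(X).edges()))
```

Building the MI matrix costs O(N V^2), so that step dominates for large N, matching the running time quoted above.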
30.2.2 Finding the MAP forest
Since all trees have the same number of parameters, we can safely use the maximum likelihood score as a model selection criterion without worrying about overfitting. However, sometimes we may want to fit a forest rather than a single tree, since inference in a forest is much faster than in a tree (we can run belief propagation in each tree in the forest in parallel). The MLE criterion will never choose to omit an edge. However, if we use the marginal likelihood or a penalized likelihood (such as BIC), the optimal solution may be a forest. Below we give the details for the marginal likelihood case.
In Section 30.3.3.2, we explain how to compute the marginal likelihood of any DAG using a Dirichlet prior for the CPTs. The resulting expression can be written as follows:

\[
\log p(\mathcal{D} \mid T) = \sum_{t \in V} \log \int \prod_{i=1}^{N} p(x_{it} \mid x_{i,\mathrm{pa}(t)}, \theta_t)\, p(\theta_t)\, d\theta_t = \sum_{t} \mathrm{score}(N_{t,\mathrm{pa}(t)})
\tag{30.15}
\]

where N_{t,pa(t)} are the counts (sufficient statistics) for node t and its parents, and score is defined in Equation (30.28).
Now suppose we only allow DAGs with at most one parent. Following [Heckerman95c], let us associate a weight with each s → t edge, w_{s,t} ≜ score(t|s) − score(t|0), where score(t|0) is the score when t has no parents. Note that the weights might be negative (unlike the MLE case, where edge weights are always non-negative because they correspond to mutual information). Then we can rewrite the objective as follows:

\[
\log p(\mathcal{D} \mid T) = \sum_{t} \mathrm{score}(t \mid \mathrm{pa}(t)) = \sum_{t} w_{\mathrm{pa}(t),t} + \sum_{t} \mathrm{score}(t \mid 0)
\tag{30.16}
\]
The last term is the same for all trees T, so we can ignore it. Thus finding the most probable tree amounts to finding a maximal branching in the corresponding weighted directed graph. This can be found using the algorithm in [Gabow84].
If the scoring function is prior and likelihood equivalent (these terms are explained in Section 30.3.3.3), we have

\[
\mathrm{score}(s \mid t) + \mathrm{score}(t \mid 0) = \mathrm{score}(t \mid s) + \mathrm{score}(s \mid 0)
\tag{30.17}
\]

and hence the weight matrix is symmetric. In this case, the maximal branching is the same as the maximal weight forest. We can apply a slightly modified version of the MST algorithm to find this [Edwards10]. To see this, let G = (V, E) be a graph with both positive and negative edge weights. Now let G′ be the graph obtained by omitting all the negative edges from G. This cannot reduce the total weight, so we can find the maximum weight forest of G by finding the MST for each connected component of G′. We can do this by running Kruskal’s algorithm directly on G′: there is no need to find the connected components explicitly.
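As a small sketch of this idea (assuming we already have the symmetric weight matrix of Equation (30.17), here simply made up), we drop the non-positive edges and let a standard Kruskal-based routine return a spanning forest of what remains.

```python
# Maximum-weight forest sketch: keep only positive-weight edges and run Kruskal
# (via networkx) on the remaining graph G'. The weight matrix W is hypothetical.
import numpy as np
import networkx as nx

def max_weight_forest(W):
    """W is a symmetric matrix of edge scores w[s,t] = score(t|s) - score(t|0)."""
    D = W.shape[0]
    Gp = nx.Graph()
    Gp.add_nodes_from(range(D))
    for s in range(D):
        for t in range(s + 1, D):
            if W[s, t] > 0:                      # negative edges can only hurt
                Gp.add_edge(s, t, weight=W[s, t])
    return nx.maximum_spanning_tree(Gp)          # returns a forest if Gp is disconnected

W = np.array([[ 0.0,  2.0, -1.0,  0.0],
              [ 2.0,  0.0,  0.5, -0.3],
              [-1.0,  0.5,  0.0, -2.0],
              [ 0.0, -0.3, -2.0,  0.0]])
forest = max_weight_forest(W)
print(sorted(forest.edges(data="weight")))       # node 3 is left isolated
```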
30.2.3 Mixtures of trees
A single tree is rather limited in its expressive power. Later in this chapter we discuss ways to learn more general graphs. However, inference in the resulting graphs can be expensive. An interesting alternative is to learn a mixture of trees [Meila00b], where each mixture component may have a different tree topology. This is like an unsupervised version of the TAN classifier discussed

Figure 30.9: A simple linear Gaussian model.
in Main Section 4.2.8.3. We can fit a mixture of trees by using EM: in the E step, we compute the responsibilities of each cluster for each data point, and in the M step, we use a weighted version of the Chow-Liu algorithm. See [Meila00b] for details.
In fact, it is possible to create an “infinite mixture of trees”, by integrating out over all possible trees. Remarkably, this can be done in O(N_G^3) time using the matrix tree theorem. This allows us to perform exact Bayesian inference of posterior edge marginals etc. However, it is not tractable to use this infinite mixture for inference of hidden nodes. See [Meila06] for details.
30.3 Learning DAG structures

In this section, we discuss how to estimate the structure of directed graphical models from observational data. This is often called Bayes net structure learning. We can only do this if we make the faithfulness assumption, which we explain in Section 30.3.1. Furthermore our output will be a set of equivalent DAGs, rather than a single unique DAG, as we explain in Section 30.3.2. After introducing these restrictions, we discuss some statistical and algorithmic techniques. If the DAG is interpreted causally, these techniques can be used for causal discovery, although this relies on additional assumptions about non-confounding. For more details, see e.g., [Glymour2019].
30.3.1 Faithfulness

The Markov assumption allows us to infer CI properties of a distribution p from a graph G. To go in the opposite direction, we need to assume that the generating distribution p is faithful to the generating DAG G. This means that all the conditional independence (CI) properties of p are exactly captured by the graphical structure, so I(p) = I(G); this means there cannot be any CI properties in p that are due to particular settings of the parameters (such as zeros in a regression matrix) that are not graphically explicit. (For this reason, a faithful distribution is also called a stable distribution.)
Let us consider an example of a non-faithful distribution (from [Peters2017]). Consider a linear Gaussian model of the form

\[
X = E_X, \quad E_X \sim \mathcal{N}(0, \sigma_X^2)
\tag{30.18}
\]
\[
Y = aX + E_Y, \quad E_Y \sim \mathcal{N}(0, \sigma_Y^2)
\tag{30.19}
\]
\[
Z = bY + cX + E_Z, \quad E_Z \sim \mathcal{N}(0, \sigma_Z^2)
\tag{30.20}
\]

where the error terms are independent. If ab + c = 0, then X ⊥ Z, even though this is not implied by the DAG in Figure 30.9. Fortunately, this kind of accidental cancellation happens with zero
probability if the coefficients are drawn randomly from positive densities [SpirtesBook].

Figure 30.10: Three DAGs. G1 and G3 are Markov equivalent, G2 is not.
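The cancellation is easy to check numerically. The following sketch, with arbitrary values for a and b and unit noise variances, simulates the model above with c = −ab and verifies that the empirical correlation between X and Z vanishes.

```python
# Numerical check of the unfaithful linear Gaussian example: with c = -a*b,
# X and Z are marginally independent even though the DAG suggests otherwise.
# The particular coefficients and noise scales are arbitrary choices.
import numpy as np

rng = np.random.default_rng(0)
n = 200_000
a, b = 2.0, -0.5
c = -a * b                      # the "accidental" cancellation ab + c = 0

X = rng.normal(0.0, 1.0, n)
Y = a * X + rng.normal(0.0, 1.0, n)
Z = b * Y + c * X + rng.normal(0.0, 1.0, n)

print("corr(X, Z) =", np.corrcoef(X, Z)[0, 1])   # ~0 up to Monte Carlo error
print("corr(X, Y) =", np.corrcoef(X, Y)[0, 1])   # clearly nonzero
```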
30.3.2 Markov equivalence

Even with the faithfulness assumption, we cannot always uniquely identify a DAG from a joint distribution. To see this, consider the following 3 DGMs: X → Y → Z, X ← Y ← Z and X ← Y → Z. These all represent the same set of CI statements, namely

\[
X \perp Z \mid Y, \quad X \not\perp Z
\tag{30.21}
\]

We say these graphs are Markov equivalent, since they encode the same set of CI assumptions. That is, they all belong to the same Markov equivalence class. However, the DAG X → Y ← Z encodes X ⊥ Z and X ̸⊥ Z|Y, so corresponds to a different distribution.
Verma and Pearl [Verma90] proved the following theorem.

Theorem 30.3.1. Two structures are Markov equivalent iff they have the same skeleton, i.e., they have the same edges (disregarding direction) and they have the same set of v-structures (colliders whose parents are not adjacent).

For example, referring to Figure 30.10, we see that G1 ≢ G2, since reversing the 2 → 4 arc creates a new v-structure. However, G1 ≡ G3, since reversing the 1 → 5 arc does not create a new v-structure.
We can represent a Markov equivalence class using a single partially directed acyclic graph or PDAG (also called an essential graph or pattern), in which some edges are directed and some undirected (see Main Section 4.5.4.1). The undirected edges represent reversible edges; any combination is possible so long as no new v-structures are created. The directed edges are called compelled edges, since changing their orientation would change the v-structures and hence change the equivalence class. For example, the PDAG X − Y − Z represents {X → Y → Z, X ← Y ← Z, X ← Y → Z} which encodes X ̸⊥ Z and X ⊥ Z|Y. See Figure 30.10 for another example.
The significance of the above theorem is that, when we learn the DAG structure from data, we will not be able to uniquely identify all of the edge directions, even given an infinite amount of data. We say that we can learn DAG structure “up to Markov equivalence”. This also cautions us
not to read too much into the meaning of particular edge orientations, since we can often change them without changing the model in any observable way. (If we want to distinguish between edge orientations within a PDAG (e.g., if we want to imbue a causal interpretation on the edges), we can use interventional data, as we discuss in Section 30.5.2.)

Figure 30.11: PDAG representation of Markov equivalent DAGs.
30.3.3 Bayesian model selection: statistical foundations

In this section, we discuss how to compute the exact posterior over graphs, p(G|D), ignoring for now the issue of computational tractability. We assume there is no missing data, and that there are no hidden variables. This is called the complete data assumption.
For simplicity, we will focus on the case where all the variables are categorical and all the CPDs are tables. Our presentation is based in part on [Heckerman95c], although we will follow the notation of Main Section 4.2.7.3. In particular, let x_{it} ∈ {1, . . . , K_t} be the value of node t in case i, where K_t is the number of states for node t. Let θ_{tck} ≜ p(x_t = k | x_{pa(t)} = c), for k = 1 : K_t, and c = 1 : C_t, where C_t is the number of parent combinations (possible conditioning cases). For notational simplicity, we will often assume K_t = K, so all nodes have the same number of states. We will also let d_t = dim(pa(t)) be the degree or fan-in of node t, so that C_t = K^{d_t}.
30.3.3.1 Deriving the likelihood

Assuming there is no missing data, and that all CPDs are tabular, the likelihood can be written as follows:

\[
p(\mathcal{D} \mid G, \theta) = \prod_{i=1}^{N} \prod_{t=1}^{N_G} \mathrm{Cat}(x_{it} \mid x_{i,\mathrm{pa}(t)}, \theta_t)
\tag{30.22}
\]
\[
= \prod_{i=1}^{N} \prod_{t=1}^{N_G} \prod_{c=1}^{C_t} \prod_{k=1}^{K_t} \theta_{tck}^{\mathbb{I}(x_{i,t}=k,\, x_{i,\mathrm{pa}(t)}=c)}
= \prod_{t=1}^{N_G} \prod_{c=1}^{C_t} \prod_{k=1}^{K_t} \theta_{tck}^{N_{tck}}
\tag{30.23}
\]

where N_{tck} is the number of times node t is in state k and its parents are in state c. (Technically these counts depend on the graph structure G, but we drop this from the notation.)
30.3.3.2 Deriving the marginal likelihood

Choosing the graph with the maximum likelihood will always pick a fully connected graph (subject to the acyclicity constraint), since this maximizes the number of parameters. To avoid such overfitting,
we will choose the graph with the maximum marginal likelihood, p(D|G), where we integrate out the parameters; the magic of the Bayesian Occam’s razor (Main Section 3.8.1) will then penalize overly complex graphs.
To compute the marginal likelihood, we need to specify priors on the parameters. We will make two standard assumptions. First, we assume global prior parameter independence, which means p(θ) = \prod_{t=1}^{N_G} p(θ_t). Second, we assume local prior parameter independence, which means p(θ_t) = \prod_{c=1}^{C_t} p(θ_{tc}) for each t. It turns out that these assumptions imply that the prior for each row of each CPT must be a Dirichlet [Geiger97], that is, p(θ_{tc}) = Dir(θ_{tc} | α_{tc}). Given these assumptions, and using the results of Main Equation (3.94), we can write down the marginal likelihood of any DAG as follows:

\[
p(\mathcal{D} \mid G) = \prod_{t=1}^{N_G} \prod_{c=1}^{C_t} \int \Big[ \prod_{i: x_{i,\mathrm{pa}(t)}=c} \mathrm{Cat}(x_{it} \mid \theta_{tc}) \Big] \mathrm{Dir}(\theta_{tc})\, d\theta_{tc}
\tag{30.24}
\]
\[
= \prod_{t=1}^{N_G} \prod_{c=1}^{C_t} \frac{B(N_{tc} + \alpha_{tc})}{B(\alpha_{tc})}
\tag{30.25}
\]
\[
= \prod_{t=1}^{N_G} \prod_{c=1}^{C_t} \frac{\Gamma(\alpha_{tc})}{\Gamma(N_{tc} + \alpha_{tc})} \prod_{k=1}^{K_t} \frac{\Gamma(N_{tck} + \alpha_{tck})}{\Gamma(\alpha_{tck})}
\tag{30.26}
\]
\[
= \prod_{t=1}^{N_G} \mathrm{score}(N_{t,\mathrm{pa}(t)})
\tag{30.27}
\]

where N_{tc} = \sum_k N_{tck}, α_{tc} = \sum_k α_{tck}, N_{t,pa(t)} is the vector of counts (sufficient statistics) for node t and its parents, and score() is a local scoring function defined by

\[
\mathrm{score}(N_{t,\mathrm{pa}(t)}) \triangleq \prod_{c=1}^{C_t} \frac{B(N_{tc} + \alpha_{tc})}{B(\alpha_{tc})}
\tag{30.28}
\]

We say that the marginal likelihood decomposes or factorizes according to the graph structure.
30.3.3.3 Setting the prior

How should we set the hyper-parameters α_{tck}? It is tempting to use a Jeffreys prior of the form α_{tck} = 1/2 (Main Equation (3.233)). However, it turns out that this violates a property called likelihood equivalence, which is sometimes considered desirable. This property says that if G1 and G2 are Markov equivalent (Section 30.3.2), they should have the same marginal likelihood, since they are essentially equivalent models. [Geiger97] proved that, for complete graphs, the only prior that satisfies likelihood equivalence and parameter independence is the Dirichlet prior, where the pseudo counts have the form

\[
\alpha_{tck} = \alpha\, p_0(x_t = k, x_{\mathrm{pa}(t)} = c)
\tag{30.29}
\]

where α > 0 is called the equivalent sample size, and p_0 is some prior joint probability distribution. This is called the BDe prior, which stands for Bayesian Dirichlet likelihood equivalent.
Figure 30.12: The two most probable DAGs learned from the Sewell-Shah data. From [Heckerman97]. Used with kind permission of David Heckerman.
To derive the hyper-parameters for other graph structures, [Geiger97] invoked an additional assumption called parameter modularity, which says that if node X_t has the same parents in G1 and G2, then p(θ_t|G1) = p(θ_t|G2). With this assumption, we can always derive α_t for a node t in any other graph by marginalizing the pseudo counts in Equation (30.29).
Typically the prior distribution p_0 is assumed to be uniform over all possible joint configurations. In this case, we have α_{tck} = α/(K_t C_t), since p_0(x_t = k, x_{pa(t)} = c) = 1/(K_t C_t). Thus if we sum the pseudo counts over all C_t × K_t entries in the CPT, we get a total equivalent sample size of α. This is called the BDeu prior, where the “u” stands for uniform. This is the most widely used prior for learning Bayes net structures. For advice on setting the global tuning parameter α, see [Silander07].
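As an illustration, here is a minimal sketch of the log of the local score in Equation (30.28) under a BDeu prior, written in terms of log-gamma functions for numerical stability; the count table is a made-up sufficient statistic with C_t = 2 parent configurations and K_t = 2 child states.

```python
# Log of the local family score (Equation 30.28) under a BDeu prior.
# counts[c, k] = N_tck, the number of cases with parents in configuration c
# and node t in state k; the values below are made up.
import numpy as np
from scipy.special import gammaln

def log_local_score_bdeu(counts, alpha=1.0):
    Ct, Kt = counts.shape
    a = np.full((Ct, Kt), alpha / (Ct * Kt))     # BDeu pseudo counts alpha_tck
    # per parent configuration c: log B(N_tc + alpha_tc) - log B(alpha_tc)
    log_B_post = gammaln(counts + a).sum(axis=1) - gammaln((counts + a).sum(axis=1))
    log_B_prior = gammaln(a).sum(axis=1) - gammaln(a.sum(axis=1))
    return float((log_B_post - log_B_prior).sum())

counts = np.array([[10.0, 2.0],
                   [3.0, 7.0]])                  # Ct = 2 parent configs, Kt = 2 states
print(log_local_score_bdeu(counts, alpha=5.0))
```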
30.3.3.4 Example: analysis of the college plans dataset

We now consider a larger example from [Heckerman97], who analyzed a dataset of 5 variables, related to the decision of high school students about whether to attend college. Specifically, the variables are as follows:
• Sex: Male or female.
• SES: Socio-economic status: low, lower middle, upper middle or high.
• IQ: Intelligence quotient: discretized into low, lower middle, upper middle or high.
• PE: Parental encouragement: low or high.
• CP: College plans: yes or no.
These variables were measured for 10,318 Wisconsin high school seniors. There are 2 × 4 × 4 × 2 × 2 = 128 possible joint configurations.
Heckerman et al. computed the exact posterior over all 29,281 possible 5 node DAGs, except for ones in which SEX and/or SES have parents, and/or CP has children. (The prior probability of these graphs was set to 0, based on domain knowledge.) They used the BDeu score with α = 5, although they said that the results were robust to any α in the range 3 to 40. The top two graphs are shown in Figure 30.12. We see that the most probable one has approximately all of the probability mass, so the posterior is extremely peaked.
It is tempting to interpret this graph in terms of causality (see Main Chapter 36 for a detailed
discussion of this topic). In particular, it seems that socio-economic status, IQ and parental encouragement all causally influence the decision about whether to go to college, which makes sense. Also, sex influences college plans only indirectly through parental encouragement, which also makes sense. However, the direct link from socio-economic status to IQ seems surprising; this may be due to a hidden common cause. In Section 30.3.8.5 we will re-examine this dataset allowing for the presence of hidden variables.
30.3.3.5 Marginal likelihood for non-tabular CPDs
If all CPDs are linear Gaussian, we can replace the Dirichlet-multinomial model with the normal-gamma model, and thus derive a different exact expression for the marginal likelihood. See [Geiger94] for the details. In fact, we can easily combine discrete nodes and Gaussian nodes, as long as the discrete nodes always have discrete parents; this is called a conditional Gaussian DAG. Again, we can compute the marginal likelihood in closed form. See [Bottcher03] for the details.
In the general case (i.e., everything except Gaussians and CPTs), we need to approximate the marginal likelihood. The simplest approach is to use the BIC approximation, which has the form

\[
\sum_{t} \Big[ \log p(\mathcal{D}_t \mid \hat{\theta}_t) - \frac{K_t C_t}{2} \log N \Big]
\tag{30.30}
\]
30.3.4 Bayesian model selection: algorithms

In this section, we discuss some algorithms for approximately computing the mode of (or samples from) the posterior p(G|D).
30.3.4.1 The K2 algorithm for known node orderings

Suppose we know a total ordering of the nodes. Then we can compute the distribution over parents for each node independently, without the risk of introducing any directed cycles: we simply enumerate over all possible subsets of ancestors and compute their marginal likelihoods. If we just return the best set of parents for each node, we get the K2 algorithm [Cooper92]. In this case, we can find the best set of parents for each node using ℓ1-regularization, as shown in [Schmidt07aaai].
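The following sketch implements the enumeration step just described: given a known ordering, each node independently picks the best-scoring parent set from among its predecessors, by brute force up to a maximum fan-in. The data, the fan-in limit, and the compact BDeu-style local score used here are illustrative assumptions rather than the exact setup of [Cooper92].

```python
# Parent selection given a known total ordering: for each node, enumerate candidate
# parent sets among its predecessors (brute force up to max_fanin) and keep the one
# with the highest decomposable local score. Data and score are illustrative choices.
import numpy as np
from itertools import combinations
from scipy.special import gammaln

def bdeu_score(counts, alpha=1.0):
    """Log BDeu family score for a (parent-config x child-state) count table."""
    a = np.full(counts.shape, alpha / counts.size)
    return float((gammaln(counts + a).sum(1) - gammaln((counts + a).sum(1))
                  - gammaln(a).sum(1) + gammaln(a.sum(1))).sum())

def family_counts(X, t, parents, card):
    """Sufficient statistics N_tck for node t under the given parent set."""
    n_cfg = int(np.prod([card[p] for p in parents])) if parents else 1
    counts = np.zeros((n_cfg, card[t]))
    for row in X:
        cfg = 0
        for p in parents:                       # mixed-radix encoding of the parents
            cfg = cfg * card[p] + row[p]
        counts[cfg, row[t]] += 1
    return counts

def best_parents_given_order(X, order, card, max_fanin=2):
    parents = {}
    for i, t in enumerate(order):
        preds = order[:i]
        best_set, best_score = (), -np.inf
        for k in range(min(max_fanin, len(preds)) + 1):
            for cand in combinations(preds, k):
                s = bdeu_score(family_counts(X, t, cand, card))
                if s > best_score:
                    best_set, best_score = cand, s
        parents[t] = best_set
    return parents

card = [2, 2, 2]                                # made-up problem: 3 binary nodes
X = np.random.randint(0, 2, size=(200, 3))
print(best_parents_given_order(X, order=[0, 1, 2], card=card))
```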
30.3.4.2 Dynamic programming algorithms

In general, the ordering of the nodes is not known, so the posterior does not decompose. Nevertheless, we can use dynamic programming to find the globally optimal MAP DAG (up to Markov equivalence), as shown in [Koivisto04; Silander06].
If our goal is knowledge discovery, the MAP DAG can be misleading, for reasons we discussed in Main Section 7.4.1. A better approach is to compute the marginal probability that each edge is present, p(G_{st} = 1|D). We can also compute these quantities using dynamic programming, as shown in [Koivisto06; Parviainen11ancestor].
Unfortunately, all of these methods take O(N_G 2^{N_G}) time in the general case, making them intractable for graphs with more than about 16 nodes.
30.3.4.3 Scaling up to larger graphs

The main challenge in computing the posterior over DAGs is that there are so many possible graphs. More precisely, [Robinson73] showed that the number of DAGs on D nodes satisfies the following recurrence:

\[
f(D) = \sum_{i=1}^{D} (-1)^{i+1} \binom{D}{i} 2^{i(D-i)} f(D-i)
\tag{30.31}
\]

for D > 2. The base case is f(1) = 1. Solving this recurrence yields the following sequence: 1, 3, 25, 543, 29281, 3781503, etc.¹

¹ A longer list of values can be found at http://www.research.att.com/~njas/sequences/A003024. Interestingly, the number of DAGs is equal to the number of (0,1) matrices all of whose eigenvalues are positive real numbers [McKay04].
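The recurrence is easy to evaluate; the short sketch below (which additionally takes f(0) = 1, an assumption that makes the sum well defined for every D ≥ 1) reproduces the sequence above.

```python
# Count labelled DAGs on D nodes via Robinson's recurrence (Equation 30.31).
# We take f(0) = 1 so the sum is defined for all D >= 1.
from functools import lru_cache
from math import comb

@lru_cache(maxsize=None)
def num_dags(D):
    if D == 0:
        return 1
    return sum((-1) ** (i + 1) * comb(D, i) * 2 ** (i * (D - i)) * num_dags(D - i)
               for i in range(1, D + 1))

print([num_dags(D) for D in range(1, 7)])   # [1, 3, 25, 543, 29281, 3781503]
```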
Indeed, the general problem of finding the globally optimal MAP DAG is provably NP-complete [Chickering96np]. In view of the enormous size of the hypothesis space, we are generally forced to use approximate methods, some of which we review below.
30.3.4.4 Hill climbing methods for approximating the mode
A common way to find an approximate MAP graph structure is to use a greedy hill climbing method. At each step, the algorithm proposes small changes to the current graph, such as adding, deleting or reversing a single edge; it then moves to the neighboring graph which most increases the posterior. The method stops when it reaches a local maximum. It is important that the method only proposes local changes to the graph, since this enables the change in marginal likelihood (and hence the posterior) to be computed in constant time (assuming we cache the sufficient statistics). This is because all but one or two of the terms in Equation (30.25) will cancel out when computing the log Bayes factor δ(G → G′) = log p(G′|D) − log p(G|D).
We can initialize the search from the best tree, which can be found using exact methods discussed in Section 30.2.1. For speed, we can restrict the search so it only adds edges which are part of the Markov blankets estimated from a dependency network [SchmidtThesis]. Figure 30.13 gives an example of a DAG learned in this way from the 20-newsgroup data. For binary data, it is possible to use techniques from frequent itemset mining to find good Markov blanket candidates, as described in [Goldenberg04].
We can use techniques such as multiple random restarts to increase the chance of finding a good local maximum. We can also use more sophisticated local search methods, such as genetic algorithms or simulated annealing, for structure learning. (See also Section 30.3.6 for gradient based techniques based on continuous relaxations.)
It is also possible to perform the greedy search in the space of PDAGs instead of in the space of DAGs; this is known as the greedy equivalence search method [Chickering02opt]. Although each step is somewhat more complicated, the advantage is that the search space is smaller.
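The sketch below shows the basic hill-climbing loop under simplifying assumptions: only single-edge additions and deletions are proposed (no reversals, tabu lists, or restarts), the caller supplies an arbitrary decomposable local score, and acyclicity is checked by a reachability search. The score change for a move is obtained by rescoring only the one family the move touches, which is the point made above about locality.

```python
# Greedy hill climbing over DAGs with add/delete single-edge moves (a simplified
# sketch). local_score(t, parent_set) is any decomposable family score.
def creates_cycle(parents, s, t):
    """Would adding s -> t create a directed cycle? (Is s reachable from t?)"""
    stack, seen = [t], set()
    while stack:
        u = stack.pop()
        if u == s:
            return True
        if u not in seen:
            seen.add(u)
            stack.extend(v for v in parents if u in parents[v])  # children of u
    return False

def hill_climb(D, local_score, max_iter=100):
    parents = {t: set() for t in range(D)}
    score = {t: local_score(t, parents[t]) for t in range(D)}
    for _ in range(max_iter):
        best_delta, best_move = 0.0, None
        for s in range(D):
            for t in range(D):
                if s == t:
                    continue
                if s in parents[t]:                      # candidate deletion
                    new = parents[t] - {s}
                elif not creates_cycle(parents, s, t):   # candidate addition
                    new = parents[t] | {s}
                else:
                    continue
                delta = local_score(t, new) - score[t]   # only family t changes
                if delta > best_delta:
                    best_delta, best_move = delta, (t, new)
        if best_move is None:
            break                                        # local maximum reached
        t, new = best_move
        parents[t], score[t] = new, score[t] + best_delta
    return parents

# toy usage with a made-up score that rewards the edges 0 -> 1 and 1 -> 2
toy = lambda t, pa: float(len(pa & ({0} if t == 1 else {1} if t == 2 else set())) - 0.1 * len(pa))
print(hill_climb(3, toy))    # {0: set(), 1: {0}, 2: {1}}
```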
30.3.4.5 Sampling methods

If our goal is knowledge discovery, the MAP DAG can be misleading, for reasons we discussed in Main Section 7.4.1. A better approach is to compute the probability that each edge is present, p(G_{st} = 1|D). We can do this exactly using dynamic programming [Koivisto06; Parviainen11ancestor], although this can be expensive. An approximate method is to sample DAGs from the posterior, and then to compute the fraction of times there is an s → t edge or path for each (s, t) pair. The standard way to draw samples is to use the Metropolis Hastings algorithm (Main Section 12.2), where we use the same local proposal as we did in greedy search [Madigan94].
A faster-mixing method is to use a collapsed MH sampler, as suggested in [Friedman03nir]. This exploits the fact that, if a total ordering of the nodes is known, we can select the parents for each node independently, without worrying about cycles, as discussed in Section 30.3.4.1. By summing over all possible choices of parents, we can marginalize out this part of the problem, and just sample total orders. [Ellis08] also use order-space (collapsed) MCMC, but this time with a parallel tempering MCMC algorithm.

Figure 30.13: A locally optimal DAG learned from the 20-newsgroup data. From Figure 4.10 of [SchmidtThesis]. Used with kind permission of Mark Schmidt.
30.3.5 Constraint-based approach
We now present an approach to learning a DAG structure — up to Markov equivalence (the output of the method is a PDAG) — that uses local conditional independence tests, rather than scoring models globally with a likelihood. The CI tests are combined together to infer the global graph structure, so this approach is called constraint-based. The advantage of CI testing is that it is more local and does not require specifying a complete model. (However, the form of the CI test implicitly relies on assumptions, see e.g., [Shah2018].)

Figure 30.14: The 3 rules for inferring compelled edges in PDAGs. Adapted from [Peer05].
30.3.5.1 IC algorithm

The original algorithm, due to Verma and Pearl [Verma90], was called the IC algorithm, which stands for “inductive causation”. The method is as follows [PearlBook]:
1. For each pair of variables a and b, search for a set S_{ab} such that a ⊥ b | S_{ab}. Construct an undirected graph such that a and b are connected iff no such set S_{ab} can be found (i.e., they cannot be made conditionally independent).
2. Orient the edges involved in v-structures as follows: for each pair of nonadjacent nodes a and b with a common neighbor c, check if c ∈ S_{ab}; if it is, the corresponding DAG must be a → c → b, a ← c → b or a ← c ← b, so we cannot determine the direction; if it is not, the DAG must be a → c ← b, so add these arrows to the graph.
3. In the partially directed graph that results, orient as many of the undirected edges as possible, subject to two conditions: (1) the orientation should not create a new v-structure (since that would have been detected already if it existed), and (2) the orientation should not create a directed cycle. More precisely, follow the rules shown in Figure 30.14. In the first case, if X → Y has a known orientation, but Y − Z is unknown, then we must have Y → Z, otherwise we would have created a new v-structure X → Y ← Z, which is not allowed. The other two cases follow similar reasoning.
Figure 30.15: Example of step 1 of the PC algorithm. The panels show the true graph, the initial complete undirected graph, and the adjacencies remaining after testing independencies of order zero (none found), order one (A ⊥ C|B, A ⊥ D|B, A ⊥ E|B, C ⊥ D|B), and order two (B ⊥ E | {C, D}). Adapted from Figure 5.1 of [SpirtesBook].
30.3.5.2 PC algorithm

A significant speedup of IC, known as the PC algorithm after its creators Peter Spirtes and Clark Glymour [Spirtes91], can be obtained by ordering the search for separating sets in step 1 in terms of sets of increasing cardinality. We start with a fully connected graph, and then look for sets S_{ab} of size 0, then of size 1, and so on; as soon as we find a separating set, we remove the corresponding edge. See Figure 30.15 for an example.
Another variant on the PC algorithm is to learn the original undirected structure (i.e., the Markov blanket of each node) using generic variable selection techniques instead of CI tests. This tends to be more robust, since it avoids issues of statistical significance that can arise with independence tests. See [Pellet08] for details.
The running time of the PC algorithm is O(D^{K+1}) [SpirtesBook], where D is the number of nodes and K is the maximal degree (number of neighbors) of any node in the corresponding undirected graph.
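The following sketch implements the edge-removal (skeleton) phase just described; the conditional independence test is left as a user-supplied oracle ci_test(a, b, S) that returns True when a ⊥ b | S, which in practice would be a chi-squared or partial-correlation test at some significance level.

```python
# Skeleton phase of the PC algorithm: remove edges using separating sets of
# increasing size. ci_test(a, b, S) is a user-supplied CI oracle (assumed).
from itertools import combinations

def pc_skeleton(num_nodes, ci_test):
    adj = {a: set(range(num_nodes)) - {a} for a in range(num_nodes)}
    sepset = {}
    size = 0
    while any(len(adj[a]) - 1 >= size for a in adj):
        for a in range(num_nodes):
            for b in sorted(adj[a]):
                # condition on subsets of a's other neighbours, of the current size
                for S in combinations(sorted(adj[a] - {b}), size):
                    if ci_test(a, b, set(S)):
                        adj[a].discard(b)
                        adj[b].discard(a)
                        sepset[(a, b)] = sepset[(b, a)] = set(S)
                        break
        size += 1
    edges = {(a, b) for a in adj for b in adj[a] if a < b}
    return edges, sepset
```

The orientation steps (the v-structure rule and the rules of Figure 30.14) would then operate on the returned skeleton and separating sets.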
30.3.5.3 Frequentist vs Bayesian methods
The IC/PC algorithm relies on an oracle that can test for conditional independence between any set of variables, A ⊥ B|C. This can be approximated using hypothesis testing methods applied to a finite data set, such as chi-squared tests for discrete data. However, such methods work poorly with small sample sizes, and can run into problems with multiple testing (since so many hypotheses are being compared). In addition, errors made at any given step can lead to an incorrect final result, as erroneous constraints get propagated. In practice it is common to use a hybrid approach, where we use IC/PC to create an initial structure, and then use this to speed up Bayesian model selection, which tends to be more robust, since it avoids any hard decisions about conditional independence or lack thereof.
30.3.6 Methods based on sparse optimization
There is a 1:1 connection between sparse graphs and sparse adjacency matrices. This suggests that we can perform structure learning by using continuous optimization methods that enforce sparsity, similar to lasso and other ℓ1 penalty methods (Main Section 15.2.6). In the case of undirected graphs, this is relatively straightforward, and results in a convex objective, as we discuss in Section 30.4.2. However, in the case of DAGs, the problem is harder, because of the acyclicity constraint. Fortunately, [Zheng2018dags] showed how to encode this constraint as a smooth penalty term. (They call their method “DAGs with no tears”, since it is supposed to be painless to use.) In particular, they show how to convert the combinatorial problem into a continuous problem:

\[
\min_{\mathbf{W} \in \mathbb{R}^{D \times D}} f(\mathbf{W}) \;\;\text{s.t.}\;\; G(\mathbf{W}) \in \mathrm{DAGs}
\iff
\min_{\mathbf{W} \in \mathbb{R}^{D \times D}} f(\mathbf{W}) \;\;\text{s.t.}\;\; h(\mathbf{W}) = 0
\tag{30.32}
\]

Here W is a weighted adjacency matrix on D nodes, G(W) is the corresponding graph (obtained by thresholding W at 0), f(W) is a scoring function (e.g., penalized log likelihood), and h(W) is a constraint function that measures how close W is to defining a DAG. The constraint is given by

\[
h(\mathbf{W}) = \mathrm{tr}\big((\mathbf{I} + \alpha \mathbf{W})^d\big) - d \propto \mathrm{tr}\Big(\sum_{k=1}^{d} \alpha^k \mathbf{W}^k\Big)
\tag{30.33}
\]

where W^k = W · · · W with k terms, and α > 0 is a regularizer. Element (i, j) of W^k will be non-zero iff there is a path from j to i made of k edges. Hence the diagonal elements count the number of paths from a node to itself in k steps. Thus h(W) will be 0 if W defines a valid DAG.
The scoring function considered in [Zheng2018dags] has the form

\[
f(\mathbf{W}) = \frac{1}{2N} \|\mathbf{X} - \mathbf{X}\mathbf{W}\|_F^2 + \lambda \|\mathbf{W}\|_1
\tag{30.34}
\]

where X ∈ R^{N×D} is the data matrix. They show how to find a local optimum of the equality constrained objective using gradient-based methods. The cost per iteration is O(D^3).
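The acyclicity penalty itself is only a few lines. The sketch below follows the NOTEARS paper in applying the penalty to the elementwise square W ∘ W, so that entries are nonnegative and cancellations cannot hide a cycle, and uses the matrix exponential form tr(exp(W ∘ W)) − D, which, like Equation (30.33), is a positively weighted sum of traces of matrix powers; the example matrices are made up.

```python
# Acyclicity penalty h(W) in the spirit of "DAGs with no tears": a positively
# weighted sum of tr((W*W)^k), which is zero exactly when the weighted graph
# has no directed cycles. Example matrices are made up.
import numpy as np
from scipy.linalg import expm

def h(W):
    A = W * W                       # elementwise square => nonnegative entries
    return np.trace(expm(A)) - W.shape[0]

W_dag = np.array([[0.0, 1.2,  0.0],
                  [0.0, 0.0, -0.7],
                  [0.0, 0.0,  0.0]])     # 0 -> 1 -> 2, a valid DAG
W_cyc = W_dag.copy()
W_cyc[2, 0] = 0.5                        # adds the edge 2 -> 0, closing a cycle

print(h(W_dag))   # 0 (up to floating point)
print(h(W_cyc))   # strictly positive
```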
Several extensions of this have been proposed. For example, [Yu2019DAGGNN] replace the Gaussian noise assumption with a VAE (variational autoencoder, Main Section 21.2), and use a graph neural network as the encoder/decoder. And [Lachapelle2020] relax the linearity assumption, and allow for the use of neural network dependencies between variables.
30.3.7 Consistent estimators
A natural question is whether any of the above algorithms can recover the “true” DAG structure G (up to Markov equivalence), in the limit of infinite data. We assume that the data was generated by a distribution p that is faithful to G (see Section 30.3.1).
The posterior mode (MAP) is known to converge to the MLE, which in turn will converge to the true graph G (up to Markov equivalence), so any exact algorithm for Bayesian inference is a consistent estimator. [Chickering02opt] showed that his greedy equivalence search method (which is a form of hill climbing in the space of PDAGs) is a consistent estimator. Similarly, [SpirtesBook; Kalisch07] showed that the PC algorithm is a consistent estimator. However, the running time of these algorithms might be exponential in the number of nodes. Also, all of these methods assume that all the variables are fully observed.
30.3.8 Handling latent variables
In general, we will not get to observe the values of all the nodes (i.e., the complete data assumption does not hold), either because we have missing data, and/or because we have hidden variables. This makes it intractable to compute the marginal likelihood of any given graph structure, as we discuss in Section 30.3.8.1. It also opens up new problems, such as knowing how many hidden variables to add to the model, and how to connect them, as we discuss in Section 30.3.8.7.
30.3.8.1 Approximating the marginal likelihood
If we have hidden or missing variables h, the marginal likelihood is given by

\[
p(\mathcal{D} \mid G) = \int \sum_{\mathbf{h}} p(\mathcal{D}, \mathbf{h} \mid \theta, G)\, p(\theta \mid G)\, d\theta
= \sum_{\mathbf{h}} \int p(\mathcal{D}, \mathbf{h} \mid \theta, G)\, p(\theta \mid G)\, d\theta
\tag{30.35}
\]

In general this is intractable to compute. For example, consider a mixture model, where we don’t observe the cluster label. In this case, there are K^N possible completions of the data (assuming we have K clusters); we can evaluate the inner integral for each one of these assignments to h, but we cannot afford to evaluate all of the integrals. (Of course, most of these integrals will correspond to hypotheses with little posterior support, such as assigning single data points to isolated clusters, but we don’t know ahead of time the relative weight of these assignments.) Below we mention some faster deterministic approximations for the marginal likelihood.
30.3.8.2 BIC approximation
A simple approximation to the marginal likelihood is to use the BIC score (Main Section 3.8.7.2), which is given by

\[
\mathrm{BIC}(G) \triangleq \log p(\mathcal{D} \mid \hat{\theta}, G) - \frac{\log N}{2} \dim(G)
\tag{30.36}
\]

where dim(G) is the number of degrees of freedom in the model and θ̂ is the MAP or ML estimate. However, the BIC score often severely underestimates the true marginal likelihood [Chickering97], resulting in it selecting overly simple models. We discuss some better approximations below.
30.3.8.3 Cheeseman-Stutz approximation
We now discuss the Cheeseman-Stutz approximation (CS) to the marginal likelihood [Cheeseman96]. We first compute a MAP estimate of the parameters θ̂ (e.g., using EM). Denote the expected sufficient statistics of the data by D̄ = D̄(θ̂); in the case of discrete variables, we just “fill in” the hidden variables with their expectation. We then use the exact marginal likelihood equation on this filled-in data:

\[
p(\mathcal{D} \mid G) \approx p(\overline{\mathcal{D}} \mid G) = \int p(\overline{\mathcal{D}} \mid \theta, G)\, p(\theta \mid G)\, d\theta
\tag{30.37}
\]

However, comparing this to Equation (30.35), we can see that the value will be exponentially smaller, since it does not sum over all values of h. To correct for this, we first write

\[
\log p(\mathcal{D} \mid G) = \log p(\overline{\mathcal{D}} \mid G) + \log p(\mathcal{D} \mid G) - \log p(\overline{\mathcal{D}} \mid G)
\tag{30.38}
\]

and then we apply a BIC approximation to the last two terms:

\[
\log p(\mathcal{D} \mid G) - \log p(\overline{\mathcal{D}} \mid G)
\approx \Big[ \log p(\mathcal{D} \mid \hat{\theta}, G) - \frac{\log N}{2} \dim(G) \Big]
- \Big[ \log p(\overline{\mathcal{D}} \mid \hat{\theta}, G) - \frac{\log N}{2} \dim(G) \Big]
\tag{30.39}
\]
\[
= \log p(\mathcal{D} \mid \hat{\theta}, G) - \log p(\overline{\mathcal{D}} \mid \hat{\theta}, G)
\tag{30.40}
\]

Putting it all together we get

\[
\log p(\mathcal{D} \mid G) \approx \log p(\overline{\mathcal{D}} \mid G) + \log p(\mathcal{D} \mid \hat{\theta}, G) - \log p(\overline{\mathcal{D}} \mid \hat{\theta}, G)
\tag{30.41}
\]

The first term p(D̄|G) can be computed by plugging the filled-in data into the exact marginal likelihood. The second term p(D|θ̂, G), which involves an exponential sum (thus matching the “dimensionality” of the left hand side), can be computed using an inference algorithm. The final term p(D̄|θ̂, G) can be computed by plugging the filled-in data into the regular likelihood.
28
29 30.3.8.4 Variational Bayes EM
30
31 An even more accurate approach is to use the variational Bayes EM algorithm. Recall from
32 Main Section 10.3.5 that the key idea is to make the following factorization assumption:
33 Y
34 p(θ, z1:N |D) ≈ q(θ)q(z) = q(θ) q(zi ) (30.42)
35 i

36
where zi are the hidden variables in case i. In the E step, we update the q(zi ), and in the M step, we
37
update q(θ). The corresponding variational free energy provides a lower bound on the log marginal
38
likelihood. In [Beal06], it is shown that this bound is a much better approximation to the true log
39
marginal likelihood (as estimated by a slow annealed importance sampling procedure) than either
40
BIC or CS. In fact, one can prove that the variational bound will always be more accurate than CS
41
(which in turn is always more accurate than BIC).
42
43
30.3.8.5 Example: college plans revisited
44
45 Let us revisit the college plans dataset from Section 30.3.3.4. Recall that if we ignore the possibility
46 of hidden variables there was a direct link from socio economic status to IQ in the MAP DAG.
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
30.3. LEARNING DAG STRUCTURES

1
13
14
15
Figure 30.16: The most probable DAG with a single binary hidden variable learned from the Sewell-Shah data. MAP estimates of the CPT entries are shown for some of the nodes. From [Heckerman97]. Used with kind permission of David Heckerman.
17
18
19
20 Heckerman et al. decided to see what would happen if they introduced a hidden variable H, which
21 they made a parent of both SES and IQ, representing a hidden common cause. They also considered
22 a variant in which H points to SES, IQ and PE. For both such cases, they considered dropping none,
23 one, or both of the SES-PE and PE-IQ edges. They varied the number of states for the hidden node
24 from 2 to 6. Thus they computed the approximate posterior over 8 × 5 = 40 different models, using
25 the CS approximation.
26 The most probable model which they found is shown in Figure 30.16. This is 2 · 1010 times more
27 likely than the best model containing no hidden variable. It is also 5 · 109 times more likely than the
28 second most probable model with a hidden variable. So again the posterior is very peaked.
These results suggest that there is indeed a hidden common cause underlying both the socio-
30 economic status of the parents and the IQ of the children. By examining the CPT entries, we see
31 that both SES and IQ are more likely to be high when H takes on the value 1. They interpret this to
32 mean that the hidden variable represents “parent quality” (possibly a genetic factor). Note, however,
33 that the arc between H and SES can be reversed without changing the v-structures in the graph, and
34 thus without affecting the likelihood; this underscores the difficulty in interpreting hidden variables.
35 Interestingly, the hidden variable model has the same conditional independence assumptions
36 amongst the visible variables as the most probable visible variable model. So it is not possible to
37 distinguish between these hypotheses by merely looking at the empirical conditional independencies
38 in the data (which is the basis of the constraint-based approach to structure learning discussed in
39 Section 30.3.5). Instead, by adopting a Bayesian approach, which takes parsimony into account (and
40 not just conditional independence), we can discover the possible existence of hidden factors. This is
the basis of much of scientific and everyday human reasoning (see e.g. [Griffiths09] for a discussion).
42
43
30.3.8.6 Structural EM
44
45 One way to perform structural inference in the presence of missing data is to use a standard search
46 procedure (deterministic or stochastic), and to use the methods from Section 30.3.8.1 to estimate the
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


246

1
2 marginal likelihood. However, this approach is not very efficient, because the marginal likelihood
3 does not decompose when we have missing data, and nor do its approximations. For example, if
4 we use the CS approximation or the VBEM approximation, we have to perform inference in every
5 neighboring model, just to evaluate the quality of a single move!
6 [Friedman97nir; Thiesson98] presents a much more efficient approach called the structural
7 EM algorithm. The basic idea is this: instead of fitting each candidate neighboring graph and then
8 filling in its data, fill in the data once, and use this filled-in data to evaluate the score of all the
9 neighbors. Although this might be a bad approximation to the marginal likelihood, it can be a good
10 enough approximation of the difference in marginal likelihoods between different models, which is all
11 we need in order to pick the best neighbor.
12 More precisely, define D(G0 , θ̂0 ) to be the data filled in using model G0 with MAP parameters θ̂0 .
13 Now define a modified BIC score as follows:
14 log N
15 BIC(G, D) , log p(D|θ̂, G) − dim(G) + log p(G) + log p(θ̂|G) (30.43)
2
16
17 where we have included the log prior for the graph and parameters. One can show [Friedman97nir]
18 that if we pick a graph G which increases the BIC score relative to G0 on the expected data, it will
19 also increase the score on the actual data, i.e.,
20
21
BIC(G, D) − BIC(G0 , D) ≤ BIC(G, D) − BIC(G0 , D) (30.44)
22
To convert this into an algorithm, we proceed as follows. First we initialize with some graph G0
23
and some set of parameters θ0 . Then we fill-in the data using the current parameters — in practice,
24
this means when we ask for the expected counts for any particular family, we perform inference using
25
our current model. (If we know which counts we will need, we can precompute all of them, which is
26
much faster.) We then evaluate the BIC score of all of our neighbors using the filled-in data, and we
27
pick the best neighbor. We then refit the model parameters, fill-in the data again, and repeat. For
28
increased speed, we may choose to only refit the model every few steps, since small changes to the
29
structure hopefully won’t invalidate the parameter estimates and the filled-in data too much.
30
One interesting application is to learn a phylogenetic tree structure. Here the observed leaves are
31
the DNA or protein sequences of currently alive species, and the goal is to infer the topology of the
32
tree and the values of the missing internal nodes. There are many classical algorithms for this task
33
(see e.g., [Durbin98]), but one that uses structural EM is discussed in [FriedmanNirPhylo02].
34
Another interesting application of this method is to learn sparse mixture models [Barash02]. The
35
idea is that we have one hidden variable C specifying the cluster, and we have to choose whether to
36
add edges C → Xt for each possible feature Xt . Thus some features will be dependent on the cluster
37
id, and some will be independent. (See also [Law04] for a different way to perform this task, using
38
regular EM and a set of bits, one per feature, that are free to change across data cases.)
39
40
30.3.8.7 Discovering hidden variables
41
42 In Section 30.3.8.5, we introduced a hidden variable “by hand”, and then figured out the local topology
43 by fitting a series of different models and computing the one with the best marginal likelihood. How
44 can we automate this process?
45 Figure 30.17 provides one useful intuition: if there is a hidden variable in the “true model”, then its
46 children are likely to be densely connected. This suggest the following heuristic [Elidan00]: perform
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
30.3. LEARNING DAG STRUCTURES

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
Figure 30.17: A DGM with and without hidden variables. For example, the leaves might represent medical
symptoms, the root nodes primary causes (such as smoking, diet and exercise), and the hidden variable can
16
represent mediating factors, such as heart disease. Marginalizing out the hidden variable induces a clique.
17
18
19
20 structure learning in the visible domain, and then look for structural signatures, such as sets of
21 densely connected nodes (near-cliques); introduce a hidden variable and connect it to all nodes in
22 this near-clique; and then let structural EM sort out the details. Unfortunately, this technique does
23 not work too well, since structure learning algorithms are biased against fitting models with densely
24 connected cliques.
25 Another useful intuition comes from clustering. In a flat mixture model, also called a latent class
26 model, the discrete latent variable provides a compressed representation of its children. Thus we
27 want to create hidden variables with high mutual information with their children.
28 One way to do this is to create a tree-structured hierarchy of latent variables, each of which only
29 has to explain a small set of children. [Zhang04] calls this a hierarchical latent class model.
30 They propose a greedy local search algorithm to learn such structures, based on adding or deleting
31 hidden nodes, adding or deleting edges, etc. (Note that learning the optimal latent tree is NP-hard
32 [Roch06].)
33 Recently [Harmeling11] proposed a faster greedy algorithm for learning such models based on
34 agglomerative hierarchical clustering. Rather than go into details, we just give an example of what
35 this system can learn. Figure 30.18 shows part of a latent forest learned from the 20-newsgroup
36 data. The algorithm imposes the constraint that each latent node has exactly two children, for speed
37 reasons. Nevertheless, we see interpretable clusters arising. For example, Figure 30.18 shows separate
38 clusters concerning medicine, sports and religion. This provides an alternative to LDA and other
39 topic models (Main Section 28.5.1), with the added advantage that inference in latent trees is exact
40 and takes time linear in the number of nodes.
41 An alternative approach is proposed in [Choi11], in which the observed data is not constrained to
42 be at the leaves. This method starts with the Chow-Liu tree on the observed data, and then adds
43 hidden variables to capture higher-order dependencies between internal nodes. This results in much
44 more compact models, as shown in Figure 30.19. This model also has better predictive accuracy
45 than other approaches, such as mixture models, or trees where all the observed data is forced to
46 be at the leaves. Interestingly, one can show that this method can recover the exact latent tree
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


ed by method BIN - A modelling co-occurrences in the 20 newsgroup dataset.
248

1
2
! " " "
3
4
 ! ! !  !  !
5
6
             
7
8
      !        
9
10
            !  
11
12
         "      
13
14
    "           
15
16
                  
17
18
   
19
20
Figure 30.18: Part of a hierarchical latent tree learned from the 20-newsgroup data. From Figure 2 of 
21
[Harmeling11]. Used with kind permission of Stefan Harmeling.
22

23
24
25 structure, providing the data is generated from a tree. See [Choi11] for details. Note, however, that
26 this approach, unlike [Zhang04; Harmeling11], requires that the cardinality of all the variables,
27 hidden and observed, be the same. Furthermore, if the observed variables are Gaussian, the hidden
28 variables must be Gaussian also.
29
30 30.3.8.8 Example: Google’s Rephil
31
32
In this section, we describe a huge DGM called Rephil, which was automatically learned from data.2
33
The model is widely used inside Google for various purposes, including their famous AdSense system.3
34
The model structure is shown in Figure 30.20. The leaves are binary nodes, and represent the
35
presence or absence of words or compounds (such as “New York City”) in a text document or query.
36
The latent variables are also binary, and represent clusters of co-occuring words. All CPDs are
37
noisy-OR, since some leaf nodes (representing words) can have many parents. This means each edge
38
can be augmented with a hidden variable specifying if the link was activated or not; if the link is not
39
active, then the parent cannot turn the child on. (A very similar model was proposed independently
40
2. The original system, called “Phil”, was developed by Georges Harik and Noam Shazeer,. It has been published as
41
US Patent #8024372, “Method and apparatus for learning a probabilistic generative model for text”, filed in 2004.
42 Rephil is a more probabilistically sound version of the method, developed by Uri Lerner et al. The summary below is
43 based on notes by Brian Milch (who also works at Google).
44
3. AdSense is Google’s system for matching web pages with content-appropriate ads in an automatic way, by extracting
semantic keywords from web pages. These keywords play a role analogous to the words that users type in when
45 searching; this latter form of information is used by Google’s AdWords system. The details are secret, but [Levy11]
46 gives an overview.
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
30.3. LEARNING DAG STRUCTURES

1
2 h3 h17

3
4 president government power h4 war h20 religion h14 earth lunar orbit satellite solar
children

5
moon technology mission
6 law state human rights world israel jews h8 bible god

7 mars h1
gun
8 h2 christian jesus
space launch shuttle nasa
9
10 health
case course evidence fact question
program h9
11
food aids h21
12 insurance
version h12 ftp email
13 msg water studies h13 medicine
car
14 h25 files format phone

15 dealer h15 cancer disease doctor patients vitamin


windows h18 h11 image number
16
17
18
&OXVWHU0HUJLQJ
bmw engine honda oil

h5
card driver h10 dos h19 h26

19
video h16 disk memory h22 pc software display server
h6 puck season team h7 win
20
21 computer h24
graphics h23 system data scsi drive
games baseball league players fans hockey nhl won
22
23 Ɣ 0HUJHFOXVWHUV$DQG%LI$H[SODLQV% VWRSZRUGV
hit problem help mac
science university research

24
25
DQG%H[SODLQV$ VWRSZRUGV
Figure 30.19: A partially latent tree learned from the 20-newsgroup data. Note that some words can have
26 multiple meanings, and get connected to different latent variables, representing different “topics”. For example,

Ɣ 'LVFDUGFOXVWHUVWKDWDUHXVHGWRRUDUHO\
27 the word “win” can refer to a sports context (represented by h5) or the Microsoft Windows context (represented
28 by h25). From Figure 12 of [Choi11]. Used with kind permission of Jin Choi.
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43 Figure 30.20: Google’s rephil model. Leaves represent presence or absence of words. Internal nodes represent
44 clusters of co-occuring words, or “concepts”. All nodes are binary, and all CPDs are noisy-OR. The model
45
contains 12 million word nodes, 1 million latent cluster nodes, and 350 million edges. Used with kind
permission of Brian Milch.
46
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


250

1
2 in [Singliar06].)
3 Parameter learning is based on EM, where the hidden activation status of each edge needs to be
4 inferred [Meek97]. Structure learning is based on the old neuroscience idea that “nodes that fire
5 together should wire together”. To implement this, we run inference and check for cluster-word
6 and cluster-cluster pairs that frequently turn on together. We then add an edge from parent to
7 child if the link can significantly increase the probability of the child. Links that are not activated
8 very often are pruned out. We initialize with one cluster per “document” (corresponding to a set of
9 semantically related phrases). We then merge clusters A and B if A explains B’s top words and vice
10 versa. We can also discard clusters that are used too rarely.
11 The model was trained on about 100 billion text snippets or search queries; this takes several
12 weeks, even on a parallel distributed computing architecture. The resulting model contains 12 million
13 word nodes and about 1 million latent cluster nodes. There are about 350 million links in the model,
14 including many cluster-cluster dependencies. The longest path in the graph has length 555, so the
15 model is quite deep.
16 Exact inference in this model is obviously infeasible. However note that most leaves will be off,
17 since most words do not occur in a given query; such leaves can be analytically removed. We can
18 also prune out unlikely hidden nodes by following the strongest links from the words that are on up
19 to their parents to get a candidate set of concepts. We then perform iterative conditional modes
20 (ICM) to form approximate inference. (ICM is a deterministic version of Gibbs sampling that sets
21 each node to its most probable state given the values of its neighbors in its Markov blanket.) This
22 continues until it reaches a local maximum. We can repeat this process a few times from random
23 starting configurations. At Google, this can be made to run in 15 milliseconds!
24
25 30.3.8.9 Spectral methods
26
27
Recently, various methods have been developed that can recover the exact structure of the DAG, even
28
in the presence of (a known number of) latent variables, under certain assumptions. In particular,
29
identifiability results have been obtained for the following cases:
30 • If x contains 3 of more indepenent views of z [Goodman1974; Allman2009; Anandkumar12colt;
31 Hsu12], sometimes called the triad constraint.
32
33 • If z is categorical, and x is a GMM with mixture components which depend on z [Anandkumar2014].
34
35
• If z is composed of binary variables, and x is a set of noisy-OR CPDs [Jernite13; Arora2016].
36
In terms of algorithms, most of these methods are not based on maximum likelihood, but instead use
37
the method of moments and spectral methods. For details, see [Anandkumar2014].
38
39
30.3.8.10 Constraint-based methods for learning ADMGs
40
41 An alternative to explicitly modeling latent variables is to marginalize them out, and work with acyclic
42 directed mixed graphs (Main Section 4.5.4.2). It is possible to perform Bayesian model selection
43 for ADMGs, although the method is somewhat slow and complicated [Silva09]. Alternatively, one
44 can modify the PC/IC algorithm to learn an ADMG. This method is known as the IC* algorithm
45 [PearlBook]; one can speed it up to get the FCI algorithm (FCI stands for “fast causal inference”)
46 [SpirtesBook].
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
30.4. LEARNING UNDIRECTED GRAPH STRUCTURES

1
2 Since there will inevitably be some uncertainty about edge orientations, due to Markov equivalence,
3 the output of IC*/ FCI is not actually an ADMG, but is a closely related structured called a
4 partially oriented inducing path graph [SpirtesBook] or a marked pattern [PearlBook].
5 Such a graph has 4 kinds of edges:

6 • A marked arrow a → b signifying a directed path from a to b.
7 • An unmarked arrow a → b signifying a directed path from a to b or a latent common cause
8 a ← L → b.
9 • A bidirected arrow a ↔ b signifying a latent common causes a ← L → b.
10 • An undirected edge a − b signifying a ← b or a → b or a latent common causes a ← L → b.
11 IC*/ FCI is faster than Bayesian inference, but suffers from the same problems as the original
12 IC/PC algorithm (namely, the need for a CI testing oracle, problems due to multiple testing, no
13 probabilistic representation of uncertainty, etc.) Furthermore, by not explicitly representing the
14 latent variables, the resulting model cannot be used for inference and prediction.
15
16
30.4 Learning undirected graph structures
17
18 In this section, we discuss how to learn the structure of undirected graphical models. On the one
19 hand, this is easier than learning DAG structure because we don’t need to worry about acyclicity.
20 On the other hand, it is harder than learning DAG structure since the likelihood does not decompose
21 (see Main Section 4.3.9.1). This precludes the kind of local search methods (both greedy search and
22 MCMC sampling) we used to learn DAG structures, because the cost of evaluating each neighboring
23 graph is too high, since we have to refit each model from scratch (there is no way to incrementally
24 update the score of a model). In this section, we discuss several solutions to this problem.
25
26
30.4.1 Dependency networks
27
A simple way to learn the structure of a UGM is to represent it as a product of full conditionals:

p(x) = (1/Z) ∏_{d=1}^D p(xd | x−d)   (30.45)

This expression is called the pseudolikelihood.
34 Such a collection of local distributions defines a model called a dependency network [Heckerman00].
35 Unfortunately, a product of full conditionals which are independently estimated is not guaranteed to
36 be consistent with any valid joint distribution. However, we can still use the model inside of a Gibbs
37 sampler to approximate a joint distribution. This approach is sometimes used for data imputation
38 [Gelman01].
However, the main advantage of dependency networks is that we can use sparse regres-
40 sion techniques for each distribution p(xd |x−d ) to induce a sparse graph structure. For example,
41 [Heckerman00] use classification/ regression trees, [Meinshausen06] use `1 -regularized linear
42 regression, [Wainwright06; Wu2019nips] use `1 -regularized logistic regression, [Dobra09] uses
43 Bayesian variable selection, etc.
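As a sketch of this idea (assuming a binary N × D data matrix X, such as word-presence indicators; this is not the exact setup used to produce Figure 30.21), we can regress each variable on all the others with an `1 penalty and read off the Markov blankets from the nonzero coefficients.

    import numpy as np
    from sklearn.linear_model import LogisticRegression

    def dependency_network(X, C=0.1):
        """Estimate Markov blankets by l1-regularized logistic regression of each
        column of the binary matrix X on the remaining columns (each column is
        assumed to take both values at least once)."""
        N, D = X.shape
        W = np.zeros((D, D))            # W[d, j] = weight of x_j in the model for x_d
        for d in range(D):
            y = X[:, d]
            Xrest = np.delete(X, d, axis=1)
            clf = LogisticRegression(penalty="l1", solver="liblinear", C=C)
            clf.fit(Xrest, y)
            W[d, np.arange(D) != d] = clf.coef_.ravel()
        # symmetrize: keep an edge if either (or both) directions are nonzero
        A = (np.abs(W) > 1e-8) | (np.abs(W.T) > 1e-8)
        np.fill_diagonal(A, False)
        return W, A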
44 Figure 30.21 shows a dependency network that was learned from the 20-newsgroup data using
45 `1 regularized logistic regression, where the penalty parameter λ was chosen by BIC. Many of the
46 words present in these estimated Markov blankets represent fairly natural associations (aids:disease,
Figure 30.21: A dependency network constructed from the 20 newsgroup data. We show all edges with regression weight above 0.5 in the Markov blankets estimated by `1 penalized logistic regression. Undirected edges represent cases where a directed edge was found in both directions. From Figure 4.9 of [SchmidtThesis]. Used with kind permission of Mark Schmidt.
baseball:fans, bible:god, bmw:car, cancer:patients, etc.). However, some of the estimated statistical
30
dependencies seem less intuitive, such as baseball:windows and bmw:christian. We can gain more
31
insight if we look not only at the sparsity pattern, but also the values of the regression weights. For
32
example, here are the incoming weights for the first 5 words:
33 • aids: children (0.53), disease (0.84), fact (0.47), health (0.77), president (0.50), research (0.53)
34
35 • baseball: christian (-0.98), drive (-0.49), games (0.81), god (-0.46), government (-0.69), hit (0.62),
36 memory (-1.29), players (1.16), season (0.31), software (-0.68), windows (-1.45)
37
• bible: car (-0.72), card (-0.88), christian (0.49), fact (0.21), god (1.01), jesus (0.68), orbit (0.83),
38
program (-0.56), religion (0.24), version (0.49)
39
40 • bmw: car (0.60), christian (-11.54), engine (0.69), god (-0.74), government (-1.01), help (-0.50),
41 windows (-1.43)
42
• cancer: disease (0.62), medicine (0.58), patients (0.90), research (0.49), studies (0.70)
43
44 Words in italic red have negative weights, which represents a dissociative relationship. For example,
45 the model reflects that baseball:windows is an unlikely combination. It turns out that most of the
46 weights are negative (1173 negative, 286 positive, 8541 zero) in this model.
1
2 [Meinshausen06] discuss theoretical conditions under which dependency networks using `1 -
3 regularized linear regression can recover the true graph structure, assuming the data was generated
4 from a sparse Gaussian graphical model. We discuss a more general solution in Section 30.4.2.
5
6
30.4.2 Graphical lasso for GGMs
7
8 In this section, we consider the problem of learning the structure of undirected Gaussian graphical
9 models (GGM)s. These models are useful, since there is a 1:1 mapping between sparse parameters
10 and sparse graph structures. This allows us to extend the efficient techniques of `1 regularized
11 estimation in Main Section 15.2.6 to the graph case; the resulting method is called the graphical
12 lasso or Glasso [Friedman08glasso; Mazumder12].
13
14 30.4.2.1 MLE for a GGM
15
16 Before discussing structure learning, we need to discuss parameter estimation. The task of computing
17 the MLE for a (non-decomposable) GGM is called covariance selection [Dempster72].
18 The log likelihood can be written as
19
ℓ(Ω) = log det Ω − tr(SΩ)   (30.46)

where Ω = Σ⁻¹ is the precision matrix, and S = (1/N) Σ_{i=1}^N (xi − x̄)(xi − x̄)ᵀ is the empirical covariance
matrix. (For notational simplicity, we assume we have already estimated µ̂ = x̄.) One can show that
the gradient of this is given by

∇ℓ(Ω) = Ω⁻¹ − S   (30.47)
27
However, we have to enforce the constraints that Ωst = 0 if Gst = 0 (structural zeros), and that Ω is
28
positive definite. The former constraint is easy to enforce, but the latter is somewhat challenging
29
(albeit still a convex constraint). One approach is to add a penalty term to the objective if Ω leaves
30
the positive definite cone; this is the approach used in [Dahl08]. Another approach is to use a
31
coordinate descent method, described in [HastieBook].
32
Interestingly, one can show that the MLE must satisfy the following property: Σst = Sst if Gst = 1
33
or s = t, i.e., the covariance of a pair that are connected by an edge must match the empirical
34
covariance. In addition, we have Ωst = 0 if Gst = 0, by definition of a GGM, i.e., the precision of a
35
pair that are not connected must be 0. We say that Σ is a positive definite matrix completion of
36
S, since it retains as many of the entries in S as possible, corresponding to the edges in the graph,
37
subject to the required sparsity pattern on Σ−1 , corresponding to the absent edges; the remaining
38
entries in Σ are filled in so as to maximize the likelihood.
39
Let us consider a worked example from [HastieBook]. We will use the following adjacency matrix,
40
representing the cyclic structure, X1 − X2 − X3 − X4 − X1 , and the following empirical covariance
41
matrix:

    G = | 0 1 0 1 |        S = | 10  1  5  4 |
        | 1 0 1 0 |            |  1 10  2  6 |        (30.48)
        | 0 1 0 1 |            |  5  2 10  3 |
        | 1 0 1 0 |            |  4  6  3 10 |
1
The MLE is given by

    Σ = | 10.00  1.00  1.31  4.00 |        Ω = |  0.12 −0.01     0 −0.05 |
        |  1.00 10.00  2.00  0.87 |            | −0.01  0.11 −0.02     0 |        (30.49)
        |  1.31  2.00 10.00  3.00 |            |     0 −0.02  0.11 −0.03 |
        |  4.00  0.87  3.00 10.00 |            | −0.05     0 −0.03  0.13 |
8
9 (See ggm_fit_demo.ipynb for the code to reproduce these numbers, using the coordinate descent
10 algorithm from [Friedman08glasso].) The constrained elements in Ω, and the free elements in Σ,
11 both of which correspond to absent edges, have been highlighted.
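As a sanity check on this worked example, the constrained MLE can also be computed by solving the convex program max_Ω log det Ω − tr(SΩ) subject to the structural zeros. The sketch below assumes the cvxpy package is available; it is not the coordinate descent algorithm used in ggm_fit_demo.ipynb, but it should reproduce Equation (30.49) up to solver tolerance.

    import numpy as np
    import cvxpy as cp   # assumed available; any solver supporting log_det works

    G = np.array([[0,1,0,1],[1,0,1,0],[0,1,0,1],[1,0,1,0]])
    S = np.array([[10,1,5,4],[1,10,2,6],[5,2,10,3],[4,6,3,10]], dtype=float)

    Omega = cp.Variable((4, 4), symmetric=True)
    # structural zeros: Omega_st = 0 whenever G_st = 0 and s != t
    constraints = [Omega[i, j] == 0
                   for i in range(4) for j in range(4) if i != j and G[i, j] == 0]
    objective = cp.Maximize(cp.log_det(Omega) - cp.trace(S @ Omega))
    cp.Problem(objective, constraints).solve()

    Omega_hat = Omega.value
    Sigma_hat = np.linalg.inv(Omega_hat)
    # Sigma_hat should match S on the edges and the diagonal, and Omega_hat
    # should be (numerically) zero on the non-edges (1,3) and (2,4).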
12
13
30.4.2.2 Promoting sparsity
14
15 We now discuss one way to learn a sparse Gaussian MRF structure, which exploits the fact that
16 there is a 1:1 correspondence between zeros in the precision matrix and absent edges in the graph.
17 This suggests that we can learn a sparse graph structure by using an objective that encourages zeros
18 in the precision matrix. By analogy to lasso (see Main Section 15.2.6), one can define the following
19 `1 penalized NLL:
20
J(Ω) = − log det Ω + tr(SΩ) + λ||Ω||1   (30.50)

where ||Ω||1 = Σ_{j,k} |ωjk| is the 1-norm of the matrix. This is called the graphical lasso or Glasso.
24 Although the objective is convex, it is non-smooth (because of the non-differentiable `1 penalty)
25 and is constrained (because Ω must be a positive definite matrix). Several algorithms have been
26 proposed for optimizing this objective [Yuan07ggm; Banerjee08; Duchi08], although arguably
27 the simplest is the one in [Friedman08glasso], which uses a coordinate descent algorithm similar
28 to the shooting algorithm for lasso. An even faster method, based on soft thresholding, is described
29 in [Fattahi2018; Fattahi2018ieee].
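A convenient way to experiment with this objective is scikit-learn's off-the-shelf implementation (a sketch, not the original glasso code; the penalty λ is called alpha there). Sweeping the penalty traces out a path of increasingly sparse graphs, as in Main Figure 30.1.

    import numpy as np
    from sklearn.covariance import GraphicalLasso

    def glasso_graph(X, alpha):
        """Fit the graphical lasso at penalty alpha and return the estimated
        precision matrix and the implied adjacency matrix."""
        model = GraphicalLasso(alpha=alpha, max_iter=200).fit(X)
        Omega = model.precision_
        A = np.abs(Omega) > 1e-8
        np.fill_diagonal(A, False)
        return Omega, A

    # example: sweep the penalty to get a path of graphs of decreasing density
    rng = np.random.default_rng(0)
    X = rng.normal(size=(200, 11))   # placeholder for the flow cytometry data
    for alpha in [0.01, 0.1, 0.5, 1.0]:
        Omega, A = glasso_graph(X, alpha)
        print(alpha, A.sum() // 2, "edges")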
30 As an example, let us apply the method to the flow cytometry dataset from [Sachs05]. A discretized
31 version of the data is shown in Main Figure 20.7(a). Here we use the original continuous data.
32 However, we are ignoring the fact that the data was sampled under intervention. In Main Figure 30.1,
33 we illustrate the graph structures that are learned as we sweep λ from 0 to a large value. These
34 represent a range of plausible hypotheses about the connectivity of these proteins.
35 It is worth comparing this with the DAG that was learned in Main Figure 20.7(b). The DAG has
36 the advantage that it can easily model the interventional nature of the data, but the disadvantage
37 that it cannot model the feedback loops that are known to exist in this biological pathway (see the
38 discussion in [Schmidt09dcg]). Note that the fact that we show many UGMs and only one DAG is
39 incidental: we could easily use BIC to pick the “best” UGM, and conversely, we could easily display
40 several DAG structures, sampled from the posterior.
41
42
30.4.3 Graphical lasso for discrete MRFs/CRFs
43
44 It is possible to extend the graphical lasso idea to the discrete MRF and CRF case. However, now
45 there is a set of parameters associated with each edge in the graph, so we have to use the graph
46 analog of group lasso (see Main Section 15.2.6). For example, consider a pairwise CRF with ternary
1
nodes, and node and edge potentials given by

    ψt(yt, x) = | vᵀt1 x |        ψst(ys, yt, x) = | wᵀst11 x  wᵀst12 x  wᵀst13 x |
                | vᵀt2 x |                         | wᵀst21 x  wᵀst22 x  wᵀst23 x |        (30.51)
                | vᵀt3 x |                         | wᵀst31 x  wᵀst32 x  wᵀst33 x |
7
where we assume x begins with a constant 1 term, to account for the offset. (If x only contains 1,
8
the CRF reduces to an MRF.) Note that we may choose to set some of the vtk and wstjk weights to
9
0, to ensure identifiability, although this can also be taken care of by the prior.
10
To learn sparse structure, we can minimize the following objective:
11
J = − Σ_{i=1}^N [ Σ_t log ψt(yit, xi, vt) + Σ_{s=1}^{NG} Σ_{t=s+1}^{NG} log ψst(yis, yit, xi, wst) ]
      + λ1 Σ_{s=1}^{NG} Σ_{t=s+1}^{NG} ||wst||p + λ2 Σ_{t=1}^{NG} ||vt||₂²   (30.52)

18
where ||wst ||p is the p-norm; common choices are p = 2 or p = ∞, as explained in Main Sec-
19
tion 15.2.6. This method of CRF structure learning was first suggested in [Schmidt08]. (The use of
20
`1 regularization for learning the structure of binary MRFs was proposed in [Lee06].)
21
Although this objective is convex, it can be costly to evaluate, since we need to perform inference
22
to compute its gradient, as explained in Main Section 4.4.3 (this is true also for MRFs), due to the
23
global partition function. We should therefore use an optimizer that does not make too many calls to
24
the objective function or its gradient, such as the projected quasi-Newton method in [Schmidt09].
25
In addition, we can use approximate inference, such as loopy belief propagation (Main Section 9.4),
26
to compute an approximate objective and gradient more quickly, although this is not necessarily
27
theoretically sound.
28
Another approach is to apply the group lasso penalty to the pseudo-likelihood discussed in
29
Main Section 4.3.9.3. This is much faster, since inference is no longer required [Hoefling09].
30
Figure 30.22 shows the result of applying this procedure to the 20-newsgroup data, where yit indicates
31
the presence of word t in document i, and xi = 1 (so the model is an MRF).
32
For a more recent approach to learning sparse discrete UPGM structures, based on sparse full
33
conditionals, see the GRISE (Generalized Regularized Interaction Screening Estimator) method of
34
[Vuffray2019], which takes polynomial time, yet its sample complexity is close to the information-
35
theoretic lower bounds [Lokhov2018].
36
37
38
30.4.4 Bayesian inference for undirected graph structures
39 Although the graphical lasso is reasonably fast, it only gives a point estimate of the structure.
40 Furthermore, it is not model-selection consistent [Meinshausen05], meaning it cannot recover the
41 true graph even as N → ∞. It would be preferable to integrate out the parameters, and perform
42 posterior inference in the space of graphs, i.e., to compute p(G|D). We can then extract summaries
43 of the posterior, such as posterior edge marginals, p(Gij = 1|D), just as we did for DAGs. In this
44 section, we discuss how to do this.
45 If the graph is decomposable, and if we use conjugate priors, we can compute the marginal likelihood
46 in closed form [Dawid93]. Furthermore, we can efficiently identify the decomposable neighbors of
1
Figure 30.22: An MRF estimated from the 20-newsgroup data using group `1 regularization with λ = 256. Isolated nodes are not plotted. From Figure 5.9 of [SchmidtThesis]. Used with kind permission of Mark Schmidt.
26
27
28
29 a graph [Thomas09], i.e., the set of legal edge additions and removals. This means that we can
30 perform relatively efficient stochastic local search to approximate the posterior (see e.g. [Giudici99;
31 Armstrong08; Scott08]).
32 However, the restriction to decomposable graphs is rather limiting if one’s goal is knowledge
33 discovery, since the number of decomposable graphs is much less than the number of general
34 undirected graphs.4
35 A few authors have looked at Bayesian inference for GGM structure in the non-decomposable case
36 (e.g., [Dellaportas03; Wong03; Jones05]), but such methods cannot scale to large models because
37 they use an expensive Monte Carlo approximation to the marginal likelihood [Atay-Kayis05].
38 [Lenkoski08] suggested using a Laplace approximation. This requires computing the MAP estimate
39 of the parameters for Ω under a G-Wishart prior [Roverato02]. In [Lenkoski08], they used the
40 iterative proportional scaling algorithm [Speed86; Hara08] to find the mode. However, this is very
41 slow, since it requires knowing the maximal cliques of the graph, which is NP-hard in general.
42 In [Moghaddam09], a much faster method is proposed. In particular, they modify the gradient-
43
44
4. The number of decomposable graphs on NG nodes, for NG = 2, . . . , 8, is as follows ([Armstrong05]): 2; 8; 61; 822; 18,154; 617,675; 30,888,596. If we divide these numbers by the number of undirected graphs, which is 2^{NG(NG−1)/2}, we find the ratios are: 1, 1, 0.95, 0.8, 0.55, 0.29, 0.12. So we see that decomposable graphs form a vanishing fraction of the total hypothesis space.
1
2 based methods from Section 30.4.2.1 to find the MAP estimate; these algorithms do not need to
3 know the cliques of the graph. A further speedup is obtained by just using a diagonal Laplace
4 approximation, which is more accurate than BIC, but has essentially the same cost. This, plus the
5 lack of restriction to decomposable graphs, enables fairly fast stochastic search methods to be used
to approximate p(G|D) and its mode. This approach significantly outperformed graphical lasso, both
7 in terms of predictive accuracy and structural recovery, for a comparable computational cost.
8
9
10
11
30.5 Learning causal DAGs
12
13 Causal reasoning (which we discuss in more detail in Main Chapter 36) relies on knowing the
14 underlying structure of the DAG (although [Jaber2019] shows how to answer some queries if we
15 just know the graph up to Markov equivalence). Learning this structure is called causal discovery
16 (see e.g., [Glymour2019]).
17 If we just have two variables, we need to know if the causal model should be written as X → Y
or X ← Y. Both of these models are Markov equivalent (Section 30.3.2), meaning they cannot be
19 distinguished from observational data, yet they make very different causal predictions. We discuss
20 how to learn cause-effect pairs in Section 30.5.1.
21 When we have more than 2 variables, we need to consider more general techniques. In Section 30.3,
22 we discuss how to learn a DAG structure from observational data using likelihood based methods,
23 and hypothesis testing methods. However, these approaches cannot distinguish between models that
24 are Markov equivalent, so we need to perform interventions to reduce the size of the equivalence class
25 [Solus2019]. We discuss some suitable methods in Section 30.5.2.
26 The above techniques assume that the causal variables of interest (e.g., cancer rates, smoking
27 rates) can be measured directly. However, in many ML problems, the data is much more “low level”.
For example, consider trying to learn a causal model of the world from raw pixels. We briefly discuss this topic in Section 30.5.3.
30 For more details on causal discovery methods, see e.g., [Eberhardt2017; Peters2017; HeinzeDeml2018;
31 Guo2021].
32
33
34
30.5.1 Learning cause-effect pairs
35
36 If we only observe a pair of variables, we cannot use methods discussed in Section 30.3 to learn graph
37 structure, since such methods are based on conditional independence tests, which need at least 3
38 variables. However, intuitively, we should still be able to learn causal relationships in this case. For
example, we know that altitude X causes temperature Y and not vice versa. Suppose, for instance,
40 we measure X and Y in two different countries, say the Netherlands (low altitude) and Switzerland
41 (high altitude). If we represent the joint distribution as p(X, Y ) = p(X)p(Y |X), we find that the
42 p(Y |X) distribution is stable across the two populations, while p(X) will change. However, if we
43 represent the joint distribution as p(X, Y ) = p(Y )p(X|Y ), we find that both p(Y ) and p(X|Y ) need
44 to change across populations, so both of the corresponding distributions will be more “complicated”
45 to capture this non-stationarity in the data. In this section, we discuss some approaches that exploit
46 this idea. Our presentation is based on [Peters2017]. (See [Mooij2016] for more details.)
1
2 30.5.1.1 Algorithmic information theory
3
Suppose X ∈ {0, 1} and Y ∈ R and we represent the joint p(x, y) using
4
5 p(x, y) = p(x)p(y|x) = Ber(x|θ)N (y|µx , 1) (30.53)
6
7 We can equally well write this in the following form [Dawid02]:
8
9
p(x, y) = p(y)p(x|y) = [θN (y|µ1 , 1) + (1 − θ)N (y|µ2 , 1)]Ber(x|σ(α + βy)) (30.54)
10
where α = logit(θ) + (µ2² − µ1²)/2 and β = µ1 − µ2. We can plausibly argue that the first model, which
11
corresponds to X → Y , is more likely to be correct, since it is consists of two simple distributions
12
that seem to be rather generic. By contrast, in Equation (30.54), the distribution of p(Y ) is more
13
complex, and seems to be dependent on the specific form of p(X|Y ).
14
[Janzing10] show how to formalize this intuition using algorithmic information theory. In
15
particular, they say that X causes Y if the distributions PX and PY |X (not the random variables X
16
and Y ) are algorithmically independent. To define this, let PX (X) be the distribution induced
17
by fx (X, UX ), where UX is a bit string, and fX is represented by a Turing machine. Define PY |X
18
analogously. Finally, let K(s) be the Kolmogorov complexity of bit string s, i.e., the length of
19
the shortest program that would generate s using a universal Turing machine. We say that PX and
20
PY |X are algorithmically independent if
21
22 K(PX,Y ) = K(PX ) + K(PY |X ) (30.55)
23
24 Unfortunately, there is no algorithm to compute the Kolmogorov complexity, so this approach is
25 purely conceptual. In the sections below, we discuss some more practical metrics.
26
27 30.5.1.2 Additive noise models
28
A generic two-variable SCM of the form X → Y requires specifying the function X = fX (UX ), the
29
distribution of UX , the function Y = fY (X, UY ), and the distribution of UY . We can simplify our
30
notation by letting X = Ux and defining p(X) directly, and defining Y = fY (X, UY ) = f (X, U ),
31
where U is a noise term.
32
In general, such a model is not identifiable from a finite dataset. For example, we can imagine that
33
the value of U can be used to select between different functional mappings, Y = f (X, U = u) = fu (X).
34
Since U is not observed, the induced distribution will be a mixture of different mappings, and it
35
will generally be impossible to disentangle. For example, consider the case where X and U are
36
Bernoulli random variables, and U selects between the functions Y = fid (X) = I (Y = X) and
37
Y = fneg (X) = I (Y 6= X). In this case, the induced distribution p(Y ) is uniform, independent of X,
38
even though we have the structure X → Y .
39
The above concerns motivate the desire to restrict the flexibility of the functions at each node. One
40
natural family is additive noise models (ANM), where we assume each variable has the following
41
dependence on its parents [Hoyer2009]:
42
43
Xi = fi (Xpai ) + Ui (30.56)
44
45 In the case of two variables, we have Y = f (X) + U . If X and U are both Gaussian, and f is linear,
46 the system defines a jointly Gaussian distribution p(X, Y ), as we discussed in Main Section 2.3.2.2.
1
Figure 30.23: Signature of X causing Y. Left: If we try to predict Y from X, the residual error (noise term, shown by vertical arrows) is independent of X. Right: If we try to predict X from Y, the residual error is not constant. From Figure 8.8 of [Varshney2021]. Used with kind permission of Kush Varshney.
14
15
16
17
This is symmetric, and prevents us from distinguishing X → Y from Y → X. However, if we let f be
18
nonlinear, and/or let X or U be non-Gaussian, we can distinguish X → Y from Y → X, as we
19
discuss below.
20
21
22 30.5.1.3 Nonlinear additive noise models
23
24 Suppose pY |X is an additive noise model (possibly Gaussian noise) where f is a nonlinear function.
25 In this case, we will not, in general, be able to create an ANM for pX|Y . Thus we can determine
26 whether X → Y or vice versa as follows: we fit a (nonlinear) regression model for X → Y , and then
27 check if the residual error Y − fˆY (X) is independent of X; we then repeat the procedure swapping
28 the roles of X and Y . The theory [Peters2017] says that the independence test will only pass for
29 the causal direction. See Figure 30.23 for an illustration.
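The sketch below implements this recipe for a pair of scalar variables. It uses a generic nonparametric regressor, and, as a crude stand-in for a proper independence test such as HSIC, it scores the rank correlation between the input and the residuals and their magnitudes; the direction whose residuals look "more independent" is preferred. It illustrates the logic only and is not a competitive cause-effect classifier.

    import numpy as np
    from sklearn.ensemble import GradientBoostingRegressor
    from scipy.stats import spearmanr

    def residual_dependence(x, y):
        """Fit y = f(x) + u nonparametrically and score how dependent the residuals
        u are on x (lower = more plausible causal direction x -> y)."""
        f = GradientBoostingRegressor().fit(x[:, None], y)
        u = y - f.predict(x[:, None])
        # crude proxy for an independence test: rank correlations between x and
        # the residuals / their magnitudes (HSIC would be the proper choice)
        return abs(spearmanr(x, u)[0]) + abs(spearmanr(x, np.abs(u))[0])

    def anm_direction(x, y):
        return "x->y" if residual_dependence(x, y) < residual_dependence(y, x) else "y->x"

    # example: nonlinear ANM with uniform noise, true direction x -> y
    rng = np.random.default_rng(0)
    x = rng.uniform(-3, 3, size=1000)
    y = np.tanh(2 * x) + 0.2 * rng.uniform(-1, 1, size=1000)
    print(anm_direction(x, y))   # expected: "x->y" (not guaranteed on every draw)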
30
31
30.5.1.4 Linear models with non-Gaussian noise
32
33 If the function mapping from X to Y is linear, we cannot tell if X → Y or Y → X if we assume
34 Gaussian noise. This is apparent from the symmetry of Figure 30.23 in the linear case. However, by
35 combining linear models with non-Gaussian noise, we can recover identifiability.
36 For example, consider the ICA model from Main Section 28.6. This is a simple linear model of the
37 form y = Ax, where p(x) has a non-Gaussian distribution, and p(y|x) is a degenerate distribution,
38 since we assume the observation model is deterministic (noise-free). In the ICA case, we can uniquely
39 identify the parameters A and the corresponding latent source x. This lets us distinguish the X → Y
40 model from the Y → X model. (The intuition behind this method is that linear combinations of
41 random variables tend towards a Gaussian distribution (by the central limit theorem), so if X → Y ,
42 then p(Y ) will “look more Gaussian” than p(X).)
43 Another setting in which we can distinguish the direction of the arrow is when we have non-Gaussian
44 observation noise, i.e., y = Ax + UY , where UY is non-Gaussian. This is an example of a “linear
45 non-Gaussian acyclic model” (LiNGAM) [Shimizu2006]. The non-Gaussian additive noise results
46 in the induced distributions p(X, Y ) being different depending on whether X → Y or Y → X.
14
Figure 30.24: Illustration of information-geometric causal inference for Y = f (X). The density of the effect
p(Y ) tends to be high in regions where f is flat (and hence f −1 is steep). From Figure 4 of [Janzing2012].
15
16
17
18
30.5.1.5 Information-geometric causal inference
19 An alternative approach, known as information-geometric causal inference, or IGCI, was
20 proposed in [Daniusis2010; Janzing2012]. In this method, we assume f is a deterministic strictly
21 monotonic function on [0, 1], with f (0) = 0 and f (1) = 1, and there is no observation noise, so
22 Y = f (X). If X has the distribution p(X), then the shape of the induced distribution p(Y ) will
23 depend on the form of the function f , as illustrated in Figure 30.24. Intuitively, the peaks of p(Y )
24 will occur in regions where f has small slope, and thus f −1 has large slope. Thus pY (Y ) and f −1 (Y )
25 will depend on each other, whereas pX (X) and f (X) do not (since we assume the distribution of
26 causes is independent of the causal mechanism).
More precisely, let the functions log f′ (the log of the derivative function) and pX be viewed as random variables on the probability space [0, 1] with a uniform distribution. We say pX,Y satisfies an IGCI model if f is a mapping as above, and the following independence criterion holds: Cov[log f′, pX] = 0, where

Cov[log f′, pX] = ∫₀¹ log f′(x) pX(x) dx − ∫₀¹ log f′(x) dx ∫₀¹ pX(x) dx   (30.57)

where ∫₀¹ pX(x) dx = 1. One can show that the inverse function f⁻¹ satisfies Cov[log (f⁻¹)′, pY] ≥ 0, with equality iff f is linear.
This can be turned into an empirical test as follows. Define

C_{X→Y} = ∫₀¹ log f′(x) p(x) dx ≈ (1/(N − 1)) Σ_{j=1}^{N−1} log( |y_{j+1} − y_j| / |x_{j+1} − x_j| )   (30.58)

where x1 < x2 < · · · < xN are the observed x-values in increasing order. The quantity CY→X is defined
43 analogously. We then choose X → Y as the model whenever ĈX→Y < ĈY →X . This is called the
44 slope based approach to IGCI.
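A minimal implementation of the slope-based score in Equation (30.58) is sketched below. Following the assumptions above, both variables are first rescaled to [0, 1] (here by a simple min-max normalization, one of several conventions used in practice).

    import numpy as np

    def igci_slope_score(x, y):
        """Empirical estimate of C_{X->Y} from Equation (30.58)."""
        # rescale both variables to [0, 1]
        x = (x - x.min()) / (x.max() - x.min())
        y = (y - y.min()) / (y.max() - y.min())
        idx = np.argsort(x)
        xs, ys = x[idx], y[idx]
        dx = np.diff(xs)
        dy = np.abs(np.diff(ys))
        keep = (dx > 0) & (dy > 0)          # avoid log(0) for tied values
        return np.mean(np.log(dy[keep] / dx[keep]))

    def igci_direction(x, y):
        # choose X -> Y whenever C_{X->Y} < C_{Y->X}
        return "x->y" if igci_slope_score(x, y) < igci_slope_score(y, x) else "y->x"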
One can also show that an IGCI model satisfies the property that H(X) ≤ H(Y), where H() is the differential entropy. Intuitively, the reason is that applying a nonlinear function f to pX can introduce additional irregularities, thus making pY less uniform than pX. This is illustrated in
3 Figure 30.24. We can then choose between X → Y and X ← Y based on the difference in estimated
4 entropies.
5 An empirical comparison of the slope-based and entropy-based approaches to IGCI can be found
6 in [Mooij2016].
7
8
9
10
11
12 30.5.2 Learning causal DAGs from interventional data
13
14 In Section 30.3, we discuss how to learn a DAG structure from observational data, using either
15 likelihood-based (Bayesian) methods of model selection, or constraint-based (frequentist) methods.
16 (See [Tu2019nips] for a recent empirical comparison of such methods applied to a medical simulator.)
17 However, such approaches cannot distinguish between models that are Markov equivalent, and thus
18 the output may not be sufficient to answer all causal queries of interest.
19 To distinguish DAGs within the same Markov equivalence class, we can use interventional data,
20 where certain variables have been set, and the consequences have been measured. In particular,
21 we can modify the standard likelihood-based DAG learning method discussed in Section 30.3 to
22 take into account the fact that the data generating mechanism has been changed. For example, if
23 θijk = p(Xi = j|Xpa(i) = k) is a CPT for node i, then when we compute the sufficient statistics
P
24 Nijk = n I Xni = j, Xn,pa(i) = k , we exclude cases n where Xi was set externally by intervention,
25 rather than sampled from θijk . This technique was first proposed in [Cooper99], and corresponds
26 to Bayesian parameter inference from a set of mutiliated models with shared parameters.
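The modified count computation is simple to implement. The sketch below assumes, for a single node i, a vector xi of its observed values, a vector pa_config giving the (flattened) parent configuration in each case, and a boolean mask intervened marking the cases in which Xi was clamped; such cases are simply excluded from Nijk.

    import numpy as np

    def interventional_counts(xi, pa_config, intervened, n_states, n_pa_configs):
        """Sufficient statistics N[j, k] = #{n : X_ni = j, parents in config k,
        and X_i was NOT intervened on in case n}."""
        N = np.zeros((n_states, n_pa_configs), dtype=int)
        for x, k, was_set in zip(xi, pa_config, intervened):
            if not was_set:                 # drop cases where X_i was clamped
                N[x, k] += 1
        return N

    def cpt_posterior_mean(N, alpha=1.0):
        # posterior mean of each CPT column under a Dirichlet(alpha) prior
        return (N + alpha) / (N + alpha).sum(axis=0, keepdims=True)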
27 The preceding method assumes that we use perfect interventions, where we deterministically
28 set a variable to a chosen value. In reality, experimenters can rarely control the state of individual
29 variables. Instead, they can perform actions which may affect many variables at the same time. (This
30 is sometimes called a “fat hand intervention”, by analogy to an experiment where someone tries
31 to change a single component of some system (e.g., electronic circuit), but accidentally touching
32 multiple components and thereby causing various side effects.) We can model this by adding the
33 intervention nodes to the DAG (Main Section 4.7.3), and then learning a larger augmented DAG
34 structure, with the constraint that there are no edges between the intervention nodes, and no edges
35 from the “regular” nodes back to the intervention nodes.
36 For example, suppose we perturb various proteins in a cellular signalling pathway, and measure
37 the resulting phosphorylation status using a technique such as flow cytometry, as in [Sachs05]. An
38 example of such a dataset is shown in Main Figure 20.7(a). Main Figure 20.7(b) shows the augmented
39 DAG that was learned from the interventional flow cytometry data depicted in Main Figure 20.7(a).
40 In particular, we plot the median graph, which includes all edges for which p(Gij = 1|D) > 0.5.
41 These were computed using the exact algorithm of [Koivisto06]. See [Eaton07aistats] for details.
42 Since interventional data can help to uniquely identify the DAG, it is natural to try to choose
43 the optimal set of interventions so as to discover the graph structure with as little data as possible.
44 This is a form of active learning or experiment design, and is similar to what scientists do. See
45 e.g., [Murphy01causal; He09; Kalisch2014; Hauser2014; Mueller2017] for some approaches
46 to this problem.

1
2 30.5.3 Learning from low-level inputs
3
In many problems, the available data is quite “low level”, such as pixels in an image, and is believed
4
to be generated by some “higher level” latent causal factors, such as objects interacting in a scene.
5
Learning causal models of this type is known as causal representation learning, and combines
6
the causal discovery methods discussed in this section with techniques from latent variable modeling
7
(e.g., VAEs, Main Chapter 21) and representation learning (Main Chapter 32). For more details, see
8
e.g., [Chalupka2017; Scholkopf2021].
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
31 Non-parametric Bayesian models

This chapter is written by Vinayak Rao.

31.1 Dirichlet processes


A Dirichlet process (DP) is a nonparametric probability distribution over probability distributions,
and is useful as a flexible prior for unsupervised learning tasks like clustering and density modeling
[ferguson1973bayesian]. We give more details in the sections below.

31.1.1 Definition of a DP
Let G be a probability distribution or a probability measure (we will use the latter terminology in
this chapter) on some space Θ. Recall that a probability measure is a function that assigns values to subsets T ⊆ Θ satisfying the usual axioms of probability: 0 ≤ G(T) ≤ 1, G(Θ) = ∫_Θ G(θ)dθ = 1, and for disjoint subsets T1, . . . , TK of Θ, G(T1 ∪ . . . ∪ TK) = Σ_{k=1}^K G(Tk). Bayesian unsupervised
learning now seeks to place a prior on the probability measure G.
We have already seen examples of parametric priors over probability measures. As a simple example,
consider a Gaussian distribution N (θ|µ, σ 2 ): this is a probability measure on Θ, and by placing
priors on the parameters µ and σ 2 , we have a parametric prior on probability measures. Mixture
models form more flexible priors, allowing multimodality and asymmetry, and are parameterized
by the probabilities of the mixture components, as well as their parameters. DPs directly define a
probability on probability measures G.
A Dirichlet process is specified by a positive real number α, called the concentration parameter,
and a probability measure H, called the base measure. We write a random measure drawn from
a DP as G ∼ DP(α, H). H is typically a standard probability measure on Θ, and forms the mean
of the Dirichlet process. That is, if G ∼ DP(α, H), then for any subset T of Θ, E[G(T )] = H(T ).
The parameter α measures how concentrated the Dirichlet process is around H, with V [G(T )] =
H(T )(1−H(T ))
1+α . If Θ is R2 , then setting H to the bivariate normal N (0, I2 ) and α to a large value
implies a prior belief that G sampled from DP(α, H) is close to the normal, whereas a small α
represents a relatively uninformative prior.
We now define the Dirichlet process more precisely. Let (T1 , . . . , TK ) be a finite partition of
Θ, that is, (T1 , . . . , TK ) are disjoint sets whose union is Θ. For a probability measure G, let
(G(T1 ), . . . , G(TK )) be the vector of probabilities of the elements of this partition. Then DP(α, H) is
a prior over probability measures G satisfying the following requirement: for any finite partition, the
Figure 31.1: Partitions of the unit square. (left) One possible partition into K = 3 regions, and (center) A refined partition into K = 4 regions. In both figures, the shading of cell Tk is proportional to G(Tk), resulting from the same realization of a Dirichlet process. (right) An ‘infinite partition’ of the unit square. The Dirichlet process can informally be viewed as an infinite-dimensional Dirichlet distribution defined on this.
12
13
14
associated vector of probabilities has the following joint Dirichlet distribution:
15
16
(G(T1 ), . . . , G(TK )) ∼ Dir(αH(T1 ), . . . , αH(TK )). (31.1)
17
18 Just like the Gaussian process, the DP is defined implicitly though a set of finite-dimensional
19 distributions, in this case through the distribution of G projected onto any finite partition. The
20 finite-dimensional distributions are consistent in the following sense: if T11 and T12 form a partition
21 of T1 , then one can sample G(T1 ) in two ways: directly, by sampling
22
23 (G(T1 ), . . . , G(TK )) ∼ Dir(αH(T1 ), . . . , αH(TK )) (31.2)
24
25 or, indirectly, by sampling
26
27
(G(T11 ), G(T12 ), . . . , G(TK )) ∼ Dir(αH(T11 ), αH(T12 ), . . . , αH(TK )) (31.3)
28
and then setting G(T1 ) = G(T11 ) + G(T12 ). From the properties of the Dirichlet distribution, G(T1 )
29
sampled either way follows the same distribution. This consistency property implies, via Kolmogorov’s
30
extension theorem [kallenberg2006foundations], that underlying all finite-dimensional probability
31
vectors for different partitions is a single infinite-dimensional vector that we could informally write
32
as
33
34
(G(dθ1), . . . , G(dθ∞)) ∼ Dir(αH(dθ1), . . . , αH(dθ∞)).   (31.4)
35
36 Very roughly, this ‘infinite-dimensional Dirichlet distribution’ is the Dirichlet process. Figure 31.1
37 sketches this out.
38 Why is the Dirichlet process, defined in this indirect fashion, useful to practitioners? The answer
39 has to do with conjugacy properties that it inherits from the Dirichlet distribution. One of the
40 simplest unsupervised learning problems seeks to learn an unknown probability distribution G from
41 iid samples {θ1 , . . . , θN } drawn from it. Consider placing a DP prior on the unknown G. Then given
42 the data, one is interested in the posterior distribution over G, representing the updated probability
43 distribution over G. For a partition (T1 , . . . , TK ) of Θ, an observation falls into the cell z following a
44 multinoulli distribution:
45
46 z ∼ Cat(G(T1 ), . . . , G(TK )). (31.5)
1
2
3
4
5
6
7
Figure 31.2: Realizations from a Dirichlet process when Θ is (a) the real line, and (b) the unit square. Also shown are the base measures H. In reality, the number of atoms is infinite for both cases.
11
12
13
Under a DP prior on G, (G(T1 ), . . . , G(TK )) follows a Dirichlet distribution (equation (31.1)). From
14
the Dirichlet-multinomial conjugacy, the posterior for (G(T1 ), . . . , G(TK )) given the observations is
15
(G(T1), . . . , G(TK)) | {θ1, . . . , θN} ∼ Dir( αH(T1) + Σ_{i=1}^N I(θi ∈ T1), . . . , αH(TK) + Σ_{i=1}^N I(θi ∈ TK) )   (31.6)
19
20 This is true for any finite partition, so that following our earlier definition, the posterior over G itself
21 is a Dirichlet process, and it is easy to see that:
22
G | θ1, . . . , θN, α, H ∼ DP( α + N, (1/(α + N)) ( αH + Σ_{i=1}^N δ_{θi} ) ).   (31.7)
25
26 Thus we see that the DP prior on G is a conjugate prior for iid observations from G, with the
27 posterior distribution over G also a Dirichlet process with concentration parameter α + N , and base
28 measure a convex combination of the original base measure H and the empirical distribution of
29 the observations. Note that as N increases, the influence of the original base measure H starts to
30 wane, and the posterior base measure becomes closer and closer to the empirical distribution of the
31 observations. At the same time, the concentration parameter increases, suggesting that the posterior
32 distribution concentrates around the empirical distribution.
33
34
31.1.2 Stick breaking construction of the DP
35
36 Our discussion so far has been very abstract, with no indication of how to sample either the random
37 measure G or observations from G. We address the first question, giving a constructive definition for
38 the DP known as the stick-breaking construction [sethuraman1994constructive].
39 We first mention that probability measures G sampled from a DP are discrete with probability
40 one (see Figure 31.2), taking the form
41

G(θ) = Σ_{k=1}^∞ πk δ_{θk}(θ).   (31.8)
44
45 Thus G consists of an infinite number of atoms, the kth atom located at θk , and having weight πk .
46 Informally, this follows from Equation (31.4), which represents the DP as an infinite-dimensional
1
Figure 31.3: Illustration of the stick breaking construction. (a) We have a unit length stick, which we break at a random point β1; the length of the piece we keep is called π1; we then recursively break off pieces of the remaining stick, to generate π2, π3, . . .. From Figure 2.22 of [SudderthThesis]. Used with kind permission of Erik Sudderth. (b) Samples of πk from this process for different values of α. Generated by stick_breaking_demo.ipynb.
17
18
19
20 but infinitely-sparse Dirichlet distribution (recall that as its parameters become smaller, a Dirichlet
21 distribution concentrates on sparse distributions that are dominated by a few components).
22 For a DP, the locations θk of the atoms are drawn independently from the base measure H, whereas
23 the concentration parameter α controls the distribution of the weights πk . Observe that the infinite
24 sequence of weights (π1 , π2 , . . . ) must add up to one, since G is a probability measure. The weights
25 can be simulated by the following process sketched in Figure 31.3, and known as the stick-breaking
26 process. Start with a stick of length 1 representing the total probability mass, and sequentially
27 break off a random Beta(1, α) distributed fraction of the remaining stick. The kth break forms πk .
28 In equations, for k = 1, 2, . . . ,
29
βk ∼ Beta(1, α),   θk ∼ H,   (31.9)

πk = βk ∏_{l=1}^{k−1} (1 − βl) = βk (1 − Σ_{l=1}^{k−1} πl)   (31.10)

Then, setting G(θ) = Σ_{k=1}^∞ πk δ_{θk}(θ), one can show that G ∼ DP(α, H). The distribution over the
35
weights is often denoted by
36
37 π ∼ GEM(α), (31.11)
38
39 where GEM stands for Griffiths, Engen, and McCloskey (this term is due to [Ewens90]). Some
40 samples from this process are shown in Figure 31.3.
We note that since the number of atoms is infinite, one cannot exactly simulate from a DP in finite time. However, the sequence of weights from the GEM distribution is stochastically ordered, having decreasing averages, and the truncation error resulting from terminating after a finite number of steps quickly becomes negligible [ishwaran2001gibbs]. Nevertheless, we will see in the next section that it is possible to simulate samples and make predictions from a DP-distributed probability measure G without any truncation error. This exploits the conjugacy property of the DP.
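The sketch below draws an approximate (truncated) sample from a DP using the stick-breaking construction, with H = N(0, 1) as a placeholder base measure; for K large relative to α the leftover mass beyond the truncation is negligible.

    import numpy as np

    def sample_dp_stick_breaking(alpha, K=1000, rng=None):
        """Truncated stick-breaking sample from DP(alpha, H) with H = N(0, 1).
        Returns atom locations and weights (weights sum to slightly less than 1)."""
        rng = np.random.default_rng(rng)
        beta = rng.beta(1.0, alpha, size=K)           # stick-breaking fractions
        remaining = np.concatenate([[1.0], np.cumprod(1.0 - beta)[:-1]])
        pi = beta * remaining                         # pi_k = beta_k prod_{l<k}(1 - beta_l)
        theta = rng.normal(size=K)                    # atom locations drawn from H
        return theta, pi

    theta, pi = sample_dp_stick_breaking(alpha=2.0, K=1000, rng=0)
    # draw observations from the (approximate) random measure G
    samples = np.random.default_rng(1).choice(theta, size=500, p=pi / pi.sum())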
1
2 31.1.3 The Chinese restaurant process (CRP)
3
Consider a single observation θ1 from a DP-distributed probability measure G. The probability
4
that θ1 lies within a set T ⊆ Θ, marginalizing out the random G, is E[G(T )] = H(T ), the equality
5
following from the definition of the DP. This holds for arbitrary T , which implies that the first
6
observation θ1 is distributed as the base measure of the DP:
7
8
p(θ1 = θ|α, H) = H(θ). (31.12)
9
10
Given N observations θ1 , . . . , θN , the updated distribution over G is still a DP, but now modified
11
as in Equation (31.7). Repeating the same argument, it follows that the (N + 1)st observation is
12
distributed as the base measure of the posterior DP, given by
13
p(θ_{N+1} = θ | θ_{1:N}, α, H) = (1/(α + N)) ( αH(θ) + Σ_{k=1}^K Nk δ_{θk}(θ) )   (31.13)
17
18 where Nk is the number of observations equal to θk . The previous two equations form the basis of
19 what is called the Pólya urn or Blackwell-MacQueen sampling scheme [blackwell1973ferguson].
20 This provides a way to exactly produce samples from a DP-distributed random probability measure.
21 It is often more convenient to work with discrete variables (z1 , . . . , zN ), with zi specifying which
22 value of θk the ith sample takes. In particular, for the ith observation, θi = θzi . This allows us to
23 decouple the cluster or partition structure of the dataset (controlled by α) and the cluster parameters
24 (controlled by H). Let us assign the first observation to cluster 1, i.e., z1 = 1. The second observation
25 can either belong to the same cluster as observation 1, or belong to a new cluster, which we call
26 cluster 2. In the former event, z2 = 1, after which z3 can equal 1 or 2. In the latter event, z2 = 2,
27 and z3 can equal 1, 2, or 3. Based on the Equation (31.13), we have
28
p(z_{N+1} = z | z_{1:N}, α) = (1/(α + N)) ( α I(z = K + 1) + Σ_{k=1}^K Nk I(z = k) ),   (31.14)
31
32 assuming the first N observations have been assigned to K clusters. This is called the Chinese
33 restaurant process or CRP, based on the following analogy: observations are customers in a
34 restaurant with an infinite number of tables, each corresponding to a different cluster. Each table
35 has a dish, corresponding to the parameter θ of that cluster. When a customer enters the restaurant,
36 they may choose to join an existing table with probability proportional to the number of people
37 already sitting at this table (i.e., they join table k with probability proportional to Nk ); otherwise,
38 with probability proportional to α, they choose to sit at a new table, ordering a new dish by sampling
39 from the base measure H.
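A direct simulation of Equation (31.14) is sketched below; it returns the table assignments z1, . . . , zN, from which dishes θk ∼ H can be drawn separately (here H = N(0, 1) as a placeholder).

    import numpy as np

    def sample_crp(N, alpha, rng=None):
        """Sample cluster assignments z_1..z_N from a CRP with concentration alpha."""
        rng = np.random.default_rng(rng)
        z = np.zeros(N, dtype=int)
        counts = []                       # counts[k] = number of customers at table k
        for i in range(N):
            probs = np.array(counts + [alpha], dtype=float)
            probs /= probs.sum()
            k = rng.choice(len(probs), p=probs)
            if k == len(counts):          # start a new table
                counts.append(1)
            else:
                counts[k] += 1
            z[i] = k
        return z, np.array(counts)

    z, counts = sample_crp(N=100, alpha=2.0, rng=0)
    dishes = np.random.default_rng(1).normal(size=len(counts))   # theta_k ~ H = N(0, 1)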
The sequence Z = (z1, . . . , zN) of cluster assignments defines a partition of the integers 1 to N, and
41 the CRP is a distribution over such partitions. The probability of creating a new table diminishes as
42 the number of observations increases, but is always non-zero, and one can show that the number of
43 occupied tables K approaches α log(N ) as N → ∞ almost surely. The fact that currently occupied
44 tables are more likely to get new customers is sometimes called a rich get richer phenomenon.
45 It is important to recognize that despite being defined as a sequential process, the CRP is an
46 exchangeable process, with partition probabilities that are independent of the observation indices.
1
Indeed, it is easy to show that the probability of a partition of N integers into K clusters with sizes N1, . . . , NK is

p(N1, . . . , NK) = (α^{K−1} / [α + 1]_1^{N−1}) ∏_{k=1}^K (Nk − 1)!   (31.15)

Here, [α + 1]_1^{N−1} = ∏_{i=0}^{N−2} (α + 1 + i) is the rising factorial. Equation (31.15) depends only on the cluster
9
sizes, and is called the Ewens sampling formula [ewens1972sampling]. Exchangeability implies
10
that the probability that the first two customers sit at the same table is the same as the probability
11
that the first and last sit at the same table. Similarly all customers have the same probability of
12
ending up in a cluster of size S. The fact that the first customer can only belong to cluster 1 (i.e.,
13
that z1 = 1) does not contradict exchangeability and reflects the fact that the cluster indices are
14
chosen arbitrarily. This disappears if we index clusters by their associated parameter θk .
15
16
17 31.2 Dirichlet process mixture models
18
19 Real-world datasets are often best modeled by continuous probability densities. By contrast, a
20 sample G from a DP is discrete with probability one, and sampling observations from G will result
21 in repeated values, making it inappropriate for many applications. However, the discrete structure
22 of G is useful in clustering applications, as a prior for the latent variables underlying the observed
23 datapoints. In particular, zi and θ i can represent the cluster assignment and cluster parameter of
24 the i’th datapoint, whose value xi is a draw from some parametric distribution F (x|θ) indexed by θ,
25 with base measure H. The resulting model follows along the lines of a standard mixture model, but
26 now is an infinite mixture model, consisting of an infinite number of components or clusters, one
27 for each atom in G.
28 A very common setting when xi ∈ Rd is to set F to be the multivariate normal distribution,
29 θ = (µ, Σ), and H to be the normal-inverse-Wishart distribution. Then, each of the infinite clusters
30 has an associated mean and covariance matrix, and to generate a new observation, one picks cluster k
31 with probability πk , and simulates from a normal with mean µk and covariance Σk . See Figure 31.4
32 for some samples from this model.
33 We discuss DP mixture models (DPMM) in more detail below.
34
35
36
31.2.1 Model definition
37 We define the DPMM model as follows:
38
39 π ∼ GEM (α), θk ∼ H, k = 1, 2, . . . (31.16)
40
zi ∼ π, xi ∼ F (θzi ), i = 1, . . . , N. (31.17)
41
42
Equivalently, we can write this as
43
44
G ∼ DP(α, H) (31.18)
45
46 θ i ∼ G, xi ∼ F (θ i ), i = 1, . . . , N. (31.19)
1
2
Figure 31.4: Some samples from a Dirichlet process mixture model of 2d Gaussians, with concentration parameter α = 1 (left column) and α = 2 (right column). From top to bottom, we show N = 50, N = 500 and N = 1000 samples. Generated by dp_mixgauss_sample.ipynb.
1
2
Figure 31.5: Two views of N observations sampled from a DP mixture model. Left: representation where cluster indicators are sampled from the GEM-distributed distribution π. Right: representation where parameters are samples from the DP-distributed random measure G. The rightmost picture illustrates the case where N = 2, θ is real-valued with a Gaussian prior H(·), and F(x|θ) is a Gaussian with mean θ and variance σ². We generate two parameters, θ1 and θ2, from G, one per data point. Finally, we generate two data points, x1 and x2, from N(θ1, σ²) and N(θ2, σ²). From Figure 2.24 of [SudderthThesis]. Used with kind permission of Erik Sudderth.
23
24
G and F together define the infinite mixture model: GF(x) = Σ_{k=1}^∞ πk F(x|θk). If F(x|θ) is
26 continuous, then so is GF (x), and the Dirichlet process mixture model serves as a nonparametric
27 prior over continuous distributions or probability densities.
28 Figure 31.5 illustrates two graphical models that summarize this, corresponding to the two sets
29 of equations above. The first generates the set of weights (π1 , π2 , . . . ) from the GEM distribution,
30 along with an infinite collection of cluster parameters (θ1 , θ2 , . . . ). It then generates observations by
31 first sampling a cluster indicator zi from π, indexing the associated cluster parameter θzi and then
32 simulating the observation xi from F (θzi ). The second graphical model simulates a random measure
33 G from the DP. It generates observations by directly simulating a parameter θ i from G, and then
34 simulating xi from F (θ i ). The infinite mixture model can be viewed as the limit of a K-component
35 finite mixture model with a Dir(α/K, . . . , α/K) prior on the mixture weights (π1 , . . . , πK ) and with
36 mixture parameters θ1 , . . . , θK , as K → ∞ [Rasmussen00; Neal00]. Producing exact samples
(x1, . . . , xN) from this model involves one additional step beyond the Chinese restaurant process: after selecting a table (cluster) zi with its associated dish (parameter) θzi, the i’th customer now samples
39 an observation from the distribution F (θzi ).
40
41
31.2.2 Fitting using collapsed Gibbs sampling
42
43 Given a dataset of observations, one is interested in the posterior distribution p(G, z1 , . . . , zN |x1 , . . . , xN , α, H),
44 or equivalently, p(π, θ1 , θ2 , . . . , z1 , . . . , zN |x1 , . . . , xN , α, H). The most common way to fit a DPMM
45 is via Markov chain Monte Carlo (MCMC), producing samples by constructing a Markov chain
46 that targets this posterior distribution. We describe a collapsed Gibbs sampler based on the
1
2 Chinese restaurant process that marginalizes out the infinite-dimensional random measure G, and
3 that targets the distribution p(z1 , . . . , zN |x1 , . . . , xN , α, H) summarizing all clustering information.
4 It produces samples from this distribution by cycling through each observation xi , and updating its
5 assigned cluster zi , conditioned on all other variables. Write x−i for all observations other than the
6 ith observation, and z−i for their cluster assignments. Then we have
7
8 p(zi = k|z−i , x, α, H) ∝ p(zi = k|z−i , α)p(xi |x−i , zi = k, z−i , H) (31.20)
9
10 By exchangeability, each observation can be treated as the last customer to enter the restaurant.
11 Hence the first term is given by
12  
p(zi | z−i, α) = (1/(α + N − 1)) ( α I(zi = K−i + 1) + Σ_{k=1}^{K−i} Nk,−i I(zi = k) )   (31.21)
15
16
where Nk,−i is the number of observations in cluster k, and K−i is the number of clusters used
17
by x−i , both obtained after removing observation i, eliminating empty clusters, and renumbering
18
clusters.
19
To compute the second term, p(xi |x−i , zi = k, z−i , H), let us partition the data x−i into clusters
20
based on z−i. Let x−i,c = {xj : zj = c, j ≠ i} be the datapoints assigned to cluster c. If zi = k, then
21
xi is conditionally independent of all the datapoints except those assigned to cluster k. Hence,
22
p(xi | x−i, z−i, zi = k) = p(xi | x−i,k) = p(xi, x−i,k) / p(x−i,k),   (31.22)
25
26 where
27  
Z Y
28
p(xi, x−i,k) = ∫ p(xi | θk) [ ∏_{j≠i: zj=k} p(xj | θk) ] H(θk) dθk   (31.23)

31
32
is the marginal likelihood of all the data assigned to cluster k, including i, and p(x−i,k ) is an analogous
33
expression excluding i. Thus we see that the term p(xi |x−i , z−i , zi = k) is the posterior preditive
34
distribution for cluster k evaluated at xi .
35
If zi = k*, corresponding to a new cluster, we have

p(xi | x−i, z−i, zi = k*) = p(xi) = ∫ p(xi | θ) H(θ) dθ   (31.24)
38
39 which is just the prior predictive distribution for a new cluster evaluated at xi .
40 The overall sampler is sometimes called “Algorithm 3” (from [Neal00]). Algorithm 31.1 provides
41 the pseudocode. The algorithm is very similar to collapsed Gibbs for finite mixtures except that we
have to consider the case zi = k*. Note that in order to evaluate the integrals in Equation (31.23)
43 and Equation (31.24), we require the base measure H to be conjugate to the likelihood F . For
44 example, if we use an NIW prior for the Gaussian likelihood, we can use the results from ?? to
45 compute the predictive distributions. Extensions to the case of non-conjugate priors are discussed in
46 [Neal00].
3
Algorithm 31.1: Collapsed Gibbs sampler for DP mixture model
1 foreach i = 1 : N in random order do
2   Remove xi's sufficient statistics from old cluster zi
3   foreach k = 1 : K do
4     Compute pk(xi) = p(xi | x−i(k))
5     Set Nk,−i = dim(x−i(k))
6     Compute p(zi = k | z−i, D) = (Nk,−i / (α + N − 1)) pk(xi)
7   Compute p(zi = ∗ | z−i, D) = (α / (α + N − 1)) p(xi)
8   Normalize p(zi | ·)
9   Sample zi ∼ p(zi | ·)
10  Add xi's sufficient statistics to new cluster zi
11  If any cluster is empty, remove it and decrease K
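The sketch below implements Algorithm 31.1 for the simplest conjugate case: one-dimensional data with F(x|θ) = N(x|θ, σ²), σ² known, and H = N(0, τ²), so that the predictive distributions needed in lines 4 and 7 are available in closed form. It is meant to illustrate the bookkeeping, not to be an efficient implementation.

    import numpy as np
    from scipy.stats import norm

    def collapsed_gibbs_dpmm(x, alpha=1.0, sigma2=1.0, tau2=4.0, n_iters=100, rng=0):
        """Collapsed Gibbs (Algorithm 31.1) for a 1d DP mixture of Gaussians with
        known variance sigma2 and base measure H = N(0, tau2)."""
        rng = np.random.default_rng(rng)
        N = len(x)
        z = np.zeros(N, dtype=int)                 # start with everyone in one cluster
        counts, sums = [N], [x.sum()]              # per-cluster sufficient statistics

        def predictive(xi, n, s):
            # posterior over the cluster mean given n points with sum s, then
            # integrate the mean out of N(xi | theta, sigma2)
            v = 1.0 / (n / sigma2 + 1.0 / tau2)
            m = v * s / sigma2
            return norm.pdf(xi, loc=m, scale=np.sqrt(v + sigma2))

        for _ in range(n_iters):
            for i in rng.permutation(N):
                k_old = z[i]
                counts[k_old] -= 1; sums[k_old] -= x[i]   # remove x_i's statistics
                if counts[k_old] == 0:                    # delete the empty cluster
                    counts.pop(k_old); sums.pop(k_old)
                    z[z > k_old] -= 1
                # unnormalized probabilities for existing clusters and a new one
                p = [counts[k] * predictive(x[i], counts[k], sums[k])
                     for k in range(len(counts))]
                p.append(alpha * predictive(x[i], 0, 0.0))
                p = np.array(p); p /= p.sum()
                k_new = rng.choice(len(p), p=p)
                if k_new == len(counts):
                    counts.append(0); sums.append(0.0)
                counts[k_new] += 1; sums[k_new] += x[i]
                z[i] = k_new
        return z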
16
17
18
19
Figure 31.6: Output of the DP mixture model fit using Gibbs sampling to two different datasets. Left column: dataset with 4 clear clusters. Right column: dataset with an unclear number of clusters. Top row: single sample from the Markov chain. Bottom row: empirical fraction of times a given number of clusters is used, computed across all samples from the chain. Generated by dp_mixgauss_sample.ipynb.
1
2 31.2.3 Fitting using variational Bayes
3
This section is written by Xinglong Li.
4
5 In this section, we discuss how to fit a DP mixture model using mean field variational Bayes (??),
6 as described in [blei2006variational].
7 Given samples x1 , . . . , xN from DP mixture, the mean field variational inference (MFVI) algorithm
8 is based on the stick-breaking representation of the DP mixture. The target of the inference is the joint
9 posterior distribution of the beta random variables β = {β1 , β2 . . .} in the stick-breaking construction
10 of the DP, the locations θ = {θ1 , θ2 , . . .} of atoms, and the cluster assignments z = {z1 , . . . , zN }:
11
12 w = {β, θ, z} (31.25)
13
14
The hyperparameters are the concentration parameter of the DP and the parameter of the conjugate
15
base distribution of θ:
16
λ = {α, η} (31.26)
17
18
The variational inference algorithm minimizes the KL-divergence between qψ (w) and p(w|x, λ):
19
20 DKL (qψ (w) k p(w|x, λ)) = Eq [log qψ (w)] − Eq [log p(w, x|λ)] + log p(x|λ) (31.27)
21
22 Minimizing the KL divergence is equivalent to maximizing the evidence lower bound (ELBO):
23
24 Ł =Eq [log p(w, x|λ)] − Eq [log qψ (w)] (31.28)
25 N
X
26 =Eq [log p(β|α)] + Eq [log p(θ|η)] + (Eq [log p(zn |β)] + Eq [log p(xn |zn )]) (31.29)
27 n=1
28 − Eq [log qψ (β, θ, z)] (31.30)
29
30 To deal with the infinite parameters in β and θ, the variational inference algorithm truncates the
31 DP by fixing a value T and setting q(βT = 1) = 1, which implies that πt = 0 for t > T . Therefore,
32 qψ (w) in the MFVI for DP mixture models factorizes into
33
34 −1
TY T
Y N
Y
35 qψ (β, θ, z) = qγ t (βt ) qτ t (θt ) qφn (zn ), (31.31)
36 t=1 t=1 n=1

37
where qγ t (βt ) is the beta distribution with parameters {γt,1 , γt,2 }, qτ t (θt ) is the exponential family
38
distribution with natural parameters τ t , and qφn (zn ) is the categorical distribution of the cluster
39
assignment of observation xn , with q(zn = t) = φn,t . The free variational parameters are
40
41 ψ = {γ 1 , · · · , γ T −1 , τ 1 , · · · , τ T , φ1 , · · · , φN }. (31.32)
42
43 Notice that only qψ (w) is truncated, the true posterior p(w, x|λ) from the model need not be
44 truncated when minimizing the KL.
45 The MFVI can be optimized via the coordinate ascent algorithm, and the closed form update in
46 each step exists when the base measure of the DP is conjugate to the likelihood of observations. In
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


274

1
2 particular, suppose that conditional distribution of xn conditioned on zn and θ is an exponential
3 family distribution (??):
4
5 p(xn |zn , θ1 , θ2 , . . .) = h(xn ) exp{θTzn xn − a(θzn )} (31.33)
6
7
where xn is the sufficient statistic for the natural parameter θzn . Therefore, the conjugate base
8
distribution is
9
p(θ|η) = h(θ) exp(ηT1 θ + η2 (−a(θ)) − a(η)), (31.34)
10
11
where η 1 contains the first dim(θ) components and η2 is a scalar. See Algorithm 31.2 for the resulting
12
pseudocode.
13
Extensions of this method to infer the hyperparameters, λ = {α, η}, can be found in the appendix
14
of [blei2006variational].
15
16
17
Algorithm 31.2: Variational inference for DP mixture model
18 1 Initialize the variational parameters:
19 2 φnt is membership probability of xn in cluster t;
20 3 τ t are the natural parameters for cluster t;
21 4 γ t are the parameters for the stick breaking distribution.
22 5 while not converged do
23 6 foreach γ t do
24 7 Update thePbeta distribution qγ t (βt );
25 8 γt,1 = 1 + n φn,t
P PT
26 9 γt,2 = α + n j=t+1 φn,j
27
10 foreach τ t do
28
11 Update the exponential
P family distribution qτ t (θt ) given sufficient statistics {xn };
29
12 τ t,1 = η 1 +P n φn,t xn
30
13 τt,2 = η2 + n φn,t
31
32 14 foreach φn,t do
33 15 Update the categorical distribution qφn (zn ) for each observation;
34 16 φn,t ∝ exp(St )
Pt−1
35 17 St = Eq [log βt ] + i=1 Eq [log(1 − βi )] + Eq [θt ]T xn − Eq [a(θt )]
36
37
38
39
31.2.4 Other fitting algorithms
40
41 While collapsed Gibbs ssmpling is the simplest approach to posterior inference for DPMMs, a variety
42 of other methods have been proposed as well. One popular class of MCMC samplers works with
43 the stick-breaking representation of the DP instead of CRP, instantiating the random measure
44 G [ishwaran2001gibbs]. The sampler then proceeds by sampling the cluster assignments z given
45 G, and then G given z. An advantage of this is that the cluster assignments can be updated in
46 parallel, unlike the CRP, where they are updated sequentially. To be feasible however, these methods
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
31.3. GENERALIZATIONS OF THE DIRICHLET PROCESS

1
2 require truncating G to a finite number of atoms, though the resulting approximation error can
3 be quite small. The posterior approximation error can be eliminated altogether by slice-sampling
4 methods [kalli2011slice], that work with random truncation levels.
5 Alternatives to MCMC also exist. [Daume07] shows how one can use A* search and beam search
6 to quickly find an approximate MAP estimate. [Mansinghka07] discusses how to fit a DPMM
7 online using particle filtering, which is a like a stochastic version of beam search. This can be more
8 efficient than Gibbs sampling, particularly for large datasets.
9 In Section 31.2.3 we discussed an approach based on mean field variational inference. A variety
10 of other variational approximation methods have been proposed as well, for example [Kurihara06;
11 teh2008collapsed; Zobay09; wang2012truncation].
12
13
31.2.5 Choosing the hyper-parameters
14
15 An important issue is how to set the model hyper-parameters. These include the DP concentration
16 parameter α, as well as any parameters λ of the base measure H. For the DP, the value of α does not
17 have much impact on predictive accuracy, but has quite a strong affect the number of clusters. One
18 approach is to put a gamma prior for α, and then form its posterior, p(α|K, N, a, b) [Escobar95].
19 Simulating α given the cluster assignments z is quite straightforward, and can be incorporated
20 into the earlier Gibbs sampler. The same is the case with the hyper-parameters λ [Rasmussen00].
21 Alternatively, one can use empirical Bayes [McAuliffe06] to fit rather than sample these parameters.
22
23
31.3 Generalizations of the Dirichlet process
24
25 Dirichlet process mixture models are flexible nonparametric models of continuous probability densities,
26 and if set up with a little care, can possess important frequentist properties like asymptotic
27 consistency: with more and more observations, the posterior distribution concentrates around the
28 ‘true’ data generating distribution, with very little assumed about the this distribution. Nevertheless,
29 DPs still represent very strong prior information, especially in clustering applications. We saw that the
30 number of clusters in a dataset of size N a priori is around α log N . As indicated by Equation (31.15),
31 not just the number of clusters, but also the distribution of their sizes is controlled by a single
32 parameter α. The resulting clustering model is thus quite inflexible, and in many cases, inappropriate.
33 Two examples from machine learning that highlight its limitations are applications involving text and
34 image data. Here, it has been observed empirically that the number of unique words in a document,
35 the frequency of word usage, the number of objects in an image, or the number of pixels in an object
36 follow power-law distributions. Clusterings sampled from the CRP do not produce this property, and
37 the resulting model mismatch can result in poor predictive performance.
38
39
40
31.3.1 Pitman-Yor process
41 A popular generalization of the Dirichlet process is the Pitman-Yor process [pitman1997two]
42 (also called the two-parameter Poisson-Dirichlet process). Written at PYP(α, d, H), the Pitman-
43 Yor process includes an additional discount parameter d which must be greater than 0. The
44 concentration parameter can now be negative, but must satisfy α ≥ −d. As with the DP, a sample
45 G from a PYP is a random probability measure that is discrete almost surely. It has a stick-breaking
46 representation that generalizes that of the DP: again, we start with a stick of length 1, but now at
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


276

1
2 step k, a random Beta(1 − d, α + kd) fraction of the remaining probability mass is broken off, so that
3 G is written as
4 X∞
5 G= πk δ θ k , (31.35)
6 k=1
7 βk ∼ Beta(1 − d, α + kd), θk ∼ H, (31.36)
8 k−1 k−1
Y X
9 π k = βk (1 − βl ) = βk (1 − πl ). (31.37)
10 l=1 l=1
11
12
Because G is discrete, once again observations θ1 , . . . , θd sampled iid from G will possess a clustering
13
structure. These can directly be sampled following a sequential process that generalizes the CRP.
14
Now, when a new customer enters the restaurant, they join an existing table with Nk customers
15
with probability proportional to (Nk − d), and create a new table with probability proportional to
16
α + Kd, where K is the number of clusters.
17
Observe that the Dirichlet process is a special instance of the Pitman-Yor process, corresponding
18
to d equal to 0. Non-zero settings of d counteract the rich-get-richer dynamics of the DP to some
19
extent, increasing the probability of creating new clusters. The more clusters there are present in a
20
dataset, the higher the probability of creating a new cluster (relative to the Dirichlet process). This
21
behavior means that a large number of clusters, as well as a few very large clusters are more likely
22
under the PYP than the DP.
23
An even more general class of probability measures are the so-called Gibbs-type priors [de2013gibbs].
24
Under such a prior, given N observations, the probability these are clustered into K clusters, the kth
25
having nk observations, is
26 K
Y
27 p(n1 , . . . , nk ) = VN,K (σ − 1)Nk −1 , (31.38)
28 k=1
29
where (σ)n = σ(σ + 1) . . . (σ + n − 1), and for non-negative weights VN,K satisfying VN,K =
30
(N − σK)VN +1,K + VN +1,K+1 and V1,1 = 1. This definition ensures that the probability over
31
partitions is consistent, exchangeable and tractable, and includes the DP and PYP as special cases.
32
Besides these two, Gibbs-type priors include settings where the number of components (or the number
33
of atoms in the random measures) are random but bounded. Recall that DP and PYP mixture models
34
are infinite mixture models, with the number of components growing with the number of observations.
35
A sometimes undesirable feature of these models is that if a dataset is generated from a finite number
36
of clusters, these models will not xs recover the true number of clusters [miller2014inconsistency].
37
Instead, the estimated number of clusters will increasexs with the size of the dataset, resulting in
38
redundant clusters that are located very close to each other. Gibbs-type priors with almost surely
39
bounded number of components can learn the true number of clusters while still remaining reasonably
40
tractable: so long as one can calculate the VN,K terms, MCMC for all these models can by carried
41
out by modifications of the CRP-based sampler described earlier.
42
43
31.3.2 Dependent random probability measures
44
45 Dirichlet processes, and more generally, random probability measures, have also been generalized
46 from settings with a single set of observations to those involving grouped observations, or observations
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
31.3. GENERALIZATIONS OF THE DIRICHLET PROCESS

1
2 indexed by covariates. Consider T sets of observations {X1 , . . . , XT }, each perhaps corresponding to
3 a different country, a different year, or more abstractly, a different set of observed covariates. There
4 are two immediate (though inadequate) ways to model such data: 1) by pooling all datasets into a
5 single dataset, modeled as drawn from a DP or DPMM-distributed random probability measure G, or
6 2) by treating each group as independent, having its own random probability measure Gt . The first
7 approach fails to capture differences between the groups, while the second ignores similarities between
8 the groups (e.g., shared clusters). Dependent random probability measures seek a compromise
9 between these extremes, allowing statistical information to be shared across different groups.
10 A seminal paper in the machine learning literature that addresses this problem is the hierarchical
11 Dirichlet process (HDP) [teh06_hdp]. The basic idea here is simple: each group has its own
12 random measure drawn from a Dirichlet process DP(α, H). The twist now is that the base measure
13 H itself is random, in fact it is itself drawn from a Dirichlet process. Thus, the overall generative
14 process is
15
16 H ∼ DP(α0 , H0 ), (31.39)
17 t
G ∼ DP(α, H), t ∈ 1, . . . , T (31.40)
18 t
19 θi ∼G, t
i ∈ 1, . . . , Nt , t ∈ 1, . . . , T. (31.41)
20
t
21 The θi s might be the observations themselves, or the latent parameter underlying each observation,
t
22 with xti ∼ F (θi ).
23 Recall that if a probability measure G1 is drawn from DP(α, H), its atoms are drawn independently
24 from the base measure H. In the HDP, H, which is a draw from a DP, is discrete, so that some
25 atoms of G1 will sit on top of each other, becoming a single atom. More importantly, all measures
26 Gt will share the same infinite set of locations: each atom of H will eventually be sampled to form a
27 location of an atom of each Gt . This will allow the same clusters to appear in all groups, though they
28 will have different weights. Moreover, a big cluster in one group is a priori likely to be a big cluster
29 in another group, as underlying both is a large atom in H. Since the common probability measure
30 H itself is random, its components (both weights as well as locations) will be learned from data.
31 Despite its apparent complexity, it is fairly easy to develop an analogue of the Chinese restaurant
32 process for the HDP, allowing us to sample observations directly without having to instantiate any of
33 the infinite-dimensional measures. This is called the Chinese restaurant franchise, and essentially
34 amounts to each group having its own Chinese restaurant with the following modification: whenever
35 a customer sits at a new table and orders a dish, that dish itself is sampled from an upper CRP
36 common to all restaurants. It is also possible to develop stick-breaking representations of the HDP.
37 The HDP has found wide application in the machine learning literature. A common application is
38 document modeling, where the location of each atom is a topic, corresponding to some distribution
39 over words. Rather than bounding the number of topics, there are an infinite number of topics, with
40 document d having its own distribution over topics (represented by a measure Gd ). By tying all the
41 Gd ’s together through an HDP, different documents can share the same topics while emphasizing
42 different topics.
43 Another application involves infinite-state hidden Markov models, also called HDP-HMM.
44 Recall from ?? that a Markov chain is parameterized by a transition matrix, with row r giving
45 the distribution over the next state if the current state is r. For an infinite-state HMM, this is an
46 infinite-by-infinite matrix, with row r corresponding to a distribution Gr with an infinite number of
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


278

1
2 atoms. The different Gr ’s can be tied together by modeling them with an HDP, so that atoms from
3 each correspond to the same states [Fox08].
4 Hierarchical nonparametric models of this kind can be constructed with other measures, such as the
5 Pitman-Yor process. For certain parameter settings, the PYP possesses convenient marginalization
6 properties that the DP does not [Wood2009]. In particular, simulating a random probability
7 measure (RPM) in the following two steps
8
9
G0 |H0 ∼ PYP(0, d0 , H0 ), (31.42)
10 G1 |G0 ∼ PYP(0, d1 , G0 ) (31.43)
11
12
is equivalent to directly simulating G1 without instantiating G0 as below:
13
G1 |H0 ∼ PYP(0, d0 d1 , H0 ). (31.44)
14
15 This marginalization property (also called coagulation) allows deep hierarchies of dependent RPMs,
16 with only a smaller, dataset-dependent subset of G’s having to be instantiated. In the sequence
17 memoizer of [wood2011sequence], the authors model sequential data (e.g., text) with hierarchies
18 that are infinitely deep, but with only a finite number of levels ever having to be instantiated. If
19 needed, intermediate random measures can be instantiated by a dual fragmentation operator.
20 Deeper hierarchies like those described above allow more refined modeling of similarity between
21 different groups. Under the original HDP, the groups themselves are exchangeable, with no subset of
22 groups a priori more similar to each other than to others. For instance, suppose each of the measures
23 G1 , . . . , GT correspond to distributions over topics in different scientific journals. Modeling the Gt ’s
24 with an HDP allows statistical sharing across journals (through shared clusters and similar cluster
25 probabilities), but does not regard some journals as a priori more similar than others. If one had
26 further information, e.g., that some are physics journals and the rest are biology journals, then one
27 might add another level to the hierarchy. Now rather than each journal having a probability measure
28 Gt drawn from a DP(α, H), physics and biology journals have their own base measures Hp and Hb
29 that allow statistical sharing among physics and biology journals respectively. Like the HDP, these
30 are draws from a DP with base measure H. To allow sharing across disciplines, H is again random,
31 drawn from a DP with base measure H0 . Overall, if there are D disciplines, 1, 2, . . . , D, the overall
32 model is
33
34 H ∼ DP(α0 , H0 ), (31.45)
35 d
H ∼ DP(α1 , H), d ∈ 1, . . . , D (31.46)
36
37 Gt,d d
∼ DP(α2 , H ), t ∈ 1, . . . , Td (31.47)
t,d
38 θi ∼ Gt,d d ∈ 1, . . . , D, t ∈ 1, . . . , Td , i ∈ 1, . . . , Nt (31.48)
39
40 One might add further levels to the hierarchy given more information (e.g., if disciplines are grouped
41 into physical sciences, social sciences and humanities).
42 Dependent random probability measures can also be indexed by covariates in some continuous
43 space, whether time, space, or some Euclidean space or manifold. This space is typically endowed
44 with some distance or similarity function, and one expects that measures with similar covariates are
45 a priori more similar to each other. Thus, Gt might represent a distribution over topics in year t,
46 and one might wish to model Gt evolving gradually over time. There is rich history of dependent
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
31.4. THE INDIAN BUFFET PROCESS AND THE BETA PROCESS

1
2 random probability measures in statistics literature, starting from [maceachern1999dependent].
3 A common requirement in such models is that at any fixed time t, the marginal distribution of Gt
4 follows some well specified distribution, e.g., a Dirichlet process, or a PYP. Approaches exploit the
5 stick-breaking representation, the CRP representation or something else like the Poisson structure
6 underlying such models (see Section 31.6).
7 As a simple and early example, recall the stick-breaking construction from Section 31.1.2, where a
8 random probability measure is represented by an infinite collection of pairs (βk , θk ). To construct
9 a family of RPMs indexed by some covariate t, we need a family of such sets (βkt , θkt ) for each of
10 the possibly infinite values of t. To ensure each Gt is marginally DP distributed with concentration
11 parameter α and base measure H, we need that for each t, the βkt s are marginally iid draws from a
12 Beta(1, α), and the θkt ’s iid draws from H. Further, we do not want independence across t, rather,
13 for two times t1 and t2 , we want similarity to increase as |t1 − t2 | decreases. To achieve this, define
14 an infinite sequence of independent Gaussian processes on T , fk (t), k = 1, 2, . . . , with mean 0 and
15 some covariance function. At any time t, fk (t) for all k are iid draws from a normal distribution.
16 By transforming these Gaussian processes through the cdf of a Gaussian density (to produce an iid
17 uniform random variables at each t), and then through the inverse cdf of a Beta(1, α)), one has an
18 infinite collection of iid Beta(1, α) random variables at any time t. The GP construction means that
19 each βkt varies smoothly with t. A similar approach can be used to construct smoothly varying θkt s
20 that are marginally H-distributed, and together, these form a family of gradually varying RPMs Gt ,
21 with
22 ∞ k−1
X Y
23
Gt = πkt δθkt , where πkt = βkt (1 − βjt ) (31.49)
24 i=1 j=1
25
26 Of course, such a model comes with formidable computational challenges, however the marginal
27 properties allows standard MCMC methods from the DP and other RPMs to be adapted to such
28 settings.
29
30 31.4 The Indian buffet process and the beta process
31
32 The Dirichlet process is a Bayesian nonparametric model useful for clustering applications, where
33 the number of clusters is allowed to grow with the size of the dataset. Under this generative model,
34 each observation is assigned to a cluster through the variable zi . Equivalently, zi can be written as
35 a one-hot binary vector, and the entire clustering structure can be written as a binary matrix Z
36 consisting of N rows, each with exactly one non-zero element (Figure 31.7(a)).
37 Clustering models that limit each observation to a single cluster can be overly restrictive, failing to
38 capture the complexity of the real datasets. For instance, in computer vision applications, rather
39 than assign an image to a single cluster, it might be more appropriate to assign it a binary vector of
40 features, indicating whether different object types are or are not present in the image. Similarly, in
41 movie recommendation systems, rather than assign a movie to a single genre (‘comedy’ or ‘romance’
42 etc.), it is more realistic to assign to multiple genres (‘comedy’ AND ‘romance’). Now, in contrast to
43 all-or-nothing clustering (which would require a new genre ’romantic comedy’), different movies can
44 have different but overlapping sets of features, allowing a partial sharing of statistical information.
45 Latent feature models generalize clustering models, allowing each observation to have multiple
46 features. Nonparametric latent feature models allow the number of available features to be infinite
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


280

1
Dishes/tables Dishes
2
3
4
5

Customers
Customers
6
7
8
9
10
11 (a) (b)
12
13 Figure 31.7: (a) A realization of the cluster matrix Z from the Chinese restaurant process (CRP) (b) A
14 realization of the feature matrix Z from the Indian buffet process (IBP). Rows are customers and columns are
15
dishes (for the CRP each table has its own dish). Both produce binary matrices, with the CRP assigning a
customer to a single table (cluster), and the IBP assigning a set of dishes (features) to each customer.
16
17
18
19 (rather than fixed a priori to some finite number), with the number of active features in a dataset
20 growing with a dataset size. Such models associate the dataset with an infinite binary matrix
21 consisting of N rows, but now where each row can have multiple elements set to 1, corresponding to
22 the active features. This is shows in Figure 31.7(b). As with the clustering models, each column is
23 associated with a parameter drawn iid from some base measure H.
24 The Indian buffet process (IBP) in a Bayesian nonparametric analogue of the CRP for latent
25 feature models. As with the CRP, the IBP is specified by a concentration parameter α, and a
26 base measure H on some space Θ. The former controls the distribution over the binary feature
27 matrix, whereas feature parameters θk are drawn iid from the latter. Under the IBP, individuals
28 enter sequentially into a restaurant, now picking a set of dishes (instead of a single table). The
29 first customer samples a Poisson(α)-distributed random number of dishes, and assigns each of them
30 values drawn from H. When the ith customer enters the restaurant, they first make a pass through
31 all dishes already chosen. Suppose Nd of the earlier customers have chosen dish d: then customer
32 i selects this with probability Nd /i. This results in a rich-get-richer phenomenon, where popular
33 dishes (common features) are more likely to be selected in the future. Additionally, the i customer
34 samples a Poisson(α/i) number of new dishes. This results in a non-zero probability of new dishes,
35 that nonetheless decreases with i.
36 A key property of the Indian buffet process is that, like the CRP, it is exchangeable. In other
37 words, its statistical properties do not change if its rows are permuted. For instance, one can show
38 that the number of dishes picked by any customer is marginally Poisson(α) distributed. Similarly, the
39 distribution over the number of features shared by the first two customers is the same as the that for
40 the first and last customer. We mention that like the CRP, the ordering of the dishes (or columns) is
41 arbitrary and might appear to violate exchangeability. For instance, the first customer cannot pick
42 the first and third dishes and not the second dish, while this is possible for the second customer.
43 These artifacts disappear if we index columns by their associated parameters. Equivalently, after
44 reordering rows, we can transform the feature matrix to be left-ordered (essentially, all new dishes
45 selected by a customer must by adjacent, see [griffiths2011indian]), and we can view the IBP as a
46 prior on such left-ordered matrices.
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
31.4. THE INDIAN BUFFET PROCESS AND THE BETA PROCESS

1
2
3
4
5
6
7
8
9
10
11 (a) (b)
12
13 Figure 31.8: (a) A realization of a beta process on the real line. The the atoms of the random measure are
14 show in red, while the beta subordinator (whose value at time t sums all atoms up to t) is shown in black. (b)
15 Samples from the beta process
16
17
18 The exchangeability of the rows of the IBP implies via de Finetti’s theorem that there exists
19 an underlying random measure G, conditioned on which the rows are iid draws. Just as the
20 Chinese restaurant process represents observations drawn from a Dirichlet process-distributed random
21 probability measure, the dishes of each customer in the IBP represent observations drawn from a
22 Beta process. Like the DP, the beta process is an atomic measure taking the form
23 ∞
X
24 G(θ) = πk δθk (θ). (31.50)
25 k=1
26
27
Each of the πk ’s lie between 0 and 1, but unlike the DP where they add up to 1, the πk ’s sum up
28
to a finite random number. The beta process is thus a random measure, rather than a random
29
probability measure. One can imagine the k’th atom as a coin located at θk , with probability of
30
success equal to πk . To simulate a row of the IBP, one flips each of the infinite coins, selecting the
31
feature at θk if the k coin comes up heads. One can show that if the πk ’s sum up to a finite number,
32
the number of active features will be finite. Of the infinite atoms in G, a few will dominate the rest,
33
these will be revealed through the rich-gets-richer dynamics of the IBP as features common to a large
34
proportion of the observations.
35
The beta process has a construction similar to the stick-breaking representation of the Dirichlet
36
process. As with the DP, the locations of the atoms are independent draws from the base measure,
37
while the sequence of weights π1 , π2 , . . . are constructed from an infinite sequence of beta variables:
38
now these are Beta(α, 1) distributed, rather than Beta(1, α) distributed. The overall representation
39
of the Beta process is then
40
βk ∼ Beta(α, 1), θk ∼ H, (31.51)
41
k
Y
42
43
πk = βk πk1 = βl (31.52)
l=1
44
45 It is not hard to see that under the IBP, the total numberPof dishes underlying a dataset of size
N
46 N follows a Poisson(αHN ) distribution, where HN is HN = i=1 1/i is the N th harmonic number.
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


282

1
2 The beta process and the IBP have been generalized to three-parameter versions allowing power-law
3 behavior in the total number of dishes (features) in a dataset of size N , as well as in the number
4 of customers trying each dish [teh2009indian]. It has found application in tasks from genetics,
5 collaborative filtering, and in models for graphs. Just as with the DP, posterior inference can proceed
6 via MCMC (exploiting either the IBP or the stick-breaking representation), particle filtering or using
7 variational methods.
8
9 31.5 Small-variance asymptotics
10
11 Nonparametric Bayesian methods can serve as a basis to develop new and efficient discrete optimization
12 algorithms that have the flavor of Lloyd’s algorithm for the k-means objective function. The starting
13 point for this line of work is the view of the k-means algorithm as the small-variance asymptotic
14 limit of the EM algorithm for a mixture of Gaussians. Specifically, consider the EM algorithm to
15 estimate the unknown cluster means µ ≡ (µ1 , . . . , µk ) of a mixture of k Gaussians, all of which have
16 the same known covariance σ 2 I. In the limit as σ 2 → 0, the E-step, which computes the cluster
17 assignment probabilities of each observation given the current parameters µ(t) , now just assigns each
18 observation to the nearest cluster. The M-step recomputes a new set of parameters µ(t+1) given the
19 cluster assignment probabilities, and when each observation is hard-assigned to a single cluster, the
20 cluster means are just the means of the assigned observations. The process of repeatedly assigning
21 observations to the nearest cluster, and recomputing cluster locations by averaging the assigned
22 observations is exactly Lloyd’s algorithm for k-means clustering.
23 To avoid having to specify the number of clusters k, [kulis2011revisiting] considered a Dirichlet
24 process mixture of Gaussians, with all infinite components again having the same known variance σ 2 .
25 The base measure H from which the component means are drawn was set to a zero-mean Gaussian
26 with variance ρ2 , with α the concentration parameter. The authors then considered a Gibbs sampler
27 for this model, very closely related to the sampler in Algorithm 31.1, except that instead of collapsing
28 or marginalizing out the cluster parameters, these are instantiated. Thus, an observation xi as assigned
29 to a cluster c with mean µc with probability proportional to Nc,−i √2πσ 1
2
exp(− 2σ1 2 kxi − µc k2 ), while
30
it is assigned to a new cluster with probability proportional to α √ 12 2 exp(− 2(σ21+ρ2 ) kxi k2 ).
31 2π(σ +ρ )

32
After cycling through all observations, one then resamples the parameters of each cluster. Following
33
the Gaussian base measure and Gaussian likelihood, this too follows a Gaussian distribution, whose
34
mean is a convex combination of the prior mean 0 and data-driven term, specifically the average of
35
the assigned observations. The weight of the prior term is proportional to the inverse of the prior
36
variance ρ2 , while the weight of the likelihood term is proportional to the inverse of the likelihood
37
variance σ 2 .
38
To derive a small-variance limit of the sampler above, we cannot just let σ 2 go to 0, as that
39
would result in each observation being assigned to its own cluster. To prevent the likelihood from
40
dominating the prior in this way, one must also send the concentration parameter α to 0, so that
41
the DP prior enforces an increasingly strong penalty on a large number of clusters (recall that a
42
priori, the average number of clusters underlying N observations is α log N ). [kulis2011revisiting]
43
showed that with α scaled as α = (1 + ρ2 /σ 2 )d/2 exp(− 2σλ2 ) for some parameter λ, and taking the
44
limit σ 2 → 0 with ρ fixed, we get the following modification of the k-means algorithm:
45 1. Assign each observation to the nearest cluster, unless the distance to the nearest cluster exceeds λ,
46 in which case assign the observation to its own cluster.
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
31.5. SMALL-VARIANCE ASYMPTOTICS

1
2 2. Set the cluster means equal to the average of the assigned observations.
3
4
5 Algorithm 31.3: DP-means hard clustering algorithm for data D = {x1 , . . . , xN }
6
1 Initialize the number of clusters K = 1, with cluster parameter µ1 equal to the global mean:
7 PN
µ1 = N1 i=1 xi
8
2 Initialize all cluster assignment indicators zi = 1
9
3 while all zi s have not converged do
10
4 for each i = 1 : N do
11
5 Compute distance dik = kxi − µk k2 to cluster k for k = 1, . . . , K
12
6 if mink dik > λ then
13
7 Increase the number of clusters K by 1: K = K + 1
14
8 Assign observation i to this cluster: zi = K and µK = xi
15
16 9 else
17 10 Set zi = arg minc dik
18 11 for each k = 1 : K do
19 12 Set Dk be the observations assigned to cluster k. Compute cluster mean
P
20 µk = |D1k | x∈Dk x
21
22
23
24 Like k-means, this is a hard-clustering algorithm, except that instead of having to specify the
25 number of clusters k, one just specifies a penalty λ for introducing a new cluster. The actual number
26 of clusters is determined by the data, [kulis2011revisiting] refer to this algorithm as DP-means.
27 One can show that the iterates above converge monotonically to a local maximum of the following
28 objective function:
29
K X
X
30 1 X
kx − µk k2 + λK, where µk = x. (31.53)
31 |Dk |
k=1 x∈Dk x∈Dk
32
33
The first term in the expression above is exactly the objective function of the k-means algorithm with
34
K clusters. The second term is an penalty term that introduces a cost λ for each additional cluster.
35
Interestingly, the penalty term above corresponds to the so called Akaike information criterion (AIC),
36
a well studied approach to penalizing model complexity.
37
It is also possible to derive hard-clustering algorithms for simulataneously clustering multiple
38
datasets, while allowing these to share clusters. This is possible through the small-variance limit of a
39
Gibbs sampler for the hierarchical Dirichlet process (HDP) and results in a clustering algorithm that
40
now has two thresholding parameters, a local one λl and a global one λg . The algorithm proceeds by
41
maintaining a set of global clusters, with the local clusters of each dataset assigned to a subset of the
42
global clusters. It then repeats the following steps until convergence:
43
44 Assign observations to local clusters For xij , the i’th datapoint in dataset j, compute the
45 distance to all global clusters. For those global clusters not currently present in dataset j, add λl
46 to their distance; this reflects the cost of introducing a new cluster into dataset j. Now, assign xij
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


284

1
2 to the cluster with the smallest distance, unless the smallest distance exceeds λg + λl , in which
3 case, create a new global cluster and assign xij to it. Observe that in the latter case, the distance
4 of xij to the new cluster is 0, with λg + λl reflecting the cost of introducing a new global and
5 then local cluster.
6
Assign local clusters to global clusters For each local cluster l, compute the sum of the dis-
7
tances of all its assigned observations to the cluster mean. Call this dl . Also compute the sum of
8
the distances of the assigned observations to each global cluster. For global cluster p, call this dl,p .
9
Then assign local cluster l to the global cluster with the smallest dl,p , unless min dl,p > dl + λg in
10
which case we create a new global cluster.
11
12 Recompute global cluster means Set the global cluster means equal to the average of the as-
13 signed observations across all datasets.
14
15
In [jiang2012small], these ideas were extended from DP mixtures of Gaussians to DP mixtures of
16
more general exponential family distributions. Briefly, the hard-clustering algorithms maintained the
17
same structure, the only difference being that distance from clusters was measured using a Bregman
18
divergence specific to that exponential family distribution. For Gaussians, the Bregman divergence
19
reduces to the usual Euclidean distance.
20
In [broderick2013mad], the authors showed how such small-variance algorithms could be derived
21
directly from a probabilistic model, and independent of any specific computational algorithm such as
22
EM or Gibbs sampling. Their approach involves computing the MAP solution of the parameters of
23
the model, and then taking the small-variance limit to obtain an objective function. This approach,
24
called MAP-based asymptotic derivations from Bayes or MAD-Bayes, allowed them to derive,
25
among other things, an analog of the DP-means algorithm from feature-based models. The called
26
this the BP-means algorithm, after the beta process underlying the Indian buffet process.
27
Write X for an N × D matrix of N D-dimensional observations, Z for an N × K matrix of binary
28
feature assignments, and A for a K × D matrix of K D-dimensional features, with one seeking a
29
pair (A, Z) such that ZA approximates X as well as possible. [broderick2013mad] showed in the
30
small-variance limit, finding the MAP solution for the IBP is equivalent to the following problem:
31 argminK,Z,A tr[(X − ZA)t (X − ZA)] + λK, (31.54)
32
33 where again λ is a parameter of the algorithm, governing how the concentration parameter scales
34 with the variance, and specifying a penalty on introducing new features. This objective function
35 is very intuitive: the first term corresponds to the approximation error from K features, while the
36 second term penalizes model complexity resulting from a large number of features K. This objective
37 funtion can be optimized greedily by repeating three steps for each observation i until convergence:
38
1. Given A and K, compute the optimal value of the binary feature assignment vector zi of observation
39
i and update Z.
40
41 2. Given Z and A, introduce an additional feature vector equal to the residual for the i’th observation,
42 namely, xi − zi A. Call A0 the updated feature matrix, and Z 0 the updated feature assignment
43 matrix, where only observation i has been assigned this feature. If the configuration (K + 1, Z 0 , A0 )
44 results in a lower value of the objective function than (K, Z, A), set the former as the new
45 configuration. In other words, if the benefit of introducing a new feature outweights the penalty
46 λ, then do so.
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
31.6. COMPLETELY RANDOM MEASURES

1
2
3
4
5
6
7
8
9
10
11 P
Figure 31.9: Poisson process construction of a completely random measure G = i wi δθi . The set of pairs
12
{(w1 , θ1 ), (w2 , θ2 ), . . . } is a realization from a Poisson process with intensity λ(θ, w) = H(θ)γ(w).
13
14
15
16
3. Update the feature vectors A given the feature assignment vectors Z.
17
[broderick2013mad] showed that this algorithm monotonically decreases the objective in Equa-
18
tion (31.54), converging eventually to a local optimum. Subsequent works have considered the
19
small-variance asymptotics of more structured models, such as topic models, hidden Markov models
20
and even continuous-time stochastic process models.
21
22
23 31.6 Completely random measures
24
The Dirichlet process is a example of a random probability measure, a class of random measures
25
which always integrate to 1. The beta process, while not an RPM, belongs to another class of random
26
measures: completely random measures (CRM). A completely random measure G satisfies the
27
following property: for any two disjoint subsets T1 and T2 of Θ, the values G(T1 ) and G(T2 ) are
28
independent random variables:
29
30
G(T1 ) ⊥
⊥ G(T2 ) ∀ T1 , T2 ⊂ Θ s.t. T1 ∩ T2 = ∅. (31.55)
31
32 Note that the Dirichlet process is not a CRM, since for disjoint sets T and its complement T c = Θ \ T ,
33 we have G(T ) = 1 − G(T c ), which is as far from independent as can be. More generally, under the
34 DP, for disjoint sets T1 and T2 , the measures G(T1 ) and G(T2 ) are negatively correlated. We will see
35 later though that the Dirichlet process is closely related to another CRM, the gamma process. As a
36 side point, beyond the sum-to-one constraint, G(T1 ) does not tell us anything about the distribution
37 of probability within G(T1c ), making the DP what is known as a neutral process.
38 The simplest example of a CRM is the Poisson process (see Section 31.8). A Poisson process
39 with intensity λ(θ) is a point process
R producing points or events in Θ, with the number of points
40 in any set T following a Poisson( T λ(θ)dθ)-distribution, and with the counts in any two disjoint
41 sets independent of each other. While it is common to think of this as a point process, one can also
42 think of a realization from a Poisson process as an integer-valued random measure, where the
43 measure of any set is the number of points falling within that set. It is clear then that the Poisson
44 process is an example of a CRM.
45 It turns out that the Poisson process underlies all completely random measures in a fundamental
46 way. For some space Θ, and W the positive real line, simulate a Poisson process on the product space
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


286

1
2 W × Θ with intensity λ(θ, w). Figure 31.9 shows a realization from this Poisson process, write it as
3 M = {(θ1 , w1 ), . . . , (θ|M | , w|M | )} where |M | is the number of events. This can be used to construct
P|M |
4 an atomic measure G = i=1 wi δθi as illustrated in Figure 31.9. From its P Poisson construction, this
5 is a completely random measure on Θ, with set T ∈ Θ having measure i wi I (θi ∈ T ). Different
6 settings of λ(θ, w) give rise to different CRMs, and in fact, other than CRMs with atoms at some
7 fixed locations in Θ, this construction characterizes all CRMs.
R For a CRM, the Poisson intensity is typically chosen to factor as λ(θ, w) = H(θ)γ(w), with
8
9
Θ
H(θ)dθ = 1. Then, H(θ) is the base measure controlling the locations of the atoms in the CRM,
10 while the measure γ(w) controls the number of atoms, and the distribution of their weights. Setting
11 γ(w) = w−1 (1 − w)α−1 gives the beta process with base measure H(θ). Other choices include
12 the gamma process (γ(w) = αw−1 exp(−w)), the stable process (γ(w) = Γ(1+σ) 1
αw−1−σ ), and the
13
generalized gamma process (γ(w) = Γ(1+σ) αw−1−σ exp(−ζw)).
1
14 R R
For all three processes described earlier, the γ(w) integrates to infinity, so that Θ W λ(θ, w)dθdw =
15
∞. Consequently, the number of Poisson events, and thus the number of atoms in the CRMs are
16
infinite with probability one. At the same time, R ∞mass of the γ(w) function is mostly concentrated
17
around 0 (see Figure 31.9), and for any  > 0,  γ(w)dw is finite. It is easy to show that the sum
18
of the w’s is finite almost surely. Call this sum W , then the first condition ensures W greater than 0,
19
while the second ensures it is finite. These two conditions make it sensible to divide a realization
20
of a CRM by its sum, resulting in a random measure that integrates to 1: a random probability
21
measure. Such RPMs are called normalized completely random measures, or sometimes just
22
normalized random measures (NRMs). The Dirichlet process we saw earlier is an example of
23
an NRM: it is a normalized gamma process. This result mirrors the situation with the finite Dirichlet
24
distribution: one can simulate from a d-dimensional Dir(α1 , . . . , αd ) distribution by first simulating
25
d independent gamma variables gi ∼ Ga(αi , 1), i = 1, . . . , d, and then defining the probability
26
vector Pd1 g (g1 , . . . , gd ). The Pitman-Yor process is not an NRM except for special settings of
27 i=1 i

28 its parameters: like we saw, d = 0 is a Dirichlet process or normalized gamma process. α = 0 is a


29 normalized stable process. The normalized generalized gamma process is an NRM that includes
30 the DP and the normalized stable process as special cases.
31
32 31.7 Lévy processes
33
34 Completely random measures are also closely related to Lévy processes and Lévy subordina-
35 tors [bertoin1996levy]. A Lévy process is a continuous-time stochastic process {Lt }t≥0 taking
36 values in some space (e.g. Rd ) that satisfies two properties1 :
37
38 stationary increments: for t, ∆ > 0, the random variable Lt+∆ − Lt does not depend on t
39
independent increments: for t, ∆ > 0, Lt+∆ − Lt is independent of values before t.
40
41 A Lévy subordinator is a real-valued, nondecreasing Lévy process. If we drop the stationarity
42 condition, it should be clear that the increments of a Lévy subordinator are exactly the atoms
43 of a completely random measure (see Figure 31.8). Common examples of Lévy subordinators are
44 beta subordinators (or beta processes), gamma subordinators, stable subordinators, and generalized
45
46 1. There is also a technical continuity condition that we do not discuss here.
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
31.7. LÉVY PROCESSES

1
2 0.4 0.5

Brownian motion
3 0.4

Cauchy process
0.3
4 0.3
5 0.2
0.2
6 0.1
0.1
7
0.0 0.0
8
9 −0.1
0.00 0.25 0.50 0.75 1.00 0.00 0.25 0.50 0.75 1.00
10 t t
11 (a) (b)
12
13 Figure 31.10: (a) A realization of a Brownian motion. (b) A realization of a Cauchy process (an α-stable
14 process with α = 1).
15
16
17
18
19
20 gamma subordinators. CRMs generalize Lévy subordinators, by allowing them to be indexed by
21 some general space Θ (rather than by nonnegative reals), and by relaxing the stationarity condition
22 to allow the atoms to follow some general base measure H.
23 Unlike Lévy subordinators, a general Lévy process can have negative changes as well, with the
24 change Lt+∆ − Lt belonging to an infinitely divisible distribution, scaled by ∆. A random
25 variable X follows an infinitely divisible distribution if, for any positive integer N , there exists some
26 probability distribution such that the sum of N iid samples from that distribution has the same
27 distribution as X. Examples of infinitely divisible distributions include the Poisson distribution, the
28 gamma distribution, the Cauchy distribution, the α-stable distribution, and the inverse-Gaussian
29 distribution (among many others), though by far the most well-known example is the Gaussian
30 distribution. It is easy to see why the change in a Lévy process Lt+∆ − Lt over some time interval ∆
31 follows an infinitely divisible distribution: just divide the interval into N equal-length segments. From
32 the properties of the Lévy process, the changes over these segments are independent and identically
33 distributed, and their sum equals Lt+∆ − Lt . The Lévy-Khintchine formula shows that the
34 converse is also true: any infinitely divisible distribution represents the change of an associated Lévy
35 process. The Lévy process corresponding to the Gaussian distribution is the celebrated Brownian
36 motion (or Wiener process). Brownian motion is a fundamental and widely applied stochastic
37 process, whose mathematics was first studied by Louis Bachelier to model stock markets, and later,
38 famously, by Albert Einstein to argue about the existence of atoms. For a Brownian motion, the
39 increment Lt1 +∆ − Lt1 follows a normal N (µ∆, σ∆) distribution, where µ and σ are the drift and
40 diffusion coefficients. Setting µ and σ to 0 and 1 gives standard Brownian motion. Paths sampled
41 from the Weiner process are continuous with probability one, all other Lévy processes are jump
42 processes. Figure 31.10 shows realizations from some processes. The jump processes have a Poisson
43 construction related to the Lévy subordinators and completely random measures, and the Lévy-Itô
44 decomposition shows that every Lévy process can be decomposed into Brownian and Poisson
45 components. Levy processes have been widely applied in mathematical finance to model asset prices,
46 insurance claims, stock prices, and other financial assets.
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


288

1
2 31.8 Point processes with repulsion and reinforcement
3
4 In this section, we look more closely at the Poisson process, as well as other, more general point
5 processes that allow inter-event interactions.
6
7
8 31.8.1 Poisson process
9
10
We have already briefly seen the Poisson process: this is a point process on Rsome space Θ that is
11
parameterized by an intensity function λ(θ) ≥ 0, and that produces a Poisson( T λ(θ)dθ)-distributed
12
number of points of events in any set T ∈ Θ, with the counts in any two disjoint sets independent
13
of each other. Recall that if Ni ∼ Poisson(λi ) are independent, then a well-known property of the
14
Poisson distribution is that N1 + N2 ∼ Poisson(λ1 + λ2 ). This relates to the infinite divisibility
15
of the Poisson distribution. It is clear that the average number of points is large in areas where
16
λ(θ) is high, and small where λ(θ) is small. When the intensity λ(θ) is some constant λ, we have a
17
homogeneous Poisson process, R otherwise we have an inhomogeneous Poisson process.
18
Depending on whether Θ λ(θ)dθ is infinite or finite, a Poisson process will either produce an
19
infinite number of points (recall the Poisson process underlying the CRMs of Section 31.6) or a
20
finite number of points. The latter is more common in applications, such as modeling phone calls
21
or financial shocks (Θ is some finite time-interval), the locations of trees or forest fires (Θ is some
22
subset of the Euclidean plane), the locations of cells or galaxies (Θ is a 3-dimensional space) or
23
events in higher-dimensional spaces (for example, spatio-temporal activity). One way to simulate a
24
finite Poisson
R process is to first simulate the number of total number of points N , which follows a
25
Poisson( Θ λ(θ)dθ) distribution. One can then simulate the locations of these points by sampling N
26 times from the probability density R λ(θ) λ(θ)dθ
. For a homogeneous Poisson process, the locations are
Θ
27 uniformly distributed over Θ. One can also easily simulate a rate-λ homogeneous Poisson on the real
28 line by exploiting Rthe fact that inter-event times follow an exponential distribution with mean 1/λ.
29 If the integral Θ λ(θ)dθ is difficult to evaluate, one can also simulate a Poisson process using
30 the thinning theorem [lewis1979simulation]. Here, one needs to find a function γ(θ) such that
31 γ(θ) ≥ λ(θ), ∀θ, and such that it is easy to simulate from a rate-γ(θ) Poisson process. Suppose the
32 result is Ψ = {ψ1 , . . . , ψN }. Since γ(θ) ≥ λ(θ), Ψ is going to contain more events, and one thins Ψ
33 by keeping each element ψi in Ψ with probability λ(ψi )/γ(ψi ) (otherwise one discards it). Once can
34 show that the set of surviving points is then a realization of a Poisson process with rate λ(θ).
35 The defining feature of the Poisson process is the assumption of independence among events. In
36 many settings, this is inappropriate and unrealistic, and the knowledge of an event at some location
37 θ might suggest a reduced or elevated probability of events in neighboring areas. Point processes
38 satisfying the former property are called underdispersed (Figure 31.11(b)), while point process
39 satisfying the latter property are called overdispersed (Figure 31.11(c)). Examples of underdispersed
40 point processes include the locations of trees or train stations, which tend to be more spread out than
41 Poisson because of limited resources. An example of an overdispersed point process is earthquake
42 locations, where aftershocks tend to occur in the vicinity of the main shock.
43 A simple approach to modeling overdispersed point processes is through hierarchical extensions of
44 the Poisson process that allow the Poisson intensity function λ(·) to be a random variable. Such
45 models are called doubly-stochastic Poisson processes or Cox processes [cox1980point], and a
46 common approach is to model the intensity λ(·) via a Gaussian process. Note though that the Poisson
47
Draft of “Probabilistic Machine Learning: Advanced Topics (Supplementary Material)”. August 15, 2023
31.8. POINT PROCESSES WITH REPULSION AND REINFORCEMENT

1
2
3
4
5
6
7
8
9
10
11 (a) (b) (c)
12
13 Figure 31.11: Realizations of (a) a homogeneous Poisson process; (b) an underdispersed point process (Swedish
14 pine sapling locations [ripley2005spatial]); (c) an overdispersed point process (California redwood tree
15 locations [ripley1977modelling]).
16
17
18
process intensity function must be nonnegative, so that λ is often a transformed Gaussian process:
19
20 λ(θ) = g(`(θ)), `(θ) ∼ GP. (31.56)
21
22 Common examples for g include exponentiation, sigmoid transformation or just thresholding. Because
23 of the smoothness of the unknown λ(·), observing an event at some location suggests events are
24 likely in the neighborhood. Such models still do not capture direct interactions between point
25 process events, rather, these are mediated through the unknown intensity functions, making them
26 inappropriate for many applications. For instance, in neuroscience a neuron’s spiking can be driven
27 directly by past activity, and activity of other neurons in the network, rather than just through some
28 shared stimulus λ(t). Similarly, social media activity has a strong reciprocal component, where, for
29 instance, emails sent out by a user might be in response to past activity, or activity of other users.
30 The next few subsections show how one might explicitly model such interactions.
31
32
31.8.2 Renewal process
33
34 Renewal processes are one class of models of repulsion and reinforcement for point processes
35 defined on the real line (typically regarded as time). Recall that for a homogeneous Poisson process,
36 inter-event times follow an exponential distribution. The exponential distribution has the property of
37 memorylessness, where the time until the next event is independent of how far in the past the last
38 event occurred. That is, if τ follows the exponential distribution, then for any δ and ∆ > 0, we have
39 p(τ > ∆ + δ|τ ≥ δ) = p(τ > ∆)). Renewal processes incorporate memory by allowing the interevent
40 times to follow some general distribution on the positive reals. Examples include gamma renewal
41 processes and Weibull renewal processes, where interevent times follow the gamma and Weibull
42 distribution respectively. Both these processes include parameter settings that recover the Poisson
43 process, but also allow burstiness and refractoriness. Burstiness refers to the phenomenon where,
44 after an event has just occured, one is more likely to see more events than a longer time afterwards.
45 This is useful for modeling email activity for instance. Refractoriness refers to the opposite situation,
46 where an event occuring implies new events temporarily less likely to occur. This is useful to model
47

Author: Kevin P. Murphy. (C) MIT Press. CC-BY-NC-ND license


290

1
2
3
4
5
6
7
8
9
10
11 (a) (b)
12
13 Figure 31.12: (a) A self-exciting Hawkes process. Each event causes a rise in the event rate, which can lead
14 to burstiness. (b) A pair of mutually exciting Hawkes processes, events from each cause a rise in the event
15
rate of the other.
16
17
18
19 neural spiking activity, for instance, where after spiking, a neuron is depleted of resources, and
20 requires a recovery period before it can fire again.
21
22
23 31.8.3 Hawkes process
24
Hawkes processes [hawkes1971point] are another class of reinforcing point processes that have attracted much recent attention. They provide an intuitive framework for modeling reinforcement in point processes, through self-excitation when a single point process is involved, and through mutual excitation (sometimes called reciprocity) when collections of point processes are under study. The former, shown in Figure 31.12(a), is relevant when modeling bursty phenomena like visits to a hospital by an individual, or purchases and sales of a particular stock. The latter, displayed in Figure 31.12(b), is useful for characterizing activity on social or biological networks, such as email communications or neuronal spiking activity. In both examples, each event serves as a trigger for subsequent bursts of activity. This is achieved by letting λ(t), the event rate at time t, be a function of past activity or event history. Write H(t) = {t_k : t_k ≤ t} for the set of event times up to time t. Then the rate at time t, called the conditional intensity function at that time, is given by

λ(t | H(t)) = γ + Σ_{t_k ∈ H(t)} φ(t − t_k),  (31.57)

where γ is called the base rate and the function φ(·) is called the triggering kernel. The latter characterizes the excitatory effect of a past event on the current event rate. Figure 31.12 shows the common situation where φ(·) is an exponential kernel, φ(∆) = β e^{−∆/τ}, ∆ ≥ 0. Here, a new event causes a jump of magnitude β in the intensity λ(t), with its excitatory influence decaying exponentially back to γ, where τ is the time-scale of the decay. For a multivariate Hawkes process, there are m point processes (N_1(t), N_2(t), . . . , N_m(t)) associated with m users or nodes. Write H_i(t) for the event history of the ith process at time t. Its conditional intensity can depend on all m event
histories, taking the form

λ_i(t | {H_j(t)}_{j=1}^{m}) = γ_i + Σ_{j=1}^{m} Σ_{t_k ∈ H_j(t)} φ_{ij}(t − t_k).  (31.58)
More typically, the conditional intensity of user i depends only on the event histories of those nodes 'connected' to it, with connectivity specified by a graph structure, and with φ_{ij}(·) = 0 if there is no edge linking i and j. Alternatively, events can have marks indicating whom they are sent to (e.g., the recipients of an email), with each event only updating the conditional intensities of its recipients.
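For concreteness, the following sketch evaluates the conditional intensities of Equation (31.58) under the common assumption of exponential triggering kernels φ_{ij}(∆) = β_{ij} e^{−∆/τ}, with β_{ij} = 0 whenever there is no edge from j to i; the data and parameter values are made up for illustration.

    import numpy as np

    def conditional_intensities(t, histories, gamma, beta, tau):
        """Evaluate λ_i(t | {H_j(t)}) for a multivariate Hawkes process.

        histories[j] holds the event times of process j; gamma[i] is the base
        rate γ_i; beta[i, j] is the jump size for edge j -> i (0 if no edge);
        tau is the decay time-scale of the exponential triggering kernel.
        """
        m = len(histories)
        lam = np.array(gamma, dtype=float)
        for i in range(m):
            for j in range(m):
                past = histories[j][histories[j] <= t]
                lam[i] += beta[i, j] * np.sum(np.exp(-(t - past) / tau))
        return lam

    # Two mutually exciting processes on a fully connected two-node graph.
    hist = [np.array([0.5, 1.2]), np.array([0.9])]
    beta = np.array([[0.0, 0.5], [0.5, 0.0]])
    print(conditional_intensities(2.0, hist, gamma=[0.1, 0.1], beta=beta, tau=1.0))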
Simulating from a Hawkes process is a fairly straightforward extension of simulating from an inhomogeneous Poisson process. Consider the univariate Hawkes process, and suppose the last event occurred at time t_i. Then, until the next event occurs, at any time t > t_i we have H(t) = H(t_i), so that the conditional intensity λ(t | H(t)) = λ(t | H(t_i)). The next event time, t_{i+1}, is then just the time of the first event of a Poisson process with intensity λ(t | H(t_i)) on the interval [t_i, ∞). For most choices of the kernel φ, t_{i+1} can easily be simulated, either by integrating λ(t | H(t_i)) or by Poisson thinning. At time t_{i+1}, the history is updated to incorporate the new event, and the conditional intensity experiences a jump. The updated intensity is used to simulate the next event at some time t_{i+2} > t_{i+1} from a Poisson process with rate λ(t | H(t_{i+1})), and the process is repeated until the end of the observation interval. For the case of multivariate Hawkes processes, one has a collection of competing intensities λ_i(t | {H_j(t)}_{j=1}^{m}). The next event is the earliest event among those produced by these intensities, after which the intensities are updated and the process is repeated. A realization Ψ from a Hawkes process has log-likelihood
ℓ(Ψ) = Σ_{t∗ ∈ Ψ} log λ(t∗ | H(t∗)) − ∫ λ(t | H(t)) dt.  (31.59)
This can typically be evaluated quite easily, so that maximum likelihood estimates of parameters like the base rate, as well as parameters of the excitation kernel, can be obtained straightforwardly.
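As an illustration, here is a minimal sketch of simulating a univariate Hawkes process by the thinning scheme just described, and of evaluating the log-likelihood of Equation (31.59); it assumes an exponential triggering kernel φ(∆) = β e^{−∆/τ}, for which the integral has a closed form, and all function names and parameter values are illustrative.

    import numpy as np

    def lam(t, events, gamma, beta, tau):
        """Conditional intensity λ(t | H(t)) with an exponential triggering kernel."""
        past = events[events < t]
        return gamma + beta * np.sum(np.exp(-(t - past) / tau))

    def simulate_hawkes(T, gamma, beta, tau, seed=0):
        """Simulate on [0, T] by thinning: between events the intensity only decays,
        so its value just after the current time upper-bounds it until the next event."""
        rng, events, t = np.random.default_rng(seed), [], 0.0
        while True:
            lam_bar = lam(t, np.array(events), gamma, beta, tau) + beta   # upper bound
            t += rng.exponential(1.0 / lam_bar)                           # candidate time
            if t > T:
                break
            if rng.random() < lam(t, np.array(events), gamma, beta, tau) / lam_bar:
                events.append(t)                                          # accept
        return np.array(events)

    def log_likelihood(events, T, gamma, beta, tau):
        """ℓ(Ψ) = Σ log λ(t∗|H(t∗)) − ∫ λ(t|H(t)) dt, with a closed-form integral."""
        log_term = sum(np.log(lam(t, events, gamma, beta, tau)) for t in events)
        integral = gamma * T + beta * tau * np.sum(1.0 - np.exp(-(T - events) / tau))
        return log_term - integral

    ev = simulate_hawkes(T=50.0, gamma=0.5, beta=0.8, tau=1.0)
    print(len(ev), "events, log-likelihood", round(log_likelihood(ev, 50.0, 0.5, 0.8, 1.0), 2))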
The Hawkes process as described is a fairly simple model, and there have been a number of extensions enriching its structure. An early example is [blundell2012modelling], where the authors considered the multivariate Hawkes process, now with an underlying clustering structure. Instead of each individual point process having its own conditional intensity function, each cluster has an intensity function shared by all point processes assigned to it. The interaction kernels are also defined at the cluster level, with an event in process i causing a jump φ_{c_i c}(τ) in the intensity function of cluster c (where c_i is the cluster that point process i belongs to). The cluster structure was modeled with a Dirichlet process, allowing the authors to learn the underlying clustering, as well as the inter-cluster interaction kernels, from interaction data. In [tan2016content], the authors considered marked point processes, where event i at time t_i has some associated content y_i (for example, each event is a social media post, and y_i is the text associated with the post at time t_i). The authors allowed the jump in the conditional intensity to depend on the associated mark, with the Hawkes kernel taking the form φ(∆) = f(y_i) exp(−∆/τ). In that work, the authors modeled the function f with a Gaussian process, though other approaches, such as ones based on neural networks, are possible.
The neural Hawkes process [mei2017neural] is a more fully developed approach to modeling point processes using neural networks. It models event intensities through the state of a continuous-time
LSTM, a modification of the more standard discrete-time LSTMs from ??. Central to an LSTM is a memory cell c_i that stores long-term memory, summarizing the past up to time step i. Continuous-time LSTMs include two long-term memory cells, c_i and c̃_i, summarizing the history up to the ith Hawkes event. The first cell represents the starting value to which the intensity jumps after the ith event, and the second represents a baseline rate after the ith event. Both are updated after each event, with c(t), the instantaneous value at any intermediate time, determined by c_i decaying exponentially towards the baseline c̃_i. This mechanism allows the intensity at any time to be influenced not just by the number of events in the past, but also by the waiting times between them. It can be extended to marked point processes, where each event is also associated with a mark y: now both a learned embedding of the mark and the time since the last event are used to update the state and long-term memory. For more details, see also [du2016recurrent].
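The key continuous-time ingredient is the exponential relaxation of the memory cell between events. The following sketch illustrates only that decay mechanism, reading the intensity out of the decayed cell state through a softplus so that it stays positive; it omits the LSTM gating that actually produces c_i, c̃_i, and the decay rate, and all names and values are illustrative assumptions rather than the full architecture of [mei2017neural].

    import numpy as np

    def cell_state(t, t_i, c_i, c_bar_i, delta_i):
        """Memory cell between events: starts at c_i just after the event at t_i and
        decays exponentially towards the baseline c_bar_i at rate delta_i."""
        return c_bar_i + (c_i - c_bar_i) * np.exp(-delta_i * (t - t_i))

    def intensity(t, t_i, c_i, c_bar_i, delta_i, w=1.0):
        """Read the intensity out of the decayed cell state through a softplus,
        which keeps it positive; w stands in for a learned readout weight."""
        h = np.tanh(cell_state(t, t_i, c_i, c_bar_i, delta_i))
        return np.log1p(np.exp(w * h))

    for t in (0.0, 0.5, 1.0, 2.0, 5.0):            # decay after an event at time 0
        print(t, round(intensity(t, t_i=0.0, c_i=2.0, c_bar_i=0.5, delta_i=1.0), 3))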
31.8.4 Gibbs point process
Gibbs point processes [moller2007modern] from the statistical physics and spatial statistics literature provide a general framework for modeling interacting point processes on higher-dimensional spaces. Such spaces are more challenging than the real line, since there is now no ordering of points, and thus no natural notion of history affecting future activity. Instead, Gibbs point processes use an energy function E to quantify deviations from a Poisson process with rate 1. Specifically, under a Gibbs process, the probability density of any configuration Ψ with respect to a rate-1 Poisson process takes the form
P_β(Ψ) = (1/Z_β) exp(−β E(Ψ))  (31.60)
where β is the inverse-temperature parameter, and Z_β = E_π[exp(−β E(Ψ))] is the normalization constant (the expectation is with respect to π, the unit-rate Poisson process). Under some conditions on the energy function E, the expectation Z_β is finite, and P_β is a well-defined density whose integral with respect to π is 1. While the above equation resembles a Markov random field, the domain of Ψ is much more complicated, and evaluating Z_β now involves solving an infinite-dimensional integral. Equation (31.60) states that configurations Ψ for which E(Ψ) is small are more likely than under a Poisson process, with β controlling how pronounced this preference is. The most common energy functions are pairwise potentials, taking the form
E(Ψ) = Σ_{(s,s′) ∈ Ψ} φ(‖s − s′‖),  (31.61)
where the summation is over all pairs of events in Ψ, and φ : ℝ_+ → ℝ ∪ {∞}. The Strauss process is a specific example, with energy function specified by positive parameters a and R as
E(Ψ) = Σ_{s,s′ ∈ Ψ} a · I(‖s − s′‖ ≤ R).  (31.62)
This is a repulsive process that penalizes configurations with events separated by a distance less than R, and as a → ∞ it becomes a hardcore repulsive process, forbidding configurations with two points separated by less than R. More generally, the energy function can be piecewise constant,
parameterized by a collection of pairs (a_1, R_1), . . . , (a_n, R_n):

E(Ψ) = Σ_{s,s′ ∈ Ψ} Σ_{i=1}^{n} a_i · I(‖s − s′‖ ≤ R_i).  (31.63)
Another natural option is to use smooth functions like a squared exponential kernel. Gibbs point processes can also involve higher-order interactions, examples being Geyer's triplet point process (which penalizes occurrences of three events that are all within some distance R of each other), or area-interaction point processes (which center a disk of radius R on each event, calculate the area of the union of these disks, and define E as some function of this area).
While Gibbs point processes are flexible and interpretable point process models, the intractable normalization constant Z_β makes estimating parameters like β a formidable challenge. In practice, these models have to be fit using approximate approaches, such as maximizing a pseudo-likelihood function instead of Equation (31.60).
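To make the Strauss process concrete, the following sketch evaluates the energy of Equation (31.62) and the corresponding unnormalized density exp(−β E(Ψ)) for a point configuration; the normalizer Z_β is deliberately not computed, since it is exactly the intractable quantity discussed above, and the parameter values are illustrative.

    import numpy as np

    def strauss_energy(points, a=1.0, R=0.1):
        """E(Ψ) = a · #{pairs of points closer than R}, as in Equation (31.62)."""
        diffs = points[:, None, :] - points[None, :, :]
        dists = np.sqrt((diffs**2).sum(-1))
        close = np.triu(dists <= R, k=1)           # count each unordered pair once
        return a * close.sum()

    def unnormalized_density(points, beta=1.0, a=1.0, R=0.1):
        """exp(−β E(Ψ)); the normalizer Z_β is intractable and left uncomputed."""
        return np.exp(-beta * strauss_energy(points, a, R))

    pts = np.random.default_rng(0).random((50, 2))  # 50 points in the unit square
    print("energy:", strauss_energy(pts), " unnormalized density:", unnormalized_density(pts))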
31.8.5 Determinantal point process
Determinantal point processes, or DPPs, are another approach to modeling repulsion, and have seen considerable popularity in the machine learning literature (see e.g., [kulesza2012determinantal; macchi1975coincidence; borodin2009determinantal; lavancier2015determinantal]). Like any point process, a DPP is a probability distribution over subsets of a fixed set S, which in the DPP literature is often called the ground set. Point process applications typically have S with uncountably infinite cardinality (for example, S could be the real line), although machine learning applications of DPPs often focus on S with a finite number of elements. For instance, S could be a database of images, a collection of news articles, or a group of individuals. A sample from a DPP is a random subset of S, produced, for instance, in response to a search query. The repulsive nature of DPPs ensures diversity and parsimony in the returned subset. This could be useful, for example, to ensure that responses to a search query are not minor variations of the same image. Another application is clustering, where a DPP serves as a prior over the number of clusters and their locations, with the repulsiveness discouraging redundant clusters that are very similar to each other.
DPPs are parameterized by a similarity kernel K, whose element K_{ij} gives the similarity between elements i and j of the ground set. For finite ground sets, K is just a similarity matrix. DPPs require K to be positive semidefinite with largest eigenvalue at most 1, that is, 0 ⪯ K ⪯ I. For simplicity, K is often assumed symmetric, though this is not necessary. Given a kernel K, the associated DPP is defined as follows: if Y is the random subset of S drawn from the DPP, then the probability that Y contains any subset A of S is given by

p(A ⊆ Y) = det(K_A).  (31.64)
Here K_A is the submatrix obtained by restricting K to the rows and columns in A, and we define det(K_∅) = 1 for the empty set ∅. Observe that this probability is specified exactly, and not just up to a normalization constant. We immediately see that Y contains the empty set with probability one (this is trivially true, since the empty set is a subset of every set). We also see that the probability that element i is selected in Y is the ith diagonal element of K:

p(i ∈ Y) = det(K_{ii}) = K_{ii},  (31.65)
Figure 31.13: The determinant of a matrix V is the volume of the parallelogram spanned by the columns of V and the origin.
so that the diagonal of K gives the inclusion probabilities of the individual elements of S. More interestingly, the probability that a pair of elements {i, j} are both contained in Y is given by

p({i, j} ⊆ Y) = K_{ii} K_{jj} − K_{ij}^2.  (31.66)
The first term, K_{ii} K_{jj}, is the probability of including i and j if they were independent; this probability is adjusted downwards by subtracting K_{ij}^2, a measure of the similarity between i and j. This is the source of repulsiveness or diversity in DPPs. More generally, the determinant of a set of vectors is the volume of the parallelogram spanned by them and the origin (see Figure 31.13), and by making the probability of inclusion proportional to this volume, DPPs encourage diversity. We mention for completeness that for uncountable ground sets, like a Euclidean space, the determinant det(K_A) instead gives the product density function of a realization A. This generalizes the intensity of a Poisson process to account for interactions between point process events, and we refer the interested reader to [lavancier2015determinantal] for more details.
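The following sketch computes the marginal probabilities of Equation (31.64) to Equation (31.66) directly from a marginal kernel K; the kernel here is an arbitrary illustrative choice, built by rescaling an RBF similarity matrix so that its eigenvalues lie in [0, 1].

    import numpy as np

    def marginal_prob(K, A):
        """p(A ⊆ Y) = det(K_A) for a DPP with marginal kernel K (Equation 31.64)."""
        A = list(A)
        return np.linalg.det(K[np.ix_(A, A)]) if A else 1.0

    # An arbitrary valid marginal kernel: RBF similarities rescaled so 0 ⪯ K ⪯ I.
    x = np.random.default_rng(0).random((5, 2))
    S = np.exp(-np.sum((x[:, None] - x[None, :])**2, -1) / 0.5)
    K = S / (np.linalg.eigvalsh(S).max() + 1e-6)

    print("p(0 in Y)           =", marginal_prob(K, [0]))        # equals K[0, 0]
    print("p({0, 1} subset Y)  =", marginal_prob(K, [0, 1]))     # K00*K11 - K01^2
    print("independent product =", K[0, 0] * K[1, 1])            # larger: repulsion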
Equation (31.64) characterizes the marginal probabilities of a determinantal point process: the probability that a realization Y contains some subset of S. More often, one is interested in the probability that the realization Y equals some subset of S. The latter is of interest for simulation, inference, and parameter learning with DPPs. For this, it is typical to work with a narrower class of DPPs called L-ensembles [borodin2009determinantal; kulesza2012determinantal]. Like general DPPs, such a process is characterized by a positive semidefinite matrix L, with the probability that Y equals some configuration A given by

p(Y = A) ∝ det(L_A).  (31.67)
Note that this probability is specified only up to a normalization constant, so we do not need to upper-bound the eigenvalues of L. In fact, it is not hard to show that the normalizer is given by det(I + L), so that

p(Y = A) = det(L_A) / det(I + L).  (31.68)

A similar calculation can be used to show that p(A ⊆ Y) = det(K_A), where K = L(I + L)^{−1} = I − (I + L)^{−1}, showing that L-ensembles are indeed a special kind of DPP. Equation (31.68) and
Equation (31.64) allow parameters of the DPP to be estimated from realizations of a point process, typically by gradient descent. Without additional structure, naively calculating determinants is cubic in the cardinality of S; this is nevertheless a substantial saving when one considers the number of possible subsets of S. When even cubic scaling is too expensive, a number of approximation approaches can be adopted, and these are often closely related to approaches for reducing the cubic cost of Gaussian processes.
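A minimal sketch of the exact L-ensemble probability of Equation (31.68) and of the marginal kernel K = L(I + L)^{−1} is given below; the matrix L is an arbitrary illustrative choice, and the brute-force sum over subsets is only feasible because the ground set is tiny.

    import numpy as np
    from itertools import combinations

    def l_ensemble_prob(L, A):
        """p(Y = A) = det(L_A) / det(I + L) for an L-ensemble (Equation 31.68)."""
        A = list(A)
        num = np.linalg.det(L[np.ix_(A, A)]) if A else 1.0
        return num / np.linalg.det(np.eye(len(L)) + L)

    def marginal_kernel(L):
        """K = L (I + L)^{-1}, the marginal kernel of the L-ensemble."""
        return L @ np.linalg.inv(np.eye(len(L)) + L)

    x = np.random.default_rng(1).random((4, 2))
    L = np.exp(-np.sum((x[:, None] - x[None, :])**2, -1) / 0.5)  # any PSD L works

    # The probabilities of all 2^4 subsets sum to one.
    total = sum(l_ensemble_prob(L, A) for r in range(5) for A in combinations(range(4), r))
    print("sum over subsets:", round(total, 6))
    print("p(0 in Y) from K:", marginal_kernel(L)[0, 0])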
So far, we have only discussed how to calculate the probability of samples from a DPP. Simulating from a DPP, while straightforward, is a bit less intuitive, and we refer the reader to [lavancier2015determinantal; kulesza2012determinantal]. At a high level, these approaches use the eigenstructure of K to express a DPP as a mixture of determinantal projection point processes. The latter are DPPs whose similarity kernel has binary eigenvalues (either 0 or 1), and they are easier to sample from. Observe that any eigenvalue λ_i of the similarity kernel K must lie in the interval [0, 1]. This allows us to generate a random similarity kernel K̂ with the same eigenvectors as K but with binary eigenvalues as follows: for each i, replace eigenvalue λ_i of K with a binary variable λ̂_i simulated from a Bernoulli(λ_i) distribution. One can show that a sample drawn from the DPP with kernel K̂, after simulating K̂ from K in this fashion, is distributed exactly as a sample from the DPP with kernel K. We refer the reader to [lavancier2015determinantal] for details on how to simulate a determinantal projection point process with similarity kernel K̂.
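A compact sketch of this two-stage sampler for a finite ground set follows. The first stage is the Bernoulli thinning of eigenvalues described above; the second stage uses a standard sequential routine for projection DPPs, as described in the references just cited (pick an item with probability proportional to the squared norms of the rows of the retained eigenvectors, then project that coordinate direction out), stated here without derivation and with an arbitrary example kernel.

    import numpy as np

    def sample_dpp(K, rng=None):
        """Sample a subset from a DPP with symmetric marginal kernel K (0 ⪯ K ⪯ I)."""
        rng = np.random.default_rng() if rng is None else rng
        lam, V = np.linalg.eigh(K)
        # Stage 1: keep eigenvector i with probability λ_i (a random projection DPP K̂).
        V = V[:, rng.random(len(lam)) < np.clip(lam, 0.0, 1.0)]
        items = []
        # Stage 2: sample from the projection DPP spanned by the kept eigenvectors.
        while V.shape[1] > 0:
            probs = np.sum(V**2, axis=1) / V.shape[1]     # rows of an orthonormal basis
            i = rng.choice(len(K), p=probs)
            items.append(i)
            # Restrict the spanned subspace to vectors whose ith coordinate is zero.
            j = np.argmax(np.abs(V[i]))                   # a column with V[i, j] != 0
            V = V - np.outer(V[:, j], V[i] / V[i, j])     # zeros out row i and column j
            V = np.delete(V, j, axis=1)
            if V.shape[1] > 0:
                V, _ = np.linalg.qr(V)                    # re-orthonormalize
        return sorted(items)

    x = np.random.default_rng(0).random((6, 2))
    S = np.exp(-np.sum((x[:, None] - x[None, :])**2, -1) / 0.5)
    K = S / (np.linalg.eigvalsh(S).max() + 1e-6)          # eigenvalues in [0, 1)
    print(sample_dpp(K, np.random.default_rng(1)))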
32 Representation learning
33 Interpretability
Part VI

Decision making
34 Multi-step decision problems
35 Reinforcement learning
36 Causality
