
Advanced Probabilistic Machine Learning

Lecture 6 – Expectation propagation and Variational inference

Niklas Wahlström
Division of Systems and Control
Department of Information Technology
Uppsala University

[email protected]
www.it.uu.se/katalog/nikwa778



Summary of lecture 5 (I/IV)

A Markov random field describes the factorization of the joint distribution, not the distributions of the variables.

[Figure: undirected graph over a, b, c, d, e, f with maximal cliques {a, b, c}, {c, d, e}, {e, f}]

p(a, b, c, d, e, f) = (1/Z) ψ1(a, b, c) ψ2(c, d, e) ψ3(e, f)

The joint distribution is a product of factors corresponding to the variables in the maximal cliques of the graph.



Summary of lecture 5 (II/IV)

Global Markov: For any disjoint subsets of variables A, B, C, where C separates A from B,

A ⊥ B | C

[Figure: the MRF above with A = {a, b}, C = {c}, B = {d, e, f}]

p(b, f | c) = p(b | c) p(f | c)

In an MRF, observed nodes block paths (separation ⇒ independence).



Summary of lecture 5 (III/IV)

A factor graph (FG) explicitly describes the factorization of the joint distribution.

[Figure: factor graph with variable nodes a, b, c, d and factor nodes f1, f2, f3, f4, f5]

p(a, b, c, d) = f1(a) f2(b) f3(a, b, c) f4(b, c) f5(c, d)

Warning: the normalization constant may be explicit or included in one of the factors!



Summary of lecture 5 (IV/IV)
Belief propagation algorithm

• Initialize all leaf messages: µx→f(x) = 1 for leaf variable nodes, µf→x(x) = f(x) for leaf factor nodes
• Each variable node outputs the product of all messages from neighboring factors except for the target factor

  µx→f(x) = ∏_{fs≠f} µfs→x(x)

• Each factor node computes the product of all incoming messages times its own factor, and outputs the sum (or integral) with respect to all the variables except for the target variable

  µf→x(x) = Σ_{x\x} f(x) ∏_m µxm→f(xm)

• Marginals at each variable node are proportional to the product of all incoming messages

  p(x) ∝ ∏_s µfs→x(x)
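As a quick sanity check of these rules, here is a minimal numpy sketch (not from the lecture) running sum-product on a tiny discrete chain p(x1, x2) ∝ f1(x1) f2(x1, x2) f3(x2); the factor tables are arbitrary, and the marginals are verified against brute-force summation.

```python
import numpy as np

# Tiny chain factor graph: p(x1, x2) ∝ f1(x1) f2(x1, x2) f3(x2),
# where x1 and x2 each take 3 states. The factor tables below are arbitrary.
rng = np.random.default_rng(0)
f1 = rng.random(3)          # f1(x1)
f2 = rng.random((3, 3))     # f2(x1, x2)
f3 = rng.random(3)          # f3(x2)

# Leaf messages: a leaf factor sends its own factor.
mu_f1_x1 = f1
mu_f3_x2 = f3

# Variable-to-factor messages: product of incoming messages except the target factor.
mu_x1_f2 = mu_f1_x1
mu_x2_f2 = mu_f3_x2

# Factor-to-variable messages: multiply by the factor, sum out the other variable.
mu_f2_x2 = (f2 * mu_x1_f2[:, None]).sum(axis=0)
mu_f2_x1 = (f2 * mu_x2_f2[None, :]).sum(axis=1)

# Marginals are proportional to the product of all incoming messages.
p_x1 = mu_f1_x1 * mu_f2_x1
p_x2 = mu_f3_x2 * mu_f2_x2
p_x1, p_x2 = p_x1 / p_x1.sum(), p_x2 / p_x2.sum()

# Brute-force check from the full joint.
joint = f1[:, None] * f2 * f3[None, :]
assert np.allclose(p_x1, joint.sum(axis=1) / joint.sum())
assert np.allclose(p_x2, joint.sum(axis=0) / joint.sum())
```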
Contents

Kullback-Leibler divergence

Moment matching in factor graphs

Expectation propagation

Variational inference



Bayesian framework reminder
In this course we solve problems using Bayes’ theorem
p(θ|D) = p(D|θ) p(θ) / p(D)

• D : observed data
• θ : parameters of some model explaining the data
• p(θ): prior belief of parameters before we collected any data
• p(θ|D): posterior belief of parameters after observing the data
• p(D|θ): likelihood of the data in view of the parameters

Sometimes the posterior can be found exactly, e.g. with conjugate priors.

What if there is no exact solution?
• Stochastic approximate inference: Monte Carlo (lecture 4)
• Deterministic approximate inference: Expectation propagation and Variational inference (today)
Deterministic approx. inference: The idea
Idea: Approximate the posterior p(θ|D) with q(θ) ∈ Q

Find a tractable distribution q(θ) ∈ Q which is close to p(θ|D)

q̂(θ) = arg min_{q(θ)∈Q} D(p(θ|D) ‖ q(θ))

Short summary of Tamara's Variational Bayes tutorial at ICML 2018: medium.com/@aminamollaysa/short-summary-of-tamaras-variational-bayes-tutorial-in-icml2018-68aec59cdc37



Deterministic approx. inference: The idea
Idea: Approximate the posterior p(θ|D) with q(θ) ∈ Q

Find a tractable distribution q(θ) ∈ Q which is close to p(θ|D)

q̂(θ) = arg min_{q(θ)∈Q} D(p(θ|D) ‖ q(θ))

As distance we use the Kullback-Leibler divergence.


KL(p(x) ‖ q(x)) = −∫ p(x) ln (q(x)/p(x)) dx



Kullback-Leibler divergence



Kullback-Leibler divergence

Kullback-Leibler divergence is a distance between two distributions:

KL(p ‖ q) = −∫ p(x) ln (q(x)/p(x)) dx

Some properties
• Non-negative: KL(p(x) ‖ q(x)) ≥ 0
• KL(p(x) ‖ q(x)) = 0 if and only if p(x) = q(x)
• Non-symmetric: KL(p(x) ‖ q(x)) ≠ KL(q(x) ‖ p(x))

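The following numpy/scipy sketch (not part of the slides) illustrates these properties numerically for two arbitrarily chosen Gaussians, approximating the integral on a grid. The first value can also be checked against the closed-form Gaussian-to-Gaussian KL, ln(σ2/σ1) + (σ1² + (µ1 − µ2)²)/(2σ2²) − 1/2 ≈ 0.443 for these parameters.

```python
import numpy as np
from scipy.stats import norm

# Two univariate Gaussians; the parameters are arbitrary choices for illustration.
p = norm(loc=0.0, scale=1.0)
q = norm(loc=1.0, scale=2.0)

# KL(p || q) = -∫ p(x) ln(q(x)/p(x)) dx, approximated on a fine grid.
x = np.linspace(-15.0, 15.0, 200_001)
dx = x[1] - x[0]

def kl(a, b):
    pa, pb = a.pdf(x), b.pdf(x)
    return -np.sum(pa * np.log(pb / pa)) * dx

print(kl(p, q))   # ≈ 0.443, non-negative
print(kl(q, p))   # ≈ 1.307, a different value: KL is not symmetric
print(kl(p, p))   # ≈ 0, since KL(p || q) = 0 iff p = q
```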


Deterministic approx. inference: The idea
Idea: Approximate the posterior p(θ|D) with q(θ) ∈ Q

Find a tractable distribution q(θ) ∈ Q which is close to p(θ|D)

q̂(θ) = arg min_{q(θ)∈Q} D(p(θ|D) ‖ q(θ))

As distance we use the Kullback-Leibler divergence.


KL(p(x) ‖ q(x)) = −∫ p(x) ln (q(x)/p(x)) dx

We have two classes of variational approximations

• D(p(θ|D) ‖ q(θ)) = KL(p(θ|D) ‖ q(θ)) gives expectation propagation
• D(p(θ|D) ‖ q(θ)) = KL(q(θ) ‖ p(θ|D)) gives variational inference
Minimization of KL-divergence

Suppose we have

p(θ) = 0.2 N(θ; 5, 1²) + 0.8 N(θ; −5, 2²)

Let

q(θ) = N(θ; µ, σ²)

q̂ = min_{µ,σ} KL(p ‖ q_{µ,σ})          vs.          q̂ = min_{µ,σ} KL(q_{µ,σ} ‖ p)

[Figure: p(θ) and the fitted q̂(θ) for each of the two KL directions]
Minimization of KL-divergence
   
[Figure: p(θ) and the fitted q̂(θ) for each of the two KL directions]

q̂ = min_{µ,σ} KL(p ‖ q_{µ,σ}):
KL(p ‖ q_{µ,σ}) = −∫ p(θ) ln (q_{µ,σ}(θ)/p(θ)) dθ
Non-zero-forcing: where p > 0, q needs to be > 0.

q̂ = min_{µ,σ} KL(q_{µ,σ} ‖ p):
KL(q_{µ,σ} ‖ p) = −∫ q_{µ,σ}(θ) ln (p(θ)/q_{µ,σ}(θ)) dθ
Zero-forcing: where p ≈ 0, q needs to be ≈ 0.



Minimization of KL-divergence
 
For the first form

q̂ = min_{µ,σ} KL(p ‖ q_{µ,σ}),    KL(p ‖ q_{µ,σ}) = −∫ p(θ) ln (q_{µ,σ}(θ)/p(θ)) dθ

[Figure: p(θ) and the fitted q̂(θ)]

we have that µ̂, σ̂ = arg min_{µ,σ} KL(p ‖ q_{µ,σ}) gives

µ̂ = ∫ θ p(θ) dθ = Ep[θ]
σ̂² = ∫ (θ − µ̂)² p(θ) dθ = Ep[(θ − µ̂)²]

(see Exercise 6.2)


We call this moment matching
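As an illustration (not part of the handout), the numpy/scipy sketch below fits a Gaussian to the mixture from the earlier slide in both ways: minimizing KL(p ‖ q) via moment matching (the closed form above), and minimizing KL(q ‖ p) by a crude grid search over (µ, σ) with numerical integration, which exhibits the zero-forcing behaviour.

```python
import numpy as np
from scipy.stats import norm

# Mixture from the slide: p(θ) = 0.2 N(θ; 5, 1²) + 0.8 N(θ; −5, 2²)
w = np.array([0.2, 0.8])
mus = np.array([5.0, -5.0])
sigs = np.array([1.0, 2.0])

theta = np.linspace(-25.0, 25.0, 20_001)
dth = theta[1] - theta[0]
p = (w[:, None] * norm.pdf(theta, mus[:, None], sigs[:, None])).sum(axis=0)

# min_q KL(p || q): moment matching, closed form for a mixture of Gaussians.
m_fwd = np.sum(w * mus)                              # = -3
v_fwd = np.sum(w * (sigs**2 + mus**2)) - m_fwd**2    # = 19.4

# min_q KL(q || p): crude grid search with numerical integration.
def kl_qp(m, s):
    q = norm.pdf(theta, m, s)
    mask = q > 0
    return -np.sum(q[mask] * np.log(p[mask] / q[mask])) * dth

grid_m = np.linspace(-8.0, 8.0, 33)
grid_s = np.linspace(0.5, 6.0, 23)
scores = np.array([[kl_qp(m, s) for s in grid_s] for m in grid_m])
i, j = np.unravel_index(np.argmin(scores), scores.shape)

print("KL(p||q), moment matching:", m_fwd, np.sqrt(v_fwd))  # broad fit covering both modes
print("KL(q||p), zero-forcing:   ", grid_m[i], grid_s[j])   # locks onto the heavy mode near -5
```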
Moment matching in factor graphs



Example: Moment matching in factor graphs

Remember example from lecture 5.


• Prior: w ∼ N (m, κ2 )
• Likelihood: t|w ∼ N (w, σ 2 )
Objective: compute the marginals p(t)

[Factor graph: factor N(w; m, κ²) — variable w — factor N(t; w, σ²) — variable t, with messages µ1, µ2, µ3]

µ1(w) = N(w; m, κ²)
µ2(w) = µ1(w)
µ3(t) = ∫ N(t; w, σ²) µ1(w) dw = N(t; m, σ² + κ²)

p(t) ∝ µ3(t) = N(t; m, σ² + κ²)
Example: Moment matching in factor graphs

Remember example from lecture 5. Now we measure the sign of t


• Prior: w ∼ N (m, κ2 )
• Likelihood: y = sign(t), t|w ∼ N (w, σ 2 )
Objective: compute the posterior p(w|y = 1)

[Factor graph: N(w; m, κ²) — w — N(t; w, σ²) — t — δ(y = sign(t)) — y — δ(y = 1), with messages µ1, µ2, µ3]

µ1(w) = N(w; m, κ²)
µ2(w) = µ1(w)
µ3(t) = ∫ N(t; w, σ²) µ1(w) dw = N(t; m, σ² + κ²)



Example: Moment matching in factor graphs

Remember example from lecture 5. Now we measure the sign of t


• Prior: w ∼ N (m, κ2 )
• Likelihood: y = sign(t), t|w ∼ N (w, σ 2 )
Objective: compute the posterior p(w|y = 1)

[Factor graph as above, now also with the backward messages µ4, µ5, µ6]

µ1(w) = N(w; m, κ²)
µ2(w) = µ1(w)
µ3(t) = ∫ N(t; w, σ²) µ1(w) dw = N(t; m, σ² + κ²)
µ4(y) = δ(y = 1)
µ5(y) = µ4(y)
µ6(t) = δ(t > 0)



Example: Moment matching in factor graphs

[Factor graph and messages as above]

µ1(w) = N(w; m, κ²)          µ4(y) = δ(y = 1)
µ2(w) = µ1(w)                µ5(y) = µ4(y)
µ3(t) = N(t; m, σ² + κ²)     µ6(t) = δ(t > 0)



Example: Moment matching in factor graphs

[Factor graph and messages as above]

Idea: Approximate marginals as Gaussians using moment matching!

[Figure: p(t|y) and its Gaussian approximation q̂(t)]

p(t|y) ∝ µ3(t) µ6(t)
mt = Ep(t|y)[t]
σt² = Varp(t|y)[t]
q̂(t) = N(t; mt, σt²)
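For the sign observation, p(t|y = 1) ∝ N(t; m, σ² + κ²) δ(t > 0) is a truncated Gaussian, so the matched moments are available in closed form. A small sketch (not from the handout) using the standard truncated-Gaussian formulas, with the parameter values of the simulation slide further below (m = 0, κ = σ = 1), plus a numerical check:

```python
import numpy as np
from scipy.stats import norm

# Moment matching of p(t | y=1) ∝ N(t; m, σ² + κ²) · δ(t > 0).
# Parameter values from the simulation slide: m = 0, κ = 1, σ = 1, i.e. N(t; 0, 2) truncated to t > 0.
mu, s = 0.0, np.sqrt(2.0)
a = 0.0                                             # truncation point

# Standard moments of a Gaussian truncated to t > a.
alpha = (a - mu) / s
lam = norm.pdf(alpha) / (1.0 - norm.cdf(alpha))     # inverse Mills ratio
m_t = mu + s * lam                                  # E[t | t > a]
v_t = s**2 * (1.0 - lam * (lam - alpha))            # Var[t | t > a]

# Numerical check on a grid.
t = np.linspace(0.0, 15.0, 200_001)
dt = t[1] - t[0]
dens = norm.pdf(t, mu, s)
dens /= dens.sum() * dt
assert np.isclose(m_t, np.sum(t * dens) * dt, atol=1e-3)
assert np.isclose(v_t, np.sum((t - m_t)**2 * dens) * dt, atol=1e-3)

print(m_t, v_t)    # q̂(t) = N(t; m_t, v_t) ≈ N(t; 1.13, 0.73)
```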
Example: Moment matching in factor graphs

[Factor graph and messages as above]

• The outgoing message must be consistent with the newly approximated marginal:

  µ7(t) = q̂(t) / µ3(t)



Example: Moment matching in factor graphs

[Factor graph and messages as above, now also with the backward messages µ7 and µ8]

Now we can proceed using the Gaussian rules from Lecture 5.

µ7(t) = q̂(t) / µ3(t) ∝ N(t; m7, σ7²)                      (Gaussian division)
µ8(w) = ∫ N(t; w, σ²) µ7(t) dt = N(w; m7, σ7² + σ²)        (Gaussian marginalization)
q(w) ∝ µ1(w) µ8(w) = N(w; mw, σw²)                         (Gaussian multiplication)

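Continuing the sketch above (still not part of the handout), the three Gaussian operations can be written directly in mean/variance form, using precisions for the division and multiplication; the resulting q(w) can be compared with the p(w|y = 1) panel on the simulation slide below.

```python
import numpy as np
from scipy.stats import norm

# Sign-observation example with m = 0, κ = 1, σ = 1, so µ3(t) = N(t; 0, 2).
m3, v3 = 0.0, 2.0                         # message µ3(t)
sigma2 = 1.0                              # likelihood noise variance σ²
m_prior, v_prior = 0.0, 1.0               # prior µ1(w) = N(w; m, κ²)

# Moment-matched marginal q̂(t) of N(t; 0, 2) truncated to t > 0 (previous sketch, with α = 0).
lam = norm.pdf(0.0) / (1.0 - norm.cdf(0.0))
m_t, v_t = np.sqrt(v3) * lam, v3 * (1.0 - lam**2)

def divide(m1, v1, m2, v2):
    """Gaussian division N(m1, v1) / N(m2, v2) ∝ N(m, v), in precision form."""
    prec, h = 1/v1 - 1/v2, m1/v1 - m2/v2
    return h / prec, 1 / prec

def multiply(m1, v1, m2, v2):
    """Gaussian multiplication N(m1, v1) · N(m2, v2) ∝ N(m, v), in precision form."""
    prec, h = 1/v1 + 1/v2, m1/v1 + m2/v2
    return h / prec, 1 / prec

m7, v7 = divide(m_t, v_t, m3, v3)              # µ7(t) = q̂(t) / µ3(t)
m8, v8 = m7, v7 + sigma2                       # µ8(w) = ∫ N(t; w, σ²) µ7(t) dt
mw, vw = multiply(m_prior, v_prior, m8, v8)    # q(w) ∝ µ1(w) µ8(w)

print(mw, vw)   # ≈ 0.56 and 0.68: compare the p(w|y = 1) panel of the simulation slide
```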


Example: Moment matching in factor graphs

Simulation of previous experiment with the following parameters

m = 0, κ = 1, σ = 1, y=1

Below the approximated marginals of w and t are compared with importance sampling.

[Figure: p(w|y = 1) and p(t|y = 1), moment matching vs. importance sampling]

Note: we only did explicit moment matching at node t!
Moment matching in graphs
If more than one node in the graph needs to be approximated, the solution will be iterative.

Example
• Prior: w ∼ N(m, κ²)
• Likelihood: yn = sign(tn), tn|w ∼ N(w, σ²), n = 1, 2
What is p(w|y1, y2)?

[Figure: factor graph with w connected to t1 and t2 (observations y1, y2), messages µ1–µ11]

• Moment matching in t1 ⇒ a new message µ4
• Pass the messages to t2
• Moment matching in t2 ⇒ a new message µ9
• Pass the messages to t1
• ...
Moment matching in factor graphs

1. Initialize all messages, for example with µi(x) = N(0, 1)
2. Recompute all required messages in the graph. When you ship a non-Gaussian message to a node x, then
   a) compute p(x) and approximate it with a Gaussian q̂(x) using moment matching, and then
   b) compute the outgoing message from node x to factor f as

      µx→f(x) = q̂(x) / µf→x(x)

   [Figure: variable node x with incoming messages µf→x, µfs→x and outgoing message µx→f]
3. Repeat 2 until convergence

This is an instance of a more general framework called Expectation propagation (see Bishop Sec. 10.7 or Barber Sec. 28.8)
Expectation propagation



Expectation propagation

• Suppose we have a probabilistic model of the form

  p(θ, D) = ∏_{i=1}^{I} fi(θ)   ⇒   p(θ|D) = (1/p(D)) ∏_{i=1}^{I} fi(θ),

  i.e. for a factor graph, fi(θ) are the factors.

• Expectation propagation approximates the posterior as

  q(θ) = (1/Z) ∏_{i=1}^{I} qi(θ),   e.g. qi(θ) = N(θ; µi, Σi)

Aim: Find a tractable distribution q(θ) which is close to p(θ|D)

  q̂(θ) = arg min_{q(θ)∈Q} KL(p(θ|D) ‖ q(θ))



Expectation propagation

• The terms qj(θ) are estimated iteratively by keeping the last estimates of {q̂i(θ)}i≠j fixed

  q̂j(θ) = arg min_{qj} KL( (1/p(D)) fj(θ) ∏_{i≠j} q̂i(θ)  ‖  (1/Z) qj(θ) ∏_{i≠j} q̂i(θ) )

Comments
• In the factor graph example we approximated the marginals p(θi) rather than the factors fj(θ).
• It can be shown that this is equivalent (by factorizing each approximated factor f̃j(θ) further into its marginals).

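To make the update concrete, here is a minimal 1-D sketch (not the lecture's code) for the two-sign-observation model from the earlier example, with t integrated out so that each factor is fi(w) = Φ(yi w/σ). Each factor gets a Gaussian site qi; in each pass the site is removed from q (the cavity), the tilted distribution cavity × fi is moment matched, and the site is updated by Gaussian division. For simplicity the tilted moments are computed numerically on a grid rather than in closed form.

```python
import numpy as np
from scipy.stats import norm

# EP for p(w | y) ∝ N(w; 0, κ²) ∏_i Φ(y_i w / σ)  (the sign model with t marginalized out).
kappa2, sigma = 1.0, 1.0
y = np.array([1.0, 1.0])                      # two sign observations
factors = [lambda w, yi=yi: norm.cdf(yi * w / sigma) for yi in y]

grid = np.linspace(-10.0, 10.0, 20_001)

def moments(unnorm):
    """Mean and variance of an unnormalized density tabulated on the grid."""
    Z = np.sum(unnorm)
    m = np.sum(grid * unnorm) / Z
    v = np.sum((grid - m)**2 * unnorm) / Z
    return m, v

# q and the sites are stored as natural parameters (precision, precision × mean).
prior = np.array([1 / kappa2, 0.0])
sites = [np.zeros(2) for _ in factors]        # start from "flat" sites
q = prior.copy()

for sweep in range(10):
    for i, f in enumerate(factors):
        cavity = q - sites[i]                                 # remove site i from q
        cav_m, cav_v = cavity[1] / cavity[0], 1 / cavity[0]
        tilted = norm.pdf(grid, cav_m, np.sqrt(cav_v)) * f(grid)
        m, v = moments(tilted)                                # moment matching
        q = np.array([1 / v, m / v])                          # new global approximation
        sites[i] = q - cavity                                 # updated Gaussian site ≈ f_i

print(q[1] / q[0], 1 / q[0])   # approximate posterior mean and variance of w
```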


Variational inference



Deterministic approx. inference: The idea
Idea: Approximate the posterior p(θ|D) with q(θ) ∈ Q

Find a tractable distribution q(θ) ∈ Q which is close to p(θ|D)

q̂(θ) = arg min_{q(θ)∈Q} D(p(θ|D) ‖ q(θ))

As distance we use the Kullback-Leibler divergence.


KL(p(x) ‖ q(x)) = −∫ p(x) ln (q(x)/p(x)) dx

We have two classes of variational approximations

• D(p(θ|D) ‖ q(θ)) = KL(p(θ|D) ‖ q(θ)) gives expectation propagation
• D(p(θ|D) ‖ q(θ)) = KL(q(θ) ‖ p(θ|D)) gives variational inference
Variational inference

Aim: Find a tractable distribution q(θ) which is close to p(θ|D)

q̂(θ) = arg min_{q(θ)∈Q} KL(q(θ) ‖ p(θ|D))    (*)

It is tricky to compute this since it contains an expression of the posterior, which we don't have access to.

KL(q(θ) ‖ p(θ|D)) = −∫ q(θ) ln (p(θ|D)/q(θ)) dθ
                  = −∫ q(θ) ln (p(θ, D)/(q(θ) p(D))) dθ
                  = ln p(D) − ∫ q(θ) ln (p(θ, D)/q(θ)) dθ,   where the last integral is L(q)

By maximizing L(q) we get the minimum of (*)!



Mean field variational inference (I/II)
We want to maximize the evidence lower bound (ELBO)

L(q) = ∫ q(θ) ln (p(θ, D)/q(θ)) dθ = Eq[ln p(θ, D)] − Eq[ln q(θ)]

In mean field variational inference we make the following restriction on q(θ):

q(θ) = ∏_{i=1}^{I} qi(θi)

where θ1, . . . , θI are disjoint subsets of θ = {θ1, . . . , θI}.

From the calculus of variations it can be shown that the qj(θj) that maximizes the ELBO fulfills

ln q̂j(θj) = E_{{qi}i≠j}[ln p(θ, D)] + const.

(see for example page 465 in Bishop)
Mean field variational inference (II/II)

As in expectation propagation, the terms qj(θ) are estimated iteratively by keeping the last estimates {q̂i(θ)}i≠j fixed.

Mean field variational iteration using coordinate descent

Solve the problem iteratively
1. For j = 1, 2, . . . , I
   • Fix {qi(θi)}i≠j to their last estimated values {q̂i(θi)}i≠j
   • Find the solution to
     ln q̂j(θj) = E_{{q̂i}i≠j}[ln p(θ, D)]
   • Normalize q̂j(θj)
2. Repeat 1 until convergence



Ex.: Variational linear regression (I/VI)
Recall Bayesian linear regression from lecture 2. For ease of notation, we consider a scalar parameter w.

The probabilistic model with unknown w is given by:


 
p(w) = N(w; 0, α⁻¹)                                           prior distribution
p(y | w) = ∏_{n=1}^{N} N(yn; w xn, β⁻¹) = N(y; w x, β⁻¹ IN)   likelihood

What value shall we pick for α?

Put a prior on α as well!



Ex.: Variational linear regression (II/VI)

The probabilistic model with unknown w and α is given by:


p(α) = Gam(α; a0, b0)
p(w|α) = N(w; 0, α⁻¹)           prior distribution
p(y | w) = N(y; w x, β⁻¹ IN)    likelihood

where Gam(α; a, b) is the Gamma distribution

Gam(α; a, b) = (1/Γ(a)) b^a α^(a−1) e^(−bα),   α ∈ [0, ∞)

[Figure: Gamma densities for (a, b) = (1, 0.5), (2, 0.5), (2, 1)]
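One practical note (my addition, not from the slides): this Gam(α; a, b) uses a rate parameter b, whereas scipy.stats.gamma is parameterized by a scale, so scale = 1/b when implementing the updates below.

```python
import numpy as np
from scipy.stats import gamma
from scipy.special import gamma as gamma_fn

# Check that Gam(α; a, b) with rate b matches scipy's gamma with scale = 1/b.
a, b = 2.0, 0.5
alpha = np.linspace(0.01, 10.0, 1000)
pdf_formula = b**a * alpha**(a - 1) * np.exp(-b * alpha) / gamma_fn(a)
assert np.allclose(pdf_formula, gamma.pdf(alpha, a, scale=1/b))
```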
Ex.: Variational linear regression (III/VI)
We make an assumption that the variational distribution factorizes
q(w, α) = q(w)q(α)
The two equations we will iterate are
ln q̂(α) = Eq̂(w) [ln p(y, w, α)] + const.
ln q̂(w) = Eq̂(α) [ln p(y, w, α)] + const.
The joint distribution of y, w and α is

p(y, w, α) = p(y|w) p(w|α) p(α)  ⇒  ln p(y, w, α) = ln p(y|w) + ln p(w|α) + ln p(α)

where

ln p(y|w) = −(β/2) (wx − y)ᵀ(wx − y) + const.
ln p(w|α) = (1/2) ln α − (α/2) w² + const.
ln p(α) = (a0 − 1) ln α − b0 α + const.
Ex.: Variational linear regression (IV/VI)
We start with q̂(α)

ln q̂(α) = Eq̂(w) [ln p(y, w, α)] + const.


= Eq̂(w) [ln p(y|w) + ln p(w|α) + ln p(α)] + const.
= ln p(α) + Eq̂(w) [ln p(w|α)] + const.
= (a0 − 1) ln α − b0 α + (1/2) ln α − (α/2) Eq̂(w)[w²] + const.

We recognize this as a Gamma distribution

ln q̂(α) = ln Gam(α; aN, bN) = (aN − 1) ln α − bN α + const.

with

aN = a0 + 1/2,
bN = b0 + (1/2) Eq̂(w)[w²]
Ex.: Variational linear regression (V/VI)
Now we proceed with q̂(w).
ln q̂(w) = Eq̂(α) [ln p(y, w, α)] + const.
= Eq̂(α) [ln p(y|w) + ln p(w|α) + ln p(α)] + const.
= ln p(y|w) + Eq̂(α) [ln p(w|α)] + const.
= −(β/2) (wx − y)ᵀ(wx − y) − (1/2) Eq̂(α)[α] w² + const.
= −(1/2) (Eq̂(α)[α] + β xᵀx) w² + β xᵀy w + const.
= −(w − mN)² / (2σN²) + const.

We recognize this as a Gaussian q̂(w) = N(w; mN, σN²) where

σN² = (Eq̂(α)[α] + β xᵀx)⁻¹
mN = β σN² xᵀy
Ex.: Variational linear regression (VI/VI)
Since

q̂(α) = Gam(α; aN, bN),   q̂(w) = N(w; mN, σN²)

we can compute

Eq̂(α)[α] = aN / bN
Eq̂(w)[w²] = mN² + σN²

Solution: Iterate the following two steps until convergence

1. Compute
   aN = a0 + 1/2,
   bN = b0 + (1/2)(mN² + σN²)
2. Compute
   σN² = (aN/bN + β xᵀx)⁻¹
   mN = β σN² xᵀy
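A compact numpy sketch of this iteration (not from the handout; the data, β and the Gamma hyperparameters are arbitrary choices for illustration, and the order of the two updates is a free choice):

```python
import numpy as np

# Mean-field VI for the scalar Bayesian linear regression example:
# q(w) = N(w; m_N, σ_N²), q(α) = Gam(α; a_N, b_N), iterated to convergence.
rng = np.random.default_rng(1)
N, w_true, beta = 50, 1.5, 4.0                      # β = known noise precision
x = rng.normal(size=N)
y = w_true * x + rng.normal(scale=1/np.sqrt(beta), size=N)

a0, b0 = 1e-2, 1e-2                                 # broad Gamma prior on α
E_alpha = a0 / b0                                   # initial E_q(α)[α]

for it in range(100):
    # q(w) update:  σ_N² = (E[α] + β xᵀx)⁻¹,  m_N = β σ_N² xᵀy
    var_N = 1.0 / (E_alpha + beta * (x @ x))
    m_N = beta * var_N * (x @ y)
    # q(α) update:  a_N = a0 + 1/2,  b_N = b0 + (m_N² + σ_N²)/2
    a_N = a0 + 0.5
    b_N = b0 + 0.5 * (m_N**2 + var_N)
    E_alpha = a_N / b_N

print(m_N, var_N)     # posterior mean and variance of w
print(a_N / b_N)      # E[α], the inferred prior precision of w
```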
A few concepts to summarize lecture 6

 
Kullback-Leibler (KL) divergence: A distance KL(p ‖ q) between two distributions p and q.

Deterministic approximate inference: Approximate Bayesian inference where the KL divergence between the approximate posterior q and the true posterior p is minimized.

Expectation propagation: A form of deterministic approximate inference where KL(p ‖ q) is minimized.

Variational inference: A form of deterministic approximate inference where KL(q ‖ p) is minimized.

