Variational Autoencoders

Recap: Story so far


• A classification MLP actually comprises two components
  • A "feature extraction network" that converts the inputs into linearly separable features
    • Or nearly linearly separable features
  • A final linear classifier that operates on the linearly separable features
• Neural networks can be used to perform linear or non-linear PCA
  • "Autoencoders"
• They can also be used to compose constructive dictionaries for data
  • Which, in turn, can be used to model data distributions
Recap: The penultimate layer
[Figure: inputs (x1, x2) and transformed features (y1, y2)]
• The network up to the output layer may be viewed as a transformation that maps data from classes that are not linearly separable to linearly separable features
• We can now attach any linear classifier on top of it for perfect classification
  • It need not be a perceptron
  • In fact, placing an SVM on top of the features may generalize better!
Recap: The behavior of the layers
Recap: Auto-encoders and PCA
Training: learn $W$ by minimizing the L2 divergence

$\hat{\mathbf{x}} = \mathbf{w}^T \mathbf{w}\mathbf{x}$

$div(\hat{\mathbf{x}}, \mathbf{x}) = \|\mathbf{x} - \hat{\mathbf{x}}\|^2 = \|\mathbf{x} - \mathbf{w}^T \mathbf{w}\mathbf{x}\|^2$

$\widehat{W} = \operatorname{argmin}_W E\left[div(\hat{\mathbf{x}}, \mathbf{x})\right] = \operatorname{argmin}_W E\left[\|\mathbf{x} - \mathbf{w}^T \mathbf{w}\mathbf{x}\|^2\right]$
Recap: Auto-encoders and PCA
[Figure: linear autoencoder with encoder $\mathbf{w}$ and decoder $\mathbf{w}^T$]

• The autoencoder finds the direction of maximum energy
  • Variance, if the input is a zero-mean RV
• All input vectors are mapped onto a point on the principal axis
Recap: Auto-encoders and PCA

• Varying the hidden layer value only generates data along the learned manifold
  • The manifold may be poorly learned
  • Any input will result in an output along the learned manifold
Recap: Learning a data-manifold
[Figure: "Sax dictionary" decoder]

• The decoder represents a source-specific generative dictionary
• Exciting it will produce typical data from the source!
Overview
• Just as autoencoders can be viewed as performing a non-linear PCA,
variational autoencoders can be viewed as performing a non-linear
Factor Analysis (FA)
• Variational autoencoders (VAEs) get their name from variational
inference, a technique that can be used for parameter estimation
• We will introduce Factor Analysis, variational inference and
expectation maximization, and finally VAEs
Why Generative Models? Training data
• Unsupervised/Semi-supervised learning: More training data available
• E.g. all of the videos on YouTube
Why generative models? Many right answers
• Caption -> Image (https://openreview.net/pdf?id=Hyvw0L9el)
  "A man in an orange jacket with sunglasses and a hat skis down a hill"
• Outline -> Image (https://arxiv.org/abs/1611.07004)
Why generative models? Intrinsic to task
Example: Super resolution

https://arxiv.org/abs/1609.04802
Why generative models? Insight

• What kind of structure can we find in complex observations (MEG recording of brain activity above, gene-expression network to the left)?
• Is there a low-dimensional manifold underlying these complex observations?
• What can we learn about the brain, cellular function, etc. if we know more about these manifolds?

https://bmcbioinformatics.biomedcentral.com/articles/10.1186/1471-2105-12-327
Factor Analysis
• Generative model: Assumes that data are generated from real valued
latent variables

Bishop – Pattern Recognition and Machine Learning


Factor Analysis model
Factor analysis assumes a generative model in which the $i$-th observation $\mathbf{x}_i \in \mathbb{R}^D$ is conditioned on a vector of real-valued latent variables $\mathbf{z}_i \in \mathbb{R}^L$.

Here we assume the prior distribution is Gaussian:
$p(\mathbf{z}_i) = \mathcal{N}(\mathbf{z}_i \mid \boldsymbol{\mu}_0, \boldsymbol{\Sigma}_0)$

We also use a Gaussian for the data likelihood:
$p(\mathbf{x}_i \mid \mathbf{z}_i, \mathbf{W}, \boldsymbol{\mu}, \boldsymbol{\Psi}) = \mathcal{N}(\mathbf{W}\mathbf{z}_i + \boldsymbol{\mu}, \boldsymbol{\Psi})$

where $\mathbf{W} \in \mathbb{R}^{D \times L}$ and $\boldsymbol{\Psi} \in \mathbb{R}^{D \times D}$ is diagonal.
Marginal distribution of observed $\mathbf{x}_i$

$p(\mathbf{x}_i \mid \mathbf{W}, \boldsymbol{\mu}, \boldsymbol{\Psi}) = \int \mathcal{N}(\mathbf{W}\mathbf{z}_i + \boldsymbol{\mu}, \boldsymbol{\Psi})\, \mathcal{N}(\mathbf{z}_i \mid \boldsymbol{\mu}_0, \boldsymbol{\Sigma}_0)\, d\mathbf{z}_i = \mathcal{N}(\mathbf{x}_i \mid \mathbf{W}\boldsymbol{\mu}_0 + \boldsymbol{\mu},\; \boldsymbol{\Psi} + \mathbf{W}\boldsymbol{\Sigma}_0\mathbf{W}^T)$

Note that we can rewrite this as:
$p(\mathbf{x}_i \mid \widehat{\mathbf{W}}, \widehat{\boldsymbol{\mu}}, \boldsymbol{\Psi}) = \mathcal{N}(\mathbf{x}_i \mid \widehat{\boldsymbol{\mu}},\; \boldsymbol{\Psi} + \widehat{\mathbf{W}}\widehat{\mathbf{W}}^T)$

where $\widehat{\boldsymbol{\mu}} = \mathbf{W}\boldsymbol{\mu}_0 + \boldsymbol{\mu}$ and $\widehat{\mathbf{W}} = \mathbf{W}\boldsymbol{\Sigma}_0^{1/2}$.

Thus, without loss of generality (since $\boldsymbol{\mu}_0$, $\boldsymbol{\Sigma}_0$ are absorbed into learnable parameters), we let:
$p(\mathbf{z}_i) = \mathcal{N}(\mathbf{z}_i \mid \mathbf{0}, \mathbf{I})$

And find:
$p(\mathbf{x}_i \mid \mathbf{W}, \boldsymbol{\mu}, \boldsymbol{\Psi}) = \mathcal{N}(\mathbf{x}_i \mid \boldsymbol{\mu},\; \boldsymbol{\Psi} + \mathbf{W}\mathbf{W}^T)$
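
• A minimal NumPy sketch (not from the slides; dimensions and parameter values are arbitrary choices) of this generative process: sample $z_i \sim \mathcal{N}(0, I)$, form $x_i = Wz_i + \mu + $ noise, and check that the sample covariance of $x$ approaches $\Psi + WW^T$:

```python
import numpy as np

rng = np.random.default_rng(0)
D, L, N = 5, 2, 200_000                        # observed dim, latent dim, #samples

W = rng.normal(size=(D, L))                    # factor loadings
mu = rng.normal(size=D)                        # mean
psi = np.diag(rng.uniform(0.1, 0.5, size=D))   # diagonal noise covariance

# Generative process: z ~ N(0, I), x | z ~ N(Wz + mu, Psi)
z = rng.normal(size=(N, L))
noise = rng.multivariate_normal(np.zeros(D), psi, size=N)
x = z @ W.T + mu + noise

# The marginal covariance of x should approach Psi + W W^T
empirical_cov = np.cov(x, rowvar=False)
analytic_cov = psi + W @ W.T
print(np.max(np.abs(empirical_cov - analytic_cov)))   # small, shrinks as N grows
```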
Marginal distribution interpretation
• We can see from $p(\mathbf{x}_i \mid \mathbf{W}, \boldsymbol{\mu}, \boldsymbol{\Psi}) = \mathcal{N}(\mathbf{x}_i \mid \boldsymbol{\mu},\; \boldsymbol{\Psi} + \mathbf{W}\mathbf{W}^T)$ that the covariance matrix of the data distribution is broken into two terms
  • A diagonal part $\boldsymbol{\Psi}$: variance not shared between variables
  • A low-rank matrix $\mathbf{W}\mathbf{W}^T$: shared variance due to latent factors
Special Case: Probabilistic PCA (PPCA)
• Probabilistic PCA is a special case of Factor Analysis
• We further restrict $\boldsymbol{\Psi} = \sigma^2\mathbf{I}$ (assume isotropic, independent variance)
• It is possible to show that when the data are centered ($\boldsymbol{\mu} = \mathbf{0}$), the limiting case $\sigma \to 0$ gives back the same solution for $\mathbf{W}$ as PCA
• Factor analysis is a generalization of PCA that models non-shared
variance (can think of this as noise in some situations, or individual
variation in others)
Inference in FA
• To find the parameters of the FA model, we use the Expectation
Maximization (EM) algorithm
• EM is very similar to variational inference
• We’ll derive EM by first finding a lower bound on the log-likelihood
we want to maximize, and then maximizing this lower bound
Evidence Lower Bound decomposition
• For any distributions $q(z)$, $p(z)$ we have:
$\mathrm{KL}(q(z)\,\|\,p(z)) \triangleq \int q(z)\log\frac{q(z)}{p(z)}\,dz$

• Consider the KL divergence of an arbitrary weighting distribution $q(z)$ from a conditional distribution $p(z\mid x,\theta)$:
$\mathrm{KL}(q(z)\,\|\,p(z\mid x,\theta)) \triangleq \int q(z)\log\frac{q(z)}{p(z\mid x,\theta)}\,dz = \int q(z)\,[\log q(z) - \log p(z\mid x,\theta)]\,dz$

Applying Bayes' rule:
$\log p(z\mid x,\theta) = \log\frac{p(x\mid z,\theta)\,p(z\mid\theta)}{p(x\mid\theta)} = \log p(x\mid z,\theta) + \log p(z\mid\theta) - \log p(x\mid\theta)$

Then:
$\mathrm{KL}(q(z)\,\|\,p(z\mid x,\theta)) = \int q(z)\,[\log q(z) - \log p(x\mid z,\theta) - \log p(z\mid\theta) + \log p(x\mid\theta)]\,dz$
Rewriting the divergence
• Since the last term does not depend on $z$, and we know $\int q(z)\,dz = 1$, we can pull it out of the integral:
$\int q(z)\,[\log q(z) - \log p(x\mid z,\theta) - \log p(z\mid\theta) + \log p(x\mid\theta)]\,dz$
$= \int q(z)\,[\log q(z) - \log p(x\mid z,\theta) - \log p(z\mid\theta)]\,dz + \log p(x\mid\theta)$
$= \int q(z)\log\frac{q(z)}{p(x\mid z,\theta)\,p(z\mid\theta)}\,dz + \log p(x\mid\theta)$
$= \int q(z)\log\frac{q(z)}{p(x, z\mid\theta)}\,dz + \log p(x\mid\theta)$

Then we have:
$\mathrm{KL}(q(z)\,\|\,p(z\mid x,\theta)) = \mathrm{KL}(q(z)\,\|\,p(x,z\mid\theta)) + \log p(x\mid\theta)$
Evidence Lower Bound
• From basic probability we have:
$\mathrm{KL}(q(z)\,\|\,p(z\mid x,\theta)) = \mathrm{KL}(q(z)\,\|\,p(x,z\mid\theta)) + \log p(x\mid\theta)$
• We can rearrange the terms to get the following decomposition:
$\log p(x\mid\theta) = \mathrm{KL}(q(z)\,\|\,p(z\mid x,\theta)) - \mathrm{KL}(q(z)\,\|\,p(x,z\mid\theta))$
• We define the evidence lower bound (ELBO) as:
$\mathcal{L}(q,\theta) \triangleq -\mathrm{KL}(q(z)\,\|\,p(x,z\mid\theta))$
Then:
$\log p(x\mid\theta) = \mathrm{KL}(q(z)\,\|\,p(z\mid x,\theta)) + \mathcal{L}(q,\theta)$
Why the name evidence lower bound?
• Rearranging the decomposition
$\log p(x\mid\theta) = \mathrm{KL}(q(z)\,\|\,p(z\mid x,\theta)) + \mathcal{L}(q,\theta)$
• we have
$\mathcal{L}(q,\theta) = \log p(x\mid\theta) - \mathrm{KL}(q(z)\,\|\,p(z\mid x,\theta))$
• Since $\mathrm{KL}(q(z)\,\|\,p(z\mid x,\theta)) \ge 0$, $\mathcal{L}(q,\theta)$ is a lower bound on the log-likelihood we want to maximize
• $p(x\mid\theta)$ is sometimes called the evidence
• When is this bound tight? When $q(z) = p(z\mid x,\theta)$
• The ELBO is also sometimes called the variational bound
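
• A small numerical check of this decomposition (a toy model with a discrete latent $z$, not from the slides): for any weighting distribution $q$, $\log p(x\mid\theta) = \mathrm{KL}(q\,\|\,p(z\mid x,\theta)) + \mathcal{L}(q,\theta)$, and the bound is tight exactly when $q$ equals the posterior:

```python
import numpy as np

rng = np.random.default_rng(1)

# Toy latent-variable model with discrete z in {0,...,K-1} and a fixed observed x
K = 4
p_z = rng.dirichlet(np.ones(K))             # p(z | theta)
p_x_given_z = rng.uniform(0.1, 1.0, K)      # p(x | z, theta) evaluated at the observed x

p_xz = p_z * p_x_given_z                    # joint p(x, z | theta)
log_evidence = np.log(p_xz.sum())           # log p(x | theta)
posterior = p_xz / p_xz.sum()               # p(z | x, theta)

def kl(q, p):
    return np.sum(q * (np.log(q) - np.log(p)))

def elbo(q):
    # L(q, theta) = -KL(q(z) || p(x, z | theta)) = E_q[log p(x, z | theta) - log q(z)]
    return np.sum(q * (np.log(p_xz) - np.log(q)))

q = rng.dirichlet(np.ones(K))               # arbitrary weighting distribution
print(log_evidence, elbo(q) + kl(q, posterior))     # equal up to float error
print(elbo(q) <= log_evidence)                      # True: the ELBO is a lower bound
print(np.isclose(elbo(posterior), log_evidence))    # True: tight when q = posterior
```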
Visualizing ELBO decomposition

Bishop – Pattern Recognition and Machine Learning

• Note: all we have done so far is decompose the log probability of the data; we still have exact equality
• This holds for any distribution $q$
Expectation Maximization
• Expectation Maximization alternately optimizes the ELBO, $\mathcal{L}(q,\theta)$, with respect to $q$ (the E step) and $\theta$ (the M step)

• Initialize $\theta^{(0)}$
• At each iteration $t = 1, \dots$
  • E step: Hold $\theta^{(t-1)}$ fixed, find $q^{(t)}$ which maximizes $\mathcal{L}(q, \theta^{(t-1)})$
  • M step: Hold $q^{(t)}$ fixed, find $\theta^{(t)}$ which maximizes $\mathcal{L}(q^{(t)}, \theta)$
The E step

Bishop – Pattern Recognition and Machine Learning

• Suppose we are at iteration $t$ of our algorithm. How do we maximize $\mathcal{L}(q, \theta^{(t-1)})$ with respect to $q$? We know that:
$\operatorname{argmax}_q \mathcal{L}(q, \theta^{(t-1)}) = \operatorname{argmax}_q \left[\log p(x\mid\theta^{(t-1)}) - \mathrm{KL}\left(q(z)\,\|\,p(z\mid x,\theta^{(t-1)})\right)\right]$
The E step

Bishop – Pattern Recognition and Machine Learning

• The first term does not involve $q$, and we know the KL divergence must be non-negative
• The best we can do is to make the KL divergence 0
• Thus the solution is to set $q^{(t)}(z) \leftarrow p(z\mid x, \theta^{(t-1)})$
The M step
• Fixing $q^{(t)}(z)$, we now solve:
$\operatorname{argmax}_\theta \mathcal{L}(q^{(t)}, \theta) = \operatorname{argmax}_\theta -\mathrm{KL}\left(q^{(t)}(z)\,\|\,p(x,z\mid\theta)\right)$
$= \operatorname{argmax}_\theta -\int q^{(t)}(z)\log\frac{q^{(t)}(z)}{p(x,z\mid\theta)}\,dz$
$= \operatorname{argmax}_\theta \int q^{(t)}(z)\,[\log p(x,z\mid\theta) - \log q^{(t)}(z)]\,dz$
$= \operatorname{argmax}_\theta \int q^{(t)}(z)\log p(x,z\mid\theta)\,dz - \int q^{(t)}(z)\log q^{(t)}(z)\,dz$ (the second term is constant w.r.t. $\theta$)
$= \operatorname{argmax}_\theta \int q^{(t)}(z)\log p(x,z\mid\theta)\,dz$
$= \operatorname{argmax}_\theta \mathbb{E}_{q^{(t)}(z)}\left[\log p(x,z\mid\theta)\right]$
The M step

Bishop – Pattern Recognition and Machine Learning

• After applying the E step, we increase the likelihood of the data by finding better parameters according to: $\theta^{(t)} \leftarrow \operatorname{argmax}_\theta \mathbb{E}_{q^{(t)}(z)}\left[\log p(x,z\mid\theta)\right]$
EM algorithm
• Initialize $\theta^{(0)}$
• At each iteration $t = 1, \dots$
  • E step: Update $q^{(t)}(z) \leftarrow p(z\mid x, \theta^{(t-1)})$
  • M step: Update $\theta^{(t)} \leftarrow \operatorname{argmax}_\theta \mathbb{E}_{q^{(t)}(z)}\left[\log p(x,z\mid\theta)\right]$
Why does EM work?
• EM does coordinate ascent on the ELBO, $\mathcal{L}(q,\theta)$
• Each iteration increases the log-likelihood until $q^{(t)}$ converges (i.e. we reach a local maximum)!
• Simple to prove. Notice that after the E step the ELBO is tight:
$\mathcal{L}(q^{(t)}, \theta^{(t-1)}) = \log p(x\mid\theta^{(t-1)}) - \mathrm{KL}\left(p(z\mid x,\theta^{(t-1)})\,\|\,p(z\mid x,\theta^{(t-1)})\right) = \log p(x\mid\theta^{(t-1)})$
By definition of argmax in the M step:
$\mathcal{L}(q^{(t)}, \theta^{(t)}) \ge \mathcal{L}(q^{(t)}, \theta^{(t-1)})$
By simple substitution:
$\mathcal{L}(q^{(t)}, \theta^{(t)}) \ge \log p(x\mid\theta^{(t-1)})$
Rewriting the left-hand side:
$\log p(x\mid\theta^{(t)}) - \mathrm{KL}\left(p(z\mid x,\theta^{(t-1)})\,\|\,p(z\mid x,\theta^{(t)})\right) \ge \log p(x\mid\theta^{(t-1)})$
Noting that KL is non-negative:
$\log p(x\mid\theta^{(t)}) \ge \log p(x\mid\theta^{(t-1)})$
Why does EM work?

Bishop – Pattern Recognition and Machine Learning

• This proof is saying the same thing we saw in pictures. Make the KL 0,
then improve our parameter estimates to get a better likelihood
A different perspective
• Consider the log-likelihood of a marginal distribution of the data $x$ in a generic latent variable model with latent variable $z$, parameterized by $\theta$:
$\ell(\theta) \triangleq \sum_{i=1}^N \log p(x_i\mid\theta) = \sum_{i=1}^N \log \int p(x_i, z_i\mid\theta)\,dz_i$
• Estimating $\theta$ is difficult because we have a log outside of the integral, so it does not act directly on the probability distribution (frequently in the exponential family)
• If we observed $z_i$, then our log-likelihood would be:
$\ell_c(\theta) \triangleq \sum_{i=1}^N \log p(x_i, z_i\mid\theta)$
This is called the complete log-likelihood
Expected Complete Log-Likelihood
• We can take the expectation of this likelihood over a distribution of the latent variable, $q(z)$:
$\mathbb{E}_{q(z)}\left[\ell_c(\theta)\right] = \sum_{i=1}^N \int q(z_i)\log p(x_i, z_i\mid\theta)\,dz_i$
• This looks similar to marginalizing, but now the log is inside the integral, so it's easier to deal with
• We can treat the latent variables as observed and solve this more easily than directly solving the log-likelihood
• Finding the $q$ that maximizes this is the E step of EM
• Finding the $\theta$ that maximizes this is the M step of EM
Back to Factor Analysis
• For simplicity, assume the data are centered. We want:
$\operatorname{argmax}_{\mathbf{W},\boldsymbol{\Psi}} \log p(\mathbf{X}\mid\mathbf{W},\boldsymbol{\Psi}) = \operatorname{argmax}_{\mathbf{W},\boldsymbol{\Psi}} \sum_{i=1}^N \log p(\mathbf{x}_i\mid\mathbf{W},\boldsymbol{\Psi}) = \operatorname{argmax}_{\mathbf{W},\boldsymbol{\Psi}} \sum_{i=1}^N \log \mathcal{N}(\mathbf{x}_i\mid\mathbf{0},\; \boldsymbol{\Psi} + \mathbf{W}\mathbf{W}^T)$
• No closed-form solution in general (PPCA can be solved in closed form)
• $\boldsymbol{\Psi}$ and $\mathbf{W}$ get coupled together in the derivative and we can't solve for them analytically
EM for Factor Analysis
$\operatorname{argmax}_{\mathbf{W},\boldsymbol{\Psi}} \mathbb{E}_{q^{(t)}(\mathbf{z})}\left[\log p(\mathbf{X},\mathbf{Z}\mid\mathbf{W},\boldsymbol{\Psi})\right] = \operatorname{argmax}_{\mathbf{W},\boldsymbol{\Psi}} \sum_{i=1}^N \mathbb{E}_{q^{(t)}(\mathbf{z}_i)}\left[\log p(\mathbf{x}_i\mid\mathbf{z}_i,\mathbf{W},\boldsymbol{\Psi})\right] + \mathbb{E}_{q^{(t)}(\mathbf{z}_i)}\left[\log p(\mathbf{z}_i)\right]$

$= \operatorname{argmax}_{\mathbf{W},\boldsymbol{\Psi}} \sum_{i=1}^N \mathbb{E}_{q^{(t)}(\mathbf{z}_i)}\left[\log p(\mathbf{x}_i\mid\mathbf{z}_i,\mathbf{W},\boldsymbol{\Psi})\right]$

$= \operatorname{argmax}_{\mathbf{W},\boldsymbol{\Psi}} \sum_{i=1}^N \mathbb{E}_{q^{(t)}(\mathbf{z}_i)}\left[\log \mathcal{N}(\mathbf{W}\mathbf{z}_i, \boldsymbol{\Psi})\right]$

$= \operatorname{argmax}_{\mathbf{W},\boldsymbol{\Psi}} \;\text{const} - \frac{N}{2}\log\det(\boldsymbol{\Psi}) - \frac{1}{2}\sum_{i=1}^N \mathbb{E}_{q^{(t)}(\mathbf{z}_i)}\left[(\mathbf{x}_i - \mathbf{W}\mathbf{z}_i)^T \boldsymbol{\Psi}^{-1}(\mathbf{x}_i - \mathbf{W}\mathbf{z}_i)\right]$

$= \operatorname{argmax}_{\mathbf{W},\boldsymbol{\Psi}} -\frac{N}{2}\log\det(\boldsymbol{\Psi}) - \sum_{i=1}^N \left[\frac{1}{2}\mathbf{x}_i^T\boldsymbol{\Psi}^{-1}\mathbf{x}_i - \mathbf{x}_i^T\boldsymbol{\Psi}^{-1}\mathbf{W}\,\mathbb{E}_{q^{(t)}(\mathbf{z}_i)}[\mathbf{z}_i] + \frac{1}{2}\operatorname{tr}\left(\mathbf{W}^T\boldsymbol{\Psi}^{-1}\mathbf{W}\,\mathbb{E}_{q^{(t)}(\mathbf{z}_i)}[\mathbf{z}_i\mathbf{z}_i^T]\right)\right]$

• We only need two sufficient statistics, $\mathbb{E}_{q^{(t)}(\mathbf{z}_i)}[\mathbf{z}_i]$ and $\mathbb{E}_{q^{(t)}(\mathbf{z}_i)}[\mathbf{z}_i\mathbf{z}_i^T]$, to enable the M step
• In practice, sufficient statistics are often what we compute in the E step
Factor Analysis E step

$\mathbb{E}_{q^{(t)}(\mathbf{z}_i)}[\mathbf{z}_i] = \mathbf{G}\,\mathbf{W}^{(t-1)T}\,\boldsymbol{\Psi}^{(t-1)-1}\,\mathbf{x}_i$
$\mathbb{E}_{q^{(t)}(\mathbf{z}_i)}[\mathbf{z}_i\mathbf{z}_i^T] = \mathbf{G} + \mathbb{E}_{q^{(t)}(\mathbf{z}_i)}[\mathbf{z}_i]\,\mathbb{E}_{q^{(t)}(\mathbf{z}_i)}[\mathbf{z}_i]^T$

where
$\mathbf{G} = \left(\mathbf{I} + \mathbf{W}^{(t-1)T}\,\boldsymbol{\Psi}^{(t-1)-1}\,\mathbf{W}^{(t-1)}\right)^{-1}$

These are derived via Bayes' rule for Gaussians.
Factor Analysis M step
$\mathbf{W}^{(t)} \leftarrow \left(\sum_{i=1}^N \mathbf{x}_i\,\mathbb{E}_{q^{(t)}(\mathbf{z}_i)}[\mathbf{z}_i]^T\right)\left(\sum_{i=1}^N \mathbb{E}_{q^{(t)}(\mathbf{z}_i)}[\mathbf{z}_i\mathbf{z}_i^T]\right)^{-1}$

$\boldsymbol{\Psi}^{(t)} \leftarrow \operatorname{diag}\left(\frac{1}{N}\sum_{i=1}^N \mathbf{x}_i\mathbf{x}_i^T - \frac{1}{N}\,\mathbf{W}^{(t)}\sum_{i=1}^N \mathbb{E}_{q^{(t)}(\mathbf{z}_i)}[\mathbf{z}_i]\,\mathbf{x}_i^T\right)$
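
• A compact NumPy sketch of these E and M steps on centered, synthetic data (illustrative only; the dimensions, iteration count, and synthetic data are my own choices, not from the slides):

```python
import numpy as np

rng = np.random.default_rng(2)
D, L, N = 8, 3, 5000

# Synthetic centered data from a true FA model (for demonstration only)
W_true = rng.normal(size=(D, L))
psi_true = rng.uniform(0.2, 0.6, size=D)
X = rng.normal(size=(N, L)) @ W_true.T + rng.normal(size=(N, D)) * np.sqrt(psi_true)

# Initialize parameters
W = rng.normal(size=(D, L))
psi = np.ones(D)                                   # diagonal of Psi

for it in range(100):
    # E step: posterior moments of z_i given x_i (Bayes rule for Gaussians)
    G = np.linalg.inv(np.eye(L) + W.T @ (W / psi[:, None]))   # (I + W^T Psi^-1 W)^-1
    Ez = X @ (W / psi[:, None]) @ G                           # N x L, rows are E[z_i]
    Ezz_sum = N * G + Ez.T @ Ez                               # sum_i E[z_i z_i^T]

    # M step: plug the sufficient statistics into the update formulas above
    W = (X.T @ Ez) @ np.linalg.inv(Ezz_sum)
    psi = np.mean(X * X, axis=0) - np.mean((Ez @ W.T) * X, axis=0)

# Check: model covariance Psi + W W^T vs. the sample covariance of the data
print(np.max(np.abs((np.diag(psi) + W @ W.T) - np.cov(X, rowvar=False))))
```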
From EM to Variational Inference
• In EM we alternately maximize the ELBO with respect to 𝜃 and
probability distribution (functional) 𝑞
• In variational inference, we drop the distinction between hidden
variables and parameters of a distribution
• I.e. we replace 𝑝(𝑥, 𝑧|𝜃) with 𝑝(𝑥, 𝑧). Effectively this puts a
probability distribution on the parameters 𝜽, then absorbs them into
𝑧
• Fully Bayesian treatment instead of a point estimate for the
parameters
Variational Inference
• Now the ELBO is just a function of our weighting distribution: $\mathcal{L}(q)$
• We assume a form for $q$ that we can optimize
• For example, mean field theory assumes $q$ factorizes:
$q(Z) = \prod_{i=1}^M q_i(Z_i)$
• Then we optimize ℒ(𝑞) with respect to one of the terms while
holding the others constant, and repeat for all terms
• By assuming a form for 𝑞 we approximate a (typically) intractable true
posterior
Mean Field update derivation
$\mathcal{L}(q) = \int q(Z)\log\frac{p(X,Z)}{q(Z)}\,dZ = \int \left[q(Z)\log p(X,Z) - q(Z)\log q(Z)\right] dZ$

$= \int \prod_i q_i(Z_i)\left[\log p(X,Z) - \sum_k \log q_k(Z_k)\right] dZ$

$= \int q_j(Z_j)\int \prod_{i\ne j} q_i(Z_i)\left[\log p(X,Z) - \sum_k \log q_k(Z_k)\right] dZ_{i\ne j}\, dZ_j$

$= \int q_j(Z_j)\left[\int \log p(X,Z)\prod_{i\ne j} q_i(Z_i)\,dZ_{i\ne j}\right] dZ_j - \int q_j(Z_j)\log q_j(Z_j)\,dZ_j + \text{const}$

$= \int q_j(Z_j)\,\mathbb{E}_{i\ne j}\left[\log p(X,Z)\right] dZ_j - \int q_j(Z_j)\log q_j(Z_j)\,dZ_j + \text{const}$
Mean Field update
$q_j^{(t)}(Z_j) \leftarrow \operatorname{argmax}_{q_j(Z_j)} \int q_j(Z_j)\,\mathbb{E}_{i\ne j}\left[\log p(X,Z)\right] dZ_j - \int q_j(Z_j)\log q_j(Z_j)\,dZ_j$
• The point of this is not the update equations themselves, but the
general idea:
• freeze some of the variables, compute expectations over those
• update the rest using these expectations
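
• A toy illustration of this coordinate-ascent idea (not from the slides): approximate a correlated 2-D Gaussian "posterior" $p(z) = \mathcal{N}(\mu, \Lambda^{-1})$ with a factorized $q(z_1)q(z_2)$, where each factor is Gaussian with variance $1/\Lambda_{jj}$ and a mean that is updated using the expectation over the other factor (the closed-form updates below are the standard mean-field result for a Gaussian target, as in Bishop, Ch. 10):

```python
import numpy as np

# Target p(z) = N(mu, Lambda^{-1}) with correlated components
mu = np.array([1.0, -2.0])
Lambda = np.array([[2.0, 1.2],    # precision matrix
                   [1.2, 2.0]])

# Factorized approximation q(z1)q(z2); only the factor means m are updated
m = np.zeros(2)
for _ in range(50):
    # Update q_1 holding q_2 fixed: expectation over z2 uses the current m[1]
    m[0] = mu[0] - (Lambda[0, 1] / Lambda[0, 0]) * (m[1] - mu[1])
    # Update q_2 holding q_1 fixed
    m[1] = mu[1] - (Lambda[1, 0] / Lambda[1, 1]) * (m[0] - mu[0])

print(m)                               # converges to mu = [1, -2]
print(1 / np.diag(Lambda))             # factor variances used by q ...
print(np.diag(np.linalg.inv(Lambda)))  # ... underestimate the true marginal variances
```

• Note the classic behavior: the factor means converge to the true means, but the factorized $q$ underestimates the marginal variances of the target, a known consequence of minimizing this direction of the KL divergence.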
Why does Variational Inference work?
• The argument is similar to the argument for EM
• When expectations are computed using the current values for the
variables not being updated, we implicitly set the KL divergence
between the weighting distributions and the posterior distributions to
0
• The update then pushes up the data likelihood

Bishop – Pattern Recognition and Machine Learning


Variational Autoencoder
• Kingma & Welling: Auto-Encoding Variational Bayes proposes
maximizing the ELBO with a trick to make it differentiable
• Discusses both the variational autoencoder model using parametric
distributions and fully Bayesian variational inference, but we will only
discuss the variational autoencoder
Problem Setup
[Figure: decoder $p(x_i\mid z_i,\theta)$ on top, sample $z_i \sim q(z_i\mid x_i,\phi)$ in the middle, encoder $q(z_i\mid x_i,\phi)$ at the bottom]

• Assume a generative model with a latent variable distributed according to some distribution $p(z_i)$
• The observed variable is distributed according to a conditional distribution $p(x_i\mid z_i,\theta)$
• Note the similarity to the Factor Analysis (FA) setup so far
Problem Setup
[Figure: same encoder/decoder diagram]

• We also create a weighting distribution $q(z_i\mid x_i,\phi)$
• This will play the same role as $q(z_i)$ in the EM algorithm, as we will see
• Note that when we discussed EM, this weighting distribution could be arbitrary: we choose to condition on $x_i$ here. This is a choice.
• Why does this make sense?
Using a conditional weighting distribution
• There are many values of the latent variables that don’t matter in
practice – by conditioning on the observed variables, we emphasize
the latent variable values we actually care about: the ones most likely
given the observations
• We would like to be able to encode our data into the latent variable
space. This conditional weighting distribution enables that encoding
Problem setup
[Figure: same encoder/decoder diagram]

• Implement $p(x_i\mid z_i,\theta)$ as a neural network; this can also be seen as a probabilistic decoder
• Implement $q(z_i\mid x_i,\phi)$ as a neural network; we can also see this as a probabilistic encoder
• Sample $z_i$ from $q(z_i\mid x_i,\phi)$ in the middle
Unpacking the encoder
[Figure: encoder network mapping $x_i$ to $\mu = u(x_i, W_1)$ and $\Sigma = \operatorname{diag}(s(x_i, W_2))$]

• We choose a family of distributions for our conditional distribution $q$. For example, a Gaussian with diagonal covariance:
$q(z_i\mid x_i,\phi) = \mathcal{N}\left(z_i \mid \mu = u(x_i, W_1),\; \Sigma = \operatorname{diag}(s(x_i, W_2))\right)$
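
• A minimal PyTorch sketch of such an encoder (my own illustrative architecture; the hidden size, the shared trunk, and predicting log-variance rather than $\Sigma$ directly are implementation choices, not prescribed by the slides):

```python
import torch
import torch.nn as nn

class GaussianEncoder(nn.Module):
    """q(z | x, phi) = N(z | u(x, W1), diag(s(x, W2))), with phi = {W1, W2}."""
    def __init__(self, x_dim: int, z_dim: int, hidden: int = 128):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(x_dim, hidden), nn.ReLU())
        self.u = nn.Linear(hidden, z_dim)    # predicts the mean mu
        self.s = nn.Linear(hidden, z_dim)    # predicts log of the diagonal variance

    def forward(self, x):
        h = self.trunk(x)
        mu = self.u(h)
        log_var = self.s(h)                  # log sigma^2 keeps the variance positive
        return mu, log_var

enc = GaussianEncoder(x_dim=784, z_dim=20)
mu, log_var = enc(torch.randn(16, 784))      # parameters of q for a batch of 16 inputs
```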
Unpacking the encoder
[Figure: same encoder diagram]

• We create neural networks to predict the parameters of $q$ from our data
• In this case, the outputs of our networks are $\mu$ and $\Sigma$
Unpacking the encoder
[Figure: same encoder diagram]

• We refer to the parameters of our networks, $W_1$ and $W_2$, collectively as $\phi$
• Together, networks $u$ and $s$ parameterize a distribution, $q(z_i\mid x_i,\phi)$, of the latent variable $z_i$ that depends in a complicated, non-linear way on $x_i$
Unpacking the decoder
[Figure: decoder network mapping $z_i$ to $\mu = u_d(z_i, W_3)$ and $\Sigma = \operatorname{diag}(s_d(z_i, W_4))$]

• The decoder follows the same logic, just swapping $x_i$ and $z_i$
• We refer to the parameters of our networks, $W_3$ and $W_4$, collectively as $\theta$
• Together, networks $u_d$ and $s_d$ parameterize a distribution, $p(x_i\mid z_i,\theta)$, of the observed variable $x_i$ that depends in a complicated, non-linear way on $z_i$
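
• And a matching decoder sketch (again an illustrative architecture; here $p(x\mid z,\theta)$ is a diagonal Gaussian whose parameters are predicted from $z$):

```python
import torch
import torch.nn as nn

class GaussianDecoder(nn.Module):
    """p(x | z, theta) = N(x | u_d(z, W3), diag(s_d(z, W4))), with theta = {W3, W4}."""
    def __init__(self, z_dim: int, x_dim: int, hidden: int = 128):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(z_dim, hidden), nn.ReLU())
        self.u_d = nn.Linear(hidden, x_dim)   # mean of the reconstruction
        self.s_d = nn.Linear(hidden, x_dim)   # log of the diagonal variance

    def forward(self, z):
        h = self.trunk(z)
        return self.u_d(h), self.s_d(h)

dec = GaussianDecoder(z_dim=20, x_dim=784)
x_mu, x_log_var = dec(torch.randn(16, 20))    # parameters of p(x | z) for 16 latent samples
```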
Understanding the setup
[Figure: same encoder/decoder diagram]

• Note that $p$ and $q$ do not have to use the same distribution family; this was just an example
• This basically looks like an autoencoder, but the outputs of both the encoder and decoder are parameters of the distributions of the latent and observed variables, respectively
• We also have a sampling step in the middle
Using EM for training
• Initialize $\theta^{(0)}$
• At each iteration $t = 1, \dots, T$
  • E step: Hold $\theta^{(t-1)}$ fixed, find $q^{(t)}$ which maximizes $\mathcal{L}(q, \theta^{(t-1)})$
  • M step: Hold $q^{(t)}$ fixed, find $\theta^{(t)}$ which maximizes $\mathcal{L}(q^{(t)}, \theta)$

• We will use a modified EM to train the model, but we will transform it so we can use standard back propagation!
Using EM for training
• Initialize $\theta^{(0)}$
• At each iteration $t = 1, \dots, T$
  • E step: Hold $\theta^{(t-1)}$ fixed, find $\phi^{(t)}$ which maximizes $\mathcal{L}(\phi, \theta^{(t-1)}, x)$
  • M step: Hold $\phi^{(t)}$ fixed, find $\theta^{(t)}$ which maximizes $\mathcal{L}(\phi^{(t)}, \theta, x)$

• First we modify the notation to account for our choice of using a parametric, conditional distribution $q$
Using EM for training
• Initialize $\theta^{(0)}$
• At each iteration $t = 1, \dots, T$
  • E step: Hold $\theta^{(t-1)}$ fixed, find $\frac{\partial\mathcal{L}}{\partial\phi}$ to increase $\mathcal{L}(\phi, \theta^{(t-1)}, x)$
  • M step: Hold $\phi^{(t)}$ fixed, find $\frac{\partial\mathcal{L}}{\partial\theta}$ to increase $\mathcal{L}(\phi^{(t)}, \theta, x)$

• Instead of fully maximizing at each iteration, we just take a step in the direction that increases $\mathcal{L}$
Computing the loss
• We need to compute the gradient for each mini-batch of $B$ data samples, using the ELBO/variational bound $\mathcal{L}(\phi,\theta,x_i)$ as the loss:
$\sum_{i=1}^B \mathcal{L}(\phi,\theta,x_i) = \sum_{i=1}^B -\mathrm{KL}\left(q(z_i\mid x_i,\phi)\,\|\,p(x_i,z_i\mid\theta)\right) = \sum_{i=1}^B -\mathbb{E}_{q(z_i\mid x_i,\phi)}\left[\log\frac{q(z_i\mid x_i,\phi)}{p(x_i,z_i\mid\theta)}\right]$

• Notice that this involves an intractable integral over all values of $z$
• We can use Monte Carlo sampling to approximate the expectation using $L$ samples from $q(z_i\mid x_i,\phi)$:
$\mathbb{E}_{q(z_i\mid x_i,\phi)}\left[f(z_i)\right] \simeq \frac{1}{L}\sum_{j=1}^L f(z_{i,j})$

$\mathcal{L}(\phi,\theta,x_i) \simeq \tilde{\mathcal{L}}^A(\phi,\theta,x_i) = \frac{1}{L}\sum_{j=1}^L \log p(x_i, z_{i,j}\mid\theta) - \log q(z_{i,j}\mid x_i,\phi)$
A lower variance estimator of the loss
• We can rewrite
$\mathcal{L}(\phi,\theta,x) = -\mathrm{KL}\left(q(z\mid x,\phi)\,\|\,p(x,z\mid\theta)\right)$
$= -\int q(z\mid x,\phi)\log\frac{q(z\mid x,\phi)}{p(x\mid z,\theta)\,p(z)}\,dz$
$= -\int q(z\mid x,\phi)\left[\log\frac{q(z\mid x,\phi)}{p(z)} - \log p(x\mid z,\theta)\right] dz$
$= -\mathrm{KL}\left(q(z\mid x,\phi)\,\|\,p(z)\right) + \mathbb{E}_{q(z\mid x,\phi)}\left[\log p(x\mid z,\theta)\right]$
• The first term can be computed analytically for some families of distributions (e.g. Gaussian); only the second term must be estimated:
$\mathcal{L}(\phi,\theta,x_i) \simeq \tilde{\mathcal{L}}^B(\phi,\theta,x_i) = -\mathrm{KL}\left(q(z_i\mid x_i,\phi)\,\|\,p(z_i)\right) + \frac{1}{L}\sum_{j=1}^L \log p(x_i\mid z_{i,j},\theta)$
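
• A sketch of $\tilde{\mathcal{L}}^B$ for the diagonal-Gaussian case with a standard normal prior (illustrative; it uses the well-known closed form of $\mathrm{KL}(\mathcal{N}(\mu,\operatorname{diag}(\sigma^2))\,\|\,\mathcal{N}(0,I))$, a single reconstruction sample $L=1$, and a Gaussian reconstruction likelihood as one possible choice of decoder distribution):

```python
import math
import torch

def elbo_B(x, x_mu, x_log_var, z_mu, z_log_var):
    """Single-sample (L = 1) estimate of L~_B = -KL(q(z|x,phi) || N(0,I)) + log p(x|z,theta)."""
    # Analytic KL between N(z_mu, diag(exp(z_log_var))) and the prior N(0, I)
    kl = 0.5 * torch.sum(z_mu**2 + z_log_var.exp() - z_log_var - 1.0, dim=-1)

    # Diagonal-Gaussian log-likelihood of x under the decoder output
    log_px = -0.5 * torch.sum(
        (x - x_mu)**2 / x_log_var.exp() + x_log_var + math.log(2.0 * math.pi), dim=-1
    )
    return (-kl + log_px).mean()   # average over the mini-batch; use the negative as the loss
```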
Full EM training procedure (not really used)
[Figure: encoder/decoder diagram with sampling step]

• For $t = 1 : b : T$
  • Estimate $\frac{\partial\mathcal{L}}{\partial\phi}$ (How do we do this? We'll get to it shortly)
  • Update $\phi$
  • Estimate $\frac{\partial\mathcal{L}}{\partial\theta}$:
    • Initialize $\Delta\theta = 0$
    • For $i = t : t+b-1$
      • Compute the outputs of the encoder (parameters of $q$) for $x_i$
      • For $\ell = 1, \dots, L$
        • Sample $z_i \sim q(z_i\mid x_i,\phi)$
        • $\Delta\theta_{i,\ell} \leftarrow$ run a forward/backward pass on the decoder (standard back propagation) using either $\tilde{\mathcal{L}}^A$ or $\tilde{\mathcal{L}}^B$ as the loss
        • $\Delta\theta \leftarrow \Delta\theta + \Delta\theta_{i,\ell}$
  • Update $\theta$
Full EM training procedure (not really used)
[Figure: encoder/decoder diagram with sampling step]

First simplification: let $L = 1$. We just want a stochastic estimate of the gradient. With a large enough $B$, we get enough samples from $q(z_i\mid x_i,\phi)$.

• For $t = 1 : b : T$
  • Estimate $\frac{\partial\mathcal{L}}{\partial\phi}$ (How do we do this? We'll get to it shortly)
  • Update $\phi$
  • Estimate $\frac{\partial\mathcal{L}}{\partial\theta}$:
    • Initialize $\Delta\theta = 0$
    • For $i = t : t+b-1$
      • Compute the outputs of the encoder (parameters of $q$) for $x_i$
      • Sample $z_i \sim q(z_i\mid x_i,\phi)$
      • $\Delta\theta_i \leftarrow$ run a forward/backward pass on the decoder (standard back propagation) using either $\tilde{\mathcal{L}}^A$ or $\tilde{\mathcal{L}}^B$ as the loss
      • $\Delta\theta \leftarrow \Delta\theta + \Delta\theta_i$
  • Update $\theta$
The E step
[Figure: encoder/decoder diagram; the gradient cannot flow back through the sampling step]

• We can use standard back propagation to estimate $\frac{\partial\mathcal{L}}{\partial\theta}$
• How do we estimate $\frac{\partial\mathcal{L}}{\partial\phi}$?
• The sampling step blocks the gradient flow
• Computing the derivatives through $q$ via the chain rule gives a very high variance estimate of the gradient
Reparameterization
• Instead of drawing $z_i \sim q(z_i\mid x_i,\phi)$, let $z_i = g(\epsilon_i, x_i, \phi)$ and draw $\epsilon_i \sim p(\epsilon)$
• $z_i$ is still a random variable but depends on $\phi$ deterministically
• Replace $\mathbb{E}_{q(z_i\mid x_i,\phi)}\left[f(z_i)\right]$ with $\mathbb{E}_{p(\epsilon)}\left[f(g(\epsilon_i, x_i, \phi))\right]$
• Example, univariate normal: $a \sim \mathcal{N}(\mu, \sigma^2)$ is equivalent to $a = g(\epsilon)$, $\epsilon \sim \mathcal{N}(0,1)$, $g(b) \triangleq \mu + \sigma b$
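
• A quick check of the univariate example (illustrative): sampling $a$ via $a = \mu + \sigma\epsilon$ with $\epsilon \sim \mathcal{N}(0,1)$ gives the same distribution as sampling $a \sim \mathcal{N}(\mu,\sigma^2)$ directly, but only the reparameterized form exposes $\mu$ and $\sigma$ to the gradient:

```python
import torch

mu = torch.tensor(1.5, requires_grad=True)
sigma = torch.tensor(0.7, requires_grad=True)

eps = torch.randn(100_000)          # epsilon ~ N(0, 1), independent of mu and sigma
a = mu + sigma * eps                # a ~ N(mu, sigma^2), differentiable w.r.t. mu, sigma

print(a.mean().item(), a.std().item())      # approximately 1.5 and 0.7

# Gradients of a Monte Carlo objective flow through mu and sigma
loss = (a**2).mean()
loss.backward()
print(mu.grad.item(), sigma.grad.item())    # roughly 2*mu and 2*sigma
```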
Reparameterization

[Figure: left, the original network with a sampling node $z_i \sim q(z_i\mid x_i,\phi)$ that blocks gradients; right, the reparameterized network where $z_i = g(\epsilon_i, x_i, \phi)$ is deterministic and the noise $\epsilon_i \sim p(\epsilon)$ enters as an input]
Full EM training procedure (not really used)
[Figure: reparameterized encoder/decoder diagram]

• For $t = 1 : b : T$
  • E Step
    • Estimate $\frac{\partial\mathcal{L}}{\partial\phi}$ using standard back propagation with either $\tilde{\mathcal{L}}^A$ or $\tilde{\mathcal{L}}^B$ as the loss
    • Update $\phi$
  • M Step
    • Estimate $\frac{\partial\mathcal{L}}{\partial\theta}$ using standard back propagation with either $\tilde{\mathcal{L}}^A$ or $\tilde{\mathcal{L}}^B$ as the loss
    • Update $\theta$
Full training procedure
[Figure: reparameterized encoder/decoder diagram]

• For $t = 1 : b : T$
  • Estimate $\frac{\partial\mathcal{L}}{\partial\phi}$ and $\frac{\partial\mathcal{L}}{\partial\theta}$ with either $\tilde{\mathcal{L}}^A$ or $\tilde{\mathcal{L}}^B$ as the loss
  • Update $\phi$, $\theta$

• Final simplification: update all of the parameters at the same time instead of using separate E, M steps
• This is standard back propagation. Just use $-\tilde{\mathcal{L}}^A$ or $-\tilde{\mathcal{L}}^B$ as the loss, and run your favorite SGD variant
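
• Putting the pieces together, a minimal end-to-end sketch of this procedure on placeholder data (all architecture and hyperparameter choices below are mine, not the slides'): encoder, reparameterize, decoder, with $-\tilde{\mathcal{L}}^B$ as the loss and a standard optimizer updating $\phi$ and $\theta$ together.

```python
import math
import torch
import torch.nn as nn

x_dim, z_dim, hidden = 32, 4, 64

encoder = nn.Sequential(nn.Linear(x_dim, hidden), nn.ReLU(), nn.Linear(hidden, 2 * z_dim))
decoder = nn.Sequential(nn.Linear(z_dim, hidden), nn.ReLU(), nn.Linear(hidden, 2 * x_dim))
opt = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)

data = torch.randn(2048, x_dim)                              # placeholder dataset

for step in range(1000):
    x = data[torch.randint(0, len(data), (64,))]             # mini-batch, B = 64

    z_mu, z_log_var = encoder(x).chunk(2, dim=-1)            # parameters of q(z | x, phi)
    eps = torch.randn_like(z_mu)                             # reparameterization: eps ~ N(0, I)
    z = z_mu + (0.5 * z_log_var).exp() * eps                 # z = g(eps, x, phi)
    x_mu, x_log_var = decoder(z).chunk(2, dim=-1)            # parameters of p(x | z, theta)

    kl = 0.5 * torch.sum(z_mu**2 + z_log_var.exp() - z_log_var - 1.0, dim=-1)
    log_px = -0.5 * torch.sum(
        (x - x_mu)**2 / x_log_var.exp() + x_log_var + math.log(2 * math.pi), dim=-1
    )
    loss = (kl - log_px).mean()                              # -L~_B, minimized by SGD/Adam

    opt.zero_grad()
    loss.backward()
    opt.step()                                               # updates phi and theta together

# Test time: use the encoder mean as a MAP estimate of z (no sampling needed)
with torch.no_grad():
    z_map = encoder(data[:5]).chunk(2, dim=-1)[0]
```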
Running the model on new data
• To get a MAP estimate of the latent variables, just use the mean
output by the encoder (for a Gaussian distribution)
• No need to take a sample
• Give the mean to the decoder
• At test time, this is used just as an auto-encoder
• You can optionally take multiple samples of the latent variables to
estimate the uncertainty
Relationship to Factor Analysis
[Figure: same encoder/decoder diagram]

• The VAE performs probabilistic, non-linear dimensionality reduction
• It uses a generative model with a latent variable distributed according to some prior distribution $p(z_i)$
• The observed variable is distributed according to a conditional distribution $p(x_i\mid z_i,\theta)$
• Training approximately runs expectation maximization to maximize the data likelihood
• This can be seen as a non-linear version of Factor Analysis
Regularization by a prior
• Looking at the form of $\mathcal{L}$ we used to justify $\tilde{\mathcal{L}}^B$ gives us additional insight:
$\mathcal{L}(\phi,\theta,x) = -\mathrm{KL}\left(q(z\mid x,\phi)\,\|\,p(z)\right) + \mathbb{E}_{q(z\mid x,\phi)}\left[\log p(x\mid z,\theta)\right]$
• We are making the latent distribution as close as possible to a prior on $z$
• While maximizing the conditional likelihood of the data under our model
• In other words, this is an approximation to Maximum Likelihood Estimation regularized by a prior on the latent space
Practical advantages of a VAE vs. an AE
• The prior on the latent space:
• Allows you to inject domain knowledge
• Can make the latent space more interpretable
• The VAE also makes it possible to estimate the variance/uncertainty in
the predictions
Interpreting the latent space

https://arxiv.org/pdf/1610.00291.pdf
Requirements of the VAE
• Note that the VAE requires 2 tractable distributions to be used:
• The prior distribution 𝑝(𝑧) must be easy to sample from
• The conditional likelihood 𝑝 𝑥|𝑧, 𝜃 must be computable
• In practice this means that the 2 distributions of interest are often
simple, for example uniform, Gaussian, or even isotropic Gaussian
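
• For instance, with a standard normal prior, generating new data is just sampling $z$ from the prior and decoding it (illustrative; `decoder` below is a stand-in for a trained decoder network like the sketches above):

```python
import torch
import torch.nn as nn

z_dim, x_dim = 4, 32
decoder = nn.Sequential(nn.Linear(z_dim, 64), nn.ReLU(), nn.Linear(64, x_dim))  # stand-in

with torch.no_grad():
    z = torch.randn(10, z_dim)     # easy to sample from the prior p(z) = N(0, I)
    x_generated = decoder(z)       # means of p(x | z, theta): 10 generated samples
print(x_generated.shape)
```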
The blurry image problem
• The samples from the VAE look blurry
• Three plausible explanations for this:
  • Maximizing the likelihood
  • Restrictions on the family of distributions
  • The lower bound approximation

https://blog.openai.com/generative-models/
The maximum likelihood explanation

• Recent evidence suggests that this is not actually the problem
• GANs can be trained with maximum likelihood and still generate sharp examples

https://arxiv.org/pdf/1701.00160.pdf
Investigations of blurriness
• Recent investigations suggest that both the simple probability
distributions and the variational approximation lead to blurry images
• Kingma & colleagues: Improving Variational Inference with Inverse Autoregressive Flow
• Zhao & colleagues: Towards a Deeper Understanding of Variational
Autoencoding Models
• Nowozin & colleagues: f-gan: Training generative neural samplers
using variational divergence minimization