Unit 5
Variational Autoencoders
Mitesh M. Khapra
Mitesh M. Khapra CS7015 (Deep Learning) : Lecture 21
Acknowledgments
Tutorial on Variational Autoencoders by Carl Doersch
Blog on Variational Autoencoders by Jaan Altosaar
Module 21.1: Revisiting Autoencoders
Before we start talking about VAEs, let us quickly revisit autoencoders
An autoencoder contains an encoder which takes the input X and maps it to a hidden representation h
The decoder then takes this hidden representation and tries to reconstruct the input from it as X̂
$h = g(WX + b)$
$\hat{X} = f(W^* h + c)$
The training happens using the following objective function
$\min_{W, W^*, c, b} \frac{1}{m} \sum_{i=1}^{m} \sum_{j=1}^{n} (\hat{x}_{ij} - x_{ij})^2$
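A minimal sketch of the forward pass and the objective above in plain Python. It assumes (our choice, not fixed by the slide) that g is the logistic sigmoid and f is the identity; matrices are lists of rows, and the names `encode`, `decode` and `reconstruction_loss` are hypothetical:

```python
import math

def matvec(W, v):
    """Multiply a matrix W (a list of rows) by a vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def encode(x, W, b):
    # h = g(Wx + b), with g assumed to be the logistic sigmoid
    return [sigmoid(a + b_k) for a, b_k in zip(matvec(W, x), b)]

def decode(h, W_star, c):
    # x_hat = f(W* h + c), with f assumed to be the identity (real-valued inputs)
    return [a + c_j for a, c_j in zip(matvec(W_star, h), c)]

def reconstruction_loss(X, W, b, W_star, c):
    # (1/m) * sum_i sum_j (x_hat_ij - x_ij)^2 over the m training examples in X
    total = 0.0
    for x in X:
        x_hat = decode(encode(x, W, b), W_star, c)
        total += sum((xh - xv) ** 2 for xh, xv in zip(x_hat, x))
    return total / len(X)

# Tiny example: n = 2 input dimensions, a 1-D hidden representation, all-zero parameters
X = [[1.0, 2.0], [3.0, 0.0]]
print(reconstruction_loss(X, W=[[0.0, 0.0]], b=[0.0], W_star=[[0.0], [0.0]], c=[0.0, 0.0]))  # prints 7.0
```

With all-zero parameters the decoder outputs zeros, so the loss is just the mean squared norm of the inputs, (1 + 4 + 9 + 0)/2 = 7.0; training would adjust W, W∗, b and c to drive it down.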
Can we do generation with autoencoders?
In other words, once the autoencoder is trained, can I remove the encoder, feed a hidden representation h to the decoder and decode an X̂ from it?
In principle, yes! But in practice there is a problem with this approach
h is a very high dimensional vector and only a few vectors in this space would actually correspond to meaningful latent representations of our input
So, of all the possible values of h, which values should I feed to the decoder? (we had asked a similar question before: slide 67, bullet 5 of lecture 19)
Ideally, we should only feed those values of h which are highly likely
In other words, we are interested in sampling from P(h|X) so that we pick only those h's which have a high probability
But unlike RBMs, autoencoders do not have such a probabilistic interpretation
They learn a hidden representation h but not a distribution P(h|X)
Similarly, the decoder is also deterministic and does not learn a distribution over X (given an h we can get an X but not P(X|h))
We will now look at variational autoencoders, which have the same structure as autoencoders but learn a distribution over the hidden variables
Module 21.2: Variational Autoencoders: The Neural Network Perspective
Let $\{X = x_i\}_{i=1}^{N}$ be the training data
We can think of X as a random variable in $\mathbb{R}^n$
For example, X could be an image and the dimensions of X correspond to pixels of the image
We are interested in learning an abstraction (i.e., given an X, find the hidden representation z)
Figure: Abstraction
We are also interested in generation (i.e., given a hidden representation, generate an X)
Figure: Generation
In probabilistic terms we are interested in P(z|X) and P(X|z) (to be consistent with the literature on VAEs we will use z instead of H and X instead of V)
Earlier we saw RBMs where we learnt P(z|X) and P(X|z)
Figure: RBM with hidden units h_1, ..., h_n ($H \in \{0, 1\}^n$, biases c_1, ..., c_n), visible units v_1, ..., v_m ($V \in \{0, 1\}^m$, biases b_1, ..., b_m) and weights $W \in \mathbb{R}^{m \times n}$
Below we list certain characteristics of RBMs
Structural assumptions: We assume certain independencies in the Markov Network
Computational: When training with Gibbs Sampling we have to run the Markov Chain for many time steps, which is expensive
Approximation: When using Contrastive Divergence, we approximate the expectation by a point estimate
(Nothing wrong with the above, but we just mention them to make the reader aware of these characteristics)
We now return to our goals
Goal 1: Learn a distribution over the latent variables (Q(z|X))
Goal 2: Learn a distribution over the visible variables (P(X|z))
VAEs use a neural network based encoder for Goal 1 and a neural network based decoder for Goal 2
We will look at the encoder first
Figure: Data X → Encoder Q_θ(z|X) → z → Decoder P_φ(X|z) → Reconstruction X̂
Now what about the decoder?
The job of the decoder is to predict a probability distribution over X: P(X|z)
Once again we will assume a certain form for this distribution
For example, if we want to predict 28 x 28 pixels and each pixel belongs to $\mathbb{R}$ (i.e., $X \in \mathbb{R}^{784}$), then what would be a suitable family for P(X|z)?
We could assume that P(X|z) is a Gaussian distribution with unit variance
The job of the decoder f would then be to predict the mean of this distribution as f_φ(z)
Figure: X_i → Q_θ(z|X) → µ, Σ → Sample → z → P_φ(X|z) → X̂_i
What would be the objective function of the decoder?
For any given training sample x_i it should maximize P(x_i) given by
$P(x_i) = \int P(z)\, P(x_i \mid z)\, dz$
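To see what this integral asks for, here is a minimal sketch in plain Python for a hypothetical 1-D toy model (not from the lecture) in which P(z) = N(0, 1) and P(x|z) = N(x; z, 1), so that the exact answer P(x) = N(x; 0, 2) is known. The integral is estimated by sampling z from the prior and averaging P(x_i|z):

```python
import math
import random

def normal_pdf(x, mean, var):
    """Density of a 1-D Gaussian N(mean, var) at x."""
    return math.exp(-(x - mean) ** 2 / (2 * var)) / math.sqrt(2 * math.pi * var)

def estimate_marginal(x, num_samples=100_000, seed=0):
    # Monte Carlo estimate of P(x) = integral of P(z) P(x|z) dz for the toy model
    # P(z) = N(0, 1), P(x|z) = N(x; z, 1): draw z ~ P(z), then average P(x|z)
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(0.0, 1.0)
        total += normal_pdf(x, z, 1.0)
    return total / num_samples

x = 1.5
print(estimate_marginal(x))       # Monte Carlo estimate
print(normal_pdf(x, 0.0, 2.0))    # exact value for this toy model: N(x; 0, 2)
```

In this 1-D toy the estimate is accurate, but when z is high dimensional most samples drawn from the prior give a negligible P(x_i|z), so such a naive estimator becomes very inefficient.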
This is the loss function for one data point (l_i(θ)) and we will just sum over all the data points to get the total loss $\mathcal{L}(\theta)$
$\mathcal{L}(\theta) = \sum_{i=1}^{m} l_i(\theta)$
In addition, we also want a constraint on the distribution over the latent variables
Specifically, we had assumed P(z) to be N(0, I) and we want Q(z|X) to be as close to P(z) as possible
Thus, we will modify the loss function such that
$l_i(\theta, \phi) = -\mathbb{E}_{z \sim Q_\theta(z|x_i)}[\log P_\phi(x_i|z)] + KL(Q_\theta(z|x_i) \,\|\, P(z))$
The KL divergence captures the difference (or distance) between 2 distributions (strictly speaking it is not a true distance, since KL(Q||P) ≠ KL(P||Q) in general)
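The KL term can be made concrete for 1-D Gaussians. The sketch below (plain Python; Q = N(µ, σ²) and P(z) = N(0, 1) are assumptions chosen for illustration) estimates KL(Q||P) straight from its definition as an expectation under Q, and compares it with the standard closed form for two Gaussians:

```python
import math
import random

def log_normal_pdf(z, mean, std):
    """Log-density of a 1-D Gaussian N(mean, std^2) at z."""
    return -0.5 * math.log(2 * math.pi * std ** 2) - (z - mean) ** 2 / (2 * std ** 2)

def kl_monte_carlo(mu, sigma, num_samples=100_000, seed=0):
    # KL(Q || P) = E_{z~Q}[log Q(z) - log P(z)], with Q = N(mu, sigma^2) and P = N(0, 1)
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(mu, sigma)
        total += log_normal_pdf(z, mu, sigma) - log_normal_pdf(z, 0.0, 1.0)
    return total / num_samples

def kl_closed_form(mu, sigma):
    # Closed form of KL(N(mu, sigma^2) || N(0, 1))
    return 0.5 * (sigma ** 2 + mu ** 2 - 1.0 - math.log(sigma ** 2))

print(kl_monte_carlo(1.0, 0.5), kl_closed_form(1.0, 0.5))  # close to each other
print(kl_closed_form(0.0, 1.0))  # 0.0: Q equals the prior, so nothing is penalized
```

The closed form also shows why the term acts as a constraint: it grows as µ moves away from 0 and, through −log σ², as σ shrinks towards 0.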
The second term in the loss function can actually be thought of as a regularizer
It ensures that the encoder does not cheat by mapping each x_i to a different point (a normal distribution with very low variance) in the Euclidean space (the KL term grows without bound as the variance of Q_θ(z|x_i) shrinks towards 0, so such degenerate encodings are penalized)
In other words, in the absence of the regularizer the encoder can learn a unique mapping for each x_i and the decoder can then decode from this unique mapping
Even with high variance in samples from the distribution, we want the decoder to be able to reconstruct the original data very well (a motivation similar to that of adding noise to the input)
To summarize, for each data point we predict a distribution such that, with high probability, a sample from this distribution should be able to reconstruct the original data point
But why do we choose a normal distribution? Isn't it too simplistic to assume that z follows a normal distribution?
Isn't it a very strong assumption that z ∼ N(0, I)?
For example, in the 2-dimensional case how can we be sure that P(z) is a normal distribution and not any other distribution?
The key insight here is that any distribution in d dimensions can be generated by the following steps
Step 1: Start with a set of d variables that are normally distributed (that's exactly what we are assuming for P(z))
Step 2: Map these variables through a sufficiently complex function (that's exactly what the first few layers of the decoder can do)
In particular, note that in the adjoining example, if z is 2-D and normally distributed, then f(z) is roughly ring shaped (giving us the distribution in the bottom figure)
$f(z) = \frac{z}{10} + \frac{z}{\|z\|}$
A non-linear neural network, such as the one we use for the decoder, could learn a complex mapping from z to f_φ(z) using its parameters φ
The initial layers of a non-linear decoder could learn their weights such that the output is f_φ(z)
The above argument suggests that even if we start with normally distributed variables, the initial layers of the decoder could learn a complex transformation of these variables, say f_φ(z), if required
The objective function of the decoder will ensure that an appropriate transformation of z is learnt to reconstruct X
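A quick numerical check of the ring example in plain Python: draw 2-D points from N(0, I) and push them through the f(z) given above (the sample size and seed are arbitrary choices):

```python
import math
import random
import statistics

def f(z):
    # The mapping from the slide: f(z) = z/10 + z/||z||, for a 2-D z
    norm = math.hypot(z[0], z[1])
    return (z[0] / 10 + z[0] / norm, z[1] / 10 + z[1] / norm)

rng = random.Random(0)
zs = [(rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)) for _ in range(10_000)]  # z ~ N(0, I)
radii = [math.hypot(*f(z)) for z in zs]

# z/10 and z/||z|| point in the same direction, so ||f(z)|| = 1 + ||z||/10:
# every mapped point lies outside the unit circle, concentrated just beyond it (a ring)
print(min(radii), statistics.mean(radii))
```

The mean radius comes out near 1 + √(π/2)/10 ≈ 1.125, since the average length of a 2-D standard normal vector is √(π/2).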
Module 21.3: Variational Autoencoders: The Graphical Model Perspective
Here we can think of z and X as random variables
We are then interested in the joint probability distribution P(X, z), which factorizes as P(X, z) = P(z)P(X|z)
This factorization is natural because we can imagine that the latent variables are fixed first and then the visible variables are drawn based on the latent variables
For example, if we want to draw a digit we could first fix the latent variables (the digit, size, angle, thickness, position and so on) and then draw a digit which corresponds to these latent variables
And of course, unlike RBMs, this is a directed graphical model
Figure: directed graphical model z → X, with a plate over the N data points
Now at inference time, we are given an X (observed variable) and we are interested in finding the most likely assignments of latent variables z which would have resulted in this observation
Mathematically, we want to find
$P(z|X) = \frac{P(X|z)\, P(z)}{P(X)}$
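For a 1-D latent variable the denominator can simply be computed numerically. The sketch below assumes a hypothetical toy model, P(z) = N(0, 1) and P(X|z) = N(X; z, 1), and evaluates Bayes' rule on a grid of z values in plain Python:

```python
import math

def normal_pdf(x, mean, var):
    """Density of a 1-D Gaussian N(mean, var) at x."""
    return math.exp(-(x - mean) ** 2 / (2 * var)) / math.sqrt(2 * math.pi * var)

def posterior_on_grid(x, lo=-6.0, hi=6.0, num_points=2001):
    # Toy model (assumption): P(z) = N(0, 1) and P(x|z) = N(x; z, 1)
    step = (hi - lo) / (num_points - 1)
    zs = [lo + k * step for k in range(num_points)]
    # Numerator of Bayes' rule, P(x|z) P(z), at every grid point
    joint = [normal_pdf(x, z, 1.0) * normal_pdf(z, 0.0, 1.0) for z in zs]
    # Denominator P(x) = integral of P(x|z) P(z) dz, approximated by a Riemann sum
    evidence = sum(joint) * step
    return zs, [p / evidence for p in joint], evidence, step

zs, post, evidence, step = posterior_on_grid(1.5)
post_mean = sum(z * p for z, p in zip(zs, post)) * step
print(post_mean)  # the exact posterior here is N(x/2, 1/2), so this is close to 0.75
```

The grid works only because z is 1-D: with d latent dimensions and k grid points per dimension we would need k^d evaluations of P(X|z)P(z), which is why P(X), and hence P(z|X), is hard to compute for realistic latent spaces.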
But what is the objective function for this neural network?
Well, we want the proposed distribution Q_θ(z|X) to be as close as possible to the true distribution P(z|X)
We can capture this using the following objective function
$\min_\theta D[Q_\theta(z|X) \,\|\, P(z|X)]$
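Let us expand the KL divergence term
$D[Q_\theta(z|X) \| P(z|X)] = \int Q_\theta(z|X) \log Q_\theta(z|X)\, dz - \int Q_\theta(z|X) \log P(z|X)\, dz$
Using Bayes' rule, $\log P(z|X) = \log P(X|z) + \log P(z) - \log P(X)$, so
$D[Q_\theta(z|X) \| P(z|X)] = \mathbb{E}_Q[\log Q_\theta(z|X) - \log P(X|z) - \log P(z) + \log P(X)]$
$= \mathbb{E}_Q[\log Q_\theta(z|X) - \log P(z)] - \mathbb{E}_Q[\log P(X|z)] + \log P(X)$ (log P(X) comes out of the expectation because it does not depend on z)
$= D[Q_\theta(z|X) \| P(z)] - \mathbb{E}_Q[\log P(X|z)] + \log P(X)$
$\therefore \log P(X) = \mathbb{E}_Q[\log P(X|z)] - D[Q_\theta(z|X) \| P(z)] + D[Q_\theta(z|X) \| P(z|X)]$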
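The derivation can be checked numerically on a model where every term has a closed form. The sketch below (plain Python) assumes a hypothetical 1-D toy model and evaluates the right-hand side of the last line for a few arbitrary choices of Q:

```python
import math

# Toy model (assumption): P(z) = N(0, 1), P(X|z) = N(X; z, 1), and Q = N(m, v).
# For this model P(z|X) = N(X/2, 1/2) and P(X) = N(X; 0, 2) are known exactly.

def kl_gauss(m1, v1, m2, v2):
    """KL(N(m1, v1) || N(m2, v2)) for 1-D Gaussians, with v1, v2 the variances."""
    return 0.5 * (math.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)

def expected_log_lik(x, m, v):
    # E_{z~N(m, v)}[log N(x; z, 1)] = -0.5*log(2*pi) - 0.5*((x - m)^2 + v)
    return -0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + v)

def log_evidence(x):
    # log P(X = x) = log N(x; 0, 2)
    return -0.5 * math.log(2 * math.pi * 2.0) - x ** 2 / 4.0

def rhs(x, m, v):
    # E_Q[log P(X|z)] - D[Q || P(z)] + D[Q || P(z|X)]
    elbo = expected_log_lik(x, m, v) - kl_gauss(m, v, 0.0, 1.0)
    gap = kl_gauss(m, v, x / 2.0, 0.5)
    return elbo, gap, elbo + gap

for m, v in [(0.2, 0.8), (-1.0, 2.0), (0.75, 0.5)]:
    elbo, gap, total = rhs(1.5, m, v)
    print(f"m={m}, v={v}: first two terms={elbo:.4f}, last term={gap:.4f}, "
          f"sum={total:.4f}, log P(X)={log_evidence(1.5):.4f}")
```

In every case the sum equals log P(X); the last term is non-negative and vanishes when Q is exactly the posterior N(0.75, 0.5).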
So, we have
$\log P(X) = \mathbb{E}_Q[\log P(X|z)] - D[Q_\theta(z|X) \| P(z)] + D[Q_\theta(z|X) \| P(z|X)]$
Recall that we are interested in maximizing the log likelihood of the data, i.e., log P(X)
Since KL divergence (the last term) is always ≥ 0, we can say that
$\mathbb{E}_Q[\log P(X|z)] - D[Q_\theta(z|X) \| P(z)] \le \log P(X)$
The quantity on the LHS is thus a lower bound for the quantity that we want to maximize and is known as the Evidence Lower Bound (ELBO)
Since log P(X) does not depend on Q_θ, maximizing this lower bound with respect to θ is the same as minimizing D[Q_θ(z|X)||P(z|X)] (the objective we started with), and unlike that objective it does not require the intractable P(z|X); hence our equivalent objective now becomes
$\text{maximize} \;\; \mathbb{E}_Q[\log P(X|z)] - D[Q_\theta(z|X) \| P(z)]$
And this method of learning parameters of probability distributions associated with graphical models using optimization (by maximizing the ELBO) is called variational inference
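Variational inference can be illustrated on a hypothetical 1-D toy model, P(z) = N(0, 1) and P(X|z) = N(X; z, 1), with Q restricted to Gaussians N(m, v). The sketch below maximizes the ELBO by brute-force grid search over (m, v); this is purely illustrative, since a VAE instead computes Q's parameters with an encoder network and trains it by gradient-based optimization:

```python
import math

# Toy model (assumption): P(z) = N(0, 1), P(X|z) = N(X; z, 1); Q = N(m, v)

def elbo(x, m, v):
    # E_Q[log P(x|z)] - D[Q || P(z)], both available in closed form here
    expected_log_lik = -0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + v)
    kl_to_prior = 0.5 * (v + m ** 2 - 1.0 - math.log(v))
    return expected_log_lik - kl_to_prior

def fit_q_by_grid_search(x):
    # Variational inference in its simplest form: pick the (m, v) with the largest ELBO
    best = None
    for i in range(-300, 301):          # m in [-3, 3] in steps of 0.01
        m = i / 100
        for j in range(1, 201):         # v in (0, 2] in steps of 0.01
            v = j / 100
            value = elbo(x, m, v)
            if best is None or value > best[0]:
                best = (value, m, v)
    return best

value, m, v = fit_q_by_grid_search(1.5)
print(m, v)  # the exact posterior for this model is N(0.75, 0.5)
print(value, -0.5 * math.log(4 * math.pi) - 1.5 ** 2 / 4)  # best ELBO vs log P(X)
```

The best Q lands on the exact posterior, where the ELBO reaches log P(X); with a restricted family of Q's that cannot express the posterior, a gap would remain.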
Why is this any easier? It is easy because of certain assumptions that we make, as discussed next
First we will just reintroduce the parameters in the
equation to make things explicit
Now let us look at the other term in the objective function
$\sum_{i=1}^{N} \mathbb{E}_Q[\log P_\phi(X_i|z)]$
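Further, as usual, we need to assume some parametric form for P(X|z)
For example, if we assume that P(X|z) is a Gaussian with mean µ(z) and variance I, then
$\log P(X = X_i|z) = C - \frac{1}{2}\|X_i - \mu(z)\|^2$
where the constant C does not depend on z (for $X \in \mathbb{R}^n$ it is $-\frac{n}{2}\log 2\pi$)
µ(z) in turn is a function of the parameters of the decoder and can be written as f_φ(z)
$\log P(X = X_i|z) = C - \frac{1}{2}\|X_i - f_\phi(z)\|^2$
Our effective objective function thus becomes
$\min_{\theta,\phi} \sum_{i=1}^{N} \left( \frac{1}{2}\left[\operatorname{tr}(\Sigma(X_i)) + \mu(X_i)^T \mu(X_i) - k - \log\det(\Sigma(X_i))\right] + \frac{1}{2}\|X_i - f_\phi(z)\|^2 \right)$
where the bracketed term is the closed form of D[N(µ(X_i), Σ(X_i)) || N(0, I)], k is the dimension of z, z is sampled from Q_θ(z|X_i), and the constant C has been dropped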
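The objective can be sketched per training example in plain Python. This is a minimal illustration, not the lecture's implementation: it assumes a diagonal Σ(X_i) (so the log-determinant is a sum of logs), takes µ(X_i) and the diagonal of Σ(X_i) as given instead of computing them with an encoder, uses a single sample of z, and uses a hypothetical fixed linear decoder:

```python
import math
import random

def kl_to_standard_normal(mu, var):
    # 0.5 * [tr(Sigma) + mu^T mu - k - log det(Sigma)] for a diagonal Sigma = diag(var)
    k = len(mu)
    return 0.5 * (sum(var) + sum(m * m for m in mu) - k - sum(math.log(v) for v in var))

def vae_example_loss(x, mu, var, decoder, rng):
    # One-sample estimate of the per-example objective (constant C dropped):
    # KL term + 0.5 * ||x - f_phi(z)||^2, with z drawn from N(mu, diag(var))
    z = [rng.gauss(m, math.sqrt(v)) for m, v in zip(mu, var)]
    x_hat = decoder(z)
    reconstruction = 0.5 * sum((xi - xh) ** 2 for xi, xh in zip(x, x_hat))
    return kl_to_standard_normal(mu, var) + reconstruction

# Hypothetical decoder f_phi: a fixed linear map from a 2-D z to a 3-D x (illustration only)
def toy_decoder(z):
    return [z[0], z[1], z[0] + z[1]]

rng = random.Random(0)
print(vae_example_loss([0.3, -0.2, 0.1], mu=[0.3, -0.2], var=[0.1, 0.1],
                       decoder=toy_decoder, rng=rng))
```

Summing this quantity over the N training examples gives a one-sample estimate of the objective above.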
The above loss can be easily computed and we can update the parameters θ of the encoder and φ of the decoder using backpropagation
However, there is a catch!
The network is not end-to-end differentiable because the output f_φ(z) is not an end-to-end differentiable function of the input X
Why? Because after passing X through the network we simply compute µ(X) and Σ(X) and then sample a z to be fed to the decoder
This makes the entire process non-deterministic: z is a random sample rather than a differentiable function of µ(X) and Σ(X), so we cannot backpropagate the gradient from the decoder through the sampling step to the encoder parameters θ
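VAEs use a neat trick to get around this problem
This is known as the reparameterization trick, wherein we move the process of sampling to an input layer
For the 1-dimensional case, given µ and σ, we can sample from N(µ, σ²) by first sampling ε ∼ N(0, 1) and then computing
$z = \mu + \sigma \cdot \epsilon$
The randomness now enters only through the input ε, so z is a deterministic, differentiable function of µ and σ (with ∂z/∂µ = 1 and ∂z/∂σ = ε), and gradients can flow back into the encoder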
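To see why this makes gradients available, the sketch below (plain Python, with a hypothetical loss(z) = z² standing in for the decoder's loss) estimates the gradient of an expected loss with respect to µ and σ by differentiating through z = µ + σε. In a real VAE an automatic differentiation framework would apply the same chain rule through the decoder:

```python
import random

def reparam_gradients(mu, sigma, num_samples=100_000, seed=0):
    # Estimate d/dmu and d/dsigma of E[loss(z)] for z ~ N(mu, sigma^2) with loss(z) = z^2,
    # by writing z = mu + sigma * eps and differentiating through that expression
    rng = random.Random(seed)
    g_mu = g_sigma = 0.0
    for _ in range(num_samples):
        eps = rng.gauss(0.0, 1.0)   # the only source of randomness, an "input" to the network
        z = mu + sigma * eps        # deterministic and differentiable in mu and sigma
        dloss_dz = 2.0 * z          # derivative of loss(z) = z^2
        g_mu += dloss_dz * 1.0      # chain rule with dz/dmu = 1
        g_sigma += dloss_dz * eps   # chain rule with dz/dsigma = eps
    return g_mu / num_samples, g_sigma / num_samples

# E[z^2] = mu^2 + sigma^2, so the exact gradients are (2 * mu, 2 * sigma) = (3.0, 1.6)
print(reparam_gradients(1.5, 0.8))
```

Without the reparameterization, z would be drawn directly from N(µ, σ²) and there would be no expression through which to apply the chain rule.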
Abstraction
After the model parameters are learned, we feed an X to the encoder
By doing a forward pass using the learned parameters of the model we compute µ(X) and Σ(X)
We then sample a z from the distribution N(µ(X), Σ(X)), or equivalently use the same reparameterization trick
In other words, once we have obtained µ(X) and Σ(X), we first sample ε ∼ N(0, I) and then compute
$z = \mu + \sigma \odot \epsilon$
where (assuming, as is usual, a diagonal Σ(X)) σ holds the square roots of the diagonal entries of Σ(X) and ⊙ is the element-wise product
Figure: X_i → Q_θ(z|X) → µ, Σ; z = µ + σ ∗ ε with ε ∼ N(0, I); z → P_φ(X|z) → X̂_i
Generation
After the model parameters are learned, we remove the encoder and feed a z ∼ N(0, I) to the decoder
The decoder will then predict f_φ(z) and we can draw an X ∼ N(f_φ(z), I)
Why would this work?
Well, we had trained the model to minimize D(Q_θ(z|X)||P(z)) where P(z) was N(0, I)
If the model is trained well, then Q_θ(z|X) should also become close to N(0, I)
Hence, if we feed z ∼ N(0, I), it is almost as if we are feeding a z ∼ Q_θ(z|X), and the decoder was indeed trained to produce a good f_φ(z) from such a z
Hence this will work!
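The generation procedure can be sketched in plain Python as follows. The decoder here is a hypothetical fixed function standing in for a trained f_φ, so the outputs only illustrate the sampling steps, not a real model:

```python
import random

def generate(decoder, latent_dim, num_samples, rng):
    # Generation: z ~ N(0, I) -> f_phi(z) -> X ~ N(f_phi(z), I)
    samples = []
    for _ in range(num_samples):
        z = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]  # sample from the prior
        mean = decoder(z)                                      # decoder predicts f_phi(z)
        x = [m + rng.gauss(0.0, 1.0) for m in mean]            # add unit-variance noise
        samples.append(x)
    return samples

# Hypothetical stand-in for a trained decoder: a fixed map from 2-D z to 3-D x
def toy_decoder(z):
    return [z[0], z[1], 2.0 * z[0] - z[1]]

rng = random.Random(0)
for x in generate(toy_decoder, latent_dim=2, num_samples=3, rng=rng):
    print(x)
```

Note that the encoder plays no part here: only the prior N(0, I) and the decoder are used, exactly as described above.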