Variational AutoEncoders (VAE) With PyTorch - Alexander Van de Kleut

This document defines the architecture for a variational autoencoder (VAE) model, with encoder, decoder, and autoencoder classes. The encoder learns the mean and standard deviation of the approximate posterior over the latent space and samples from it using the reparameterization trick; the decoder reconstructs the input from a latent vector. The model is trained on MNIST, and the latent space, image reconstructions, and interpolations are visualized.


import torch; torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
import torchvision
import numpy as np
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200

device = 'cuda' if torch.cuda.is_available() else 'cpu'

We begin with a standard (non-variational) autoencoder. The Encoder is a torch.nn.Module subclass: __init__ defines its layers, and forward maps a batch of input images to points in the latent space.

class Encoder(nn.Module):
    def __init__(self, latent_dims):
        super(Encoder, self).__init__()
        self.linear1 = nn.Linear(784, 512)
        self.linear2 = nn.Linear(512, latent_dims)

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.linear1(x))
        return self.linear2(x)
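
As a quick sanity check (not part of the original post), a freshly constructed Encoder should map a batch of 1×28×28 images to latent_dims-dimensional codes; the dummy batch below is purely illustrative.

enc = Encoder(latent_dims=2)
dummy = torch.randn(16, 1, 28, 28)  # a fake batch of 16 single-channel 28x28 "images"
print(enc(dummy).shape)             # expected: torch.Size([16, 2])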

The Decoder mirrors the encoder: it maps a latent vector back to 784 pixel values, applies a sigmoid so the outputs lie in [0, 1] like the input pixels, and reshapes the result to 1 × 28 × 28.

class Decoder(nn.Module):
    def __init__(self, latent_dims):
        super(Decoder, self).__init__()
        self.linear1 = nn.Linear(latent_dims, 512)
        self.linear2 = nn.Linear(512, 784)

    def forward(self, z):
        z = F.relu(self.linear1(z))
        z = torch.sigmoid(self.linear2(z))
        return z.reshape((-1, 1, 28, 28))

The Autoencoder composes the two modules: it encodes the input to a latent vector and decodes that vector back into an image.
class Autoencoder(nn.Module):
    def __init__(self, latent_dims):
        super(Autoencoder, self).__init__()
        self.encoder = Encoder(latent_dims)
        self.decoder = Decoder(latent_dims)

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)

def train(autoencoder, data, epochs=20):
    opt = torch.optim.Adam(autoencoder.parameters())
    for epoch in range(epochs):
        for x, y in data:
            x = x.to(device)  # GPU
            opt.zero_grad()
            x_hat = autoencoder(x)
            loss = ((x - x_hat)**2).sum()
            loss.backward()
            opt.step()
    return autoencoder
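
The loop above minimizes the summed squared reconstruction error with Adam and runs silently. A minimal variant that also reports the average per-image reconstruction error each epoch might look like the sketch below (train_verbose is not part of the original post):

def train_verbose(autoencoder, data, epochs=20):
    # Same loop as train(), with per-epoch loss reporting added for illustration.
    opt = torch.optim.Adam(autoencoder.parameters())
    for epoch in range(epochs):
        total, count = 0.0, 0
        for x, y in data:
            x = x.to(device)  # GPU
            opt.zero_grad()
            x_hat = autoencoder(x)
            loss = ((x - x_hat)**2).sum()
            loss.backward()
            opt.step()
            total += loss.item()
            count += x.shape[0]
        print(f'epoch {epoch}: average reconstruction error {total / count:.2f}')
    return autoencoder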

latent_dims = 2
autoencoder = Autoencoder(latent_dims).to(device) # GPU

data = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data',
        transform=torchvision.transforms.ToTensor(),
        download=True),
    batch_size=128,
    shuffle=True)

autoencoder = train(autoencoder, data)


def plot_latent(autoencoder, data, num_batches=100):
    for i, (x, y) in enumerate(data):
        z = autoencoder.encoder(x.to(device))
        z = z.to('cpu').detach().numpy()
        plt.scatter(z[:, 0], z[:, 1], c=y, cmap='tab10')
        if i > num_batches:
            plt.colorbar()
            break

plot_latent(autoencoder, data)
def plot_reconstructed(autoencoder, r0=(-5, 10), r1=(-10, 5), n=12):
    w = 28
    img = np.zeros((n*w, n*w))
    for i, y in enumerate(np.linspace(*r1, n)):
        for j, x in enumerate(np.linspace(*r0, n)):
            z = torch.Tensor([[x, y]]).to(device)
            x_hat = autoencoder.decoder(z)
            x_hat = x_hat.reshape(28, 28).to('cpu').detach().numpy()
            img[(n-1-i)*w:(n-1-i+1)*w, j*w:(j+1)*w] = x_hat
    plt.imshow(img, extent=[*r0, *r1])

plot_reconstructed(autoencoder)
plot_reconstructed decodes a regular grid of latent points, so the ranges r0 and r1 should roughly cover the region that plot_latent showed the data occupying. To turn this into a variational autoencoder we keep the Decoder exactly as it is and replace the Encoder with a variational version that outputs a distribution over latent space rather than a single point.
class VariationalEncoder(nn.Module):
    def __init__(self, latent_dims):
        super(VariationalEncoder, self).__init__()
        self.linear1 = nn.Linear(784, 512)
        self.linear2 = nn.Linear(512, latent_dims)
        self.linear3 = nn.Linear(512, latent_dims)

        self.N = torch.distributions.Normal(0, 1)
        self.N.loc = self.N.loc.to(device)  # hack to get sampling on the GPU when one is available
        self.N.scale = self.N.scale.to(device)
        self.kl = 0

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.linear1(x))
        mu = self.linear2(x)
        sigma = torch.exp(self.linear3(x))
        z = mu + sigma*self.N.sample(mu.shape)  # reparameterization trick
        # closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over batch and latent dimensions
        self.kl = (0.5*(sigma**2 + mu**2 - 1) - torch.log(sigma)).sum()
        return z
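
The self.kl term above is the closed-form KL divergence between the diagonal Gaussian N(mu, sigma^2) produced by the encoder and the standard normal prior, summed over the batch and the latent dimensions. As a sanity check (not in the original post), it can be compared against torch.distributions.kl_divergence; the mu and sigma tensors below are illustrative.

mu = torch.randn(4, 2)          # illustrative means
sigma = torch.rand(4, 2) + 0.1  # illustrative (positive) standard deviations
manual = (0.5*(sigma**2 + mu**2 - 1) - torch.log(sigma)).sum()
reference = torch.distributions.kl_divergence(
    torch.distributions.Normal(mu, sigma),
    torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(sigma))).sum()
print(torch.allclose(manual, reference))  # expected: True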

The VariationalAutoencoder is identical to the Autoencoder above except that it uses a VariationalEncoder in place of the Encoder.

class VariationalAutoencoder(nn.Module):
    def __init__(self, latent_dims):
        super(VariationalAutoencoder, self).__init__()
        self.encoder = VariationalEncoder(latent_dims)
        self.decoder = Decoder(latent_dims)

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)

Training changes in only one way: the KL term computed during the encoder's forward pass, stored in autoencoder.encoder.kl, is added to the reconstruction loss.
def train(autoencoder, data, epochs=20):
    opt = torch.optim.Adam(autoencoder.parameters())
    for epoch in range(epochs):
        for x, y in data:
            x = x.to(device)  # GPU
            opt.zero_grad()
            x_hat = autoencoder(x)
            loss = ((x - x_hat)**2).sum() + autoencoder.encoder.kl
            loss.backward()
            opt.step()
    return autoencoder

vae = VariationalAutoencoder(latent_dims).to(device) # GPU


vae = train(vae, data)

plot_latent(vae, data)
plot_reconstructed(vae, r0=(-3, 3), r1=(-3, 3))
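
Because the KL term pulls the approximate posterior towards N(0, I), latent vectors drawn directly from the prior should now decode to plausible digits. A minimal sketch of this (not in the original post):

with torch.no_grad():
    z = torch.randn(8, latent_dims).to(device)  # sample 8 latent vectors from the prior
    samples = vae.decoder(z).to('cpu').numpy()  # decode them into images
plt.imshow(np.concatenate([s.reshape(28, 28) for s in samples], axis=1))
plt.xticks([])
plt.yticks([])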
def interpolate(autoencoder, x_1, x_2, n=12):
    z_1 = autoencoder.encoder(x_1)
    z_2 = autoencoder.encoder(x_2)
    z = torch.stack([z_1 + (z_2 - z_1)*t for t in np.linspace(0, 1, n)])
    interpolate_list = autoencoder.decoder(z)
    interpolate_list = interpolate_list.to('cpu').detach().numpy()

    w = 28
    img = np.zeros((w, n*w))
    for i, x_hat in enumerate(interpolate_list):
        img[:, i*w:(i+1)*w] = x_hat.reshape(28, 28)
    plt.imshow(img)
    plt.xticks([])
    plt.yticks([])

x, y = next(iter(data))  # hack to grab a single batch


x_1 = x[y == 1][1].to(device) # find a 1
x_2 = x[y == 0][1].to(device) # find a 0

interpolate(vae, x_1, x_2, n=20)

interpolate(autoencoder, x_1, x_2, n=20)


from PIL import Image

def interpolate_gif(autoencoder, filename, x_1, x_2, n=100):
    z_1 = autoencoder.encoder(x_1)
    z_2 = autoencoder.encoder(x_2)

    z = torch.stack([z_1 + (z_2 - z_1)*t for t in np.linspace(0, 1, n)])

    interpolate_list = autoencoder.decoder(z)
    interpolate_list = interpolate_list.to('cpu').detach().numpy()*255

    # convert to 8-bit greyscale so PIL can write the frames to a GIF
    images_list = [Image.fromarray(img.reshape(28, 28).astype(np.uint8)).resize((256, 256))
                   for img in interpolate_list]
    images_list = images_list + images_list[::-1]  # loop back to the beginning

    images_list[0].save(
        f'{filename}.gif',
        save_all=True,
        append_images=images_list[1:],
        loop=1)

interpolate_gif(vae, "vae", x_1, x_2)

