Assignment_CycleGAN
Goals
In this notebook, you will write a generative model based on the paper Unpaired Image-to-
Image Translation using Cycle-Consistent Adversarial Networks by Zhu et al. 2017, commonly
referred to as CycleGAN.
You will be training a model that can convert horses into zebras, and vice versa. Once again, the
emphasis of the assignment will be on the loss functions. In order for you to see good outputs
more quickly, you'll be training your model starting from a pre-trained checkpoint. You are also
welcome to train it from scratch on your own, if you so choose.
Learning Objectives
1. Implement the loss functions of a CycleGAN model.
2. Observe the two GANs used in CycleGAN.
Getting Started
You will start by importing libraries, defining a visualization function, and getting the pre-trained
CycleGAN checkpoint.
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0)
import glob
import random
import os
from torch.utils.data import Dataset
from PIL import Image
# Inspired by https://github.com/aitorzip/PyTorch-CycleGAN/blob/master/datasets.py
class ImageDataset(Dataset):
    def __init__(self, root, transform=None, mode='train'):
        self.transform = transform
        self.files_A = sorted(glob.glob(os.path.join(root, '%sA' % mode) + '/*.*'))
        self.files_B = sorted(glob.glob(os.path.join(root, '%sB' % mode) + '/*.*'))
        if len(self.files_A) > len(self.files_B):
            self.files_A, self.files_B = self.files_B, self.files_A
        self.new_perm()
        assert len(self.files_A) > 0, "Make sure you downloaded the horse2zebra images!"

    def new_perm(self):
        self.randperm = torch.randperm(len(self.files_B))[:len(self.files_A)]

    def __len__(self):
        return min(len(self.files_A), len(self.files_B))
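To draw samples with a DataLoader, the dataset also needs a __getitem__ method. One possible sketch, consistent with new_perm and __len__ above (the rescaling to [-1, 1] is an assumption made to match the generator's tanh output, not necessarily the provided code), pairs each A image with a randomly permuted B image:

    def __getitem__(self, index):
        # Load one image from each pile; B is chosen via the random permutation above
        item_A = self.transform(Image.open(self.files_A[index % len(self.files_A)]))
        item_B = self.transform(Image.open(self.files_B[self.randperm[index]]))
        if index == len(self) - 1:
            self.new_perm()  # reshuffle the A/B pairing after each full pass
        # Assumed rescaling from [0, 1] to [-1, 1]
        return (item_A - 0.5) * 2, (item_B - 0.5) * 2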
Generator
The code for a CycleGAN generator is much like Pix2Pix's U-Net with the addition of the residual
block between the encoding (contracting) and decoding (expanding) blocks.
Diagram of a CycleGAN generator: composed of encoding blocks, residual blocks, and then decoding blocks.
Residual Block
Perhaps the most notable architectural difference between the U-Net you used for Pix2Pix and
the architecture you're using for CycleGAN is the residual blocks. In CycleGAN, after the
contracting blocks, there are convolutional layers where the output is ultimately added to the
original input so that the network changes the image as little as possible. You can think of
this transformation as a kind of skip connection, where instead of being concatenated as new
channels before the convolution that combines them, the input is added directly to the output of the
convolution. In the visualization below, you can imagine the stripes being generated by the
convolutions and then added to the original image of the horse to transform it into a zebra.
These skip connections also allow the network to be deeper, because they help with the vanishing
gradient issues that arise when a neural network gets too deep and the gradients multiply in
backpropagation to become very small; instead, these skip connections enable more gradient
flow. A deeper network is often able to learn more complex features.
Residual block explanation: shows horse going through convolutions leading to stripes, added to
the original horse image to get a zebra
class ResidualBlock(nn.Module):
    '''
    ResidualBlock Class:
    Performs two convolutions and an instance normalization, the input is added
    to this output to form the residual block output.
    Values:
        input_channels: the number of channels to expect from a given input
    '''
    def __init__(self, input_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, input_channels, kernel_size=3, padding=1, padding_mode='reflect')
        self.conv2 = nn.Conv2d(input_channels, input_channels, kernel_size=3, padding=1, padding_mode='reflect')
        self.instancenorm = nn.InstanceNorm2d(input_channels)
        self.activation = nn.ReLU()
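The forward pass is where the skip connection happens: the block transforms the input and then adds the original back. A minimal sketch of such a forward method inside ResidualBlock (the exact ordering of normalization and activation is an assumption based on the docstring, not necessarily the graded code):

    def forward(self, x):
        original_x = x.clone()  # keep the untouched input for the skip connection
        x = self.activation(self.instancenorm(self.conv1(x)))
        x = self.instancenorm(self.conv2(x))
        return original_x + x   # residual output = input + learned change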
class ContractingBlock(nn.Module):
    '''
    ContractingBlock Class
    Performs a convolution with stride 2 (downsampling), followed by an optional
    instance norm and an activation.
    Values:
        input_channels: the number of channels to expect from a given input
    '''
    def __init__(self, input_channels, use_bn=True, kernel_size=3, activation='relu'):
        super(ContractingBlock, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, input_channels * 2, kernel_size=kernel_size, padding=1, stride=2, padding_mode='reflect')
        self.activation = nn.ReLU() if activation == 'relu' else nn.LeakyReLU(0.2)
        if use_bn:
            self.instancenorm = nn.InstanceNorm2d(input_channels * 2)
        self.use_bn = use_bn
class ExpandingBlock(nn.Module):
    '''
    ExpandingBlock Class:
    Performs a convolutional transpose operation in order to upsample, with an optional instance norm.
    '''
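A constructor for this block could look like the sketch below, which is consistent with how the Generator later passes it hidden_channels * 4 and then hidden_channels * 2 (the specific ConvTranspose2d settings are assumptions, not necessarily the notebook's exact code):

    def __init__(self, input_channels, use_bn=True):
        super(ExpandingBlock, self).__init__()
        # Assumed upsampling layer: doubles height/width, halves the channel count
        self.conv1 = nn.ConvTranspose2d(input_channels, input_channels // 2,
                                        kernel_size=3, stride=2, padding=1, output_padding=1)
        if use_bn:
            self.instancenorm = nn.InstanceNorm2d(input_channels // 2)
        self.use_bn = use_bn
        self.activation = nn.ReLU()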
class FeatureMapBlock(nn.Module):
    '''
    FeatureMapBlock Class
    The first and final layer of a Generator -
    maps the output to the desired number of output channels.
    Values:
        input_channels: the number of channels to expect from a given input
        output_channels: the number of channels to expect for a given output
    '''
    def __init__(self, input_channels, output_channels):
        super(FeatureMapBlock, self).__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=7, padding=3, padding_mode='reflect')
CycleGAN Generator
Finally, you can put all the blocks together to create your CycleGAN generator.
class Generator(nn.Module):
    '''
    Generator Class
    A series of 2 contracting blocks, 9 residual blocks, and 2 expanding blocks to
    transform an input image into an image from the other class, with an upfeature
    layer at the start and a downfeature layer at the end.
    Values:
        input_channels: the number of channels to expect from a given input
        output_channels: the number of channels to expect for a given output
    '''
    def __init__(self, input_channels, output_channels, hidden_channels=64):
        super(Generator, self).__init__()
        self.upfeature = FeatureMapBlock(input_channels, hidden_channels)
        self.contract1 = ContractingBlock(hidden_channels)
        self.contract2 = ContractingBlock(hidden_channels * 2)
        res_mult = 4
        self.res0 = ResidualBlock(hidden_channels * res_mult)
        self.res1 = ResidualBlock(hidden_channels * res_mult)
        self.res2 = ResidualBlock(hidden_channels * res_mult)
        self.res3 = ResidualBlock(hidden_channels * res_mult)
        self.res4 = ResidualBlock(hidden_channels * res_mult)
        self.res5 = ResidualBlock(hidden_channels * res_mult)
        self.res6 = ResidualBlock(hidden_channels * res_mult)
        self.res7 = ResidualBlock(hidden_channels * res_mult)
        self.res8 = ResidualBlock(hidden_channels * res_mult)
        self.expand2 = ExpandingBlock(hidden_channels * 4)
        self.expand3 = ExpandingBlock(hidden_channels * 2)
        self.downfeature = FeatureMapBlock(hidden_channels, output_channels)
        self.tanh = torch.nn.Tanh()
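The forward pass applies these blocks in the order the docstring describes, ending with a tanh so the outputs land in [-1, 1]. A minimal sketch of that flow inside Generator (not necessarily the notebook's exact method):

    def forward(self, x):
        # upfeature -> contract -> 9 residual blocks -> expand -> downfeature -> tanh
        x = self.upfeature(x)
        x = self.contract2(self.contract1(x))
        for res in [self.res0, self.res1, self.res2, self.res3, self.res4,
                    self.res5, self.res6, self.res7, self.res8]:
            x = res(x)
        x = self.expand3(self.expand2(x))
        return self.tanh(self.downfeature(x))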
PatchGAN Discriminator
Next, you will define the discriminator—a PatchGAN. It will be very similar to what you saw in
Pix2Pix.
class Discriminator(nn.Module):
    '''
    Discriminator Class
    Structured like the contracting path of the U-Net, the discriminator will
    output a matrix of values classifying corresponding portions of the image as real or fake.
    Parameters:
        input_channels: the number of image input channels
        hidden_channels: the initial number of discriminator convolutional filters
    '''
    def __init__(self, input_channels, hidden_channels=64):
        super(Discriminator, self).__init__()
        self.upfeature = FeatureMapBlock(input_channels, hidden_channels)
        self.contract1 = ContractingBlock(hidden_channels, use_bn=False, kernel_size=4, activation='lrelu')
        self.contract2 = ContractingBlock(hidden_channels * 2, kernel_size=4, activation='lrelu')
        self.contract3 = ContractingBlock(hidden_channels * 4, kernel_size=4, activation='lrelu')
        self.final = nn.Conv2d(hidden_channels * 8, 1, kernel_size=1)
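Its forward pass runs the image through the feature-map layer, the three contracting blocks, and the final 1x1 convolution, producing one prediction per patch rather than a single scalar. A minimal sketch of that flow inside Discriminator (an assumption consistent with the layers above):

    def forward(self, x):
        x = self.upfeature(x)
        x = self.contract3(self.contract2(self.contract1(x)))
        return self.final(x)  # one real/fake value per patch of the input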
Training Preparation
Now you can put everything together for training! You will start by defining your parameters:
• adv_criterion: an adversarial loss function to keep track of how well the GAN is fooling
the discriminator and how well the discriminator is catching the GAN
• recon_criterion: a loss function that rewards generated images for being similar to a
target image; it is used by the terms that "reconstruct" an image
• n_epochs: the number of times you iterate through the entire dataset when training
• dim_A: the number of channels of the images in pile A
• dim_B: the number of channels of the images in pile B (note that in the visualization this
is currently treated as equivalent to dim_A)
• display_step: how often to display/visualize the images
• batch_size: the number of images per forward/backward pass
• lr: the learning rate
• target_shape: the size of the input and output images (in pixels)
• load_shape: the size for the dataset to load the images at before randomly cropping
them to target_shape as a simple data augmentation
• device: the device type
import torch.nn.functional as F
adv_criterion = nn.MSELoss()
recon_criterion = nn.L1Loss()
n_epochs = 20
dim_A = 3
dim_B = 3
display_step = 200
batch_size = 1
lr = 0.0002
load_shape = 286
target_shape = 256
device = 'cuda'
You will then load the images of the dataset while introducing some data augmentation (e.g.
crops and random horizontal flips).
transform = transforms.Compose([
transforms.Resize(load_shape),
transforms.RandomCrop(target_shape),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
])
import torchvision
dataset = ImageDataset("horse2zebra", transform=transform)
Next, you can initialize your generators and discriminators, as well as their optimizers. For
CycleGAN, you will have two generators and two discriminators since there are two GANs:
def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
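The checkpoint loading below restores state into gen_AB, gen_BA, a shared generator optimizer, and the two discriminators, so those objects have to be created first. A minimal sketch of that setup, assuming a single Adam optimizer shared by both generators and betas of (0.5, 0.999) (a common GAN choice, assumed here rather than confirmed by this notebook):

gen_AB = Generator(dim_A, dim_B).to(device)   # translates pile A -> pile B
gen_BA = Generator(dim_B, dim_A).to(device)   # translates pile B -> pile A
gen_opt = torch.optim.Adam(
    list(gen_AB.parameters()) + list(gen_BA.parameters()), lr=lr, betas=(0.5, 0.999))
disc_A = Discriminator(dim_A).to(device)      # judges whether an image belongs to pile A
disc_A_opt = torch.optim.Adam(disc_A.parameters(), lr=lr, betas=(0.5, 0.999))
disc_B = Discriminator(dim_B).to(device)      # judges whether an image belongs to pile B
disc_B_opt = torch.optim.Adam(disc_B.parameters(), lr=lr, betas=(0.5, 0.999))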
# Feel free to change pretrained to False if you're training the model from scratch
pretrained = True
if pretrained:
    pre_dict = torch.load('cycleGAN_100000.pth')
    gen_AB.load_state_dict(pre_dict['gen_AB'])
    gen_BA.load_state_dict(pre_dict['gen_BA'])
    gen_opt.load_state_dict(pre_dict['gen_opt'])
    disc_A.load_state_dict(pre_dict['disc_A'])
    disc_A_opt.load_state_dict(pre_dict['disc_A_opt'])
    disc_B.load_state_dict(pre_dict['disc_B'])
    disc_B_opt.load_state_dict(pre_dict['disc_B_opt'])
else:
    gen_AB = gen_AB.apply(weights_init)
    gen_BA = gen_BA.apply(weights_init)
    disc_A = disc_A.apply(weights_init)
    disc_B = disc_B.apply(weights_init)
Discriminator Loss
First, you're going to be implementing the discriminator loss. This is the same as in previous
assignments, so it should be a breeze :) Don't forget to detach your generator!
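One way you might implement get_disc_loss, shown here as a sketch consistent with the unit test below: score detached fake images against zeros, score real images against ones, and average the two terms.

def get_disc_loss(real_X, fake_X, disc_X, adv_criterion):
    # Detach the fakes so no gradients flow back into the generator
    disc_fake_X_hat = disc_X(fake_X.detach())
    disc_fake_X_loss = adv_criterion(disc_fake_X_hat, torch.zeros_like(disc_fake_X_hat))
    # Real images should be classified as real (ones)
    disc_real_X_hat = disc_X(real_X)
    disc_real_X_loss = adv_criterion(disc_real_X_hat, torch.ones_like(disc_real_X_hat))
    return (disc_fake_X_loss + disc_real_X_loss) / 2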
# UNIT TEST
test_disc_X = lambda x: x * 97
test_real_X = torch.tensor(83.)
test_fake_X = torch.tensor(89.)
test_adv_criterion = lambda x, y: x * 79 + y * 73
assert torch.abs((get_disc_loss(test_real_X, test_fake_X, test_disc_X, test_adv_criterion)) - 659054.5000) < 1e-6
test_disc_X = lambda x: x.mean(0, keepdim=True)
test_adv_criterion = torch.nn.BCEWithLogitsLoss()
test_input = torch.ones(20, 10)
# If this runs, it's a pass - checks that the shapes are treated correctly
get_disc_loss(test_input, test_input, test_disc_X, test_adv_criterion)
print("Success!")
Generator Loss
While there are some changes to the CycleGAN architecture from Pix2Pix, the most important
distinguishing feature of CycleGAN is its generator loss. You will be implementing that here!
Adversarial Loss
The first component of the generator's loss you're going to implement is its adversarial loss—
this once again is pretty similar to the GAN loss that you've implemented in the past. The
important thing to note is that the criterion now is based on least squares loss, rather than
binary cross entropy loss or W-loss.
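One way you might implement get_gen_adversarial_loss, as a sketch consistent with the unit test below: translate real_X into the other domain, then reward the generator when the target-domain discriminator scores the fake as real.

def get_gen_adversarial_loss(real_X, disc_Y, gen_XY, adv_criterion):
    fake_Y = gen_XY(real_X)            # translate X into the Y domain
    disc_fake_Y_hat = disc_Y(fake_Y)
    # The generator wants the discriminator to output "real" (ones) on its fakes
    adversarial_loss = adv_criterion(disc_fake_Y_hat, torch.ones_like(disc_fake_Y_hat))
    return adversarial_loss, fake_Y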
# UNIT TEST
test_disc_Y = lambda x: x * 97
test_real_X = torch.tensor(83.)
test_gen_XY = lambda x: x * 89
test_adv_criterion = lambda x, y: x * 79 + y * 73
test_res = get_gen_adversarial_loss(test_real_X, test_disc_Y, test_gen_XY, test_adv_criterion)
assert torch.abs(test_res[0] - 56606652) < 1e-6
assert torch.abs(test_res[1] - 7387) < 1e-6
test_disc_Y = lambda x: x.mean(0, keepdim=True)
test_adv_criterion = torch.nn.BCEWithLogitsLoss()
test_input = torch.ones(20, 10)
# If this runs, it's a pass - checks that the shapes are treated correctly
get_gen_adversarial_loss(test_input, test_disc_Y, test_gen_XY, test_adv_criterion)
print("Success!")
Identity Loss
Here you get to see some of the genuinely new material! You'll want to measure the change in an
image when you pass the generator an example from the target domain instead of the input
domain it's expecting. The output should be the same as the input since it is already of the target
domain class. For example, if you put a horse through a zebra -> horse generator, you'd expect
the output to be the same horse because nothing needed to be transformed. It's already a horse!
You don't want your generator to be transforming it into any other thing, so you want to
encourage this behavior. In encouraging this identity mapping, the authors of CycleGAN found
that for some tasks, this helped properly preserve the colors of an image, even when the
expected input (here, a zebra) was put in. This was particularly useful for the photos <->
paintings mapping and, while an optional aesthetic component, you might find it useful for your
applications down the line.
Diagram showing a real horse image going through a zebra -> horse generator and the ideal
output being the same input image
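One way you might implement get_identity_loss, as a sketch consistent with the unit test below: pass real_X through the generator that maps into X's own domain and penalize any change.

def get_identity_loss(real_X, gen_YX, identity_criterion):
    identity_X = gen_YX(real_X)                         # should come back unchanged
    identity_loss = identity_criterion(identity_X, real_X)
    return identity_loss, identity_X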
# UNIT TEST
test_real_X = torch.tensor(83.)
test_gen_YX = lambda x: x * 89
test_identity_criterion = lambda x, y: (x + y) * 73
test_res = get_identity_loss(test_real_X, test_gen_YX, test_identity_criterion)
assert torch.abs(test_res[0] - 545310) < 1e-6
assert torch.abs(test_res[1] - 7387) < 1e-6
print("Success!")
Cycle Consistency Loss
Diagram showing a real zebra image being transformed into a horse and then back into a zebra; the output zebra should be the same as the input zebra.
Since you've already generated a fake image for the adversarial part, you can pass that fake
image back to produce a full cycle—this loss will encourage the cycle to preserve as much
information as possible.
Fun fact: Cycle consistency is a broader concept that's used outside of CycleGAN a lot too! It's
helped with data augmentation and has been used on text translation too, e.g. French -> English
-> French should get the same phrase back.
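One way you might implement get_cycle_consistency_loss, as a sketch consistent with the unit test below: map the already-generated fake_Y back to the X domain and compare the round trip against the original real_X.

def get_cycle_consistency_loss(real_X, fake_Y, gen_YX, cycle_criterion):
    cycle_X = gen_YX(fake_Y)                      # complete the X -> Y -> X cycle
    cycle_loss = cycle_criterion(cycle_X, real_X)
    return cycle_loss, cycle_X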
# UNIT TEST
test_real_X = torch.tensor(83.)
test_fake_Y = torch.tensor(97.)
test_gen_YX = lambda x: x * 89
test_cycle_criterion = lambda x, y: (x + y) * 73
test_res = get_cycle_consistency_loss(test_real_X, test_fake_Y, test_gen_YX, test_cycle_criterion)
assert torch.abs(test_res[1] - 8633) < 1e-6
assert torch.abs(test_res[0] - 636268) < 1e-6
print("Success!")
Total Loss
Finally, you can put the pieces together in get_gen_loss: the adversarial loss for both generators, plus the identity and cycle consistency losses weighted by lambda_identity and lambda_cycle. It returns the total generator loss along with the generated images fake_A and fake_B.
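One way you might assemble get_gen_loss, as a sketch consistent with the unit test below and built from the helper functions defined above:

def get_gen_loss(real_A, real_B, gen_AB, gen_BA, disc_A, disc_B, adv_criterion,
                 identity_criterion, cycle_criterion, lambda_identity, lambda_cycle):
    # Adversarial loss for both translation directions (also produces the fakes)
    adv_loss_BA, fake_A = get_gen_adversarial_loss(real_B, disc_A, gen_BA, adv_criterion)
    adv_loss_AB, fake_B = get_gen_adversarial_loss(real_A, disc_B, gen_AB, adv_criterion)
    gen_adversarial_loss = adv_loss_BA + adv_loss_AB
    # Identity loss: each generator should leave images of its output domain alone
    identity_loss_A, _ = get_identity_loss(real_A, gen_BA, identity_criterion)
    identity_loss_B, _ = get_identity_loss(real_B, gen_AB, identity_criterion)
    gen_identity_loss = identity_loss_A + identity_loss_B
    # Cycle consistency loss: A -> B -> A and B -> A -> B should reconstruct the originals
    cycle_loss_BA, _ = get_cycle_consistency_loss(real_A, fake_B, gen_BA, cycle_criterion)
    cycle_loss_AB, _ = get_cycle_consistency_loss(real_B, fake_A, gen_AB, cycle_criterion)
    gen_cycle_loss = cycle_loss_BA + cycle_loss_AB
    # Total loss
    gen_loss = gen_adversarial_loss + lambda_identity * gen_identity_loss + lambda_cycle * gen_cycle_loss
    return gen_loss, fake_A, fake_B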
# UNIT TEST
test_real_A = torch.tensor(97)
test_real_B = torch.tensor(89)
test_gen_AB = lambda x: x * 83
test_gen_BA = lambda x: x * 79
test_disc_A = lambda x: x * 47
test_disc_B = lambda x: x * 43
test_adv_criterion = lambda x, y: x * 73 + y * 71
test_recon_criterion = lambda x, y: (x + y) * 61
test_lambda_identity = 59
test_lambda_cycle = 53
test_res = get_gen_loss(
    test_real_A,
    test_real_B,
    test_gen_AB,
    test_gen_BA,
    test_disc_A,
    test_disc_B,
    test_adv_criterion,
    test_recon_criterion,
    test_recon_criterion,
    test_lambda_identity,
    test_lambda_cycle)
assert test_res[0].item() == 4047804560
assert test_res[1].item() == 7031
assert test_res[2].item() == 8051
print("Success!")
CycleGAN Training
Lastly, you can train the model and see some of your zebras, horses, and some that might not
quite look like either! Note that this training will take a long time, so feel free to use the pre-
trained checkpoint as an example of what a pretty-good CycleGAN does.
def train(save_model=False):
    mean_generator_loss = 0
    mean_discriminator_loss = 0
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    cur_step = 0