C3W1 Data Augmentation Assignment
Goals
In this notebook you're going to build a generator that can be used to help create data to train a
classifier. There are many cases where this might be useful. If you are interested in any of these
topics, you are welcome to explore the linked papers and articles!
• With smaller datasets, GANs can provide useful data augmentation that substantially
improves classifier performance.
• You have one type of data already labeled and would like to make predictions on another
related dataset for which you have no labels. (You'll learn about the techniques for this
use case in future notebooks!)
• You want to protect the privacy of the people who provided their information so you can
provide access to a generator instead of real data.
• You have input data with many missing values, where the input dimensions are
correlated and you would like to train a model on complete inputs.
• You would like to be able to identify a real-world abnormal feature in an image for the
purpose of diagnosis, but have limited access to real examples of the condition.
In this assignment, you're going to be acting as a bug enthusiast — more on that later.
Learning Objectives
1. Understand some use cases for data augmentation and why GANs suit this task.
2. Implement a classifier that takes a mixed dataset of reals/fakes and analyze its accuracy.
Getting Started
Data Augmentation
Before you implement GAN-based data augmentation, you should know a bit about data
augmentation in general, specifically for image datasets. It is common practice to augment
image-based datasets in ways appropriate to the dataset: having your dataloader randomly
flip images across their vertical axis, randomly crop them to a particular size, or randomly
add a bit of noise or color in ways that are true-to-life.
In general, data augmentation helps to stop your model from overfitting to the data and lets
you make small datasets effectively many times larger. However, a sufficiently powerful
classifier often still overfits to the original examples, which is why GANs are particularly
useful here: they can generate new images instead of simply modifying existing ones.
CIFAR
The CIFAR-10 and CIFAR-100 datasets are extremely widely used within machine learning --
they contain many thousands of “tiny” 32x32 color images of different classes representing
relatively common real-world objects like airplanes and dogs, with 10 classes in CIFAR-10 and
100 classes in CIFAR-100. In CIFAR-100, there are 20 “superclasses” which each contain five
classes. For example, the “fish” superclass contains “aquarium fish, flatfish, ray, shark, trout”.
For the purposes of this assignment, you’ll be looking at a small subset of these images to
simulate a small data regime, with only 40 images of each class for training.
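To simulate that small-data regime, one needs to subsample a fixed number of examples per class. A hypothetical helper (not part of the assignment code) might look like:

```python
import torch

def subsample_per_class(images, labels, n_per_class):
    # Keep the first n_per_class examples of each class, preserving the
    # original ordering so images stay aligned with their labels.
    counts = {}
    keep = []
    for i, y in enumerate(labels.tolist()):
        if counts.get(y, 0) < n_per_class:
            keep.append(i)
            counts[y] = counts.get(y, 0) + 1
    idx = torch.tensor(keep)
    return images[idx], labels[idx]
```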
Initializations
You will begin by importing some useful libraries and packages and defining a visualization
function that has been provided. You will also re-use your conditional generator code and
helper functions from earlier assignments. This will let you control which class of images to
augment for your classifier.
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import CIFAR100
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
torch.manual_seed(0) # Set for our testing purposes, please do not change!
Generator
class Generator(nn.Module):
    '''
    Generator Class
    Values:
        input_dim: the dimension of the input vector, a scalar
        im_chan: the number of channels of the output image, a scalar
              (CIFAR100 is in color (red, green, blue), so 3 is your default)
        hidden_dim: the inner dimension, a scalar
    '''
    def __init__(self, input_dim=10, im_chan=3, hidden_dim=64):
        super(Generator, self).__init__()
        self.input_dim = input_dim
        # Build the neural network
        self.gen = nn.Sequential(
            self.make_gen_block(input_dim, hidden_dim * 4, kernel_size=4),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
            self.make_gen_block(hidden_dim * 2, hidden_dim, kernel_size=4),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=2, final_layer=True),
        )
Training
Now you can begin training your models. First, you will define some new parameters:
• cifar100_shape: the dimensions of each CIFAR image, which is 32 x 32 pixels with three
channels (red, green, and blue), so 3 x 32 x 32
• n_classes: the number of classes in CIFAR100 (100 fine-grained classes, such as bee,
beetle, and butterfly)
cifar100_shape = (3, 32, 32)
n_classes = 100
You will also include the same parameters from previous assignments. Then you want to set
your generator's input dimension. Recall that for conditional GANs, the generator's input is
the noise vector concatenated with the one-hot class vector.
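Putting this together looks roughly like the following. The noise dimension z_dim of 64 is an assumption carried over from earlier assignments; adjust it if yours differs:

```python
n_classes = 100
z_dim = 64  # assumed noise dimension from previous assignments

# The conditional generator takes noise concatenated with a one-hot class
# vector, so its input dimension is the sum of the two.
generator_input_dim = z_dim + n_classes
```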
Classifier
For the classifier, you will use the same code that you wrote in an earlier assignment (the same
as previous code for the discriminator as well since the discriminator is a real/fake classifier).
class Classifier(nn.Module):
    '''
    Classifier Class
    Values:
        im_chan: the number of channels of the input image, a scalar
        n_classes: the total number of classes in the dataset, an integer scalar
        hidden_dim: the inner dimension, a scalar
    '''
    def __init__(self, im_chan, n_classes, hidden_dim=32):
        super(Classifier, self).__init__()
        self.disc = nn.Sequential(
            self.make_classifier_block(im_chan, hidden_dim),
            self.make_classifier_block(hidden_dim, hidden_dim * 2),
            self.make_classifier_block(hidden_dim * 2, hidden_dim * 4),
            self.make_classifier_block(hidden_dim * 4, n_classes, final_layer=True),
        )
Pre-training (Optional)
You are provided the code to pre-train the models (GAN and classifier) given to you in this
assignment. However, this is intended only for your personal curiosity -- for the assignment to
run as intended, you should not use any checkpoints besides the ones given to you.
# This code is here for you to train your own generator or classifier
# outside the assignment on the full dataset if you'd like -- for the
# purposes of this assignment, please use the provided checkpoints
class Discriminator(nn.Module):
    '''
    Discriminator Class
    Values:
        im_chan: the number of channels of the input image, a scalar
              (CIFAR100 is in color, so 3 channels is your default)
        hidden_dim: the inner dimension, a scalar
    '''
    def __init__(self, im_chan=3, hidden_dim=64):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            self.make_disc_block(im_chan, hidden_dim, stride=1),
            self.make_disc_block(hidden_dim, hidden_dim * 2),
            self.make_disc_block(hidden_dim * 2, hidden_dim * 4),
            self.make_disc_block(hidden_dim * 4, 1, final_layer=True),
        )
def train_generator():
    gen = Generator(generator_input_dim).to(device)
    gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
    discriminator_input_dim = cifar100_shape[0] + n_classes
    disc = Discriminator(discriminator_input_dim).to(device)
    disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)

    def weights_init(m):
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
            torch.nn.init.normal_(m.weight, 0.0, 0.02)
        if isinstance(m, nn.BatchNorm2d):
            torch.nn.init.normal_(m.weight, 0.0, 0.02)
            torch.nn.init.constant_(m.bias, 0)
    gen = gen.apply(weights_init)
    disc = disc.apply(weights_init)

    criterion = nn.BCEWithLogitsLoss()
    cur_step = 0
    mean_generator_loss = 0
    mean_discriminator_loss = 0
    for epoch in range(n_epochs):
        # Dataloader returns the batches and the labels
        for real, labels in dataloader:
            cur_batch_size = len(real)
            real = real.to(device)

            # Build the conditional inputs: a one-hot class vector for the
            # generator and its image-shaped version for the discriminator
            one_hot_labels = get_one_hot_labels(labels.to(device), n_classes).float()
            image_one_hot_labels = one_hot_labels[:, :, None, None].repeat(
                1, 1, cifar100_shape[1], cifar100_shape[2])

            # Update the discriminator
            disc_opt.zero_grad()
            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            noise_and_labels = combine_vectors(fake_noise, one_hot_labels)
            fake = gen(noise_and_labels)
            fake_image_and_labels = combine_vectors(fake, image_one_hot_labels)
            real_image_and_labels = combine_vectors(real, image_one_hot_labels)
            disc_fake_pred = disc(fake_image_and_labels.detach())
            disc_real_pred = disc(real_image_and_labels)
            disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
            disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
            disc_loss = (disc_fake_loss + disc_real_loss) / 2
            disc_loss.backward(retain_graph=True)
            disc_opt.step()

            # Update the generator
            gen_opt.zero_grad()
            disc_fake_pred = disc(fake_image_and_labels)
            gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
            gen_loss.backward()
            gen_opt.step()
def train_classifier():
    criterion = nn.CrossEntropyLoss()
    n_epochs = 10
    display_step = 10
    batch_size = 512
    lr = 0.0002
    validation_dataloader = DataLoader(
        CIFAR100(".", train=False, download=True, transform=transform),
        batch_size=batch_size)

    classifier = Classifier(cifar100_shape[0], n_classes).to(device)
    classifier_opt = torch.optim.Adam(classifier.parameters(), lr=lr)
    cur_step = 0
    for epoch in range(n_epochs):
        for examples, labels in dataloader:
            examples = examples.to(device)
            labels = labels.to(device)
            classifier_opt.zero_grad()
            labels_hat = classifier(examples)
            classifier_loss = criterion(labels_hat, labels)
            classifier_loss.backward()
            classifier_opt.step()

            # Periodically evaluate loss and accuracy on the validation set
            if cur_step % display_step == 0:
                classifier_val_loss = 0
                classifier_correct = 0
                num_validation = 0
                for val_example, val_label in validation_dataloader:
                    cur_batch_size = len(val_example)
                    num_validation += cur_batch_size
                    val_example = val_example.to(device)
                    val_label = val_label.to(device)
                    labels_hat = classifier(val_example)
                    classifier_val_loss += criterion(labels_hat, val_label) * cur_batch_size
                    classifier_correct += (labels_hat.argmax(1) == val_label).float().sum()
            cur_step += 1
As a bug master, you want a classifier capable of classifying different species of bugs: bees,
beetles, butterflies, caterpillars, and more. Luckily, you found a great dataset with a lot of
animal species and objects, and you trained your classifier on that.
But the bug classes don't do as well as you would like. Now your plan is to train a GAN on the
same data so it can generate new bugs to make your classifier better at distinguishing between
all of your favorite bugs!
You will fine-tune your model by augmenting the original real data with fake data and, in
the process, observe how these fake, GAN-generated bugs can increase the accuracy of
your classifier. After this, you will prove your worth as a bug master.
Sampling Ratio
Suppose you've decided that although you have this pre-trained general generator and this
general classifier, capable of identifying 100 classes with some accuracy (~17%), what you'd
really like is a model that can classify the five different kinds of bugs in the dataset. You'll
fine-tune your model by augmenting your data with the generated images. Keep in mind
that both the generator and the classifier were trained on the same images: the 40 images
per class you painstakingly found, so your generator may not be great. This is the caveat of
data augmentation: ultimately, you are still bound by the real data you have, but you want to
try to create more. To make your models even better, you would need to take more bug
photos, label them, and add them to your training set, and/or use higher-quality photos.
To start, you'll first need to write some code to sample a combination of real and generated
images. Given a probability, p_real, you'll need to generate a combined tensor where roughly
p_real of the returned images are sampled from the real images. Note that you should not
interpolate the images here: you should choose each image from the real or fake set with a given
probability. For example, if your real images are a tensor of [[1, 2, 3, 4, 5]] and your
fake images are a tensor of [[-1, -2, -3, -4, -5]], and p_real = 0.2, two potential
random return values are [[1, -2, 3, -4, -5]] or [[-1, 2, -3, -4, -5]].
Notice that p_real = 0.2 does not guarantee that exactly 20% of the samples are real, just
that when choosing an image for the combined set, there is a 20% probability that that image
will be chosen from the real images, and an 80% probability that it will be selected from the
fake images.
In addition, we will expect the images to remain in the same order to maintain their alignment
with their labels (this applies to the fake images too!).
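One possible sketch of combine_sample that satisfies these requirements (one coin flip per index, no interpolation, order preserved) is shown below; treat it as a starting point rather than the canonical solution:

```python
import torch

def combine_sample(real, fake, p_real):
    # Draw one Bernoulli(p_real) coin per index: where it lands heads, keep
    # the real image; otherwise take the fake image at the same position,
    # so ordering (and label alignment) is preserved.
    mask = torch.rand(len(real)) < p_real
    combined = real.clone()
    combined[~mask] = fake[~mask]
    return combined
```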
n_test_samples = 9999
test_combination = combine_sample(
    torch.ones(n_test_samples, 1),
    torch.zeros(n_test_samples, 1),
    0.3
)
# Check that the shape is right
assert tuple(test_combination.shape) == (n_test_samples, 1)
# Check that the ratio is right
assert torch.abs(test_combination.mean() - 0.3) < 0.05
# Make sure that no mixing happened
assert test_combination.median() < 1e-5
test_combination = combine_sample(
    torch.ones(n_test_samples, 10, 10),
    torch.zeros(n_test_samples, 10, 10),
    0.8
)
# Check that the shape is right
assert tuple(test_combination.shape) == (n_test_samples, 10, 10)
# Make sure that no mixing happened
assert torch.abs((test_combination.sum([1, 2]).median()) - 100) < 1e-5
When you're training a generator, you will often have to look at different checkpoints and
choose the one that performs best (either empirically or using some evaluation method).
Here, you are given four generator checkpoints: gen_1.pt, gen_2.pt, gen_3.pt, gen_4.pt.
You'll also have some scratch area to write whatever code you'd like to solve this problem,
but you must return a p_real and the file name of your selected generator checkpoint. You
can hard-code/brute-force these numbers if you would like, but you are encouraged to
solve this problem in a more general way. In practice, you would also want a test set (since
it is possible to overfit on a validation set), but for simplicity you can just focus on the
validation set.
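One generic way to organize that search is a small grid-search helper that takes the evaluation function as an argument (an eval_augmentation function appears later in the notebook and would be the natural thing to pass in):

```python
def pick_best(eval_fn, p_reals, gen_names):
    # Exhaustively score every (p_real, checkpoint) pair on the validation
    # metric and return the best-scoring combination.
    best_score, best_p, best_name = float("-inf"), None, None
    for p in p_reals:
        for name in gen_names:
            score = eval_fn(p, name)
            if score > best_score:
                best_score, best_p, best_name = score, p, name
    return best_p, best_name, best_score
```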
train_set = torch.load("insect_train.pt")
val_set = torch.load("insect_val.pt")
dataloader = DataLoader(
    torch.utils.data.TensorDataset(train_set["images"], train_set["labels"]),
    batch_size=batch_size,
    shuffle=True
)
validation_dataloader = DataLoader(
    torch.utils.data.TensorDataset(val_set["images"], val_set["labels"]),
    batch_size=batch_size
)
display_step = 1
lr = 0.0002
n_epochs = 20
classifier_opt = torch.optim.Adam(classifier.parameters(), lr=lr)
cur_step = 0
best_score = 0
for epoch in range(n_epochs):
    for real, labels in dataloader:
        real = real.to(device)
        labels = labels.to(device)
        one_hot_labels = get_one_hot_labels(labels, n_classes).float()
You'll likely find that the worst performance is when the generator is performing alone: this
corresponds to the case where you might be trying to hide the underlying examples from the
classifier. Perhaps you don't want other people to know about your specific bugs!
accuracies = []
p_real_all = torch.linspace(0, 1, 21)
for p_real_vis in tqdm(p_real_all):
    accuracies += [eval_augmentation(p_real_vis, best_gen_name, n_test=4)]
plt.plot(p_real_all.tolist(), accuracies)
plt.ylabel("Accuracy")
_ = plt.xlabel("Proportion of Real Images")
Here's a visualization of what the generator is actually generating, with real examples of each
class above the corresponding generated image.
one_hot_labels = get_one_hot_labels(train_labels.to(device), n_classes).float()
fake_noise = get_noise(len(train_images), z_dim, device=device)
noise_and_labels = combine_vectors(fake_noise, one_hot_labels)
gen = Generator(generator_input_dim).to(device)
gen.load_state_dict(torch.load(best_gen_name))
fake = gen(noise_and_labels)
show_tensor_images(torch.cat([train_images.cpu(), fake.cpu()]))