Assignment-U-Net
Goals
In this notebook, you're going to implement a U-Net for a biomedical imaging segmentation
task. Specifically, you're going to be labeling neurons, so one might call this a neural neural
network! ;)
Note that this is not a GAN, generative model, or unsupervised learning task. This is a supervised
learning task, so there's only one correct answer (like a classifier!). You will see how this
architecture underlies the Generator component of Pix2Pix in the next notebook this week.
Learning Objectives
1. Implement your own U-Net.
2. Observe your U-Net's performance on a challenging segmentation task.
Getting Started
You will start by importing libraries, defining a visualization function, and getting the neural
dataset that you will be using.
Dataset
For this notebook, you will be using a dataset of electron microscopy images and segmentation
data. The information about the dataset you'll be using can be found here!
[Image: dataset example]
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0)
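The "visualization function" mentioned above is not shown in this section; below is a minimal sketch of one, assuming batches should be displayed as an image grid using make_grid and matplotlib. The name show_tensor_images matches what the training code at the end of this notebook calls.
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    '''
    Visualize a batch of image tensors by plotting up to num_images of them in a grid.
    '''
    image_unflat = image_tensor.detach().cpu().view(-1, *size)
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()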
U-Net Architecture
Now you can build your U-Net from its components. The figure below is from the paper, U-Net:
Convolutional Networks for Biomedical Image Segmentation, by Ronneberger et al. 2015. It
shows the U-Net architecture and how it contracts and then expands.
In other words, images are first fed through many convolutional layers which reduce the height and
width while increasing the number of channels, which the authors refer to as the "contracting path." For
example, a set of two 2 x 2 convolutional filters with a stride of 2 will take a 1 x 28 x 28 (channels,
height, width) grayscale image and produce a 2 x 14 x 14 representation. The "expanding path"
does the opposite, gradually growing the image back up while reducing the number of channels.
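As a quick sanity check on that example (a small sketch using the imports above), two 2 x 2 filters with stride 2 halve the height and width while producing 2 output channels:
example_image = torch.randn(1, 1, 28, 28)  # (batch, channels, height, width)
downsample = nn.Conv2d(1, 2, kernel_size=2, stride=2)  # two 2 x 2 filters, stride 2
print(downsample(example_image).shape)  # torch.Size([1, 2, 14, 14])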
Contracting Path
You will first implement the contracting blocks for the contracting path. This path is the encoder
section of the U-Net, which has several downsampling steps as part of it. The authors give more
detail of this part in the following paragraph from the paper (Ronneberger, 2015):
The contracting path follows the typical architecture of a convolutional network. It consists
of the repeated application of two 3 x 3 convolutions (unpadded convolutions), each followed
by a rectified linear unit (ReLU) and a 2 x 2 max pooling operation with stride 2 for
downsampling. At each downsampling step we double the number of feature channels.
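As a reference, here is a minimal sketch of a ContractingBlock consistent with that description and with the unit test below (two unpadded 3 x 3 convolutions that double the channels, ReLU activations, and a 2 x 2 max pool with stride 2). Treat it as one possible implementation, not the graded solution.
class ContractingBlock(nn.Module):
    '''
    One downsampling step of the U-Net: two 3 x 3 convolutions (doubling the
    channels), each followed by a ReLU, then a 2 x 2 max pool with stride 2.
    '''
    def __init__(self, input_channels):
        super(ContractingBlock, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, input_channels * 2, kernel_size=3)
        self.conv2 = nn.Conv2d(input_channels * 2, input_channels * 2, kernel_size=3)
        self.activation = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.activation(self.conv1(x))
        x = self.activation(self.conv2(x))
        return self.maxpool(x)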
#UNIT TEST
def test_contracting_block(test_samples=100, test_channels=10, test_size=50):
    test_block = ContractingBlock(test_channels)
    test_in = torch.randn(test_samples, test_channels, test_size, test_size)
    test_out_conv1 = test_block.conv1(test_in)
    # Make sure that the first convolution has the right shape
    assert tuple(test_out_conv1.shape) == (test_samples, test_channels * 2, test_size - 2, test_size - 2)
    # Make sure that the right activation is used
    assert torch.all(test_block.activation(test_out_conv1) >= 0)
    assert torch.max(test_block.activation(test_out_conv1)) >= 1
    test_out_conv2 = test_block.conv2(test_out_conv1)
    # Make sure that the second convolution has the right shape
    assert tuple(test_out_conv2.shape) == (test_samples, test_channels * 2, test_size - 4, test_size - 4)
    test_out = test_block(test_in)
    # Make sure that the pooling has the right shape
    assert tuple(test_out.shape) == (test_samples, test_channels * 2, test_size // 2 - 2, test_size // 2 - 2)

test_contracting_block()
test_contracting_block(10, 9, 8)
print("Success!")
Expanding Path
Next, you will implement the expanding blocks for the expanding path. This is the decoder
section of the U-Net, which has several upsampling steps as part of it. In order to do this, you'll also
need to write a crop function so that you can crop the image from the contracting path and
concatenate it to the current image on the expanding path; this forms a skip connection.
Again, the details are from the paper (Ronneberger, 2015):
Every step in the expanding path consists of an upsampling of the feature map
followed by a 2 x 2 convolution (“up-convolution”) that halves the number of feature
channels, a concatenation with the correspondingly cropped feature map from the
contracting path, and two 3 x 3 convolutions, each followed by a ReLU. The cropping is
necessary due to the loss of border pixels in every convolution.
Fun fact: later models based on this architecture often use padding in the convolutions to
prevent the size of the image from changing outside of the upsampling / downsampling steps!
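Here is a minimal sketch of a center-crop helper consistent with the unit test below: it crops the skip-connection tensor down to the spatial size given by the target shape. Again, one possible implementation rather than the graded solution.
def crop(image, new_shape):
    '''
    Center-crop an image tensor of shape (batch, channels, height, width)
    to the height and width given in new_shape.
    '''
    middle_height = image.shape[2] // 2
    middle_width = image.shape[3] // 2
    start_h = middle_height - new_shape[2] // 2
    start_w = middle_width - new_shape[3] // 2
    return image[:, :, start_h:start_h + new_shape[2], start_w:start_w + new_shape[3]]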
#UNIT TEST
def test_expanding_block_crop(test_samples=100, test_channels=10, test_size=100):
    # Make sure that the crop function is the right shape
    skip_con_x = torch.randn(test_samples, test_channels, test_size + 6, test_size + 6)
    x = torch.randn(test_samples, test_channels, test_size, test_size)
    cropped = crop(skip_con_x, x.shape)
    assert tuple(cropped.shape) == (test_samples, test_channels, test_size, test_size)
    # Make sure that the crop function takes the right area
    test_meshgrid = torch.meshgrid([torch.arange(0, test_size), torch.arange(0, test_size)])
    test_meshgrid = test_meshgrid[0] + test_meshgrid[1]
    test_meshgrid = test_meshgrid[None, None, :, :].float()
    cropped = crop(test_meshgrid, torch.Size([1, 1, test_size // 2, test_size // 2]))
    assert cropped.max() == (test_size - 1) * 2 - test_size // 2
    assert cropped.min() == test_size // 2
    assert cropped.mean() == test_size - 1

test_expanding_block_crop()
print("Success!")
#UNIT TEST
def test_expanding_block(test_samples=100, test_channels=10, test_size=50):
    test_block = ExpandingBlock(test_channels)
    skip_con_x = torch.randn(test_samples, test_channels // 2, test_size * 2 + 6, test_size * 2 + 6)
    x = torch.randn(test_samples, test_channels, test_size, test_size)
    x = test_block.upsample(x)
    x = test_block.conv1(x)
    # Make sure that the first convolution produces the right shape
    assert tuple(x.shape) == (test_samples, test_channels // 2, test_size * 2 - 1, test_size * 2 - 1)
    original_x = crop(skip_con_x, x.shape)
    x = torch.cat([x, original_x], axis=1)
    x = test_block.conv2(x)
    # Make sure that the second convolution produces the right shape
    assert tuple(x.shape) == (test_samples, test_channels // 2, test_size * 2 - 3, test_size * 2 - 3)
    x = test_block.conv3(x)
    # Make sure that the final convolution produces the right shape
    assert tuple(x.shape) == (test_samples, test_channels // 2, test_size * 2 - 5, test_size * 2 - 5)
    x = test_block.activation(x)

test_expanding_block()
print("Success!")
Final Layer
Now you will write the final feature mapping block, which takes in a tensor with arbitrarily many
channels and produces a tensor with the same number of pixels but with the correct number of
output channels. From the paper (Ronneberger, 2015):
At the final layer a 1x1 convolution is used to map each 64-component feature vector
to the desired number of classes. In total the network has 23 convolutional layers.
# UNQ_C4 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED CLASS: FeatureMapBlock
class FeatureMapBlock(nn.Module):
    '''
    FeatureMapBlock Class
    The final layer of a UNet -
    maps each pixel to a pixel with the correct number of output dimensions
    using a 1x1 convolution.
    Values:
        input_channels: the number of channels to expect from a given input
        output_channels: the number of channels to expect for a given output
    '''
    def __init__(self, input_channels, output_channels):
        super(FeatureMapBlock, self).__init__()
        # "At the final layer a 1x1 convolution is used to map each
        # 64-component feature vector to the desired number of classes."
        #### START CODE HERE ####
        self.conv = nn.Conv2d(None, None, kernel_size=None)
        #### END CODE HERE ####

    def forward(self, x):
        '''
        Completes a forward pass of FeatureMapBlock: given an image tensor,
        returns it mapped to the desired number of output channels.
        '''
        return self.conv(x)

# UNIT TEST
assert tuple(FeatureMapBlock(10, 60)(torch.randn(1, 10, 10, 10)).shape) == (1, 60, 10, 10)
print("Success!")
U-Net
Now you can put it all together! Here, you'll write a UNet class which will combine a series of the
three kinds of blocks you've implemented.
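As a reference point, here is a minimal sketch of how such a UNet class might be assembled from the blocks above (once the FeatureMapBlock blanks are filled in), assuming four contracting blocks, four expanding blocks, and a hidden width of 64. The depth and widths below are one choice consistent with the unit test that follows, not the graded solution.
class UNet(nn.Module):
    '''
    A U-Net built from the three block types above: an input feature map block,
    four contracting blocks, four expanding blocks, and an output feature map block.
    '''
    def __init__(self, input_channels, output_channels, hidden_channels=64):
        super(UNet, self).__init__()
        self.upfeature = FeatureMapBlock(input_channels, hidden_channels)
        self.contract1 = ContractingBlock(hidden_channels)
        self.contract2 = ContractingBlock(hidden_channels * 2)
        self.contract3 = ContractingBlock(hidden_channels * 4)
        self.contract4 = ContractingBlock(hidden_channels * 8)
        self.expand1 = ExpandingBlock(hidden_channels * 16)
        self.expand2 = ExpandingBlock(hidden_channels * 8)
        self.expand3 = ExpandingBlock(hidden_channels * 4)
        self.expand4 = ExpandingBlock(hidden_channels * 2)
        self.downfeature = FeatureMapBlock(hidden_channels, output_channels)

    def forward(self, x):
        # Contracting path, keeping intermediate outputs for the skip connections
        x0 = self.upfeature(x)
        x1 = self.contract1(x0)
        x2 = self.contract2(x1)
        x3 = self.contract3(x2)
        x4 = self.contract4(x3)
        # Expanding path, concatenating the cropped skip connections
        x5 = self.expand1(x4, x3)
        x6 = self.expand2(x5, x2)
        x7 = self.expand3(x6, x1)
        x8 = self.expand4(x7, x0)
        return self.downfeature(x8)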
#UNIT TEST
test_unet = UNet(1, 3)
assert tuple(test_unet(torch.randn(1, 1, 256, 256)).shape) == (1, 3, 117, 117)
print("Success!")
Training
Finally, you will put this into action! Remember that these are your parameters:
import torch.nn.functional as F
criterion = nn.BCEWithLogitsLoss()
n_epochs = 200
input_dim = 1
label_dim = 1
display_step = 20
batch_size = 4
lr = 0.0002
initial_shape = 512
target_shape = 373
device = 'cuda'
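The training loop below expects the EM volumes and their label masks as tensors named volumes and labels. A minimal loading sketch is shown here, assuming the dataset ships as the multi-page TIFF files train-volume.tif and train-labels.tif (the file names and the use of imageio are assumptions about the course environment). The labels are cropped to target_shape because the unpadded convolutions make the U-Net's output smaller than its input.
import imageio  # assumed available in the course environment

# Load the EM stack and its segmentation labels, scaled to [0, 1],
# with shape (num_slices, 1, height, width)
volumes = torch.Tensor(imageio.mimread('train-volume.tif'))[:, None, :, :] / 255
labels = torch.Tensor(imageio.mimread('train-labels.tif'))[:, None, :, :] / 255
# Crop the labels to the spatial size the U-Net actually outputs
labels = crop(labels, torch.Size([len(labels), label_dim, target_shape, target_shape]))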
def train():
    # Note: the loop surrounding the original logging block below is a reconstruction
    # consistent with the parameters above, and assumes volumes and labels are loaded.
    dataset = torch.utils.data.TensorDataset(volumes, labels)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    unet = UNet(input_dim, label_dim).to(device)
    unet_opt = torch.optim.Adam(unet.parameters(), lr=lr)
    cur_step = 0
    for epoch in range(n_epochs):
        for real, labels in tqdm(dataloader):
            real = real.to(device)
            labels = labels.to(device)
            unet_opt.zero_grad()
            pred = unet(real)
            unet_loss = criterion(pred, labels)
            unet_loss.backward()
            unet_opt.step()
            if cur_step % display_step == 0:
                print(f"Epoch {epoch}: Step {cur_step}: U-Net loss: {unet_loss.item()}")
                show_tensor_images(
                    crop(real, torch.Size([len(real), 1, target_shape, target_shape])),
                    size=(input_dim, target_shape, target_shape)
                )
                show_tensor_images(labels, size=(label_dim, target_shape, target_shape))
                show_tensor_images(torch.sigmoid(pred), size=(label_dim, target_shape, target_shape))
            cur_step += 1

train()