23/9/24, 18:23 Image Segmentation — A Beginner’s Guide | Medium


Image Segmentation — A Beginner’s Guide
The essentials of Image Segmentation + implementation in TensorFlow

Raj Pulapakura

6 min read · Feb 4, 2024

Image segmentation is a computer vision technique that assigns a label to every pixel in an image such that pixels with the same label share certain characteristics.

For example, in a street scene, all pixels belonging to cars might be labeled with one color, while those belonging to the road might be labeled with another.

https://medium.com/@raj.pulapakura/image-segmentation-a-beginners-guide-0ede91052db7 1/22

But to understand image segmentation and why it is useful, let’s go back to the basics.

Boring Classifiers

Cute doggo. Source


Is there a cute little doggo in this picture? Of course there is.

This is a classification task. It tells us if there is a dog in the image.

But what if we want to know exactly where the dog is?

One approach is to draw a bounding box around the dog, a task called object detection.


Cute doggo + bounding box. Source + Author

If that’s all you want, then you’re done! But if you want to know exactly where
the dog is, on the pixel level, then you’ll need something better. That’s where
image segmentation comes into play.
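To make the difference between the three tasks concrete, here is a small NumPy sketch; the toy "dog" mask is made up for illustration. A pixel mask answers the classification question and also yields a bounding box, but a bounding box cannot be turned back into a pixel mask.

```python
import numpy as np

def mask_to_bbox(mask):
    """Return (row_min, col_min, row_max, col_max) of the smallest box enclosing
    all non-zero pixels, or None if the mask is empty."""
    rows = np.flatnonzero(mask.any(axis=1))  # rows containing object pixels
    cols = np.flatnonzero(mask.any(axis=0))  # columns containing object pixels
    if rows.size == 0:
        return None
    return int(rows[0]), int(cols[0]), int(rows[-1]), int(cols[-1])

# Hypothetical 6x8 segmentation mask: 1 = dog pixel, 0 = everything else.
dog_mask = np.zeros((6, 8), dtype=np.uint8)
dog_mask[1:4, 2:5] = 1
dog_mask[4, 3] = 1

is_dog_present = bool(dog_mask.any())  # classification: is there a dog?
dog_box = mask_to_bbox(dog_mask)       # detection: roughly where is it?
box_area = (dog_box[2] - dog_box[0] + 1) * (dog_box[3] - dog_box[1] + 1)
dog_pixels = int(dog_mask.sum())       # segmentation: exactly which pixels
```

Here the box covers 12 pixels while the dog occupies only 10; the mask keeps that distinction, the box loses it.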

Image Segmentation


Street segmentation. Source

The core task of image segmentation is to classify each pixel in an image. In the street scene above, there are 5 classes: road (pink), vehicles (red), buildings (yellow), nature (green), and sky (blue). Each pixel is assigned exactly one of these classes.
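In practice a segmentation mask is usually stored as an (H, W) array of integer class indices, and the colours are only a visualization. A minimal NumPy sketch of that mapping; the class order and the exact RGB values are assumptions chosen to match the figure's description.

```python
import numpy as np

# Hypothetical class order and palette (RGB) for the street scene above.
CLASS_NAMES = ["road", "vehicles", "buildings", "nature", "sky"]
PALETTE = np.array([
    [255, 105, 180],  # road: pink
    [255, 0, 0],      # vehicles: red
    [255, 255, 0],    # buildings: yellow
    [0, 128, 0],      # nature: green
    [0, 0, 255],      # sky: blue
], dtype=np.uint8)

def colorize(mask):
    """Map an (H, W) array of class indices to an (H, W, 3) RGB image."""
    return PALETTE[mask]  # fancy indexing looks up one colour per pixel
```

For example, `colorize(np.array([[4, 4], [0, 1]]))` gives a 2x2x3 image whose top row is blue (sky) and whose bottom row is pink (road) then red (vehicles).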

But sometimes you want to differentiate between individual cars, or individual trees. To this end, there are 3 main types of image segmentation, each providing a different level of detail and information.

Semantic vs. Instance vs. Panoptic

Semantic vs. Instance vs. Panoptic segmentation. Source

Semantic segmentation classifies each pixel based on its semantic class. All the birds belong to the same class.

Instance segmentation assigns unique labels to different instances, even if they are of the same semantic class. Each bird gets its own label.

Panoptic segmentation combines the two, providing both class-level and instance-level labels. Each bird has its own instance label, but they are all identified as a “bird”.
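The three label formats can be compared on a toy example. This NumPy sketch uses a made-up 4x6 scene with two birds; the panoptic encoding `class_id * 1000 + instance_id` is one simple convention assumed here for illustration, not the only one in use.

```python
import numpy as np

# Semantic mask: background = 0, bird = 1. Both birds share class 1.
semantic = np.array([
    [0, 1, 1, 0, 0, 0],
    [0, 1, 1, 0, 1, 1],
    [0, 0, 0, 0, 1, 1],
    [0, 0, 0, 0, 0, 0],
])

# Instance mask: each bird gets its own id (0 = not an instance).
instance = np.array([
    [0, 1, 1, 0, 0, 0],
    [0, 1, 1, 0, 2, 2],
    [0, 0, 0, 0, 2, 2],
    [0, 0, 0, 0, 0, 0],
])

# Panoptic: pack class and instance into one integer per pixel
# (assumed convention: class_id * 1000 + instance_id).
panoptic = semantic * 1000 + instance

# Both pieces of information can be recovered from the panoptic map.
recovered_class = panoptic // 1000
recovered_instance = panoptic % 1000
```

The panoptic map holds three values here: 0 (background), 1001 (bird #1) and 1002 (bird #2), so it says both "this is a bird" and "this is the second bird".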

Cool, but how do we actually implement image segmentation?

There are a couple of ways, such as thresholding and clustering, but deep
learning (my fav) really takes the spotlight when it comes to image
segmentation.
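For contrast with the deep-learning approach, here is a sketch of the simplest classical method: thresholding a grayscale image into foreground and background. The threshold is picked with Otsu's method, which is our choice of technique (the article does not name one); it selects the threshold that maximizes the between-class variance of the two resulting pixel groups.

```python
import numpy as np

def otsu_threshold(gray):
    """Otsu's method on a 2-D uint8 image: return the threshold t (0-255) that
    maximizes between-class variance when pixels <= t form one class and
    pixels > t form the other."""
    hist = np.bincount(gray.ravel(), minlength=256).astype(np.float64)
    prob = hist / hist.sum()
    omega = np.cumsum(prob)                # weight of the "<= t" class
    mu = np.cumsum(prob * np.arange(256))  # cumulative mean up to t
    mu_total = mu[-1]
    with np.errstate(divide="ignore", invalid="ignore"):
        sigma_b = (mu_total * omega - mu) ** 2 / (omega * (1.0 - omega))
    sigma_b = np.nan_to_num(sigma_b)       # empty class -> variance 0
    return int(np.argmax(sigma_b))

def threshold_segment(gray):
    """Binary segmentation: 1 for pixels brighter than the Otsu threshold."""
    return (gray > otsu_threshold(gray)).astype(np.uint8)
```

This works when foreground and background differ clearly in brightness, but it has no notion of "car" or "road", which is why learned models dominate multi-class segmentation.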

Real-time body part panoptic segmentation. GIF from TensorFlow Blog



U-Net
The U-Net architecture was initially designed for medical image
segmentation, but it has since been adapted for many other use cases.

U-Net. Image by author.

The U-Net has an encoder-decoder structure.


The encoder is used to compress the input image into a latent space
representation through convolutions and downsampling.

The decoder is used to expand the latent representation back into a segmented image through convolutions and upsampling.

The long gray arrows running across the “U” are skip connections, and they
serve two main purposes:

1. During the forward pass, they give the decoder direct access to the high-resolution features from the encoder, which helps it recover the fine spatial detail lost during downsampling.

2. During the backward pass, they act as a “gradient superhighway” for gradients from the decoder to flow to the encoder.

The output of the model has the same width and height as the input, but its number of channels equals the number of classes we are segmenting.
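The shape bookkeeping of the encoder, bottleneck and decoder can be traced with a few lines of plain Python. This sketch assumes the layout used in this article's TensorFlow implementation ('same'-padded convolutions, 2x2 pooling, four levels with filters doubling from 64); these are design choices, not requirements of every U-Net.

```python
def unet_shapes(height, width, n_classes, base_filters=64, depth=4):
    """Trace (H, W, C) shapes through a U-Net.

    Assumes 'same' padding (convolutions keep H and W), 2x2 max pooling on
    the way down, stride-2 transposed convolutions on the way up, and
    H, W divisible by 2**depth.
    """
    if height % (2 ** depth) or width % (2 ** depth):
        raise ValueError("height and width must be divisible by 2**depth")
    shapes = {}
    h, w = height, width
    for level in range(depth):
        filters = base_filters * 2 ** level
        shapes[f"c{level + 1}"] = (h, w, filters)  # skip-connection features
        h, w = h // 2, w // 2
        shapes[f"p{level + 1}"] = (h, w, filters)  # after 2x2 max pooling
    shapes["bridge"] = (h, w, base_filters * 2 ** depth)
    for level in reversed(range(depth)):
        filters = base_filters * 2 ** level
        h, w = h * 2, w * 2                        # upsample by 2
        shapes[f"u{level + 1}"] = (h, w, filters)  # after concat + conv block
    shapes["output"] = (h, w, n_classes)           # 1x1 conv to class channels
    return shapes
```

For a 512x512 input and 5 classes, the bottleneck is 32x32x1024 and the output is 512x512x5, matching the comments in the implementation.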

Code it up
If you’re keen to code, let’s implement the U-Net architecture for semantic
segmentation in TensorFlow.

U-Net Architecture
Defining the model architecture is rather straightforward.

from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D,
                                     concatenate, Conv2DTranspose)
from tensorflow.keras.models import Model

def conv_block(x, n_filters):
    """two 3x3 convolutions, each followed by ReLU"""
    x = Conv2D(n_filters, (3, 3), padding='same', activation='relu')(x)
    x = Conv2D(n_filters, (3, 3), padding='same', activation='relu')(x)
    return x

def encoder_block(x, n_filters):
    """conv block and max pooling"""
    x = conv_block(x, n_filters)
    p = MaxPooling2D((2, 2))(x)
    return x, p  # we will need x for the skip connections later

def decoder_block(x, skip, n_filters):
    """upsample, skip connection, and conv block"""
    x = Conv2DTranspose(n_filters, (2, 2), strides=(2, 2), padding='same')(x)
    x = concatenate([x, skip])  # concatenate = skip connection
    x = conv_block(x, n_filters)
    return x

def unet_model(n_classes, img_height, img_width, img_channels):
    inputs = Input((img_height, img_width, img_channels))  # 512x512x3

    # Contraction path, encoder
    c1, p1 = encoder_block(inputs, n_filters=64)  # c1=512x512x64  p1=256x256x64
    c2, p2 = encoder_block(p1, n_filters=128)     # c2=256x256x128 p2=128x128x128
    c3, p3 = encoder_block(p2, n_filters=256)     # c3=128x128x256 p3=64x64x256
    c4, p4 = encoder_block(p3, n_filters=512)     # c4=64x64x512   p4=32x32x512

    # Bottleneck
    bridge = conv_block(p4, n_filters=1024)  # bridge=32x32x1024

    # Expansive path, decoder. The skip connections use the pre-pooling
    # feature maps (c4..c1), whose size matches the upsampled tensor.
    u4 = decoder_block(bridge, c4, n_filters=512)  # 64x64x512
    u3 = decoder_block(u4, c3, n_filters=256)      # 128x128x256
    u2 = decoder_block(u3, c2, n_filters=128)      # 256x256x128
    u1 = decoder_block(u2, c1, n_filters=64)       # 512x512x64

    outputs = Conv2D(n_classes, (1, 1), activation='softmax')(u1)  # 512x512xn_classes
    # notice the softmax activation in the final layer

    model = Model(inputs=[inputs], outputs=[outputs])
    return model

IMG_HEIGHT, IMG_WIDTH = 512, 512

# example classes: [road, vehicles, buildings, nature, background]
# instantiate model to predict 5 classes
model = unet_model(
    n_classes=5,
    img_height=IMG_HEIGHT,
    img_width=IMG_WIDTH,
    img_channels=3
)
# input: 512x512x3
# output: 512x512x5

Loss Function: Categorical Cross Entropy

How do we optimize this model? Well, since image segmentation is really just classification on the pixel level, we can use the standard classification loss function, which is Categorical Cross Entropy.

model.compile(
    optimizer="adam",  # optimizer choice here is an assumption
    loss="categorical_crossentropy",
)
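Because the loss is just classification applied to every pixel, it can be written out directly. This NumPy sketch computes the categorical cross-entropy at each pixel and averages over all pixels, which is roughly what Keras computes for a (batch, H, W, classes) target with its default reduction; the clipping constant `eps` is an assumption to avoid log(0).

```python
import numpy as np

def pixelwise_categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean categorical cross-entropy over all pixels.

    y_true: one-hot masks, shape (batch, H, W, n_classes)
    y_pred: softmax probabilities, same shape
    """
    y_pred = np.clip(y_pred, eps, 1.0 - eps)                # avoid log(0)
    per_pixel = -np.sum(y_true * np.log(y_pred), axis=-1)   # (batch, H, W)
    return float(per_pixel.mean())
```

A uniform guess over 2 classes gives ln 2 ≈ 0.693 per pixel, and a perfect prediction gives a loss close to 0.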

We can interpret each pixel of the resulting (512x512x5) volume as a vector of length 5. Since the last layer uses a softmax activation across the last dimension, each pixel vector contains the probabilities of that pixel belonging to each class.
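To turn those probabilities into a segmentation map, you take the most probable class at each pixel (an argmax over the last axis). A NumPy sketch, using random logits as a hypothetical stand-in for a real 4x4 model output with 5 classes:

```python
import numpy as np

def softmax(logits, axis=-1):
    """Numerically stable softmax along the class axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def probabilities_to_mask(probs):
    """(H, W, n_classes) probabilities -> (H, W) array of class indices."""
    return np.argmax(probs, axis=-1)

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 4, 5))  # stand-in for raw per-pixel scores
probs = softmax(logits)              # each pixel's 5 values now sum to 1
mask = probabilities_to_mask(probs)  # one class index per pixel
```

Since softmax preserves the ordering of the scores, the argmax of the probabilities is the same as the argmax of the raw logits.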


Intuition for model output

Before we can train the model, we need a dataset. The dataset should
contain (image, mask) pairs, where the image (x) is of shape (512x512x3) and
the mask (y) is of shape (512x512x5).

Here is an example ground truth mask:


Image by Prince Canuma

Each pixel can only belong to one class, so it contains a “1” in one of the class
channels, and a “0” in the other channels. You can think of each pixel as a
one-hot vector (because that’s what it is).
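Masks are often stored on disk as (H, W) arrays of class indices, so they need converting to this one-hot form before training with categorical cross-entropy. A minimal NumPy sketch; in TensorFlow, `tf.one_hot(mask, depth=n_classes)` does the same job. Alternatively, Keras's `sparse_categorical_crossentropy` loss accepts the integer masks directly, which saves memory.

```python
import numpy as np

def to_one_hot(mask, n_classes):
    """(H, W) integer mask -> (H, W, n_classes) one-hot float32 mask."""
    # Row k of the identity matrix is the one-hot vector for class k.
    return np.eye(n_classes, dtype=np.float32)[mask]
```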

Once you have your dataset prepared, you’re ready to train:


model.fit(
train_ds,
validation_data=val_ds,
epochs=10,
)

Of course, this code alone is not enough to train a successful model. If you actually want to implement this, you need to consider preprocessing, rescaling, batching, etc.
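As a rough illustration of what rescaling and batching involve, here is a NumPy-only generator that yields one epoch of (x, y) batches. It assumes images are uint8 arrays of shape (N, H, W, 3) and masks are integer arrays of shape (N, H, W); in a real TensorFlow project a `tf.data` pipeline is the more idiomatic route.

```python
import numpy as np

def batch_generator(images, masks, n_classes, batch_size=8, shuffle=True, seed=0):
    """Yield one epoch of (x, y) batches.

    x: images rescaled from [0, 255] to [0, 1], shape (b, H, W, 3)
    y: one-hot masks, shape (b, H, W, n_classes)
    """
    rng = np.random.default_rng(seed)
    n = len(images)
    order = rng.permutation(n) if shuffle else np.arange(n)
    for start in range(0, n, batch_size):
        idx = order[start:start + batch_size]
        x = images[idx].astype(np.float32) / 255.0           # rescaling
        y = np.eye(n_classes, dtype=np.float32)[masks[idx]]  # one-hot masks
        yield x, y
```

The last batch may be smaller than `batch_size` when N is not a multiple of it.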

I’ve prepared a Kaggle notebook that tackles car segmentation (segmenting different parts of a car). It contains the complete code to run an image segmentation model, so check it out here.

Final Notes
Class Imbalance: Image segmentation data often suffers from severe class imbalance. For example, in an average street-view image, cars and buildings take up many pixels, but stop signs take up very few. The model sees far less data on stop signs, so it tends to perform poorly at segmenting them. To mitigate this, you can use Focal Categorical Cross Entropy and class weights, which place more emphasis on minority classes.
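Both remedies can be folded into the loss itself. This NumPy sketch implements a per-pixel focal categorical cross-entropy, -w_c (1 - p)^γ log(p) for the true class c, together with inverse-frequency class weights. Computing weights from pixel counts this way is one common choice and an assumption here; `gamma=2.0` is likewise just an example default.

```python
import numpy as np

def inverse_frequency_weights(masks, n_classes):
    """Per-class weights proportional to 1 / pixel count (rare classes weigh more).
    Classes absent from `masks` are treated as having one pixel."""
    counts = np.bincount(masks.ravel(), minlength=n_classes).astype(np.float64)
    counts = np.maximum(counts, 1.0)
    return counts.sum() / (n_classes * counts)

def weighted_focal_crossentropy(y_true, y_pred, class_weights, gamma=2.0, eps=1e-7):
    """Mean focal categorical cross-entropy over pixels.

    y_true, y_pred: (batch, H, W, n_classes); class_weights: (n_classes,).
    The (1 - p) ** gamma factor down-weights pixels the model already gets right;
    with gamma=0 and all weights 1 this reduces to plain cross-entropy.
    """
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    w = np.asarray(class_weights, dtype=np.float64)  # broadcasts over class axis
    per_class = -w * y_true * (1.0 - y_pred) ** gamma * np.log(y_pred)
    return float(per_class.sum(axis=-1).mean())
```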


Other Architectures: U-Net is not the only image segmentation architecture, although it is one of the conceptually simplest. Others include SegNet, Mask R-CNN, and PSPNet.

Binary Segmentation: If there is only one class you’re segmenting (e.g. segmenting a brain tumor in an MRI scan), then the output of the model only needs a single channel: (512x512x1) instead of (512x512x5). For the mask, each pixel will contain a “1” if that pixel belongs to a tumor, or a “0” if it does not. Make sure to also change “softmax” to “sigmoid” in the final activation of the model (a softmax over a single channel would always output 1), and use the (Focal) Binary Cross Entropy loss function.
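A NumPy sketch of the binary case: a sigmoid turns each pixel's single score into a tumor probability, a 0.5 cutoff turns probabilities into the 0/1 mask, and binary cross-entropy is averaged over pixels. The logits below are hypothetical, and the 0.5 cutoff is a common default rather than a rule.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def binary_crossentropy(y_true, y_prob, eps=1e-7):
    """Mean BCE over pixels; y_true in {0, 1}, y_prob = sigmoid outputs."""
    y_prob = np.clip(y_prob, eps, 1.0 - eps)  # avoid log(0)
    loss = -(y_true * np.log(y_prob) + (1.0 - y_true) * np.log(1.0 - y_prob))
    return float(loss.mean())

logits = np.array([[2.0, -3.0], [0.5, -0.1]])  # hypothetical single-channel output
probs = sigmoid(logits)
tumor_mask = (probs > 0.5).astype(np.uint8)   # 1 = tumor, 0 = not tumor
```

In Keras terms, the corresponding final layer would be `Conv2D(1, (1, 1), activation='sigmoid')`.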

Thanks for reading!
