Image Segmentation Using TensorFlow

Last Updated : 11 Aug, 2025

Image segmentation is a computer vision technique that divides an image into regions by assigning a class label to every pixel. In the dataset used here, each pixel is labeled as pet, pet outline or background. Pixels that share a label form a region, so the model knows both what is in the image and exactly where it is.

  • In regular classification, a computer just says what the whole picture is (like “cat” or “dog”).
  • In object detection, the computer draws boxes around things it finds.
  • Segmentation shows the exact shape of objects.
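
To make the difference concrete, here is a minimal NumPy sketch of what each task returns for one image. The 6x6 toy mask, the "pet" label and the box format (row_min, col_min, row_max, col_max) are illustrative assumptions, not part of the tutorial's pipeline.

```python
import numpy as np

# Toy 6x6 ground-truth mask: a 3x2 object (label 1) on background (label 0).
mask = np.zeros((6, 6), dtype=np.int64)
mask[2:5, 1:3] = 1

# Classification: a single label for the whole image (hypothetical label).
image_label = "pet"

# Object detection: a bounding box, derived here from the mask as
# (row_min, col_min, row_max, col_max).
rows, cols = np.nonzero(mask)
bbox = (int(rows.min()), int(cols.min()), int(rows.max()), int(cols.max()))

# Segmentation: one label per pixel, with the same height and width as the image.
print(image_label, bbox, mask.shape)
```

The box only bounds the object, while the mask records which pixels inside that box actually belong to it.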

Step-by-Step Image Segmentation

Let's build an image segmentation model with TensorFlow, step by step.

Step 1: Import Libraries

We import NumPy, Matplotlib, TensorFlow, TensorFlow Datasets and Keras.

Python
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras

Step 2: Load the Dataset

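We load the Oxford-IIIT Pet dataset with TensorFlow Datasets. It comes with predefined train and test splits, and with_info=True also returns metadata such as the split sizes.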

Python
dataset, info = tfds.load('oxford_iiit_pet:4.*.*', with_info=True)

Output:

Loading the Oxford-IIIT dataset

Step 3: Set Constants

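We set the constants used in the rest of the pipeline:

  • Batch size sets how many images are processed per training step and buffer size sets how many samples the shuffle step draws from.
  • Width and height set the 224×224 input size used for the VGG16 encoder.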
Python
BATCH_SIZE = 64
BUFFER_SIZE = 1000
width, height = 224, 224
TRAIN_LENGTH = info.splits['train'].num_examples
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE
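
As a quick check of the arithmetic, here is a small sketch of how the number of steps per epoch follows from the split size. The figure of 3,680 training images is an assumption about the dataset version; the tutorial reads the real value from info.splits['train'].num_examples.

```python
# Assumed split size (the tutorial reads it from the dataset metadata).
train_length = 3680
batch_size = 64

# Number of full batches in one pass over the training split.
steps_per_epoch = train_length // batch_size

# Images left over after the full batches.
leftover = train_length - steps_per_epoch * batch_size

print(steps_per_epoch, leftover)
```

Under that assumption each epoch runs 57 steps, and the 32 leftover images form a smaller batch at the end of each pass over the data.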

Step 4: Data Preprocessing and Augmentation

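We define the preprocessing functions, which do the following:

  • Convert image pixels to float and scale them to the 0–1 range.
  • Subtract 1 from the mask so its labels run from 0 to 2 instead of 1 to 3, matching the class indices the loss expects.
  • Resize images and masks to the same size.
  • Randomly flip the image and its mask together to add variety to the training data.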
Python
def normalize(input_image, input_mask):
    img = tf.cast(input_image, dtype=tf.float32) / 255.0
    input_mask -= 1
    return img, input_mask


@tf.function
def load_train_ds(example):
    img = tf.image.resize(example['image'], (width, height))
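    # Masks hold class labels, so use nearest-neighbor resizing to avoid
    # interpolating between label values.
    mask = tf.image.resize(example['segmentation_mask'], (width, height),
                           method='nearest')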
    if tf.random.uniform(()) > 0.5:
        img = tf.image.flip_left_right(img)
        mask = tf.image.flip_left_right(mask)
    img, mask = normalize(img, mask)
    return img, mask
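
Mask values are class labels rather than intensities, so resizing should never blend them. The NumPy sketch below illustrates the point on a toy mask; block averaging is used only as a simplified stand-in for an interpolating resize and is not how tf.image.resize is implemented.

```python
import numpy as np

# Toy 4x4 mask with raw labels 1 and 3 only; label 2 does not occur.
mask = np.array([[1, 3, 3, 3]] * 4)

# Averaging each 2x2 block (stand-in for an interpolating resize)
# blends labels 1 and 3 into 2, a class that was never there.
averaged = mask.reshape(2, 2, 2, 2).mean(axis=(1, 3))

# Nearest-neighbor style downsampling keeps one original pixel per block,
# so only labels that exist in the mask can appear.
nearest = mask[::2, ::2]

# Shift labels from 1..3 to 0..2, as the normalize step does.
shifted = nearest - 1

print(averaged.tolist(), nearest.tolist(), shifted.tolist())
```

The averaged mask contains label 2 even though no pixel had that label, while the nearest-neighbor mask keeps the original label set.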

Step 5: Build Data Pipelines

We build the input pipelines for training and testing:

  • map: Applies preprocessing to each sample.
  • cache, shuffle, batch, repeat, prefetch: Optimize data loading and training throughput.
Python
train = dataset['train'].map(
    load_train_ds, num_parallel_calls=tf.data.AUTOTUNE)
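# The test split reuses load_train_ds, so it also receives random flips.
test = dataset['test'].map(load_train_ds)
# cache() runs after map(), so the flip chosen for each image during the
# first pass over the data is reused in later passes.
train_ds = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()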
train_ds = train_ds.prefetch(buffer_size=tf.data.AUTOTUNE)
test_ds = test.batch(BATCH_SIZE)
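
The effect of chaining batch and repeat can be sketched in plain Python. This is only an analogue written with generators, not the tf.data implementation, and the helper names are made up for the illustration; shuffling, caching and prefetching are left out.

```python
import itertools

def batch(items, batch_size):
    """Yield consecutive slices of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

def batch_then_repeat(items, batch_size):
    """Yield batches forever by restarting after each full pass."""
    while True:
        yield from batch(items, batch_size)

# Five samples in batches of two: two full batches, one partial batch,
# then the stream starts over.
stream = batch_then_repeat(list(range(5)), batch_size=2)
first_four = list(itertools.islice(stream, 4))
print(first_four)
```

Because the repeated stream never ends, the training call has to be told how many batches make up one epoch, which is the role of STEPS_PER_EPOCH.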

Step 6: Visualize the Data

We define a helper that shows the input image, the ground-truth mask and, once the model is trained, the predicted mask side by side. Here we use it on one training sample.

Python
def display_images(display_list):
    plt.figure(figsize=(15, 15))
    titles = ['Input Image', 'True Mask', 'Predicted Mask']
    for i, image in enumerate(display_list):
        plt.subplot(1, len(display_list), i + 1)
        plt.title(titles[i])
        plt.imshow(keras.preprocessing.image.array_to_img(image))
        plt.axis('off')
    plt.show()

for img, mask in train.take(1):
    display_images([img, mask])

Output:

Input Image

Step 7: Model Construction

We build the model from a VGG16 encoder and an FCN-8-style decoder:

  • Uses pre-trained VGG16 for feature extraction.
  • Extracts the outputs of the five pooling layers (block1_pool to block5_pool).
  • Freezes the VGG16 weights so that only the decoder is trained.
  • The decoder upsamples the deepest features, adds skip connections from block4_pool and block3_pool and produces pixel-wise class probabilities.
Python
base_model = keras.applications.vgg16.VGG16(
    include_top=False, input_shape=(width, height, 3))
layer_names = ['block1_pool', 'block2_pool',
               'block3_pool', 'block4_pool', 'block5_pool']
base_model_outputs = [base_model.get_layer(
    name).output for name in layer_names]
base_model.trainable = False
VGG_16 = keras.Model(inputs=base_model.input, outputs=base_model_outputs)


def fcn8_decoder(convs, n_classes):
    f1, f2, f3, f4, p5 = convs
    n = 4096
    c6 = keras.layers.Conv2D(n, (7, 7), activation='relu', padding='same')(p5)
    c7 = keras.layers.Conv2D(n, (1, 1), activation='relu', padding='same')(c6)
    f5 = c7
    o = keras.layers.Conv2DTranspose(
        n_classes, (4, 4), strides=(2, 2), use_bias=False)(f5)
    o = keras.layers.Cropping2D((1, 1))(o)
    o2 = keras.layers.Conv2D(
        n_classes, (1, 1), activation='relu', padding='same')(f4)
    o = keras.layers.Add()([o, o2])
    o = keras.layers.Conv2DTranspose(
        n_classes, (4, 4), strides=(2, 2), use_bias=False)(o)
    o = keras.layers.Cropping2D((1, 1))(o)
    o2 = keras.layers.Conv2D(
        n_classes, (1, 1), activation='relu', padding='same')(f3)
    o = keras.layers.Add()([o, o2])
    o = keras.layers.Conv2DTranspose(
        n_classes, (8, 8), strides=(8, 8), use_bias=False)(o)
    o = keras.layers.Activation('softmax')(o)
    return o

Output:

Building the Model
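
The kernel sizes, strides and crops in the decoder are chosen so that each upsampled map matches the skip connection it is added to. The sketch below traces only the spatial sizes, assuming a 224x224 input and the default 'valid' padding of Conv2DTranspose; the helper name is made up for this illustration.

```python
def transposed_conv_size(size, kernel, stride):
    """Output size of a transposed convolution with 'valid' padding."""
    return (size - 1) * stride + kernel

# Encoder: each of the five VGG16 pooling layers halves the spatial size.
size = 224
pool_sizes = []
for _ in range(5):
    size //= 2
    pool_sizes.append(size)  # block1_pool ... block5_pool

# Decoder step 1: upsample block5 features (kernel 4, stride 2), crop 1 px per side.
up1 = transposed_conv_size(pool_sizes[4], kernel=4, stride=2) - 2

# Decoder step 2: upsample again the same way after adding block4_pool.
up2 = transposed_conv_size(up1, kernel=4, stride=2) - 2

# Decoder step 3: final 8x upsampling (kernel 8, stride 8) after adding block3_pool.
out = transposed_conv_size(up2, kernel=8, stride=8)

print(pool_sizes, up1, up2, out)
```

The sizes after the first two steps equal those of block4_pool and block3_pool, which is what allows the Add layers to merge them, and the final size equals the input size.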

Step 8: Build and Compile Segmentation Model

We connect the encoder and decoder into a single Keras model with three output classes and compile it with the Adam optimizer and sparse categorical cross-entropy loss.

Python
def segmentation_model():
    inputs = keras.layers.Input(shape=(width, height, 3))
    convs = VGG_16(inputs)
    outputs = fcn8_decoder(convs, 3)
    return keras.Model(inputs, outputs)


model = segmentation_model()
model.compile(
    optimizer=keras.optimizers.Adam(),
    # The decoder already ends in a softmax, so the loss receives
    # probabilities rather than logits.
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['accuracy']
)
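
Sparse categorical cross-entropy compares integer labels with per-pixel class probabilities. The NumPy sketch below shows the calculation on made-up values for two pixels; it takes probabilities directly, which corresponds to a model whose last layer is a softmax.

```python
import numpy as np

# Made-up softmax output for a 1x2-pixel image with 3 classes.
probs = np.array([[[0.7, 0.2, 0.1],
                   [0.1, 0.1, 0.8]]])

# Integer class index for each pixel.
labels = np.array([[0, 2]])

# Pick the probability the model assigned to each pixel's true class.
true_class_probs = np.take_along_axis(
    probs, labels[..., np.newaxis], axis=-1)[..., 0]

# Loss is the negative log of that probability, averaged over pixels.
per_pixel_loss = -np.log(true_class_probs)
loss = per_pixel_loss.mean()

print(true_class_probs, loss)
```

The more probability the model puts on the correct class of a pixel, the smaller that pixel's contribution to the loss.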

Step 9: Train the Model

We train the model for 15 epochs. After each epoch, Keras evaluates it on VALIDATION_STEPS batches of the test set rather than on the whole set, which keeps validation fast.

Python
EPOCHS = 15
VAL_SUBSPLITS = 5
VALIDATION_STEPS = info.splits['test'].num_examples // BATCH_SIZE // VAL_SUBSPLITS

model_history = model.fit(
    train_ds, epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_data=test_ds,
    validation_steps=VALIDATION_STEPS
)

Output:

Training the Model
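
The number of validation batches per epoch follows from two integer divisions. The sketch below works it out under the assumption that the test split has 3,669 images; the tutorial reads the real value from info.splits['test'].num_examples.

```python
# Assumed split size (the tutorial reads it from the dataset metadata).
test_length = 3669
batch_size = 64
val_subsplits = 5

# Batches in the whole test split, then one fifth of that.
validation_steps = test_length // batch_size // val_subsplits

# Images actually evaluated after each epoch.
images_per_validation = validation_steps * batch_size

print(validation_steps, images_per_validation)
```

Under that assumption, each epoch is validated on 11 batches, or 704 images, instead of the full test split.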

Step 10: Predict and visualize the Results

The model makes predictions and we visualize them:

  • Converts model output to a simple mask for visualization.
  • Displays results for sample images to verify segmentation performance.
Python
def create_mask(pred_mask):
    pred_mask = tf.argmax(pred_mask, axis=-1)
    pred_mask = pred_mask[..., tf.newaxis]
    return pred_mask[0]


def show_predictions(dataset=None, num=1):
    for image, mask in dataset.take(num):
        pred_mask = model.predict(image)
        display_images([image[0], mask[0], create_mask(pred_mask)])

Output:

Prediction
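
The conversion from class probabilities to a displayable mask can be illustrated with NumPy on a tiny made-up prediction. It mirrors the steps of create_mask but uses np.argmax instead of tf.argmax.

```python
import numpy as np

# Made-up model output: 1 image, 2x2 pixels, 3 class probabilities per pixel.
pred = np.array([[[[0.8, 0.1, 0.1], [0.2, 0.7, 0.1]],
                  [[0.1, 0.2, 0.7], [0.6, 0.3, 0.1]]]])

# Most likely class per pixel: shape (1, 2, 2).
mask = np.argmax(pred, axis=-1)

# Add a channel axis so the mask can be shown as a one-channel image: (1, 2, 2, 1).
mask = mask[..., np.newaxis]

# Keep the first (and only) image of the batch: shape (2, 2, 1).
first_mask = mask[0]

print(first_mask[..., 0].tolist())
```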

Step 11: Compute Segmentation Metrics

We measure segmentation quality per class with two overlap metrics: Intersection over Union (IoU) and the Dice score. Both compare the predicted region of a class with its ground-truth region, and a value of 1 means a perfect match.

Python
def compute_metrics(y_true, y_pred):
    class_wise_iou, class_wise_dice_score = [], []
    smooth = 1e-5
    for i in range(3):
        intersection = np.sum((y_pred == i) & (y_true == i))
        y_true_area = np.sum(y_true == i)
        y_pred_area = np.sum(y_pred == i)
        combined_area = y_true_area + y_pred_area
        iou = (intersection + smooth) / (combined_area - intersection + smooth)
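        # Smooth is added once to numerator and denominator so that an
        # absent class yields a Dice score of 1 rather than 2.
        dice = (2 * intersection + smooth) / (combined_area + smooth)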
        class_wise_iou.append(iou)
        class_wise_dice_score.append(dice)
    return class_wise_iou, class_wise_dice_score

Output:

IoU and Dice Score
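
A worked example on a 2x2 mask shows how the two metrics relate. This is a minimal sketch for a single class without the smoothing term, so it assumes the class occurs in at least one of the two masks; the function name is made up for the illustration.

```python
import numpy as np

def iou_and_dice(y_true, y_pred, class_id):
    """IoU and Dice for one class, without smoothing."""
    true_px = (y_true == class_id)
    pred_px = (y_pred == class_id)
    intersection = np.sum(true_px & pred_px)
    union = np.sum(true_px | pred_px)
    iou = intersection / union
    dice = 2 * intersection / (np.sum(true_px) + np.sum(pred_px))
    return iou, dice

# Class 1 covers 2 pixels in the ground truth and 3 pixels in the prediction.
y_true = np.array([[0, 0], [1, 1]])
y_pred = np.array([[0, 1], [1, 1]])

iou1, dice1 = iou_and_dice(y_true, y_pred, class_id=1)
print(iou1, dice1)
```

For class 1 the intersection is 2 pixels and the union is 3, so IoU is 2/3, while Dice is 4/5. Dice is never lower than IoU for the same pair of masks.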

We used TensorFlow and the Oxford-IIIT Pet Dataset to build a deep learning image segmentation model that assigns class labels to every pixel, allowing us to accurately separate pet images into distinct regions. Through a step-by-step pipeline, covering data preparation, model design using a VGG16 encoder and FCN-style decoder, training and evaluation, we demonstrated how raw image data can be turned into detailed, pixel-level segmentations, providing both clear visual results and reliable quantitative metrics for assessing model performance.