Finalised Question 2
Overview
This notebook provides a helper function to load in the Oxford-IIIT Pets dataset suitable for classification and semantic segmentation, to help with
Assignment 1B, Question 2.
It also provides an example of how to load in the MobileNetV3Small Network which you are required to fine tune for the second part of the
question.
Please read the comments and instructions within this notebook. It has been carefully designed to help you with many of the tasks required.
Please make sure you read the assignment brief on canvas, and check the FAQ for other information.
import tensorflow as tf
import keras
from keras import layers
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, concatenate, BatchNormalization, SpatialDropout2D, Activation
from keras.models import Model
import numpy as np
import pandas as pd
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import glob
import time
Data loading and pre-processing functions
We first provide some helper functions to format the data in the way we need. You shouldn't need to change these, though you are welcome to if
you like.
One thing you may want to do is create additional augmentation functions, and the flip_lr_augmentation function below could be used as a
template to create additional augmentation types.
def preprocess_segmentation_mask(segmentation_mask):
    """Preprocess the segmentation mask for binary segmentation.

    We want to just merge the edges and foreground of the doggo/catto, and
    then treat it as a binary semantic segmentation task.

    To achieve this, we will just subtract two, converting to values of [-1, 0, 1],
    and then apply the abs function to convert the -1 values (edges) to the
    foreground. We will also convert it to 32 bit float, which will be needed
    for working with tf.

    Args:
        segmentation_mask (array):
            original segmentation mask

    Returns:
        preprocessed segmentation_mask
    """
    return tf.abs(tf.cast(segmentation_mask, tf.float32) - 2)
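As a quick sanity check of the mapping, both non-background mask values (1 and 3) end up as foreground, while the background value (2) maps to 0:

print(preprocess_segmentation_mask(tf.constant([1, 2, 3])))
# tf.Tensor([1. 0. 1.], shape=(3,), dtype=float32)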
def return_image_label_mask(ds_out):
    """ function to return image, class label and segmentation mask

    The original dataset contains additional information, such as the filename and
    the species. We don't care about any of that for this work, so will
    discard them and just keep the original image as our input, and then
    a tuple of our outputs that will be the class label and the semantic
    segmentation mask.

    Args:
        ds_out: dict
            original dataset output

    Returns:
        RGB image
        tuple of class label and preprocessed segmentation mask
    """
    # preprocess the segmentation mask
    seg_mask = preprocess_segmentation_mask(ds_out['segmentation_mask'])
    image = tf.cast(ds_out['image'], tf.float32)
    # image = standardise_image(image)
    return image, (ds_out['label'], seg_mask)
def mobilenet_preprocess_image(image):
    """Apply preprocessing that is suitable for MobileNetV3.

    You should use this preprocessing for both your model and the mobilenet model.
    """
    image = (image - 127.5) / 255.0
    return image


def unprocess_image(image):
    """Undo the preprocessing above so we can plot images."""
    image = image * 255.0 + 127.5
    return image
def resize_image(image, output, image_size):
    """Resize the image and segmentation mask, and apply preprocessing.

    Each image in the dataset is of a different size. The resizing will make sure
    each image is the same size.
    """
    # resize the image and the semantic segmentation mask
    image = tf.image.resize(image, [image_size, image_size])
    image = mobilenet_preprocess_image(image)
    mask = tf.image.resize(output[1], [image_size, image_size])
    return image, (output[0], mask)
def flip_lr_augmentation(image, output, flip_prob):
    """The function will flip the image along the left-right axis with
    a defined probability.
    """
    def flip():
        # flip the image and the mask together so they stay aligned
        return tf.image.flip_left_right(image), tf.image.flip_left_right(output[1])
    def no_flip():
        return image, output[1]
    # randomly decide whether to flip this example
    flip_lr_cond = tf.random.uniform([]) < flip_prob
    # apply augmentation
    image, seg = tf.cond(flip_lr_cond, flip, no_flip)
    # return the image, and output
    return image, (output[0], seg)
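For instance, a photometric augmentation built from the same template might randomly perturb brightness. This is a sketch: the brightness_augmentation name and the aug_prob and max_delta defaults are assumptions, not part of the provided helper code, and you would add a corresponding ds.map call in the loader's augmentation block.

def brightness_augmentation(image, output, aug_prob=0.5, max_delta=0.2):
    """Randomly adjust image brightness with a defined probability.

    The segmentation mask is returned unchanged, since a photometric
    augmentation does not move any pixels.
    """
    def adjust():
        return tf.image.random_brightness(image, max_delta=max_delta)
    def no_adjust():
        return image
    # randomly decide whether to apply the augmentation
    aug_cond = tf.random.uniform([]) < aug_prob
    image = tf.cond(aug_cond, adjust, no_adjust)
    return image, output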
def select_tasks(image, output, classification, segmentation):
    """Select which task outputs to keep.

    By default for each input there are two outputs. This function allows
    you to select which outputs to use, so the problem can be reduced to a
    single task problem for initial experimenting.
    """
    # both tasks
    if classification and segmentation:
        return image, output
    # just classification
    elif classification:
        return image, output[0]
    # just segmentation
    elif segmentation:
        return image, output[1]
    # neither task selected doesn't really make sense, so return the image
    # for a self-supervised task
    else:
        return image, image
class TrainForTime(keras.callbacks.Callback):
    """callback to terminate training after a time limit is reached

    Can be used to control how long training runs for, and will terminate
    training once a specified time limit is reached.
    """

    def __init__(
        self,
        train_time_mins=15,
    ):
        super().__init__()
        self.train_time_mins = train_time_mins
        self.epochs = 0
        self.train_time = 0
        self.end_early = False
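    # A minimal sketch of the timing logic, an assumption built on the
    # standard keras.callbacks.Callback hooks and the `time` module
    # imported at the top of the notebook.
    def on_train_begin(self, logs=None):
        # record when training starts
        self.start_time = time.time()

    def on_epoch_end(self, epoch, logs=None):
        self.epochs += 1
        self.train_time = (time.time() - self.start_time) / 60.0
        if self.train_time > self.train_time_mins:
            # time limit reached: ask keras to stop after this epoch
            self.end_early = True
            self.model.stop_training = True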
# Plot side-by-side
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.imshow(image_vis)
plt.title("Original Image")
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(aug_image_vis)
plt.title("Horizontally Flipped Image")
plt.axis('off')
plt.show()
Data Loader
We will now put the above functions together into a data loader that we can use to feed data directly to the network. You can use this directly as it is.
However, you may modify it to add some additional functionality, such as further data augmentations.
def load_oxford_pets(split, classification=True, segmentation=True,
                     batch_size=32, shuffle=True, augment=True, image_size=224):
    """Load the Oxford-IIIT Pets dataset for question 1b.

    Handles loading of data for 1b, including processing of images and
    semantic segmentation masks. This function will organise the
    tensorflow dataset to return an output that is a tuple, where the
    tuple will be (classification_labels, segmentation_masks).

    Parameters
    ----------
    split : string
        either 'train', 'val' or 'test'
    classification : bool
        whether to include classification labels
    segmentation : bool
        whether to include semantic segmentation masks
    batch_size : int
        size of batches to use
    shuffle : bool
        whether to shuffle the dataset (WILL ONLY APPLY TO TRAIN)
    augment : bool
        whether to augment the dataset (WILL ONLY APPLY TO TRAIN)
    image_size : int
        new image size

    Returns
    -------
    tf.data.Dataset containing the Oxford-IIIT Pets dataset
    """
    # let's do some error checking first
    # check for a valid dataset split; this must be train, val or test
    if (split != 'train') and (split != 'val') and (split != 'test'):
        raise ValueError('Arg for split must be either \'train\', \'val\' or \'test\'')

    if (not classification) and (not segmentation):
        print("WARNING: One of the tasks (classification and segmentation) must be selected")
        print("Setting both to enabled")
        classification = True
        segmentation = True

    # check that if using the val split, shuffle is false. If not, print a warning and force shuffle to be false
    if (split == 'val') and shuffle:
        print("WARNING: shuffle is set to true, but have specified split to be \'val\'")
        print('The shuffle argument will be ignored')
        shuffle = False
    # check that if using the test split, shuffle is false. If not, print a warning and force shuffle to be false
    if (split == 'test') and shuffle:
        print("WARNING: shuffle is set to true, but have specified split to be \'test\'")
        print('The shuffle argument will be ignored')
        shuffle = False
    # check that if using the val split, augment is false. If not, print a warning and force augment to be false
    if (split == 'val') and augment:
        print("WARNING: augment is set to true, but have specified split to be \'val\'")
        print('The augment argument will be ignored')
        augment = False
    # check that if using the test split, augment is false. If not, print a warning and force augment to be false
    if (split == 'test') and augment:
        print("WARNING: augment is set to true, but have specified split to be \'test\'")
        print('The augment argument will be ignored')
        augment = False

    # the dataset by default only has train and test splits. If val is requested, pull the first 30% of the test set
    if (split == 'val'):
        split = 'test[:30%]'
    # the test set then becomes the remaining 70% of the original test set
    elif (split == 'test'):
        split = 'test[30%:]'

    # load the dataset and keep just the image, class label and segmentation mask
    ds = tfds.load('oxford_iiit_pet', split=split)
    ds = ds.map(return_image_label_mask, num_parallel_calls=tf.data.AUTOTUNE)
    # resize images and masks to a common size and apply MobileNet preprocessing
    ds = ds.map(lambda inp, out: resize_image(inp, out, image_size), num_parallel_calls=tf.data.AUTOTUNE)
    # shuffle if requested (only ever true for the training split, given the checks above)
    if shuffle:
        ds = ds.shuffle(1000)

    # augmentation
    # only apply if in the training split and augment has been set to True
    if split == 'train' and augment:
        # apply a left-right flip with 50% probability
        flip_lr_prob = 0.5
        # flip operation
        ds = ds.map(lambda inp, out: flip_lr_augmentation(inp, out, flip_lr_prob), num_parallel_calls=tf.data.AUTOTUNE)

    # and now remove any tasks that we don't want. Note that we call this last as it means that all the other functions
    # can safely assume that data for both tasks is in the dataset
    ds = ds.map(lambda inp, out: select_tasks(inp, out, classification, segmentation))

    # batch the data and prefetch for efficiency
    ds = ds.batch(batch_size)
    return ds.prefetch(tf.data.AUTOTUNE)
# iterate over a few training batches, printing the segmentation mask shape
num_plot = 4
i = 0
for image, output in train_ds:
    print(output[1].shape)
    i += 1
    if i >= num_plot:
        break
plt.savefig('doggos_cattos.png')
We can use the classification and segmentation flags to pull out just one output as well, as the below demonstrates.
In [6]: # classification only; classification = True, segmentation = False (note batch size is 1 here)
train_class_only = load_oxford_pets('train', classification=True, segmentation=False, shuffle=True, augment=True, batch_size=1, image_size=300)

# segmentation only; classification = False, segmentation = True (note batch size is 1 here)
train_seg_only = load_oxford_pets('train', classification=False, segmentation=True, shuffle=True, augment=True, batch_size=1, image_size=300)

[29]
(1, 300, 300, 1)
While for the question you do need to train networks to do both tasks simultaneously, when you start playing with the problem it might be easier to get things working for one task, and then add the second.
Note that we will need to set the include_preprocessing option to False when loading this base network, because that preprocessing step is already implemented in the Datasets we defined above.
We also set include_top=False to avoid loading the model with the final Dense classification layer that is used for the original ImageNet model.
For this task, you can ignore the input_shape warning, though it is important to keep in mind that the difference in size between the data used for the pre-trained model and our data may have an impact on our model (what that impact might be is for you to investigate :) ). Depending on what input shape you select, you may also be able to eliminate this warning.
For more information on fine-tuning models, you can refer to many of the examples from class, or the Keras documentation.
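For example, loading the base network might look like the following sketch. The input_shape here is an assumption; match it to whatever image_size you choose.

base_model = keras.applications.MobileNetV3Small(
    input_shape=(160, 160, 3),    # assumed; match your chosen image_size
    include_top=False,            # drop the ImageNet classification head
    include_preprocessing=False,  # preprocessing already happens in our tf.data pipeline
    weights='imagenet',
)
base_model.trainable = False      # freeze the weights for initial feature extraction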
Data Loading
This section loads the Oxford-IIIT Pets dataset using the provided load_oxford_pets function. The batch size and image size are set
appropriately. Augmentation is turned on for training, and off for validation/testing. Additional augmentations have been added to improve
generalization.
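For reference, the matching train and validation loaders would look like the following sketch. The image_size of 160 and batch_size of 32 are assumptions, inferred from the segmentation branch comments and the 115-step training epochs below.

image_size = 160
batch_size = 32

train_ds = load_oxford_pets(
    split='train',
    image_size=image_size,
    batch_size=batch_size,
    classification=True,
    segmentation=True,
    shuffle=True,
    augment=True
)
val_ds = load_oxford_pets(
    split='val',
    image_size=image_size,
    batch_size=batch_size,
    classification=True,
    segmentation=True,
    shuffle=False,
    augment=False
)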
test_ds = load_oxford_pets(
    split='test',
    image_size=image_size,
    batch_size=batch_size,
    classification=True,
    segmentation=True,
    shuffle=False,
    augment=False
)
# Classification branch
class_branch = layers.GlobalAveragePooling2D()(x)
class_output = layers.Dense(num_classes, activation='softmax', name='classification')(class_branch)
# Segmentation branch — upsample to match 160x160
seg_branch = layers.Conv2D(128, 3, activation='relu', padding='same')(x) # assume x is ~20x20
seg_branch = layers.UpSampling2D(2)(seg_branch) # ~40x40
seg_branch = layers.Conv2D(64, 3, activation='relu', padding='same')(seg_branch)
seg_branch = layers.UpSampling2D(2)(seg_branch) # ~80x80
seg_branch = layers.Conv2D(32, 3, activation='relu', padding='same')(seg_branch)
seg_branch = layers.UpSampling2D(2)(seg_branch) # ~160x160
seg_branch = layers.Conv2D(3, 1, activation='softmax', name='segmentation')(seg_branch) # Output: (160, 160, 3)
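These branches sit inside a build_scratch_model function whose backbone is not shown above. The following is a self-contained sketch under that assumption: any backbone that reduces the 160x160 input to a roughly 20x20 feature map x would do, and num_classes is 37 for the 37 pet breeds in the dataset.

def build_scratch_model(input_shape=(160, 160, 3), num_classes=37):
    inputs = layers.Input(shape=input_shape)

    # assumed backbone: three stride-2 stages downsample 160 -> 80 -> 40 -> 20
    x = inputs
    for filters in (32, 64, 128):
        x = layers.Conv2D(filters, 3, strides=2, activation='relu', padding='same')(x)

    # Classification branch
    class_branch = layers.GlobalAveragePooling2D()(x)
    class_output = layers.Dense(num_classes, activation='softmax', name='classification')(class_branch)

    # Segmentation branch, as above: upsample 20 -> 40 -> 80 -> 160
    seg_branch = layers.Conv2D(128, 3, activation='relu', padding='same')(x)
    seg_branch = layers.UpSampling2D(2)(seg_branch)
    seg_branch = layers.Conv2D(64, 3, activation='relu', padding='same')(seg_branch)
    seg_branch = layers.UpSampling2D(2)(seg_branch)
    seg_branch = layers.Conv2D(32, 3, activation='relu', padding='same')(seg_branch)
    seg_branch = layers.UpSampling2D(2)(seg_branch)
    seg_output = layers.Conv2D(3, 1, activation='softmax', name='segmentation')(seg_branch)

    return Model(inputs, outputs=[class_output, seg_output])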
scratch_model = build_scratch_model()
scratch_model.compile(
    optimizer='adam',
    loss={
        'classification': 'sparse_categorical_crossentropy',
        'segmentation': 'sparse_categorical_crossentropy'  # use sparse loss for integer masks
    },
    metrics={
        'classification': 'accuracy',
        'segmentation': 'accuracy'
    }
)
# Train
scratch_model.fit(train_ds, epochs=5, validation_data=test_ds)
Epoch 1/5
115/115 ━━━━━━━━━━━━━━━━━━━━ 186s 2s/step - classification_accuracy: 0.0231 - classification_loss: 3.6208 - loss: 4.3136 - segmentation_accuracy: 0.5973 - segmentation_loss: 0.6928 - val_classification_accuracy: 0.0327 - val_classification_loss: 3.6073 - val_loss: 4.1386 - val_segmentation_accuracy: 0.7455 - val_segmentation_loss: 0.5307
Epoch 2/5
115/115 ━━━━━━━━━━━━━━━━━━━━ 182s 2s/step - classification_accuracy: 0.0354 - classification_loss: 3.6037 - loss: 4.0832 - segmentation_accuracy: 0.7666 - segmentation_loss: 0.4795 - val_classification_accuracy: 0.0370 - val_classification_loss: 3.5980 - val_loss: 4.0159 - val_segmentation_accuracy: 0.8129 - val_segmentation_loss: 0.4170
Epoch 3/5
115/115 ━━━━━━━━━━━━━━━━━━━━ 168s 1s/step - classification_accuracy: 0.0382 - classification_loss: 3.5884 - loss: 4.0221 - segmentation_accuracy: 0.7975 - segmentation_loss: 0.4336 - val_classification_accuracy: 0.0323 - val_classification_loss: 3.5831 - val_loss: 3.9780 - val_segmentation_accuracy: 0.8206 - val_segmentation_loss: 0.3932
Epoch 4/5
115/115 ━━━━━━━━━━━━━━━━━━━━ 172s 1s/step - classification_accuracy: 0.0617 - classification_loss: 3.5292 - loss: 3.9158 - segmentation_accuracy: 0.8217 - segmentation_loss: 0.3866 - val_classification_accuracy: 0.0572 - val_classification_loss: 3.5329 - val_loss: 3.9430 - val_segmentation_accuracy: 0.8131 - val_segmentation_loss: 0.4088
Out[8]: <keras.src.callbacks.history.History at 0x2ca959a72e0>
from sklearn.metrics import classification_report, jaccard_score, f1_score

# Extract predictions
y_true_cls = []
y_pred_cls = []
y_true_seg = []
y_pred_seg = []

# loop over the test set one batch at a time, collecting predictions
for images, (labels_cls, labels_seg) in test_ds:
    preds_cls, preds_seg = scratch_model.predict(images)

    # Classification
    y_true_cls.extend(labels_cls.numpy())
    y_pred_cls.extend(np.argmax(preds_cls, axis=-1))

    # Segmentation
    labels_seg_np = labels_seg.numpy().reshape(-1).astype(int)
    preds_seg_np = np.argmax(preds_seg, axis=-1).reshape(-1).astype(int)
    y_true_seg.extend(labels_seg_np.tolist())
    y_pred_seg.extend(preds_seg_np.tolist())
# Classification evaluation
print("Classification Report (Scratch Model):")
print(classification_report(y_true_cls, y_pred_cls))
# Segmentation evaluation
print("Segmentation IoU:", jaccard_score(y_true_seg, y_pred_seg, average='macro', zero_division=0))
print("Segmentation F1 Score:", f1_score(y_true_seg, y_pred_seg, average='macro', zero_division=0))
Classification Report (Scratch Model):
precision recall f1-score support
inputs = layers.Input(shape=input_shape)
x = base_model(inputs)
# Classification head
class_branch = layers.GlobalAveragePooling2D()(x)
class_output = layers.Dense(num_classes, activation='softmax', name='classification')(class_branch)
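As with the scratch model, the surrounding build_mobilenet_model function is truncated above. A self-contained sketch follows; the segmentation head (five 2x upsampling stages taking MobileNetV3Small's 5x5 feature map back to 160x160) is an assumption, not the notebook's own head.

def build_mobilenet_model(input_shape=(160, 160, 3), num_classes=37):
    base_model = keras.applications.MobileNetV3Small(
        input_shape=input_shape,
        include_top=False,
        include_preprocessing=False,
        weights='imagenet',
    )
    base_model.trainable = False  # freeze the pre-trained backbone

    inputs = layers.Input(shape=input_shape)
    x = base_model(inputs)  # 160x160 input -> 5x5 feature map (32x downsampling)

    # Classification head
    class_branch = layers.GlobalAveragePooling2D()(x)
    class_output = layers.Dense(num_classes, activation='softmax', name='classification')(class_branch)

    # Assumed segmentation head: upsample 5 -> 10 -> 20 -> 40 -> 80 -> 160
    seg_branch = x
    for filters in (128, 64, 32, 16, 8):
        seg_branch = layers.UpSampling2D(2)(seg_branch)
        seg_branch = layers.Conv2D(filters, 3, activation='relu', padding='same')(seg_branch)
    seg_output = layers.Conv2D(3, 1, activation='softmax', name='segmentation')(seg_branch)

    return Model(inputs, outputs=[class_output, seg_output])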
mobilenet_model = build_mobilenet_model()
mobilenet_model.compile(
    optimizer='adam',
    loss={
        'classification': 'sparse_categorical_crossentropy',
        'segmentation': 'sparse_categorical_crossentropy'  # use sparse loss for integer masks
    },
    metrics={
        'classification': 'accuracy',
        'segmentation': 'accuracy'
    }
)
Epoch 1/5
115/115 ━━━━━━━━━━━━━━━━━━━━ 135s 615ms/step - classification_accuracy: 0.0241 - classification_loss: 3.6584 - loss: 4.2824 - segmentation_accuracy: 0.6977 - segmentation_loss: 0.6239 - val_classification_accuracy: 0.0296 - val_classification_loss: 3.6206 - val_loss: 4.1400 - val_segmentation_accuracy: 0.7406 - val_segmentation_loss: 0.5193
Epoch 2/5
115/115 ━━━━━━━━━━━━━━━━━━━━ 81s 689ms/step - classification_accuracy: 0.0303 - classification_loss: 3.6310 - loss: 4.1411 - segmentation_accuracy: 0.7461 - segmentation_loss: 0.5100 - val_classification_accuracy: 0.0280 - val_classification_loss: 3.6204 - val_loss: 4.1256 - val_segmentation_accuracy: 0.7495 - val_segmentation_loss: 0.5050
Epoch 3/5
115/115 ━━━━━━━━━━━━━━━━━━━━ 70s 599ms/step - classification_accuracy: 0.0245 - classification_loss: 3.6240 - loss: 4.1266 - segmentation_accuracy: 0.7504 - segmentation_loss: 0.5025 - val_classification_accuracy: 0.0288 - val_classification_loss: 3.6132 - val_loss: 4.1158 - val_segmentation_accuracy: 0.7513 - val_segmentation_loss: 0.5025
Epoch 4/5
115/115 ━━━━━━━━━━━━━━━━━━━━ 73s 618ms/step - classification_accuracy: 0.0273 - classification_loss: 3.6180 - loss: 4.1163 - segmentation_accuracy: 0.7545 - segmentation_loss: 0.4983 - val_classification_accuracy: 0.0374 - val_classification_loss: 3.6082 - val_loss: 4.0899 - val_segmentation_accuracy: 0.7640 - val_segmentation_loss: 0.4818
Epoch 5/5
115/115 ━━━━━━━━━━━━━━━━━━━━ 76s 655ms/step - classification_accuracy: 0.0326 - classification_loss: 3.6134 - loss: 4.1033 - segmentation_accuracy: 0.7594 - segmentation_loss: 0.4898 - val_classification_accuracy: 0.0327 - val_classification_loss: 3.5969 - val_loss: 4.1494 - val_segmentation_accuracy: 0.7229 - val_segmentation_loss: 0.5536
Out[18]: <keras.src.callbacks.history.History at 0x168468d5a20>
In [19]: # Evaluate scratch model
scratch_eval = scratch_model.evaluate(test_ds)
print("Scratch model evaluation:", scratch_eval)

# Evaluate MobileNet model
mobilenet_eval = mobilenet_model.evaluate(test_ds)
print("MobileNet model evaluation:", mobilenet_eval)

81/81 ━━━━━━━━━━━━━━━━━━━━ 26s 321ms/step - classification_accuracy: 0.0631 - classification_loss: 3.5424 - loss: 3.9710 - segmentation_accuracy: 0.7992 - segmentation_loss: 0.4286
Scratch model evaluation: [3.9609272480010986, 3.531313180923462, 0.4281884729862213, 0.06542056053876877, 0.7993440628051758]
81/81 ━━━━━━━━━━━━━━━━━━━━ 19s 236ms/step - classification_accuracy: 0.0323 - classification_loss: 3.5896 - loss: 4.1431 - segmentation_accuracy: 0.7223 - segmentation_loss: 0.5535
MobileNet model evaluation: [4.149447441101074, 3.596855640411377, 0.5536401271820068, 0.032710280269384384, 0.7228771448135376]