10_Fine-Tuning_Pretrained_Models_for_Computer Vision.ipynb - Colab

Overview
In this chapter, you'll learn how to fine-tune a pretrained model for computer vision to perform
multi-class classification while using data augmentation to improve its performance. We'll be
using a subset of a popular dataset, Fruits-360, to fine-tune a relatively lightweight model,
ResNet18, to classify images from four different fruits and vegetables: figs, oranges,
mandarines, and onions (FOMO, for short).

By the end of this chapter, you'll be familiar with data augmentation techniques for images and
you'll know how to freeze parts of a model in order to speed up training. You'll also be able to
make your model give you an "I don't know" answer and to use test-time augmentation to handle
"split" predictions.

Learning Objectives


By the end of this chapter, you should be able to:

Apply data augmentation techniques to your input images
Freeze (and unfreeze) parts of a pretrained model to selectively fine-tune them
Use predicted probabilities to allow models to give "I don't know" answers
Use test-time augmentation to handle low-confidence predictions

To run this notebook on Google Colab, you will need to install the following libraries:
transformers and evaluate.

In Google Colab, you can run the following command to install these libraries:

!pip install transformers evaluate

Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting datasets>=2.0.0 (from evaluate)
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill (from evaluate)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting multiprocess (from evaluate)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting xxhash (from evaluate)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata
...
Downloading evaluate-0.4.3-py3-none-any.whl (84 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
Downloading dill-0.3.8-py3-none-any.whl (116 kB)
Downloading fsspec-2024.9.0-py3-none-any.whl (179 kB)
...

Moreover, the get_image_from_url() and save_images() functions can be easily imported from a
set of helper functions we're making available for your convenience. You can download it from
the following link:

https://raw.githubusercontent.com/lftraining/LFD273-code/main/helper_functions.py

In Google Colab, you can run the following command to download the file:

!wget https://raw.githubusercontent.com/lftraining/LFD273-code/main/helper_functions.py

--2025-01-07 18:27:21--  https://raw.githubusercontent.com/lftraining/LFD273-code/main/helper_functions.py
HTTP request sent, awaiting response... 200 OK
Length: 3583 (3.5K) [text/plain]
Saving to: ‘helper_functions.py’

helper_functions.py 100%[===================>]   3.50K  --.-KB/s    in 0s

2025-01-07 18:27:21 (50.4 MB/s) - ‘helper_functions.py’ saved [3583/3583]

Once the file is downloaded, you only need to import the required helper functions:

from helper_functions import get_image_from_url, save_images

Fine-Tuning Pretrained Models


Fine-Tuning Pretrained Models: Overview

Fine-tuning a model entails updating its weights by training it further on your own data. Just
like feature extraction, we need to replace the model's "head" with our own to account for a
different number of classes. However, unlike feature extraction, we won't be training the new
"head" separately.

We'll be using ResNet18 once again but, this time, our task is going to be multi-class
classification.

FOMO Dataset


We'd like to introduce the FOMO dataset! FOMO stands for "Figs, Oranges, Mandarines, and
Onions", and it is a subset of the Fruits-360: A dataset of images containing fruits and
vegetables by Horea Muresan and Mihai Oltean, licensed under MIT license.

You can build your own FOMO dataset by (partially) cloning the original repository from the
Fruits-360 dataset as shown below (if you're not running the commands below on Google Colab,
we highly recommend you download the prepared dataset instead):

!git clone -n --depth=1 --filter=tree:0 https://github.com/Horea94/Fruit-Images-Dataset.git
!cd Fruit-Images-Dataset && git sparse-checkout set --no-cone Training/Fig Training/Mandarine Training/Orange "Training/Onion White" Test/Fig Test/Mandarine Test/Orange "Test/Onion White" && git checkout
!mv Fruit-Images-Dataset/ FOMO-Dataset/

Cloning into 'Fruit-Images-Dataset'...


remote: Enumerating objects: 1, done.
remote: Counting objects: 100% (1/1), done.
remote: Total 1 (delta 0), reused 1 (delta 0), pack-reused 0 (from 0)
Receiving objects: 100% (1/1), done.

Alternatively, you can download the version we have already prepared from the link below:

https://raw.githubusercontent.com/lftraining/LFD273-code/main/data/Fruits360/FOMO.tar.gz

Once the file is downloaded, you need to uncompress it. If you are using Google Colab, simply
run the commands below to perform both actions:

!wget https://raw.githubusercontent.com/lftraining/LFD273-code/main/data/Fruits360/FOMO.tar.gz
!tar -xvzf FOMO.tar.gz

--2025-01-07 18:27:31--  https://raw.githubusercontent.com/lftraining/LFD273-code/main/data/Fruits360/FOMO.tar.gz
HTTP request sent, awaiting response... 200 OK
Length: 12701839 (12M) [application/octet-stream]
Saving to: ‘FOMO.tar.gz’

FOMO.tar.gz 100%[===================>]  12.11M  --.-KB/s    in 0.1s

2025-01-07 18:27:32 (111 MB/s) - ‘FOMO.tar.gz’ saved [12701839/12701839]

FOMO-Dataset/Training/
FOMO-Dataset/Training/Orange/
FOMO-Dataset/Training/Orange/21_100.jpg
FOMO-Dataset/Training/Orange/r_232_100.jpg
FOMO-Dataset/Training/Orange/296_100.jpg
...

The FOMO dataset contains only images of four fruits/vegetables for easier and faster training.
It is organized into two main folders, Training and Test, and each folder has four subfolders, one
for each class, Fig, Mandarine, Onion White, and Orange. There are 2,109 images in total for
training, and 706 images for testing.

Training/Fig/0_100.jpg

Training/Fig/1_100.jpg

...

Training/Orange/0_100.jpg

Training/Orange/1_100.jpg

...

Test/Fig/27_100.jpg

Test/Fig/28_100.jpg

...

Let's visualize one of each:

from torchvision.io import read_image, write_png


from torchvision.utils import make_grid
from PIL import Image
train_folder = './FOMO-Dataset/Training'
classes = ['Fig', 'Orange', 'Mandarine', 'Onion White']
images = [read_image(f'{train_folder}/{cl}/0_100.jpg') for cl in classes]
image_grid = make_grid(images, nrow=4)
write_png(image_grid, 'fomo.png')
Image.open('fomo.png')

What we're actually missing in our data pipeline is some data augmentation!

In our previous example, we used a pretrained model as a feature extractor, and we saved resources during training by extracting features from all images in our dataset beforehand.

On the one hand, it was much faster to train the model's new "head". On the other hand, we
deprived ourselves of the possibility of using data augmentation on our images. This wasn't
much of an issue in our previous example because we had more than 30,000 images to train our
model; it wasn't like we were lacking data.

Now, we have just over 2,000 images that we can use to train our model. That's still far from "too few" images, but our training may still benefit from some new, augmented data.

Let's take a look at what data augmentation is and how we can easily augment our data using
Torchvision's transforms.

Transforms and Augmentations


The purpose of data augmentation is to artificially create new data without incurring the cost of
actually acquiring it. Let's take an image of a fig, for example:

# The funny thing is, it is both a FIG and a FIG(ure)!


fig = Image.open(f'{train_folder}/Fig/0_100.jpg')
fig

Now, let's flip it:

fig.transpose(Image.FLIP_LEFT_RIGHT)

It's still a fig, right?

fig.rotate(45)


What about now?

We could go on and on with this, but our point should already be clear: we can make changes to
the original image such as rotating it, flipping it, zooming, cropping, adding noise to it, and it will
still be an image of a fig, as if the picture were taken from a different angle, or from a different
viewpoint, for example.

Augmented images may not be as good as the original ones, but they often suffice for the
purpose of training a model and, best of all, they're free! Torchvision offers plenty of different
transforms, divided into several groups.

Transforms offered by Torchvision, divided into groups

Geometry: Transformations such as Resize(), CenterCrop(), or RandomRotation() that modify the image's shape or its boundaries.
Color: These transformations modify the image's colors only, such as Grayscale() or ColorJitter().
Composition: The typical Compose() transformation is nothing but a way to compose several transformations in a given list together, but this group also includes RandomApply() to randomly apply a list of transformations.
Miscellaneous: This group includes the typical Normalize() transformation that standardizes the pixel values, and the custom Lambda() transformation.
Conversion: These transformations are mostly used to convert images back and forth to tensors, such as ToTensor() and ToPILImage().
Auto-augmentation: These transformations, such as RandAugment() and AutoAugment(), implement random auto-augmentation techniques.
Functional: As opposed to the other groups, these transforms do not hold parameters of their own, so you're in charge of specifying them (for example, drawing a random angle to perform a rotation).

Note: Don't mind the transformations prefixed with v2 just yet. They perform exactly the same
transformations on images as their counterparts, and they were created to handle additional
objects as well, such as bounding boxes. We'll get to them later on.

One particularly interesting transformation is the RandAugment() which, as its name suggests,
applies random augmenting transformations to the images, so let's try it out.
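As a quick illustration (a minimal sketch reusing the fig image loaded earlier), you can apply RandAugment() directly to a PIL image and save a grid with a few random variants of it:

from PIL import Image
from torchvision import transforms as T
from torchvision.io import write_png
from torchvision.utils import make_grid

# Each call to RandAugment() draws a fresh set of random operations
# (rotations, shears, color changes, etc.), so every variant looks different
augmenter = T.RandAugment()
fig = Image.open('./FOMO-Dataset/Training/Fig/0_100.jpg')
variants = [T.PILToTensor()(augmenter(fig)) for _ in range(4)]
write_png(make_grid(variants, nrow=4), 'fig_augmented.png')
Image.open('fig_augmented.png')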

It is important to highlight that data augmentation belongs in the training set only as far as we're
concerned here, so we need to have different transformations for the training and validation/test sets. Both transformations must include the prescribed transformations of the underlying model
(ResNet18 in our case), but only the former should include augmentation.

Let's create a helper function that returns a composition of transformations according to the
split they're going to be applied to:

from torchvision import transforms as T


from torchvision.models import get_weight
weights = get_weight('ResNet18_Weights.DEFAULT')
transforms_fn = weights.transforms()
def get_transforms(train=True):
transfs = [T.RandAugment()] if train else []
transfs.append(transforms_fn)
return T.Compose(transfs)

Once the transformations are taken care of, we can once again create ImageFolder datasets
using them:

from torchvision.datasets import ImageFolder


datasets = {}
datasets['train'] = ImageFolder(root='./FOMO-Dataset/Training', transform=get_transforms(True))
datasets['test'] = ImageFolder(root='./FOMO-Dataset/Test', transform=get_transforms(False))

The transformation is an attribute of the dataset, as we can see below. Notice that it was
wrapped with a StandardTransform. The StandardTransform combines both transform (applied
to the dataset's features) and target_transform (obviously applied to the dataset's targets) in a
single object.

The base class for making datasets compatible with Torchvision, VisionDataset, which
ImageFolder inherits from, actually has three arguments for transformations:

transforms: a callable that takes both an image and label and transforms both
transform: a callable that takes an image and transforms it
target_transform: a callable that takes a label/target and transforms it

As you probably guessed, they are mutually exclusive: you can either use transforms alone, or a
combination of transform and target_transform. If you choose separate transformations, they
will nonetheless be internally combined as a single StandardTransform.

datasets['train'].transforms

StandardTransform
Transform: Compose(
RandAugment(num_ops=2, magnitude=9, num_magnitude_bins=31,
interpolation=InterpolationMode.NEAREST, fill=None)
ImageClassification(
crop_size=[224]
resize_size=[256]
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]
interpolation=InterpolationMode.BILINEAR
)
)
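As a small illustration of the separate transform/target_transform route (just a sketch, with a made-up binary target), notice how both callables still end up wrapped in a single StandardTransform:

from torchvision.datasets import ImageFolder

# transform preprocesses the image; target_transform remaps the integer label
# (here, a hypothetical binary target: 1 if the class is 'Onion White', else 0)
demo_dataset = ImageFolder(
    root='./FOMO-Dataset/Training',
    transform=get_transforms(True),
    target_transform=lambda label: int(label == 2),  # 2 corresponds to 'Onion White'
)
demo_dataset.transforms  # a single StandardTransform combining both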

Then, let's create some data loaders for our datasets:

from torch.utils.data import DataLoader


dataloaders = {}
dataloaders['train'] = DataLoader(datasets['train'], batch_size=32, shuffle=True)
dataloaders['test'] = DataLoader(datasets['test'], batch_size=32)

Model
We'll keep using ResNet18 as our main model, as it is powerful and small enough for our needs.
Let's load its pretrained weights one more time:

import torch
repo = 'pytorch/vision'
weights = get_weight('ResNet18_Weights.DEFAULT')
model = torch.hub.load(repo, 'resnet18', weights=weights)

Downloading: "https://github.com/pytorch/vision/zipball/main" to /root/.cache/torch/h


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.ca
100%|██████████| 44.7M/44.7M [00:00<00:00, 113MB/s]

Model: Model Freezing


Freezing the model means that it won't learn anymore, that is, its weights won't be updated
anymore. Do you remember what makes a tensor of parameters updateable? It was its
requires_grad attribute.

It works like a switch: if it's on, that layer will learn/be updated during backpropagation but, if it's
off, it will simply produce outputs through its forward() method and nothing else.

We can easily freeze the whole model using a simple helper function, as shown below:

def freeze_model(model):
for parameter in model.parameters():
parameter.requires_grad = False
freeze_model(model)


Of course, a completely frozen model cannot learn anything, so what's the point of that? Well,
even if the model is frozen, we may still replace its parts, so let's replace its "head" (the fc layer)
with our own:

import torch.nn as nn
# Classification Head for the FOMO dataset
torch.manual_seed(42)
model.fc = nn.Linear(512, 4)

Newly created layers are, by default, not frozen. So, if we double-check which parameters from our modified ResNet18 are capable of learning, what would be the result?

for name, param in model.named_parameters():


if param.requires_grad == True:
print(name)

fc.weight
fc.bias

Only tensors from our new fc layer, its weight and its bias, can learn anything. The whole "body"
of the model is simply producing outputs that will be fed into the last, learnable, layer.

Sounds familiar? We're doing feature extraction again! Instead of creating an explicit headless
model to extract features and then training a new "head" to be reattached later, we're doing
everything in one go.

Our initial motivation for this was data augmentation, but there is an added bonus to this
approach: we can choose how many parts of the underlying ResNet18 model we want to
actually freeze. The number of layers you may choose not to freeze is roughly a function of the
size of your dataset. The more data you have, the more layers you can afford to update.

Technically speaking, we're only truly fine-tuning a model if we do not freeze any of its
pretrained weights. On the other side of the spectrum, if we freeze everything but the
classification "head", we're doing feature extraction.

In this example, we'll keep everything but the "head" frozen - so we're effectively doing feature
extraction once again. But we encourage you to try unfreezing everything and truly fine-tune your
modified ResNet18 model, just keep in mind that the training process will take longer.
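If you decide to go down that path, a minimal sketch of partial unfreezing (here, arbitrarily unfreezing only ResNet18's last residual block, layer4, together with the new "head") could look like this:

import torch.optim as optim

# Freeze everything first, then selectively switch the last block and the head back on
freeze_model(model)
for parameter in model.layer4.parameters():  # last residual block of ResNet18
    parameter.requires_grad = True
for parameter in model.fc.parameters():      # our new classification head
    parameter.requires_grad = True

# Only the trainable parameters need to go to the optimizer
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=3e-3)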

Moreover, if you'd like to train different base models, such as AlexNet, VGG, or Inception, to
name a few, please refer to the table below for each model's expected input image size, the
layer(s) corresponding to the model's "head", and the appropriate replacement layer
(considering the number of extracted features):
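As a rough sketch of what those replacements look like (the sizes below come from torchvision's standard model definitions, not from the original table, so double-check them against your model's printed architecture):

import torch.nn as nn
from torchvision import models

num_classes = 4  # FOMO has four classes

# ResNet18 (224x224 inputs): the "head" is the fc layer, fed by 512 features
resnet = models.resnet18(weights='DEFAULT')
resnet.fc = nn.Linear(512, num_classes)

# AlexNet / VGG (224x224 inputs): the "head" is the last layer of the classifier block
vgg = models.vgg16(weights='DEFAULT')
vgg.classifier[6] = nn.Linear(4096, num_classes)

# Inception V3 (299x299 inputs): a main "head" (fc) plus an auxiliary one used during training
inception = models.inception_v3(weights='DEFAULT')
inception.fc = nn.Linear(2048, num_classes)
inception.AuxLogits.fc = nn.Linear(768, num_classes)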


Notice that the Inception model has two "heads": it uses a side (auxiliary) "head" during training, originally introduced to mitigate the vanishing gradients problem that can make deep networks stop learning, although it turned out to act as a regularizer instead. At evaluation/prediction time, though, only the main classifier "head" is used.

So, let's load a mini-batch from our data loader and pass it through our model, just to make sure
we're getting the expected output back, namely, four logits (one for each class) for each data
point:

images, labels = next(iter(dataloaders['train']))


device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
logits = model(images.to(device))
logits

tensor([[-1.1355, -0.2022, -0.1163, 0.1853],


[-0.4704, -0.3673, 1.2446, 0.2377],
[-0.7080, 0.1837, 0.1494, -0.6758],
[-1.2240, -0.3268, 0.2787, -0.4828],
[-1.1817, -0.0728, -0.0032, -0.7081],
[-0.8558, -0.3409, 0.4991, -1.0397],
[-1.1391, -0.2130, 0.2410, 0.0195],
[-1.2571, -0.4477, 0.6080, -0.1691],
[-1.4153, -0.2740, 0.1672, -0.6846],
[-0.4690, -0.2555, 0.4823, 0.1124],
[-1.2957, -0.6879, 0.2374, -0.8903],

[-0.4013, -0.3365, 0.5862, 0.2260],
[-1.0981, 0.1739, 0.8421, -0.4000],
[-0.7232, -0.1500, 0.3659, -0.0563],
[-1.0256, -0.7758, 0.3387, 0.2696],
[-0.4930, -0.0232, 0.2135, -0.6374],
[-0.9779, -0.4253, -0.1125, -0.7417],
[-2.0073, -0.2654, 0.5722, 0.4444],
[-0.9135, 0.0850, 0.0961, -0.6709],
[-0.8877, -1.1032, 0.5218, -0.1693],
[-0.8426, -0.6283, 0.9137, 0.2830],
[-0.5762, -0.2691, 0.8709, 0.2452],
[-0.6045, -0.4768, 0.7342, -0.5818],
[-0.5095, -0.4764, -0.0650, -0.7979],
[-0.1320, -0.0939, 0.4242, -0.1109],
[-0.3943, 0.1010, 1.0155, -0.0632],
[-0.6062, -0.5808, 0.5493, 0.2789],
[-0.8768, -0.5799, 0.3674, 0.0076],
[-1.7618, -0.4277, 0.0644, -0.1051],
[-1.0389, -0.1231, 0.1616, 0.0137],
[-0.6106, -0.4904, 0.7372, 0.1672],
[-1.1167, -0.4123, -0.0733, -0.0666]], device='cuda:0',
grad_fn=<AddmmBackward0>)

Let's fine-tune our model now! Sure, we're not so much fine-tuning it as extracting features from it, but the training loop itself won't change regardless of how many layers you choose to freeze.

Training Loop


We're dealing with a multi-class classification task now, and that calls for cross-entropy loss:


loss_fn = nn.CrossEntropyLoss()

We'll keep using Adam as the optimizer:

import torch.optim as optim


## Suggested learning rate
lr = 3e-3
optimizer = optim.Adam(model.parameters(), lr=lr)

Finally, the training loop itself is the same as before:

from tqdm import tqdm


device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
batch_losses = []
## Training
for i, (batch_features, batch_targets) in tqdm(enumerate(dataloaders['train'])):
model.train()
batch_features = batch_features.to(device)
batch_targets = batch_targets.to(device)
# Step 1 - forward pass
predictions = model(batch_features)
# Step 2 - computing the loss
loss = loss_fn(predictions, batch_targets)
# Step 3 - computing the gradients

loss.backward()
batch_losses.append(loss.item())
# Step 4 - updating parameters and zeroing gradients
optimizer.step()
optimizer.zero_grad()

66it [00:14, 4.46it/s]

Let's plot the training losses:

from matplotlib import pyplot as plt


plt.plot(batch_losses)
plt.xlabel('Mini-Batches')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.legend()

WARNING:matplotlib.legend:No artists with labels found to put in legend. Note that a


<matplotlib.legend.Legend at 0x7b59e08ff340>

Your plot should look like the one above, but it will surely be slightly different since we're using
random augmentations during training.

Then, let's compute losses for our validation set and plot them as well:


## Validation
with torch.inference_mode():
val_losses = []
for i, (val_features, val_targets) in enumerate(dataloaders['test']):
model.eval()
val_features = val_features.to(device)
val_targets = val_targets.to(device)
# Step 1 - forward pass
predictions = model(val_features)
# Step 2 - computing the loss
loss = loss_fn(predictions, val_targets)
val_losses.append(loss.item())
plt.plot(val_losses)
plt.xlabel('Mini-Batches')
plt.ylabel('Loss')
plt.title('Validation Loss')
plt.legend()

WARNING:matplotlib.legend:No artists with labels found to put in legend. Note that a


<matplotlib.legend.Legend at 0x7b59e24f1510>

That's a pretty atypical loss plot, right? But keep in mind that this plot is not over epochs, but
over mini-batches. Besides, remember that we did not shuffle the validation set, so the plot is
actually showing us there's one particular subset of our data that's more problematic than the
others.


By the way, your plot may look slightly different once again. Since the model itself was trained
on data that was randomly augmented, it will produce slightly different validation losses and
predictions too, as we'll shortly see.

Let's retrieve the labels of the problematic batches:

problematic_batches = [14, 15, 16]


high_loss_labels = []
for i, (images, targets) in enumerate(dataloaders['test']):
if i in problematic_batches:
high_loss_labels.append(targets)
torch.cat(high_loss_labels)

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

Which class is that?

datasets['test'].class_to_idx

{'Fig': 0, 'Mandarine': 1, 'Onion White': 2, 'Orange': 3}

Looks like our model doesn't like onions very much!

Next, let's properly evaluate our model.

Evaluation
We'll start with our model's accuracy on the training set:

import evaluate
accuracy = evaluate.load("accuracy")
model.eval()
for features, targets in dataloaders['train']:
pred = model(features.to(device))
pred_class = torch.nn.functional.softmax(pred, dim=1).argmax(dim=1)
accuracy.add_batch(references=targets, predictions=pred_class)
accuracy.compute()


Downloading builder script: 100% 4.20k/4.20k [00:00<00:00, 282kB/s]
{'accuracy': 0.9947842579421526}

Not bad, right? But, if we run it more than once, we may get slightly different results, which is
puzzling since we did set our model to evaluation mode. What can possibly be happening here?

It turns out, we are augmenting our training set using the RandAugment() transformation,
remember? It means that we'll get a slightly different, augmented, image every time we fetch it!

Therefore, we need to create a plain-vanilla dataset and data loaders that do not randomly
augment the training images, so we can get a consistent evaluation metric for our training set:

datasets['train_vanilla'] = ImageFolder(root='./FOMO-Dataset/Training', transform=get_transforms(False))
dataloaders['train_vanilla'] = DataLoader(datasets['train_vanilla'], batch_size=32, shuffle=False)
accuracy = evaluate.load("accuracy")
model.eval()
for features, targets in dataloaders['train_vanilla']:
pred = model(features.to(device))
pred_class = torch.nn.functional.softmax(pred, dim=1).argmax(dim=1)
accuracy.add_batch(references=targets, predictions=pred_class)
accuracy.compute()

{'accuracy': 1.0}

accuracy = evaluate.load("accuracy")
model.eval()
for features, targets in dataloaders['test']:
pred = model(features.to(device))
pred_class = torch.nn.functional.softmax(pred, dim=1).argmax(dim=1)
accuracy.add_batch(references=targets, predictions=pred_class)
accuracy.compute()

{'accuracy': 1.0}

Inference
Our model is doing pretty well, but does it stand the test of real-world images? Let's find out!

We can use the same save_images() function as in "Transfer Learning and Pretrained Models"
to download a few images of fruits and vegetables.


If you prefer not to sign up for an API key, you may skip the next cell and use only the provided URLs in the cells that follow it.

Feel free to use those images to try out the model. We'll be using a few images that are typically
downloaded by that function (at the time of writing) to illustrate some ideas below. However,
since it is possible that the save_images() function downloads different images for you, we're
using the original images instead:

fig_url = 'https://cdn.pixabay.com/photo/2012/09/08/17/38/euro-dynasty-56405_1280.jpg'
get_image_from_url(fig_url)

We can also use the previous chapter's predict() function once again:


def predict(path_or_url, model, transforms_fn, categories, topk=1, headers=None):


if path_or_url.startswith('http'):
img = get_image_from_url(path_or_url, headers=headers)
else:
img = Image.open(path_or_url)
# Preprocesses the image using the transforms_fn
preproc_img = transforms_fn(img)
# If there are only three dimensions (CHW), unsqueeze the first
# to get a mini-batch of one (NCHW)
if len(preproc_img.shape) == 3:
preproc_img = preproc_img.unsqueeze(0)
# Never forget to set the model to evaluation mode!
model.eval()
# We find in which device the model is loaded on
# and send the preprocessed image to the same device
# to get predictions from the model
device = next(iter(model.parameters())).device
pred = model(preproc_img.to(device))
# If the output is a dictionary, extract logits from it
if isinstance(pred, dict):
pred = pred['logits']
# Binary classification
is_binary = (pred.shape[1] == 1)
if is_binary:
# Uses sigmoid function to convert predicted logits into probabilities
probability = torch.sigmoid(pred).squeeze()
# In binary classification, we need to use a threshold to determine if
# it is a positive or a negative class
threshold = 0.5
pred_class = probability > threshold
values = (1-probability) if (probability <= threshold) else probability
return [{'label': categories[pred_class], 'value': values.tolist()}]
# Multi-class classification
else:
# Uses softmax function to convert predicted logits into probabilities
probabilities = torch.nn.functional.softmax(pred[0], dim=0)
# In multi-class classification, we may take the top-k results only
values, indices = torch.topk(probabilities, topk)
return [{'label': categories[i], 'value': v.item()} for i, v in zip(indices, values)]

Let's use the function to make a prediction for our fig (don't worry if you get different values,
that's expected since each trained model will be unique thanks to random augmentation during
training):

predict(fig_url, model, transforms_fn, datasets['train'].classes, 4)

[{'label': 'Fig', 'value': 0.7930049896240234},


{'label': 'Mandarine', 'value': 0.1789637804031372},
{'label': 'Onion White', 'value': 0.02123061753809452},
{'label': 'Orange', 'value': 0.006800664588809013}]


What if we show our model a pineapple? It did not see pineapples before, so what do you think
will happen?

pineapple_url = 'https://storage.needpix.com/rsynced_images/pineapple-1477419208LRZ.jpg'
get_image_from_url(pineapple_url)


predictions = predict(pineapple_url, model, transforms_fn, datasets['train'].classes, 4)


predictions

[{'label': 'Onion White', 'value': 0.4970454275608063},


{'label': 'Mandarine', 'value': 0.2572656273841858},
{'label': 'Orange', 'value': 0.19827018678188324},
{'label': 'Fig', 'value': 0.04741879180073738}]

Once again, don't mind if you got different values, or even a different class at the top. Just like in the previous call to the prediction function, this is expected, and the same holds true for all calls to the predict function we'll make until the end of this chapter. The model cannot help but pick one of the classes it "knows": mandarines, onions, oranges, and figs. Still, in the run above, the highest probability (Onion White) sits just under 50%, and we can actually use this information to our advantage. Instead of blindly returning the class with the highest probability back to the user, why not use an acceptance threshold? If the highest probability falls under the threshold, it may be the case that the input image does not actually belong to the distribution the model was trained on, that is, it may be an image of a completely different object (e.g. a pineapple).

def most_likely(predictions):
category, prob = predictions[0]['label'], predictions[0]['value']
if prob >= .5: # choose your threshold here
return category
else:
return 'Uncertain'
most_likely(predictions)

'Uncertain'


Unfortunately, sometimes the model will be very confident about its predictions, even if the input
image does not belong to one of the trained classes...

url = 'https://upload.wikimedia.org/wikipedia/commons/7/72/Igel.JPG'
headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)'}
predictions = predict(url, model, transforms_fn, datasets['train'].classes, 4, headers=headers)
predictions

[{'label': 'Onion White', 'value': 0.6845197677612305},


{'label': 'Fig', 'value': 0.20679008960723877},
{'label': 'Mandarine', 'value': 0.0595860555768013},
{'label': 'Orange', 'value': 0.049104101955890656}]

most_likely(predictions)

'Onion White'

Or, sometimes, the image does belong to one of the classes, but the model is still confused
about it...

onion_url = 'https://www.haleo.co.uk/wp-content/uploads/2014/10/onions.jpg'
get_image_from_url(onion_url)


predictions = predict(onion_url, model, transforms_fn, datasets['train'].classes, 4)


predictions

[{'label': 'Onion White', 'value': 0.5000604391098022},


{'label': 'Orange', 'value': 0.4420488774776459},
{'label': 'Fig', 'value': 0.04643503949046135},
{'label': 'Mandarine', 'value': 0.011455634608864784}]

most_likely(predictions)

'Onion White'


One alternative to mitigate these issues is to use test-time augmentation. The idea is quite
simple: instead of making a single prediction (as shown above), submit several augmented
images to the model, and, just like in typical ensemble models, use a voting rule to pick the
winning prediction.

Let's create an ensemble of predictions by using the same random augmentations as the training set (get_transforms()). Keep in mind that, since these are random augmentations, your results will surely differ from those shown below:

ensemble = [predictions] + [predict(onion_url, model, get_transforms(), datasets['train'].classes, 4) for _ in range(6)]


ensemble

[[{'label': 'Onion White', 'value': 0.5000604391098022},


{'label': 'Orange', 'value': 0.4420488774776459},
{'label': 'Fig', 'value': 0.04643503949046135},
{'label': 'Mandarine', 'value': 0.011455634608864784}],
[{'label': 'Onion White', 'value': 0.531319260597229},
{'label': 'Orange', 'value': 0.4170907437801361},
{'label': 'Fig', 'value': 0.03984801098704338},
{'label': 'Mandarine', 'value': 0.011741990223526955}],
[{'label': 'Orange', 'value': 0.47790488600730896},
{'label': 'Onion White', 'value': 0.45384833216667175},
{'label': 'Fig', 'value': 0.05601492151618004},
{'label': 'Mandarine', 'value': 0.012231873348355293}],
[{'label': 'Orange', 'value': 0.4900090992450714},
{'label': 'Onion White', 'value': 0.47705844044685364},
{'label': 'Fig', 'value': 0.028102917596697807},
{'label': 'Mandarine', 'value': 0.004829528275877237}],
[{'label': 'Onion White', 'value': 0.7307137250900269},
{'label': 'Orange', 'value': 0.20167745649814606},
{'label': 'Fig', 'value': 0.05790228396654129},
{'label': 'Mandarine', 'value': 0.009706567972898483}],
[{'label': 'Onion White', 'value': 0.716722846031189},
{'label': 'Orange', 'value': 0.15526121854782104},
{'label': 'Fig', 'value': 0.12147890031337738},
{'label': 'Mandarine', 'value': 0.006537081208080053}],
[{'label': 'Onion White', 'value': 0.5456726551055908},
{'label': 'Orange', 'value': 0.40326961874961853},
{'label': 'Fig', 'value': 0.02918623387813568},
{'label': 'Mandarine', 'value': 0.02187146432697773}]]

For each prediction, let's take the most likely class only:

[p[0] for p in ensemble]

[{'label': 'Onion White', 'value': 0.5000604391098022},


{'label': 'Onion White', 'value': 0.531319260597229},
{'label': 'Orange', 'value': 0.47790488600730896},
{'label': 'Orange', 'value': 0.4900090992450714},
{'label': 'Onion White', 'value': 0.7307137250900269},
{'label': 'Onion White', 'value': 0.716722846031189},
{'label': 'Onion White', 'value': 0.5456726551055908}]


There we go! Onion White is the winner!
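To turn that eyeballed vote into code, a minimal sketch (using the ensemble list built above) can simply count the top-1 labels and pick the majority:

from collections import Counter

# Majority vote over the top-1 label of each augmented prediction
votes = Counter(p[0]['label'] for p in ensemble)
winner, num_votes = votes.most_common(1)[0]
winner, num_votes  # e.g. ('Onion White', 5) in the run shown above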

Let's save our trained model to disk so we can load it and serve it using TorchServe in the next
chapter.

torch.save(model.state_dict(), 'fomo_model.pth')
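If you later want the trained model back (for instance, before serving it), a minimal sketch of reloading it is to rebuild the same architecture, swap in the four-class "head", and load the saved state dict:

import torch
import torch.nn as nn
from torchvision.models import resnet18

# Rebuild the architecture with the same 4-class head, then load the trained weights
restored_model = resnet18()
restored_model.fc = nn.Linear(512, 4)
restored_model.load_state_dict(torch.load('fomo_model.pth', map_location='cpu'))
restored_model.eval()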

Zero-Shot Image Classification
Before moving on to serving a trained model, let's briefly discuss Hugging Face pipelines for
image classification once again. We already used them in "Pretrained Models for Natural
Language Processing" to easily make predictions. Let's try it with a couple of images:

url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/3/3e/Raccoon_in_Central_Park_
headers = {'User-Agent': 'CoolBot/0.0 (https://example.org/coolbot/; coolbot@example.org)'}
img1 = get_image_from_url(url, headers)
img1


url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/0/0b/Sea_Gull_at_Point_Lobos_
headers = {'User-Agent': 'CoolBot/0.0 (https://example.org/coolbot/; coolbot@example.org)'}
img2 = get_image_from_url(url, headers)
img2


We got the images, so let's load the same image classification pipeline we used before, and use
it to classify both images:

from transformers import pipeline


device = 0 if torch.cuda.is_available() else -1
classifier = pipeline('image-classification', model='google/vit-base-patch16-224', device=device)
classifier([img1, img2])
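The google/vit-base-patch16-224 checkpoint above is a regular classifier with a fixed set of labels. For actual zero-shot image classification, where you supply the candidate labels yourself at inference time, a minimal sketch (assuming a CLIP checkpoint such as openai/clip-vit-base-patch32) would look like this:

from transformers import pipeline

# Zero-shot: the model scores the image against labels it was never fine-tuned on
zero_shot = pipeline('zero-shot-image-classification',
                     model='openai/clip-vit-base-patch32',
                     device=device)
zero_shot(img1, candidate_labels=['raccoon', 'seagull', 'fig', 'onion'])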
