10 - Fine-Tuning Pretrained Models for Computer Vision
Overview
In this chapter, you'll learn how to fine-tune a pretrained model for computer vision to perform
multi-class classification while using data augmentation to improve its performance. We'll be
using a subset of a popular dataset, Fruits-360, to fine-tune a relatively lightweight model,
ResNet18, to classify images from four different fruits and vegetables: figs, oranges,
mandarines, and onions (FOMO, for short).
By the end of this chapter, you'll be familiar with data augmentation techniques for images and
you'll know how to freeze parts of a model in order to speed up training. You'll also be able to
make your model give you an "I don't know" answer and to use test-time augmentation to handle
"split" predictions.
To run this notebook on Google Colab, you will need to install the following libraries:
transformers and evaluate.
In Google Colab, you can run the following command to install these libraries:
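!pip install transformers evaluate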
Moreover, the get_image_from_url() and save_images() functions can be easily imported from a
set of helper functions we're making available for your convenience. You can download it from
the following link:
https://raw.githubusercontent.com/lftraining/LFD273-code/main/helper_functions.py
In Google Colab, you can run the following command to download the file:
!wget https://raw.githubusercontent.com/lftraining/LFD273-code/main/helper_functions.py
helper_functions.py 100%[===================>] 3.50K --.-KB/s in 0s
Once the file is downloaded, you only need to import the required helper functions:
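from helper_functions import get_image_from_url, save_images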
Fine-tuning a model entails updating its weights by training it further on your own data. Just as in feature extraction, we need to replace the model's "head" with our own to account for a different number of classes. Unlike feature extraction, however, we won't be training the new "head" separately.
We'll be using ResNet18 once again but, this time, our task is going to be multi-class
classification.
You can build your own FOMO dataset by (partially) cloning the original Fruits-360 dataset repository, as shown below (if you're not running the commands below on Google Colab, we highly recommend downloading the prepared dataset instead):
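A minimal sketch of such a partial clone; the repository URL and class folder names below are assumptions, so adjust them if the upstream repository has changed:

# Assumptions: upstream repository URL and class folder names
!git clone --depth 1 --filter=blob:none --sparse https://github.com/Horea94/Fruit-Images-Dataset.git
%cd Fruit-Images-Dataset
!git sparse-checkout set "Training/Fig" "Training/Mandarine" "Training/Onion White" "Training/Orange" "Test/Fig" "Test/Mandarine" "Test/Onion White" "Test/Orange"
%cd ..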
Alternatively, you can download the version we have already prepared from the link below:
https://raw.githubusercontent.com/lftraining/LFD273-code/main/data/Fruits360/FOMO.tar.gz
Once the file is downloaded, you need to uncompress it. If you are using Google Colab, simply
run the commands below to perform both actions:
!wget https://raw.githubusercontent.com/lftraining/LFD273-code/main/data/Fruits360/FOMO.tar.gz
!tar -xvzf FOMO.tar.gz
FOMO-Dataset/Training/
FOMO-Dataset/Training/Orange/
FOMO-Dataset/Training/Orange/21_100.jpg
FOMO-Dataset/Training/Orange/r_232_100.jpg
FOMO-Dataset/Training/Orange/296_100.jpg
...
The FOMO dataset contains only images of four fruits/vegetables for easier and faster training.
It is organized into two main folders, Training and Test, and each folder has four subfolders, one
for each class, Fig, Mandarine, Onion White, and Orange. There are 2,109 images in total for
training, and 706 images for testing.
Training/Fig/0_100.jpg
Training/Fig/1_100.jpg
...
Training/Orange/0_100.jpg
Training/Orange/1_100.jpg
...
Test/Fig/27_100.jpg
Test/Fig/28_100.jpg
...
What we're actually missing in our data pipeline is some data augmentation!
In our previous example, we used a pretrained model as a feature extractor, and we saved resources during training by extracting features from all images in our dataset beforehand.
On the one hand, it was much faster to train the model's new "head". On the other hand, we
deprived ourselves of the possibility of using data augmentation on our images. This wasn't
much of an issue in our previous example because we had more than 30,000 images to train our
model; it wasn't like we were lacking data.
Now, we have a little under 3,000 images that we can use to train our model. That's still far from "too few" images, but our training may still benefit from some new, augmented data.
Let's take a look at what data augmentation is and how we can easily augment our data using
Torchvision's transforms.
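To see data augmentation in action, let's open one of the fig images from the training set with PIL (the file name is taken from the Training/Fig folder listed above):

from PIL import Image

fig = Image.open('FOMO-Dataset/Training/Fig/0_100.jpg')
fig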
fig.transpose(Image.FLIP_LEFT_RIGHT)
fig.rotate(45)
We could go on and on with this, but our point should already be clear: we can make changes to the original image, such as rotating, flipping, zooming, cropping, or adding noise, and it will still be an image of a fig, as if the picture had been taken from a different angle or from a different viewpoint, for example.
Augmented images may not be as good as the original ones, but they often suffice for the
purpose of training a model and, best of all, they're free! Torchvision offers plenty of different
transforms, divided into several groups.
Note: Don't mind the transformations prefixed with v2 just yet. They perform exactly the same
transformations on images as their counterparts, and they were created to handle additional
objects as well, such as bounding boxes. We'll get to them later on.
One particularly interesting transformation is RandAugment(), which, as its name suggests, applies random augmenting transformations to the images, so let's try it out.
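A minimal way of trying it out on the fig image loaded above:

from torchvision.transforms import RandAugment

augmenter = RandAugment()  # defaults: num_ops=2, magnitude=9
augmenter(fig)             # a randomly augmented copy of the PIL image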
It is important to highlight that, for our purposes here, data augmentation belongs in the training set only, so we need different transformations for the training and validation/test sets. Both must include the preprocessing prescribed by the underlying model (ResNet18 in our case), but only the training transformation should include augmentation.
Let's create a helper function that returns a composition of transformations according to the
split they're going to be applied to:
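A sketch of such a helper; the argument name and exact structure are assumptions, but the composition matches the transforms displayed further below (RandAugment followed by the weights' own preprocessing):

from torchvision.transforms import Compose, RandAugment
from torchvision.models import get_weight

weights = get_weight('ResNet18_Weights.DEFAULT')

def get_transform(split='train'):
    # the preprocessing prescribed by the pretrained weights (resize, crop, normalize)
    base_transform = weights.transforms()
    if split == 'train':
        # augmentation belongs in the training split only
        return Compose([RandAugment(), base_transform])
    return base_transform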
Once the transformations are taken care of, we can once again create ImageFolder datasets
using them:
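For example (the paths follow the folder structure extracted above; the dictionary layout is an assumption):

from torchvision.datasets import ImageFolder

datasets = {
    'train': ImageFolder('FOMO-Dataset/Training', transform=get_transform('train')),
    'test': ImageFolder('FOMO-Dataset/Test', transform=get_transform('test')),
}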
The transformation is an attribute of the dataset, as we can see below. Notice that it was
wrapped with a StandardTransform. The StandardTransform combines both transform (applied
to the dataset's features) and target_transform (obviously applied to the dataset's targets) in a
single object.
The base class for making datasets compatible with Torchvision, VisionDataset, which
ImageFolder inherits from, actually has three arguments for transformations:
transforms: a callable that takes both an image and label and transforms both
transform: a callable that takes an image and transforms it
target_transform: a callable that takes a label/target and transforms it
As you probably guessed, they are mutually exclusive: you can either use transforms alone, or a
combination of transform and target_transform. If you choose separate transformations, they
will nonetheless be internally combined as a single StandardTransform.
datasets['train'].transforms

StandardTransform
Transform: Compose(
    RandAugment(num_ops=2, magnitude=9, num_magnitude_bins=31, interpolation=InterpolationMode.NEAREST, fill=None)
    ImageClassification(
        crop_size=[224]
        resize_size=[256]
        mean=[0.485, 0.456, 0.406]
        std=[0.229, 0.224, 0.225]
        interpolation=InterpolationMode.BILINEAR
    )
)
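The data loaders used in the cells below can be built along these lines; the batch size is an assumption, and the test loader is deliberately left unshuffled (which matters for the validation-loss plot later):

from torch.utils.data import DataLoader

dataloaders = {
    'train': DataLoader(datasets['train'], batch_size=32, shuffle=True),  # batch size is an assumption
    'test': DataLoader(datasets['test'], batch_size=32, shuffle=False),
}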
Model
We'll keep using ResNet18 as our main model, as it is powerful and small enough for our needs.
Let's load its pretrained weights one more time:
import torch
from torchvision.models import get_weight

repo = 'pytorch/vision'
weights = get_weight('ResNet18_Weights.DEFAULT')
model = torch.hub.load(repo, 'resnet18', weights=weights)
Each parameter's requires_grad attribute works like a switch: if it's on, that layer will learn (that is, be updated during backpropagation); if it's off, the layer will simply produce outputs through its forward() method and nothing else.
We can easily freeze the whole model using a simple helper function, as shown below:
def freeze_model(model):
    for parameter in model.parameters():
        parameter.requires_grad = False

freeze_model(model)
Of course, a completely frozen model cannot learn anything, so what's the point of that? Well,
even if the model is frozen, we may still replace its parts, so let's replace its "head" (the fc layer)
with our own:
import torch.nn as nn
# Classification Head for the FOMO dataset
torch.manual_seed(42)
model.fc = nn.Linear(512, 4)
Newly created layers are, by default, not frozen. So, if we double-check which parameters of our modified ResNet18 are capable of learning, what would be the result?
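One quick way to check:

[name for name, parameter in model.named_parameters() if parameter.requires_grad]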
fc.weight
fc.bias
Only tensors from our new fc layer, its weight and its bias, can learn anything. The whole "body"
of the model is simply producing outputs that will be fed into the last, learnable, layer.
Sound familiar? We're doing feature extraction again! Instead of creating an explicit headless model to extract features and then training a new "head" to be reattached later, we're doing everything in one go.
Our initial motivation for this was data augmentation, but there is an added bonus to this
approach: we can choose how many parts of the underlying ResNet18 model we want to
actually freeze. The number of layers you may choose not to freeze is roughly a function of the
size of your dataset. The more data you have, the more layers you can afford to update.
Technically speaking, we're only truly fine-tuning a model if we do not freeze any of its
pretrained weights. On the other side of the spectrum, if we freeze everything but the
classification "head", we're doing feature extraction.
In this example, we'll keep everything but the "head" frozen - so we're effectively doing feature
extraction once again. But we encourage you to try unfreezing everything and truly fine-tune your
modified ResNet18 model, just keep in mind that the training process will take longer.
Moreover, if you'd like to train different base models, such as AlexNet, VGG, or Inception, to
name a few, please refer to the table below for each model's expected input image size, the
layer(s) corresponding to the model's "head", and the appropriate replacement layer
(considering the number of extracted features):
Notice that the Inception model has two "heads", as it uses a side (auxiliary) "head" during training. It was originally introduced to mitigate the vanishing gradients problem that can make very deep networks stop learning, but it proved to have a regularizing effect instead. During evaluation/prediction time, though, only the main classifier "head" is used.
So, let's load a mini-batch from our data loader and pass it through our model, just to make sure
we're getting the expected output back, namely, four logits (one for each class) for each data
point:
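A sketch of that check (the device setup may have happened earlier in the notebook; the dataloaders dictionary is the one built above):

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

features, targets = next(iter(dataloaders['train']))
logits = model(features.to(device))
logits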
[-0.4013, -0.3365, 0.5862, 0.2260],
[-1.0981, 0.1739, 0.8421, -0.4000],
[-0.7232, -0.1500, 0.3659, -0.0563],
[-1.0256, -0.7758, 0.3387, 0.2696],
[-0.4930, -0.0232, 0.2135, -0.6374],
[-0.9779, -0.4253, -0.1125, -0.7417],
[-2.0073, -0.2654, 0.5722, 0.4444],
[-0.9135, 0.0850, 0.0961, -0.6709],
[-0.8877, -1.1032, 0.5218, -0.1693],
[-0.8426, -0.6283, 0.9137, 0.2830],
[-0.5762, -0.2691, 0.8709, 0.2452],
[-0.6045, -0.4768, 0.7342, -0.5818],
[-0.5095, -0.4764, -0.0650, -0.7979],
[-0.1320, -0.0939, 0.4242, -0.1109],
[-0.3943, 0.1010, 1.0155, -0.0632],
[-0.6062, -0.5808, 0.5493, 0.2789],
[-0.8768, -0.5799, 0.3674, 0.0076],
[-1.7618, -0.4277, 0.0644, -0.1051],
[-1.0389, -0.1231, 0.1616, 0.0137],
[-0.6106, -0.4904, 0.7372, 0.1672],
[-1.1167, -0.4123, -0.0733, -0.0666]], device='cuda:0',
grad_fn=<AddmmBackward0>)
Let's fine-tune our model now! Sure, we're not so much fine-tuning it as extracting features from it, but the training loop itself won't change regardless of how many layers you choose to freeze.
loss_fn = nn.CrossEntropyLoss()

# Assumption: the optimizer cell is not shown above, so Adam and its learning rate are placeholders here
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

batch_losses = []
model.train()
# a single pass over the training set (the original may loop over several epochs)
for features, targets in dataloaders['train']:
    features, targets = features.to(device), targets.to(device)
    # Step 1 - forward pass
    predictions = model(features)
    # Step 2 - computing the loss
    loss = loss_fn(predictions, targets)
    # Step 3 - computing gradients
    loss.backward()
    batch_losses.append(loss.item())
    # Step 4 - updating parameters and zeroing gradients
    optimizer.step()
    optimizer.zero_grad()
Your plot should look like the one above, but it will surely be slightly different since we're using
random augmentations during training.
Then, let's compute losses for our validation set and plot them as well:
## Validation
with torch.inference_mode():
    val_losses = []
    for i, (val_features, val_targets) in enumerate(dataloaders['test']):
        model.eval()
        val_features = val_features.to(device)
        val_targets = val_targets.to(device)
        # Step 1 - forward pass
        predictions = model(val_features)
        # Step 2 - computing the loss
        loss = loss_fn(predictions, val_targets)
        val_losses.append(loss.item())

plt.plot(val_losses)
plt.xlabel('Mini-Batches')
plt.ylabel('Loss')
plt.title('Validation Loss')
plt.legend()
That's a pretty atypical loss plot, right? But keep in mind that this plot is not over epochs, but
over mini-batches. Besides, remember that we did not shuffle the validation set, so the plot is
actually showing us there's one particular subset of our data that's more problematic than the
others.
By the way, your plot may look slightly different once again. Since the model itself was trained
on data that was randomly augmented, it will produce slightly different validation losses and
predictions too, as we'll shortly see.
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
datasets['test'].class_to_idx
Evaluation
We'll start with our model's accuracy on the training set:
import evaluate

accuracy = evaluate.load("accuracy")
model.eval()
for features, targets in dataloaders['train']:
    pred = model(features.to(device))
    pred_class = torch.nn.functional.softmax(pred, dim=1).argmax(dim=1)
    accuracy.add_batch(references=targets, predictions=pred_class)
accuracy.compute()
{'accuracy': 0.9947842579421526}
Not bad, right? But if we run it more than once, we may get slightly different results, which is puzzling since we did set our model to evaluation mode. What could possibly be happening here?
It turns out we are augmenting our training set using the RandAugment() transformation, remember? That means we get a slightly different, augmented image every time we fetch one!
Therefore, we need to create a plain-vanilla dataset and data loader that do not randomly augment the training images, so we can get a consistent evaluation metric for our training set:
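A sketch of that re-evaluation, using the hypothetical names plain_dataset and plain_loader and the same assumed batch size as before:

plain_dataset = ImageFolder('FOMO-Dataset/Training', transform=get_transform('test'))
plain_loader = DataLoader(plain_dataset, batch_size=32, shuffle=False)

accuracy = evaluate.load("accuracy")
model.eval()
for features, targets in plain_loader:
    pred = model(features.to(device))
    pred_class = torch.nn.functional.softmax(pred, dim=1).argmax(dim=1)
    accuracy.add_batch(references=targets, predictions=pred_class)
accuracy.compute()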
{'accuracy': 1.0}
Let's now compute the accuracy on the test set as well:

accuracy = evaluate.load("accuracy")
model.eval()
for features, targets in dataloaders['test']:
    pred = model(features.to(device))
    pred_class = torch.nn.functional.softmax(pred, dim=1).argmax(dim=1)
    accuracy.add_batch(references=targets, predictions=pred_class)
accuracy.compute()
{'accuracy': 1.0}
Inference
Our model is doing pretty well, but does it stand the test of real-world images? Let's find out!
We can use the same save_images() function as in "Transfer Learning and Pretrained Models"
to download a few images of fruits and vegetables.
If you prefer not to sign up for an API key, you may skip the next cell and use only the provided URLs in the cells that follow it.
Feel free to use those images to try out the model. We'll be illustrating the ideas below with a few images that the function typically downloaded at the time of writing. However, since save_images() may download different images for you, we're using the original image URLs instead:
fig_url = 'https://cdn.pixabay.com/photo/2012/09/08/17/38/euro-dynasty-56405_1280.jpg'
get_image_from_url(fig_url)
We can also use the previous chapter's predict() function once again:
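The original helper isn't reproduced here; the stand-in below is a sketch consistent with how predict() is called later in the chapter (URL, model, transforms, class names, top-k, optional request headers) and with the label/value dictionaries it returns. The transforms_fn name is assumed to be the non-augmenting test-split transform:

transforms_fn = get_transform('test')

def predict(url, model, transforms_fn, classes, topk=4, headers=None):
    # download the image and apply the model's preprocessing
    img = get_image_from_url(url, headers)
    batch = transforms_fn(img).unsqueeze(0).to(device)
    model.eval()
    with torch.inference_mode():
        probs = torch.nn.functional.softmax(model(batch), dim=1).squeeze(0)
    # top-k classes with their probabilities, most likely first
    values, indices = probs.topk(topk)
    return [{'label': classes[i], 'value': v} for v, i in zip(values.tolist(), indices.tolist())]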
Let's use the function to make a prediction for our fig (don't worry if you get different values,
that's expected since each trained model will be unique thanks to random augmentation during
training):
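For instance (a sketch using the stand-in above):

predictions = predict(fig_url, model, transforms_fn, datasets['train'].classes, 4)
predictions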
What if we show our model a pineapple? It did not see pineapples before, so what do you think
will happen?
pineapple_url = 'https://storage.needpix.com/rsynced_images/pineapple-1477419208LRZ.jpg'
get_image_from_url(pineapple_url)
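We make the prediction the same way (again, your values will differ):

predictions = predict(pineapple_url, model, transforms_fn, datasets['train'].classes, 4)
predictions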
Once again, don't mind if you got different values, or even a different class at the top. Just like in
the previous call to the prediction function, this is expected. The same holds true for all calls to
the predict function we'll make until the end of this chapter. The model output suggests it's more
likely to be a mandarine, but then again, it only "knows" mandarines, onions, oranges, and figs.
Still, the highest probability (mandarine) is only 43%, and we can actually use this information to
our advantage. Instead of blindly returning the class with the highest probability back to the
user, why not use an acceptance threshold? If the highest probability falls under the threshold, it
may be the case that the input image does not actually belong to the distribution the model was
initially trained for, that is, it may be an image of a completely different object (e.g. pineapples).
def most_likely(predictions):
    category, prob = predictions[0]['label'], predictions[0]['value']
    if prob >= .5: # choose your threshold here
        return category
    else:
        return 'Uncertain'

most_likely(predictions)
'Uncertain'
Unfortunately, sometimes the model will be very confident about its predictions, even if the input
image does not belong to one of the trained classes...
url = 'https://upload.wikimedia.org/wikipedia/commons/7/72/Igel.JPG'
headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (K
predictions = predict(url, model, transforms_fn, datasets['train'].classes, 4, headers=headers)
predictions
most_likely(predictions)
'Onion White'
Or, sometimes, the image does belong to one of the classes, but the model is still confused
about it...
onion_url = 'https://www.haleo.co.uk/wp-content/uploads/2014/10/onions.jpg'
get_image_from_url(onion_url)
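And the corresponding prediction (a sketch; your values will differ):

predictions = predict(onion_url, model, transforms_fn, datasets['train'].classes, 4)
predictions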
most_likely(predictions)
'Onion White'
One alternative to mitigate these issues is to use test-time augmentation. The idea is quite
simple: instead of making a single prediction (as shown above), submit several augmented
images to the model, and, just like in typical ensemble models, use a voting rule to pick the
winning prediction.
Let's create an ensemble of predictions by using the same random augmentations as the training set (get_transform()). Keep in mind that, since these are random augmentations, your results will surely be different from those shown below:
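A sketch of such an ensemble, using the stand-in predict() above and an assumed number of augmented copies:

n_copies = 10  # assumed ensemble size
tta_predictions = [
    predict(onion_url, model, get_transform('train'), datasets['train'].classes, 4)
    for _ in range(n_copies)
]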
For each prediction, let's take the most likely class only:
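For example, keeping the top class of each run and taking a majority vote:

from collections import Counter

top_classes = [p[0]['label'] for p in tta_predictions]
Counter(top_classes).most_common(1)[0][0]  # the winning class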
Let's save our trained model to disk so we can load it and serve it using TorchServe in the next
chapter.
torch.save(model.state_dict(), 'fomo_model.pth')
url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/3/3e/Raccoon_in_Central_Park_
headers = {'User-Agent': 'CoolBot/0.0 (https://example.org/coolbot/; [email protected])
img1 = get_image_from_url(url, headers)
img1
url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/0/0b/Sea_Gull_at_Point_Lobos_
headers = {'User-Agent': 'CoolBot/0.0 (https://example.org/coolbot/; [email protected])
img2 = get_image_from_url(url, headers)
img2
We got the images, so let's load the same image classification pipeline we used before, and use
it to classify both images:
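A sketch of that step, assuming the Hugging Face image-classification pipeline with its default checkpoint (the checkpoint used before may differ):

from transformers import pipeline

classifier = pipeline('image-classification')  # assumption: default checkpoint
classifier(img1), classifier(img2)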