Image Classification - Building Image Classification Model
Introduction
“Build a deep learning model in a few minutes? It’ll take hours to train! I don’t even have a good enough machine.”
I’ve heard this countless times from aspiring data scientists who shy away from building deep learning models on
their own machines.
You don’t need to be working for Google or other big tech firms to work on deep learning datasets! It is entirely
possible to build your own neural network from the ground up in a matter of minutes without needing to lease out
Google’s servers. Fast.ai’s students trained a model on the ImageNet dataset in 18 minutes (https://www.analyticsvidhya.com/blog/2018/08/fast-ais-algorithm-beat-googles-code-in-a-popular-image-recognition-challenge/), and I will showcase something similar in this article.
Deep learning is a vast field, so we’ll narrow our focus a bit and take up the challenge of solving an Image
Classification problem. Additionally, we’ll be using a very simple deep learning architecture to achieve a pretty
impressive accuracy score.
You can consider the Python code we’ll see in this article as a benchmark for building Image Classification
models. Once you get a good grasp on the concept, go ahead and play around with the code, participate in
competitions and climb up the leaderboard!
If you’re new to deep learning and are fascinated by the field of computer vision (who isn’t?!), do check out the
‘Computer Vision using Deep Learning’ course (https://trainings.analyticsvidhya.com/courses/course-v1:AnalyticsVidhya+CVDL101+CVDL101_T1/about?utm_source=imageclassarticle&utm_medium=blog).
It’s a comprehensive introduction to this wonderful field and will set you up for what is inevitably going to be a huge
job market in the near future.
Problem Statement
More than 25% of the entire revenue in E-Commerce is attributed to apparel & accessories. A major problem
e-commerce retailers face is categorizing these apparel items from the images alone, especially when the categories
provided by the brands are inconsistent. This poses an interesting computer vision problem that has caught the eye of
several deep learning researchers.
Fashion MNIST is a drop-in replacement for the well-known ‘hello world’ of machine learning, the MNIST
dataset, which can be explored in the ‘Identify the Digits’ practice problem. Instead of digits, the images show a
type of apparel, e.g. T-shirt, trousers, bag, etc. The dataset used in this problem was created by Zalando
Research.
A given image can potentially be classified into any one of n categories. Manually checking and
classifying images is a very tedious process. The task becomes nearly impossible when we’re faced with a
massive number of images, say 10,000 or even 100,000. How useful would it be if we could automate this entire
process and quickly label images according to their corresponding class?
Self-driving cars are a great example to understand where image classification is used in the real-world. To
enable autonomous driving, we can build an image classification model that recognizes various objects, such as
vehicles, people, moving objects, etc. on the road. We’ll see a couple more use cases later in this article, but there
are plenty more applications around us. Use the comments section below the article to let me know what
potential use cases you can come up with!
Now that we have a handle on our subject matter, let’s dive into how an image classification model is built, what
its prerequisites are, and how it can be implemented in Python.
Our data needs to be in a particular format in order to solve an image classification problem. We will see this in
action in a couple of sections, but keep these pointers in mind until we get there:
You should have 2 folders, one for the train set and the other for the test set. In the training set, you will have a
.csv file and an image folder:
The .csv file contains the names of all the training images and their corresponding true labels
The image folder has all the training images.
The .csv file in our test set is different from the one in the training set. It contains the names of all the test
images, but no corresponding labels. Can you guess why? Our model will be trained on the images in the training
set, and the label predictions will be made on the test set images.
If your data is not in the format described above, you will need to convert it accordingly (otherwise the predictions
will be awry and fairly useless).
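To make the expected layout concrete, here is a minimal sketch that reads a training index and a test index with Python’s standard csv module. The column names filename and label are assumptions for illustration (your dataset’s .csv files may use different names), and the two tiny in-memory CSVs stand in for real train.csv and test.csv files.

```python
import csv
import io

def read_index(csv_file, expect_labels):
    """Return (filenames, labels) from an index CSV.

    Assumes a hypothetical 'filename' column; a 'label' column must be
    present for the training index and absent for the test index.
    """
    reader = csv.DictReader(csv_file)
    filenames, labels = [], []
    for row in reader:
        filenames.append(row["filename"])
        if expect_labels:
            labels.append(int(row["label"]))
        elif "label" in row:
            raise ValueError("the test index should not contain labels")
    return filenames, labels

# Tiny in-memory stand-ins for train.csv and test.csv
train_csv = io.StringIO("filename,label\n0.png,3\n1.png,7\n")
test_csv = io.StringIO("filename\n2.png\n")

train_files, train_labels = read_index(train_csv, expect_labels=True)
test_files, _ = read_index(test_csv, expect_labels=False)
print(train_files, train_labels, test_files)  # ['0.png', '1.png'] [3, 7] ['2.png']
```

Rejecting a labelled test index is a cheap way to catch a train/test mix-up before any training time is spent.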
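Before we dive into the Python code, let’s take a moment to understand how an image classification model
is typically designed. We can divide this process broadly into 4 stages, each of which requires a certain amount of
time to execute:
Loading and pre-processing the data
Defining the model’s architecture
Training the model
Estimating the model’s performance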
Let me explain each of the above steps in a bit more detail. This section is crucial because not every model is
built in the first go. You will need to go back after each iteration, fine-tune your steps, and run it again. Having a
solid understanding of the underlying concepts will go a long way in accelerating the entire process.
Stage 1: Loading and pre-processing the data
Data is gold as far as deep learning models are concerned. Your image classification model has a far better
chance of performing well if you have a good number of images in the training set. Also, the shape of the data
varies according to the architecture/framework that we use.
This makes data pre-processing a critical step (as it is in any project). I highly recommend going through
‘Basics of Image Processing in Python (https://www.analyticsvidhya.com/blog/2014/12/image-processing-python-basics/)’
to understand more about how pre-processing works with image data.
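As a small illustration of how the data shape depends on the framework, the sketch below uses NumPy to move a batch of 28 x 28 grayscale images between the channels-last layout (N, H, W, C) that Keras uses by default and the channels-first layout (N, C, H, W) that frameworks such as PyTorch expect. The batch itself is dummy data.

```python
import numpy as np

# A dummy batch of 5 grayscale 28 x 28 images in channels-last layout (N, H, W, C)
batch_last = np.zeros((5, 28, 28, 1), dtype=np.float32)

# Reorder the axes to channels-first (N, C, H, W)
batch_first = np.transpose(batch_last, (0, 3, 1, 2))

# Flattened rows of 784 pixel values can be reshaped back into images
flat = np.arange(5 * 784, dtype=np.float32).reshape(5, 784)
images = flat.reshape(-1, 28, 28, 1)

print(batch_last.shape, batch_first.shape, images.shape)
# (5, 28, 28, 1) (5, 1, 28, 28) (5, 28, 28, 1)
```

The pixel values never change here; only the order of the axes does, which is why a shape mismatch is usually a pre-processing fix rather than a model fix.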
But we are not quite there yet. In order to see how our model performs on unseen data (and before exposing it to
the test set), we need to create a validation set. This is done by partitioning the training set data.
In short, we train the model on the training data and validate it on the validation data. Once we are satisfied with
the model’s performance on the validation set, we can use it for making predictions on the test data.
Time required for this step: We require around 2-3 minutes for this task.
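A minimal NumPy sketch of creating a validation set by partitioning the training data. The 20% hold-out fraction and the seed are assumed values, and train_val_split is a hypothetical helper; scikit-learn’s train_test_split does the same job in a typical Keras workflow.

```python
import numpy as np

def train_val_split(X, y, val_fraction=0.2, seed=42):
    """Shuffle once, then hold out val_fraction of the rows for validation."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(len(X))
    n_val = int(len(X) * val_fraction)
    val_idx, train_idx = indices[:n_val], indices[n_val:]
    return X[train_idx], X[val_idx], y[train_idx], y[val_idx]

# Toy data: 10 "images" of shape 28 x 28 x 1 with integer labels 0..9
X = np.random.default_rng(0).random((10, 28, 28, 1))
y = np.arange(10)

X_train, X_val, y_train, y_val = train_val_split(X, y)
print(X_train.shape, X_val.shape)  # (8, 28, 28, 1) (2, 28, 28, 1)
```

Shuffling before splitting matters: if the training .csv happens to be sorted by label, taking the last rows without shuffling would give a validation set containing only a few classes.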
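Stage 2: Defining the model’s architecture
This is another crucial step in our deep learning model building process. We have to define what our model will
look like, and that requires answering questions like:
How many convolutional layers do we want?
What should be the activation function for each layer?
How many hidden units should each layer have?
And many more. These are essentially the hyperparameters of the model, which play a MASSIVE part in deciding
how good the predictions will be.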
How do we decide these values? Excellent question! A good idea is to pick these values based on existing
research/studies. Another idea is to keep experimenting with the values until you find the best match, but this can
be quite a time-consuming process.
Time required for this step: It should take around 1 minute to define the architecture of the model.
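To see why experimenting with values gets time-consuming, this sketch enumerates a hypothetical search space with itertools.product. The candidate values are illustrative only, not recommendations, and the cost estimate simply reuses this article’s rough figure of 5 minutes of training per model.

```python
from itertools import product

# Hypothetical search space; the candidate values are illustrative only
search_space = {
    "conv_layers": [1, 2, 3],
    "dense_units": [64, 128, 256],
    "dropout": [0.25, 0.5],
    "learning_rate": [1e-3, 1e-4],
}

names = list(search_space)
configs = [dict(zip(names, values)) for values in product(*search_space.values())]

# Rough cost, assuming about 5 minutes of training per configuration
total_minutes = len(configs) * 5
print(len(configs), total_minutes)  # 36 180
```

Even this small grid means three hours of training, which is why starting from values reported in existing research usually saves a lot of time.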
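Stage 3: Training the model
We train the model on the training images and their true labels, and check its progress on the validation images.
We also define the number of epochs in this step. For starters, we will run the model for 10 epochs (you can
change the number of epochs later).
Time required for this step: Since the model has to learn the structures in the data during training, this step takes
around 5 minutes.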
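An epoch is one full pass over the training data, which Keras makes in mini-batches. The sketch below works with assumed numbers (48,000 training images left after a 20% validation hold-out, and Keras’s default batch_size of 32) to count the weight updates and to check that one shuffled epoch visits every sample exactly once. iterate_minibatches is a hypothetical helper, not a Keras function.

```python
import math
import numpy as np

def iterate_minibatches(n_samples, batch_size, rng):
    """Yield shuffled index batches that together cover every sample exactly once."""
    order = rng.permutation(n_samples)
    for start in range(0, n_samples, batch_size):
        yield order[start:start + batch_size]

# Assumed numbers: 60,000 images minus a 20% validation hold-out, batch size 32
n_samples, batch_size, epochs = 48_000, 32, 10
steps_per_epoch = math.ceil(n_samples / batch_size)
print(steps_per_epoch, steps_per_epoch * epochs)  # 1500 15000

rng = np.random.default_rng(0)
first_epoch = list(iterate_minibatches(n_samples, batch_size, rng))
```

So 10 epochs here means roughly 15,000 small weight updates, each computed on 32 images.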
Stage 4: Estimating the model’s performance
Finally, we load the test data (images) and go through the pre-processing step here as well. We then predict the
classes for these images using the trained model.
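The trained model outputs a probability for each class (through its final softmax layer), and the predicted class is simply the one with the highest probability. A minimal NumPy sketch, using made-up probabilities for three test images:

```python
import numpy as np

# Made-up softmax outputs for 3 test images over 10 classes (each row sums to 1)
probs = np.array([
    [0.01, 0.02, 0.90, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01],
    [0.70, 0.05, 0.05, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02],
    [0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.55],
])

predicted_classes = probs.argmax(axis=1)  # index of the most probable class per row
confidence = probs.max(axis=1)            # how sure the model is about that choice
print(predicted_classes)  # [2 0 9]
```

Keeping the confidence alongside the class is handy for spotting the images the model finds hardest.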
We will be picking up a really cool challenge to understand image classification. We have to build a model that
can classify a given set of images according to the apparel (shirt, trousers, shoes, socks, etc.). It’s actually a
problem faced by many e-commerce retailers, which makes it an even more interesting computer vision problem.
We have a total of 70,000 images (28 x 28 pixels each), of which 60,000 are in the training set and 10,000 in the
test set. The training images are pre-labelled according to the apparel type, with 10 classes in total. The
test images are, of course, not labelled. The challenge is to identify the type of apparel present in all the test
images.
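Since the labels are integers from 0 to 9, it helps to map predictions back to apparel names when inspecting results. The mapping below is the standard Fashion-MNIST class encoding; we assume, without having verified it, that the DataHack version of the data uses the same one.

```python
# Standard Fashion-MNIST class indices (assumed to match the DataHack labels)
FASHION_LABELS = {
    0: "T-shirt/top", 1: "Trouser", 2: "Pullover", 3: "Dress", 4: "Coat",
    5: "Sandal", 6: "Shirt", 7: "Sneaker", 8: "Bag", 9: "Ankle boot",
}

def decode(predictions):
    """Turn integer class predictions into readable apparel names."""
    return [FASHION_LABELS[int(p)] for p in predictions]

print(decode([9, 0, 8]))  # ['Ankle boot', 'T-shirt/top', 'Bag']
```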
We will build our model on Google Colab (https://colab.research.google.com) since it provides a free GPU to train
our models.
Step 1: Setting up Google Colab
Since we’re importing our data from a Google Drive link, we’ll need to add a few lines of code to our Google Colab
notebook. Create a new Python 3 notebook and write the following code blocks:
!pip install PyDrive
This will install PyDrive. Now we will import a few required libraries and authenticate:
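import os
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
auth.authenticate_user()  # log in with your Google account
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)  # authorized Google Drive client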
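To download the dataset, we will use the ID of the file uploaded on Google Drive:
download = drive.CreateFile({'id': 'YOUR_FILE_ID'})  # placeholder: paste your file's Drive ID here
Replace 'YOUR_FILE_ID' in the above code with the ID of your file. Now we will download this file and unzip it: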
download.GetContentFile('train_LbELtWX.zip')
!unzip train_LbELtWX.zip
You have to run these code blocks every time you start your notebook.
Step 2: Import the libraries we’ll need during our model building phase.
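import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D
from keras.utils import to_categorical
from keras.preprocessing import image
import numpy as np
import pandas as pd
from tqdm import tqdm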
Step 3: Recall the pre-processing steps we discussed earlier. We’ll be using them here after loading the data.
train = pd.read_csv('train.csv')
Next, we will read all the training images, store them in a list, and finally convert that list into a numpy array.
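# We have grayscale images, so we load them with color_mode='grayscale' (use 'rgb' for color images)
train_image = []
for i in tqdm(range(train.shape[0])):
    # ASSUMPTION: the path was truncated in the source; we assume images are stored as train/<id>.png
    img = image.load_img('train/' + str(train['id'][i]) + '.png', target_size=(28, 28), color_mode='grayscale')
    img = image.img_to_array(img)  # (28, 28, 1) float array
    img = img / 255                # scale pixel values to [0, 1]
    train_image.append(img)
X = np.array(train_image)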
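What the loading loop does to each image can be sketched without Keras: convert to 32-bit floats, scale pixel values from 0-255 into [0, 1], add a single channel axis, and stack everything into one array. The random uint8 arrays below are stand-ins for images read from disk, and to_model_input is a hypothetical helper.

```python
import numpy as np

def to_model_input(gray_images):
    """Stack 2-D uint8 grayscale images into a float32 (N, H, W, 1) array scaled to [0, 1]."""
    scaled = [img.astype(np.float32) / 255.0 for img in gray_images]
    return np.stack(scaled)[..., np.newaxis]  # add the single grayscale channel

rng = np.random.default_rng(0)
# Four random 28 x 28 "images" standing in for files read from disk
fake_images = [rng.integers(0, 255, size=(28, 28), dtype=np.uint8, endpoint=True) for _ in range(4)]

X_demo = to_model_input(fake_images)
print(X_demo.shape, X_demo.dtype)  # (4, 28, 28, 1) float32
```

Scaling to [0, 1] keeps the inputs small and on a common scale, which generally makes gradient-based training better behaved.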
As it is a multi-class classification problem (10 classes), we will one-hot encode the target variable.
y=train['label'].values
y = to_categorical(y)
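One-hot encoding turns each integer label into a vector with a single 1 at the label’s index. A minimal NumPy equivalent of the kind of matrix to_categorical produces here (one_hot is a hypothetical helper):

```python
import numpy as np

def one_hot(labels, num_classes):
    """Return an (N, num_classes) float32 matrix with a single 1 per row."""
    labels = np.asarray(labels, dtype=int)
    encoded = np.zeros((labels.size, num_classes), dtype=np.float32)
    encoded[np.arange(labels.size), labels] = 1.0  # set the label's column in each row
    return encoded

y_demo = one_hot([0, 3, 9], num_classes=10)
print(y_demo[1])  # [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
```

Taking argmax along each row recovers the original labels, which is exactly how predictions are turned back into classes later.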
We will create a simple architecture with 2 convolutional layers, one dense hidden layer and an output layer.
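model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)))  # assumed: 32 filters, 3x3 (line lost in source)
model.add(Conv2D(64, (3, 3), activation='relu'))  # assumed: 64 filters, 3x3 (line lost in source)
model.add(MaxPooling2D(pool_size=(2, 2)))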
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))
model.compile(loss='categorical_crossentropy',optimizer='Adam',metrics=['accuracy'])
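To see how data flows through a network of this shape, the sketch below computes each layer’s output size and parameter count by hand. It assumes the two convolutional layers use 32 and 64 filters with 3 x 3 kernels, ‘valid’ padding and stride 1, followed by 2 x 2 max pooling; Keras’s model.summary() reports the same kind of figures for a built model.

```python
def conv_out(size, kernel, stride=1):
    """Spatial size after a 'valid' (no padding) convolution or pooling window."""
    return (size - kernel) // stride + 1

def conv_params(kernel, in_channels, filters):
    return kernel * kernel * in_channels * filters + filters  # weights + biases

def dense_params(n_in, n_out):
    return n_in * n_out + n_out  # weights + biases

# Assumed layers: Conv2D(32, 3x3) -> Conv2D(64, 3x3) -> MaxPooling2D(2x2) -> Flatten
side = conv_out(28, 3)              # 26
side = conv_out(side, 3)            # 24
side = conv_out(side, 2, stride=2)  # 12
flat = side * side * 64             # inputs to the Dense(128) layer

params = {
    "conv1": conv_params(3, 1, 32),
    "conv2": conv_params(3, 32, 64),
    "dense": dense_params(flat, 128),
    "output": dense_params(128, 10),
}
print(flat)                  # 9216
print(params)
print(sum(params.values()))  # 1199882
```

Almost all of the roughly 1.2 million parameters sit in the first Dense layer, which is where most of the model’s capacity (and most of its risk of overfitting) lives.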
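In this step, we will train the model on the training set images and validate it using, you guessed it, the validation
set. First we hold out part of the training data as the validation set, then we train for 10 epochs:
# Hold out 20% of the training images for validation (test_size=0.2 is an assumed value)
from sklearn.model_selection import train_test_split
X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=42, test_size=0.2)
model.fit(X_train, y_train, epochs=10, validation_data=(X_val, y_val))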
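We’ll initially follow the steps we performed when dealing with the training data: load the test images, pre-process
them, and predict their classes using the trained model. Older versions of Keras offered model.predict_classes() for
this, but it has been removed from recent releases, so we take the argmax of the probabilities from model.predict().
download = drive.CreateFile({'id': 'YOUR_FILE_ID'})  # placeholder: Drive ID of the test zip
download.GetContentFile('test_ScVgIM0.zip')
!unzip test_ScVgIM0.zip
test = pd.read_csv('test.csv')  # ASSUMPTION: name of the test .csv inside the zip
test_image = []
for i in tqdm(range(test.shape[0])):
    # ASSUMPTION: path reconstructed as for the training images (test/<id>.png)
    img = image.load_img('test/' + str(test['id'][i]) + '.png', target_size=(28, 28), color_mode='grayscale')
    img = image.img_to_array(img)
    img = img / 255
    test_image.append(img)
test = np.array(test_image)
# making predictions: the most probable class for each test image
prediction = np.argmax(model.predict(test), axis=1)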
We will also create a submission file to upload on the DataHack platform page (to see how our results fare on the
leaderboard).
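download = drive.CreateFile({'id': 'YOUR_FILE_ID'})  # placeholder: Drive ID of the sample submission
download.GetContentFile('sample_submission_I5njJSF.csv')
sample = pd.read_csv('sample_submission_I5njJSF.csv')
sample['label'] = prediction
sample.to_csv('sample_cnn.csv', header=True, index=False)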
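Before uploading, it can be worth checking that the submission has exactly one row per test image and only valid class IDs. A standard-library sketch of such a check follows; the id/label column names and the image IDs are hypothetical, so match them to the columns in the provided sample submission file.

```python
import csv
import io

def write_submission(image_ids, predictions, out, num_classes=10):
    """Validate predictions, then write a header plus one 'id,label' row per test image."""
    if len(image_ids) != len(predictions):
        raise ValueError("expected exactly one prediction per test image")
    bad = [p for p in predictions if not 0 <= int(p) < num_classes]
    if bad:
        raise ValueError(f"invalid class indices: {bad}")
    writer = csv.writer(out)
    writer.writerow(["id", "label"])  # hypothetical column names
    for image_id, label in zip(image_ids, predictions):
        writer.writerow([image_id, int(label)])

buffer = io.StringIO()
write_submission(["img_a", "img_b"], [9, 0], buffer)  # hypothetical IDs and predictions
print(buffer.getvalue())
```

Validating everything before writing means a bad prediction never leaves a half-written file behind.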
Download this sample_cnn.csv file and upload it on the contest page to generate your results and check your
ranking on the leaderboard. This will give you a benchmark solution to get you started with any Image
Classification problem!
You can try hyperparameter tuning and regularization techniques to improve your model’s performance further. I
encourage you to check out this article to understand this fine-tuning step in much more detail: ‘A
Comprehensive Tutorial to learn Convolutional Neural Networks from Scratch
(https://www.analyticsvidhya.com/blog/2018/12/guide-convolutional-neural-network-cnn/)’.
Let’s test our learning on a different dataset. We’ll be cracking the ‘Identify the Digits
(https://datahack.analyticsvidhya.com/contest/practice-problem-identify-the-digits/)’ practice problem in this
section. Go ahead and download the dataset. Before you proceed further, try to solve this on your own. You
already have the tools to solve it – you just need to apply them! Come back here to check your results or if you get
stuck at some point.
In this challenge, we need to identify the digit in a given image. We have a total of 70,000 images – 49,000
labelled ones in the training set and the remaining 21,000 in the test set (the test images are unlabelled). We need
to identify/predict the class of these unlabelled images.
Ready to begin? Awesome! Create a new Python 3 notebook and run the following code:
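# Setting up Colab
!pip install PyDrive
import os
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)
download = drive.CreateFile({'id': 'YOUR_FILE_ID'})  # placeholder: Drive ID of the training zip
download.GetContentFile('Train_UQcUa52.zip')
!unzip Train_UQcUa52.zip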
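# Importing libraries
import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D
from keras.utils import to_categorical
from keras.preprocessing import image
import numpy as np
import pandas as pd
from tqdm import tqdm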
train = pd.read_csv('train.csv')
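# Reading the training images
train_image = []
for i in tqdm(range(train.shape[0])):
    # ASSUMPTION: the path was truncated in the source; we assume images unzip to Images/train/<filename>
    img = image.load_img('Images/train/' + train['filename'][i], target_size=(28, 28), color_mode='grayscale')
    img = image.img_to_array(img)
    img = img/255
    train_image.append(img)
X = np.array(train_image)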
y=train['label'].values
y = to_categorical(y)
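model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)))  # assumed: 32 filters, 3x3 (line lost in source)
model.add(Conv2D(64, (3, 3), activation='relu'))  # assumed: 64 filters, 3x3 (line lost in source)
model.add(MaxPooling2D(pool_size=(2, 2)))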
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))
model.compile(loss='categorical_crossentropy',optimizer='Adam',metrics=['accuracy'])
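# Creating a validation set (test_size=0.2 is an assumed value) and training the model
from sklearn.model_selection import train_test_split
X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=42, test_size=0.2)
model.fit(X_train, y_train, epochs=10, validation_data=(X_val, y_val))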
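download = drive.CreateFile({'id': 'YOUR_FILE_ID'})  # placeholder: Drive ID of the test .csv
download.GetContentFile('Test_fCbTej3.csv')
test_file = pd.read_csv('Test_fCbTej3.csv')
test_image = []
for i in tqdm(range(test_file.shape[0])):
    # ASSUMPTION: path reconstructed as for the training images (Images/test/<filename>)
    img = image.load_img('Images/test/' + test_file['filename'][i], target_size=(28, 28), color_mode='grayscale')
    img = image.img_to_array(img)
    img = img/255
    test_image.append(img)
test = np.array(test_image)
prediction = np.argmax(model.predict(test), axis=1)  # predict_classes() no longer exists in recent Keras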
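download = drive.CreateFile({'id': 'YOUR_FILE_ID'})  # placeholder: Drive ID of the sample submission
download.GetContentFile('Sample_Submission_lxuyBuB.csv')
sample = pd.read_csv('Sample_Submission_lxuyBuB.csv')
sample['filename'] = test_file['filename']
sample['label'] = prediction
sample.to_csv('sample_cnn.csv', header=True, index=False)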
Submit this file on the practice problem page to get a pretty decent accuracy number. It’s a good start but there’s
always scope for improvement. Keep playing around with the hyperparameter values and see if you can improve
on our basic model.
End Notes
Who said deep learning models require hours or days to train? My aim here was to show that you can come
up with a pretty decent deep learning model in double-quick time. You should pick up similar challenges and try
to code them yourself as well. There’s nothing like learning by doing!
The top data scientists and analysts have this kind of code ready before a Hackathon
(https://datahack.analyticsvidhya.com/) even begins. They use it to make early submissions before
diving into a detailed analysis. Once they have a benchmark solution, they start improving their model using
different techniques.
Did you find this article helpful? Do share your valuable feedback in the comments section below. Feel free to
share your complete code notebooks as well, which will be helpful to our community members.
(https://www.analyticsvidhya.com/blog/author/pulkits/)
My research interests lie in the field of Machine Learning and Deep Learning. I have an enthusiasm for
learning new skills and technologies.