Image Datasets, Dataloaders, and Transforms in Pytorch
Last Updated: 08 Jun, 2023
Deep learning in PyTorch is becoming increasingly popular due to its ease of use, support for multiple hardware platforms, and efficient processing. Image datasets, dataloaders, and transforms are essential components for achieving successful results with deep learning models in PyTorch.
In this article, we will discuss image datasets, dataloaders, and transforms in Python using the PyTorch library. Image datasets store collections of images that can be used in deep learning models for training, testing, or validation. These images are collected from a variety of sources, such as online websites, physical devices like cameras, and user-generated content. Dataloaders are responsible for loading the image datasets and providing them in batches to the models. Transforms are operations that alter certain aspects of the images, such as color, size, shape, and brightness. In PyTorch, these components can be combined to build deep learning models for tasks such as object recognition, image classification, and image segmentation.
Popular datasets such as ImageNet, CIFAR-10, and MNIST can be used as the basis for creating image datasets and dataloaders. Common image transforms such as random rotation, random crop, random horizontal or vertical flipping, normalization, and color augmentation can be used to create model-ready data, and dataloaders can then load batches of this data efficiently for model training.
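As a point of reference, torchvision already bundles these benchmark datasets and standard transforms. The sketch below is a minimal illustration of that built-in pipeline using CIFAR-10; the root path, batch size, and transform parameters are illustrative choices, not values required by this article.
Python3
import torch
import torchvision

# Built-in pipeline: dataset -> transforms -> dataloader (illustrative parameters)
cifar_transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomCrop(32, padding=4),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
cifar_train = torchvision.datasets.CIFAR10(root='./data', train=True,
                                           download=True, transform=cifar_transform)
cifar_loader = torch.utils.data.DataLoader(cifar_train, batch_size=64, shuffle=True)
In the rest of the article we build the same three pieces, dataset, transforms, and dataloader, by hand for a custom image folder.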
Image Datasets, Dataloaders, and Transforms
We will be implementing them on a sample dataset which can be downloaded from this link. You can download this dataset and follow along with this article to understand the concept better.
Import the necessary libraries
We will first import the libraries we will be using in this article.
Python3
import os
import random  # used later to pick a random sample index

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torchvision
Image Dataset
An image dataset can be created by defining a class that inherits from the torch.utils.data.Dataset class. Two methods must be overridden in the derived class:
- __len__(): returns the number of samples present in the dataset.
- __getitem__(): returns the sample at the ith index from the dataset.
We can load the image dataset in Pytorch as follows:
Python3
class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.images = os.listdir(data_dir)
        self.transform = transform

    # Number of samples in the dataset
    def __len__(self):
        return len(self.images)

    # Load the image at the given index and apply the transform, if any
    def __getitem__(self, index):
        image_path = os.path.join(self.data_dir, self.images[index])
        image = np.array(Image.open(image_path))
        if self.transform:
            image = self.transform(image)
        return image
Now let us use this class on our sample dataset.
Python3
data_path = './maps/train'

dataset = ImageDataset(data_path)
dataset_length = len(dataset)
print('Number of training examples:', dataset_length)

random_index = random.randint(0, dataset_length - 1)
plt.imshow(dataset[random_index])
plt.show()
Output:
Number of training examples: 1096
Custom Transforms
A custom transform can be created by defining a class with a __call__() method. Such transforms can be used to define preprocessing and data augmentation functions. Here we define a custom transform that preprocesses the input image by splitting it into two equal halves:
Python3
class CustomTransform(object):
    def __init__(self, split_percent=0.5):
        self.split_percent = split_percent

    def __call__(self, image):
        # Split the image along its width into an (image, target) pair
        split = int(image.shape[1] * self.split_percent)
        image1 = image[:, :split, :]
        image2 = image[:, split:, :]
        return image1, image2
To use multiple transform objects in PyTorch, you can make use of the torchvision.transforms.Compose class. This class allows you to create an object that represents a composition of different transform objects while maintaining the order in which you want them to be applied.
Python3
transform = torchvision.transforms.Compose([
    CustomTransform(),
])
Now let us use this transform with the custom dataset class:
Python3
dataset = ImageDataset(data_path, transform=transform)
image, target = dataset[random_index]

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(image)
plt.title('Image')
plt.subplot(1, 2, 2)
plt.imshow(target)
plt.title('Target')
plt.show()
Output:
(The two halves of the sample are displayed side by side: the input image on the left and its target on the right.)
Data augmentation
We can also define a transform that performs data augmentation. Data augmentation is a very useful tool when the dataset is small and we want to increase the amount and diversity of data. Below is an example of a transform that applies a random vertical flip and random color jittering to the input sample.
Python3
class CustomAugmentation(object):
    def __init__(self, flip_prob=0.5, jitter_prob=0.5):
        self.flip_prob = flip_prob
        self.jitter_prob = jitter_prob

    def __call__(self, sample):
        # After CustomTransform the sample is an (image, target) pair; stacking it
        # applies the same random augmentation to both halves.
        pair = isinstance(sample, (tuple, list))
        image = np.stack(sample) if pair else np.asarray(sample)[None]
        if np.random.random() < self.flip_prob:
            # Vertical flip; .copy() avoids negative strides that break tensor conversion
            image = np.flip(image, axis=1).copy()
        if np.random.random() < self.jitter_prob:
            # Random color jitter: add noise, then clip back to the valid uint8 range
            image = image.astype(np.int32)
            image = image + np.random.randint(-50, 50, size=image.shape, dtype=np.int32)
            image = np.clip(image, 0, 255).astype(np.uint8)
        return (image[0], image[1]) if pair else image[0]
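For reference, torchvision ships built-in transforms that cover similar augmentations. The snippet below is only a rough counterpart and rests on the assumption that the input is a PIL Image or a tensor; the built-ins do not operate on the raw NumPy arrays our custom pipeline produces, so it is not a drop-in replacement for CustomAugmentation here.
Python3
# Rough built-in counterpart of CustomAugmentation (assumes PIL Image or tensor input,
# unlike the NumPy-based custom pipeline above)
builtin_augmentation = torchvision.transforms.Compose([
    torchvision.transforms.RandomVerticalFlip(p=0.5),
    torchvision.transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
])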
Now we will define a transform based on the custom augmentation we defined earlier and display different variations of the target image.
Python3
aug_transform = torchvision.transforms.Compose([
    CustomTransform(),
    CustomAugmentation(),
])
nonaug_transform = torchvision.transforms.Compose([
    CustomTransform(),
])

aug_dataset = ImageDataset(data_path, transform=aug_transform)
nonaug_dataset = ImageDataset(data_path, transform=nonaug_transform)

image, target = nonaug_dataset[random_index]

plt.figure(figsize=(10, 10))
plt.subplot(2, 2, 1)
plt.imshow(target)
plt.title('Non-augmented image')
for i in range(2, 5):
    image, target1 = aug_dataset[random_index]
    plt.subplot(2, 2, i)
    plt.imshow(target1)
    plt.title('Augmented image')
plt.show()
Output:
(A 2x2 grid: the non-augmented target in the first panel and three randomly augmented variants of it in the remaining panels.)
Custom Dataloaders
A custom dataloader can be defined by wrapping the dataset with the torch.utils.data.DataLoader class. It enables us to control various aspects of data loading, such as the batch size, the number of worker processes, and whether to shuffle the data. We can define a dataloader in PyTorch as follows:
Python3
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=4,
    shuffle=True,
    num_workers=2
)
print('Number of batches:', len(dataloader))
Output:
Number of batches: 274
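With the dataloader in place, a single batch can be pulled and inspected to confirm what a model would actually receive. The sketch below is only an illustrative check: since each sample is an (image, target) pair of uint8 NumPy arrays, the default collate function returns two batched tensors in (batch, height, width, channels) order, and the exact spatial size depends on the dataset.
Python3
# Fetch one batch and inspect its structure (illustrative check).
# On Windows/macOS, iterate dataloaders with num_workers > 0 only inside an
# "if __name__ == '__main__':" guard.
images, targets = next(iter(dataloader))
print(images.shape, images.dtype)    # e.g. torch.Size([4, H, W, 3]) torch.uint8
print(targets.shape, targets.dtype)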
Training and testing dataset
Now we will combine everything above to define the training and testing datasets. We perform preprocessing on both datasets, while augmentation is applied only to the training dataset. The PyTorch implementation is as follows:
Python3
train_path = './maps/train'
test_path = './maps/val'

train_transform = torchvision.transforms.Compose([
    CustomTransform(),
    CustomAugmentation(),
])
test_transform = torchvision.transforms.Compose([
    CustomTransform(),
])

train_dataset = ImageDataset(train_path, transform=train_transform)
test_dataset = ImageDataset(test_path, transform=test_transform)

train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=2
)
test_dataloader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=2
)
print('Number of training batches:', len(train_dataloader))
print('Number of testing batches:', len(test_dataloader))
Output:
Number of training batches: 274
Number of testing batches: 1098
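To show how these pieces fit together end to end, here is a minimal sketch of a training loop that consumes the train dataloader. The model, loss, and optimizer are hypothetical placeholders chosen only for illustration (a single Conv2d layer with an MSE loss), not part of the pipeline above; the permute and scaling step is needed because the custom dataset yields uint8 arrays in (H, W, C) order rather than normalized (C, H, W) tensors.
Python3
# Hypothetical placeholders for illustration only: a single conv layer trained
# to map the input half to the target half with an MSE loss.
model = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 1  # illustrative value
for epoch in range(num_epochs):
    for images, targets in train_dataloader:
        # Convert uint8 (B, H, W, C) batches to float (B, C, H, W) in [0, 1]
        images = images.permute(0, 3, 1, 2).float() / 255.0
        targets = targets.permute(0, 3, 1, 2).float() / 255.0
        optimizer.zero_grad()
        loss = criterion(model(images), targets)
        loss.backward()
        optimizer.step()
In a real project the placeholder model would be replaced by an image-to-image network suited to the task, and a similar loop over test_dataloader (without gradient updates) would be used for evaluation.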