Image Classification Using PyTorch Lightning
Last Updated: 05 Aug, 2025
PyTorch Lightning is a lightweight wrapper around PyTorch. It organizes PyTorch code so that it is more readable and reproducible, and it makes research easier by handling many engineering details automatically. It abstracts away hardware management, distributed training, checkpointing and more, which lets us focus on designing experiments and models. Its main advantages are:
- Modular Code: Separates research code from engineering code.
- Scalability: Runs on GPUs/TPUs or distributed setups by changing Trainer arguments rather than model code.
- Simplicity: Training loops, logging and checkpointing are handled for us.
- Reliability: Code is less error-prone and easier to debug.
Step-by-Step Implementation
Let's walk through image classification with PyTorch Lightning step by step:
Step 1: Installing and Importing Required Libraries
Here we will use PyTorch, the main library for building and training deep learning models in Python, along with the following modules:
- torch.nn: Provides layers such as convolutional and fully connected (linear) layers used to build neural networks.
- torch.nn.functional: Offers activation and loss functions (e.g., relu, nll_loss) that are used inside model methods.
- torch.utils.data.DataLoader: Loads datasets in convenient batches for training and validation.
- torchvision: Provides popular datasets (like CIFAR-10) and image transformation functions (e.g., ToTensor, Normalize).
- pytorch_lightning: High-level framework that organizes PyTorch code, making training, validation and scaling easier and cleaner.
- pytorch_lightning.Trainer: Automates the training, validation and testing loops and supports features such as logging and hardware acceleration.
Python
# The leading "!" runs pip from a Jupyter/Colab cell; omit it in a terminal
!pip install pytorch_lightning torchvision torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import pytorch_lightning as pl
from pytorch_lightning import Trainer
Step 2: Dataset Preparation
For this tutorial, we'll use the CIFAR-10 dataset, which consists of 60,000 32x32 color images in 10 classes (50,000 for training and 10,000 for testing). PyTorch provides easy access to this dataset through torchvision.datasets.
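- transforms.ToTensor() converts PIL images to tensors with pixel values scaled to [0, 1].
- transforms.Normalize() standardizes each channel using the given mean and standard deviation.
- train=False loads the 10,000-image test split, which is used here as the validation set.
- The data is loaded into PyTorch DataLoaders for batching and shuffling.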
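Here is a worked example of what the two transforms do to a single channel value. ToTensor maps an 8-bit intensity in [0, 255] to [0, 1]. Normalize with mean 0.5 and standard deviation 0.5 then computes (x - mean) / std, which maps [0, 1] onto [-1, 1]. The plain-Python sketch below reproduces only that arithmetic. The helper name `to_normalized` is hypothetical and not part of torchvision.

```python
def to_normalized(pixel, mean=0.5, std=0.5):
    """Mimic ToTensor followed by Normalize for one 8-bit channel value."""
    scaled = pixel / 255.0          # ToTensor: [0, 255] -> [0.0, 1.0]
    return (scaled - mean) / std    # Normalize: (x - mean) / std

for p in (0, 128, 255):
    print(p, round(to_normalized(p), 4))
```

Black (0) ends up at -1.0 and white (255) at 1.0, so the network sees inputs centered roughly around zero.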
Python
# ToTensor scales pixels to [0, 1]; Normalize then maps each RGB channel to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# train=True: 50,000 training images; train=False: 10,000 test images used for validation
train_data = datasets.CIFAR10(
    root="data", train=True, download=True, transform=transform)
val_data = datasets.CIFAR10(
    root="data", train=False, download=True, transform=transform)

train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64, shuffle=False)
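With batch_size=64 and the default drop_last=False, a DataLoader yields ceil(N / 64) batches per pass. The last batch is smaller whenever N is not a multiple of 64. The plain-Python sketch below uses a hypothetical `iterate_batches` helper (not the real DataLoader, which also collates tensors and can use worker processes) to make these counts concrete for CIFAR-10.

```python
import math
import random

def iterate_batches(num_samples, batch_size, shuffle, seed=0):
    """Yield lists of sample indices, roughly as DataLoader's sampler does."""
    indices = list(range(num_samples))
    if shuffle:
        # DataLoader reshuffles every epoch; this sketch shuffles once with a fixed seed
        random.Random(seed).shuffle(indices)
    for start in range(0, num_samples, batch_size):
        yield indices[start:start + batch_size]

train_batches = list(iterate_batches(50_000, 64, shuffle=True))
val_batches = list(iterate_batches(10_000, 64, shuffle=False))

print(len(train_batches), math.ceil(50_000 / 64))  # 782 782
print(len(val_batches), len(val_batches[-1]))      # 157 16
```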
Step 3: Define the Image Classification Model
We'll define a simple convolutional neural network (CNN) for image classification. The model will consist of convolutional layers followed by fully connected layers.
- Two convolutional layers extract hierarchical features.
- Max pooling layers reduce spatial dimensions.
- Two fully connected layers map extracted features to class scores.
- The output uses log-softmax, which pairs with the negative log-likelihood loss and is more numerically stable than taking the log of a softmax.
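To see why log-softmax is preferred, compare a naive log(softmax(x)) with the log-sum-exp form. In the standard-library sketch below, the function names are illustrative. When one logit is far below another, the naive softmax probability underflows to 0 and its logarithm becomes -inf. The stable form returns the exact value.

```python
import math

def naive_log_softmax(logits):
    exps = [math.exp(z) for z in logits]
    total = sum(exps)
    # log(0) is undefined, so an underflowed probability becomes -inf
    return [math.log(e / total) if e > 0 else float("-inf") for e in exps]

def stable_log_softmax(logits):
    # Subtract the max before exponentiating (log-sum-exp trick)
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]

logits = [0.0, -1000.0]
print(naive_log_softmax(logits))   # [0.0, -inf]
print(stable_log_softmax(logits))  # [0.0, -1000.0]
```

PyTorch's F.log_softmax uses the stable formulation. F.nll_loss expects these log-probabilities as its input, which is why the model pairs the two.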
Python
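class ImageClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # 3 input channels (RGB) -> 32 feature maps, 3x3 kernel, stride 1
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # After two conv + pool stages, a 32x32 image becomes 64 maps of 6x6
        self.fc1 = nn.Linear(64 * 6 * 6, 128)
        self.fc2 = nn.Linear(128, 10)  # one score per CIFAR-10 class

    def forward(self, x):
        x = F.relu(self.conv1(x))   # 32x32 -> 30x30
        x = F.max_pool2d(x, 2)      # 30x30 -> 15x15
        x = F.relu(self.conv2(x))   # 15x15 -> 13x13
        x = F.max_pool2d(x, 2)      # 13x13 -> 6x6
        x = x.view(-1, 64 * 6 * 6)  # flatten to (batch, 2304)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)  # log-probabilities per class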
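The hard-coded 64 * 6 * 6 in fc1 and view() comes from two rules. An unpadded, stride-1 convolution gives floor((n + 2p - k) / s) + 1 outputs per side, and 2x2 max pooling halves each side, rounding down. The plain-Python sketch below (helper names are illustrative) traces a 32x32 CIFAR-10 image through both stages. It also counts the model's trainable parameters.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1

def pool_out(size, kernel=2):
    """Output length of max pooling whose stride equals its kernel size."""
    return size // kernel

size = 32
size = pool_out(conv_out(size, kernel=3))  # conv1: 32 -> 30, pool: 30 -> 15
size = pool_out(conv_out(size, kernel=3))  # conv2: 15 -> 13, pool: 13 -> 6
flat = 64 * size * size
print(size, flat)  # 6 2304

def conv_params(c_in, c_out, k):
    return c_in * c_out * k * k + c_out  # weights + one bias per output map

def linear_params(n_in, n_out):
    return n_in * n_out + n_out          # weights + biases

total = (conv_params(3, 32, 3) + conv_params(32, 64, 3)
         + linear_params(flat, 128) + linear_params(128, 10))
print(total)  # 315722
```

If you change the input size or kernel sizes, recompute this value. Otherwise view() will either raise an error or silently regroup values across samples in the batch.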
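Step 4: Define the Training, Validation and Test Steps
PyTorch Lightning runs the training loop itself, so we only need to supply the per-batch logic. We add training, validation and test steps to the model, use the negative log-likelihood loss (which pairs with the log-softmax output) and specify the optimizer.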
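- training_step runs the forward pass on a batch and returns the training loss, which Lightning uses for backpropagation.
- validation_step computes the validation loss for each validation batch.
- test_step computes the test loss and accuracy.
- self.log() records metrics so Lightning can display and track them automatically.
- configure_optimizers returns the optimizer (Adam with a learning rate of 1e-3).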
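Conceptually, trainer.fit() is what calls these hooks. The toy loop below is a heavily simplified, hypothetical sketch of that control flow in plain Python. It has no gradients, devices, logging or checkpointing, and none of Lightning's real internals. `ToyModule` and `toy_fit` are illustrative names, and the sketch only shows the order in which the methods you write get called.

```python
class ToyModule:
    """Stand-in for a LightningModule that records which hooks run."""
    def __init__(self):
        self.calls = []

    def configure_optimizers(self):
        self.calls.append("configure_optimizers")
        return "optimizer"  # placeholder; a real module returns a torch optimizer

    def training_step(self, batch, batch_idx):
        self.calls.append(f"training_step {batch_idx}")
        return sum(batch) / len(batch)  # dummy "loss"

    def validation_step(self, batch, batch_idx):
        self.calls.append(f"validation_step {batch_idx}")
        return sum(batch) / len(batch)

def toy_fit(module, train_batches, val_batches, max_epochs):
    """Hypothetical outline of the loop that Trainer.fit automates."""
    module.configure_optimizers()
    for _ in range(max_epochs):
        for i, batch in enumerate(train_batches):
            loss = module.training_step(batch, i)
            # In Lightning, this loss drives zero_grad(), backward() and optimizer.step()
        for i, batch in enumerate(val_batches):
            # Lightning runs this in eval mode with gradients disabled
            module.validation_step(batch, i)

m = ToyModule()
toy_fit(m, train_batches=[[1, 2], [3, 4]], val_batches=[[5, 6]], max_epochs=1)
print(m.calls)
```

The real Trainer does more than this. For example, by default it runs a short validation "sanity check" (num_sanity_val_steps=2) before training starts.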
Python
class ImageClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(64 * 6 * 6, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(-1, 64 * 6 * 6)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(inputs)
        # NLL on log-softmax outputs is equivalent to cross-entropy on raw logits
        loss = F.nll_loss(outputs, labels)
        self.log('train_loss', loss)
        return loss  # Lightning runs backward() and the optimizer step with this

    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(inputs)
        val_loss = F.nll_loss(outputs, labels)
        self.log('val_loss', val_loss)  # averaged over the epoch by default
        return val_loss

    def test_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(inputs)
        test_loss = F.nll_loss(outputs, labels)
        preds = torch.argmax(outputs, dim=1)  # most likely class per image
        accuracy = (preds == labels).float().mean()
        self.log('test_loss', test_loss)
        self.log('test_accuracy', accuracy)
        return test_loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
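For reference, here is what F.nll_loss and test_step compute. With its default mean reduction, F.nll_loss takes the log-probability assigned to the correct class for each sample, negates it and averages over the batch. The accuracy in test_step is the fraction of samples whose argmax matches the label. The plain-Python sketch below reproduces both on a made-up batch of two samples and three classes. The numbers are illustrative only.

```python
import math

def nll_loss(log_probs, labels):
    """Mean negative log-likelihood, like F.nll_loss with default arguments."""
    return -sum(row[y] for row, y in zip(log_probs, labels)) / len(labels)

def accuracy(log_probs, labels):
    """Fraction of rows whose highest score is at the true label index."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in log_probs]
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)

# Made-up log-probabilities for 2 samples over 3 classes
log_probs = [
    [math.log(0.7), math.log(0.2), math.log(0.1)],
    [math.log(0.3), math.log(0.6), math.log(0.1)],
]
labels = [0, 2]

print(round(nll_loss(log_probs, labels), 4))  # 1.3296
print(accuracy(log_probs, labels))            # 0.5
```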
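Step 5: Train and Validate the Model
The Trainer manages the entire training and validation lifecycle, including logging and checkpointing. Here it runs 5 epochs on a single CPU device. Passing accelerator="auto" instead lets Lightning use a GPU when one is available.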
Python
model = ImageClassifier()
trainer = Trainer(max_epochs=5, devices=1, accelerator="cpu")
trainer.fit(model, train_loader, val_loader)
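Because validation_step calls self.log without extra arguments, the val_loss shown at the end of each epoch is an epoch-level mean rather than the last batch's value. Lightning infers a batch size from each batch and uses it when accumulating that mean. The sketch below assumes the mean is weighted by batch size and shows why that matters when the last validation batch holds only 16 of the 10,000 images. The per-batch losses are made up purely to make the difference visible.

```python
# Hypothetical per-batch validation losses: 156 full batches and a final batch of 16
batch_sizes = [64] * 156 + [16]
batch_losses = [1.0] * 156 + [2.0]  # made-up values

# Plain mean of batch means: the small last batch counts as much as a full one
unweighted = sum(batch_losses) / len(batch_losses)
# Sample-weighted mean: each image contributes equally
weighted = sum(l * n for l, n in zip(batch_losses, batch_sizes)) / sum(batch_sizes)

print(round(unweighted, 4), round(weighted, 4))  # 1.0064 1.0016
```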
Step 6: Model Testing and Evaluation
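Once the model is trained, we can evaluate it and save it for later use:
- trainer.test() runs test_step over the given data and reports the test loss and accuracy. Note that val_loader (the CIFAR-10 test split) was already used for validation during training, so this score is not measured on fully unseen data.
- Saving and loading checkpoints allows for easy model reuse or deployment.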
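This tutorial uses the CIFAR-10 test split as both the validation set and the test set. A common way to keep the test set truly unseen is to carve a validation subset out of the 50,000 training images and reserve the 10,000 test images for a single final evaluation. In PyTorch this is typically done with torch.utils.data.random_split. The plain-Python sketch below shows the same idea on indices. The 45,000/5,000 split, the seed and the `holdout_split` name are arbitrary, illustrative choices.

```python
import random

def holdout_split(num_samples, val_size, seed=42):
    """Shuffle indices once with a fixed seed and split them into train/val."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    return indices[val_size:], indices[:val_size]

train_idx, val_idx = holdout_split(50_000, val_size=5_000)
print(len(train_idx), len(val_idx))         # 45000 5000
print(set(train_idx).isdisjoint(val_idx))   # True
```

The resulting index lists could be wrapped with torch.utils.data.Subset(train_data, train_idx) to build separate training and validation loaders.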
Python
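trainer.test(model, dataloaders=val_loader)  # runs test_step on every batch
trainer.save_checkpoint("image_classifier.ckpt")
model = ImageClassifier.load_from_checkpoint("image_classifier.ckpt")
model.eval()  # switch to inference mode before making predictions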
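The reloaded model returns log-probabilities. To make a prediction, take the argmax, and exponentiate it if you also want a confidence. In practice you would call the model on a normalized image tensor inside torch.no_grad(). The plain-Python sketch below covers only the decoding step on a made-up output row. It assumes torchvision's CIFAR-10 class order, which is also available as train_data.classes.

```python
import math

# Class order used by torchvision's CIFAR10 dataset
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]

def decode_prediction(log_probs):
    """Return (class name, probability) for one row of log-softmax output."""
    best = max(range(len(log_probs)), key=log_probs.__getitem__)
    return CIFAR10_CLASSES[best], math.exp(log_probs[best])

# Made-up model output for a single image (probabilities sum to 1)
row = [math.log(p) for p in [0.01, 0.02, 0.05, 0.6, 0.1, 0.15, 0.03, 0.02, 0.01, 0.01]]
name, prob = decode_prediction(row)
print(name, round(prob, 2))  # cat 0.6
```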
Image classification is a fundamental task in deep learning, and PyTorch Lightning provides a clean and efficient framework for building, training and scaling image classification models. Its organized structure, automatic checkpointing and hardware-agnostic Trainer reduce boilerplate code and let us spend more time on the model itself.