
Image Classification Using PyTorch Lightning

Last Updated : 05 Aug, 2025

PyTorch Lightning is a lightweight wrapper for PyTorch. It organizes PyTorch code so that it is more readable and reproducible, and it makes research easier by handling many engineering details automatically. It abstracts away hardware management, distributed training, checkpointing and more, allowing us to focus on designing experiments and models. Its advantages include the following (a minimal code skeleton illustrating this structure follows the list):

  • Modular Code: Separates research code from engineering code.
  • Scalability: Effortlessly uses GPUs/TPUs or distributed setups.
  • Simplicity: Training loops, logging and checkpointing are handled for us.
  • Reliability: Code is less error-prone and easier to debug.
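
To make the split between research code and engineering code concrete, here is a minimal, hypothetical LightningModule skeleton. The layer, batch shapes and dataloader are placeholders and are not part of this tutorial's model: the model, loss and optimizer live in the module, while pl.Trainer supplies the loops, device handling and logging.

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl


class LitSkeleton(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32 * 32 * 3, 10)   # placeholder model

    def training_step(self, batch, batch_idx):
        x, y = batch                              # Lightning feeds batches from the dataloader
        loss = F.cross_entropy(self.layer(x.flatten(1)), y)
        self.log("train_loss", loss)              # logging is handled for us
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# Engineering side (hypothetical dataloader):
# pl.Trainer(max_epochs=1).fit(LitSkeleton(), train_dataloader)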

Step-by-Step Implementation

Let's see the step-by-step implementation of image classification with PyTorch Lightning.

Step 1: Installing and Importing Required Libraries

Here we will use torch (PyTorch), the main library for building and training deep learning models in Python, along with the following modules:

  • torch.nn: Contains layers like convolution and fully connected used to build neural networks.
  • torch.nn.functional: Offers activation and loss functions (e.g., relu, nll_loss) that are used inside model methods.
  • torch.utils.data.DataLoader: Loads datasets in convenient batches for training and validation.
  • torchvision: Provides popular datasets (like CIFAR-10) and image transformation functions (e.g., ToTensor, Normalize).
  • pytorch_lightning: High-level framework that organizes PyTorch code, making training, validation and scaling easier and cleaner.
  • pytorch_lightning.Trainer: Automates the training, validation and testing loops and supports features like logging and hardware acceleration.
Python
!pip install pytorch_lightning torchvision torch

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import pytorch_lightning as pl
from pytorch_lightning import Trainer
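
Optionally, we can confirm the installed versions and check whether a GPU is visible before moving on:

Python
print("torch version:", torch.__version__)
print("pytorch_lightning version:", pl.__version__)
print("CUDA available:", torch.cuda.is_available())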

Step 2: Dataset Preparation

For this tutorial, we'll use the CIFAR-10 dataset, which consists of 60,000 32x32 color images in 10 classes (50,000 for training and 10,000 for testing). PyTorch provides easy access to this dataset through torchvision.datasets.

  • transforms.ToTensor() converts PIL images to tensors.
  • transforms.Normalize() standardizes the data; with mean 0.5 and std 0.5, pixel values in [0, 1] are rescaled to roughly [-1, 1].
  • Data is loaded into PyTorch DataLoader for batching and shuffling.
Python
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_data = datasets.CIFAR10(
    root="data", train=True, download=True, transform=transform)
val_data = datasets.CIFAR10(
    root="data", train=False, download=True, transform=transform)

train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64, shuffle=False)
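
Before defining the model, it is worth pulling one batch from the loader to confirm the shapes we expect: 64 images of shape 3x32x32 and 64 integer labels. The class names come from torchvision's CIFAR10 dataset object.

Python
# Inspect one batch from the training loader
images, labels = next(iter(train_loader))
print(images.shape)   # torch.Size([64, 3, 32, 32])
print(labels.shape)   # torch.Size([64])

# CIFAR-10 class names in label order (provided by torchvision)
print(train_data.classes)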

Step 3: Define the Image Classification Model

We'll define a simple convolutional neural network (CNN) for image classification. The model will consist of convolutional layers followed by fully connected layers.

  • Two convolutional layers extract hierarchical features.
  • Max pooling layers reduce spatial dimensions.
  • Two fully connected layers map extracted features to class scores.
  • The output uses log-softmax for improved numerical stability.
Python
class ImageClassifier(pl.LightningModule):
    def __init__(self):
        super(ImageClassifier, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(64 * 6 * 6, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(-1, 64 * 6 * 6)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
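
The flattened size 64 * 6 * 6 comes from the spatial dimensions shrinking as 32 -> 30 (conv1) -> 15 (pool) -> 13 (conv2) -> 6 (pool). A quick, optional way to verify this is to run a dummy batch through the model and check the output shape:

Python
# Sanity check: a fake batch of 4 CIFAR-10-sized images
model = ImageClassifier()
dummy = torch.randn(4, 3, 32, 32)
with torch.no_grad():
    out = model(dummy)
print(out.shape)   # torch.Size([4, 10]) -> one log-probability per class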

Step 4: Configure the Training and Validation Process

PyTorch Lightning abstracts many training details, allowing us to focus on the core logic. We extend the model with training, validation and test steps and specify the optimizer and loss function.

  • The training_step method handles the forward pass and computes the training loss for each batch.
  • The validation_step computes the validation loss.
  • The test_step computes the test loss and accuracy from the predicted classes.
  • self.log() enables automatic progress monitoring and logging.
  • configure_optimizers returns the optimizer (Adam with a learning rate of 1e-3).
Python
class ImageClassifier(pl.LightningModule):
    def __init__(self):
        super(ImageClassifier, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(64 * 6 * 6, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(-1, 64 * 6 * 6)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(inputs)
        loss = F.nll_loss(outputs, labels)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(inputs)
        val_loss = F.nll_loss(outputs, labels)
        self.log('val_loss', val_loss)
        return val_loss

    def test_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(inputs)
        test_loss = F.nll_loss(outputs, labels)
        preds = torch.argmax(outputs, dim=1)
        accuracy = (preds == labels).float().mean()
        self.log('test_loss', test_loss)
        self.log('test_accuracy', accuracy)
        return test_loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
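
If we also want to monitor validation accuracy, one option is to compute it from the predicted classes inside validation_step, mirroring what test_step already does. The subclass below is a minimal sketch (ImageClassifierWithAcc is just an illustrative name, not part of the tutorial) and can be used in place of ImageClassifier:

Python
class ImageClassifierWithAcc(ImageClassifier):
    # Sketch: same model, but validation also logs accuracy
    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(inputs)
        val_loss = F.nll_loss(outputs, labels)
        preds = torch.argmax(outputs, dim=1)
        val_acc = (preds == labels).float().mean()
        self.log('val_loss', val_loss)
        self.log('val_acc', val_acc, prog_bar=True)  # show accuracy in the progress bar
        return val_loss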

Step 5: Training and Validating the Model

The Trainer manages the entire training and validation lifecycle, including logging and checkpointing.

Python
model = ImageClassifier()

trainer = Trainer(max_epochs=5, devices=1, accelerator="cpu")

trainer.fit(model, train_loader, val_loader)

Output:

[Screenshot: Training]
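
The Trainer above relies on Lightning's default behaviour. Checkpointing and early stopping can also be configured explicitly with callbacks; the sketch below uses Lightning's built-in ModelCheckpoint and EarlyStopping callbacks, monitors the val_loss logged in validation_step, and could be used instead of the Trainer call above.

Python
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

checkpoint_cb = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
early_stop_cb = EarlyStopping(monitor="val_loss", mode="min", patience=3)

trainer = Trainer(
    max_epochs=5,
    accelerator="auto",   # picks a GPU/TPU if available, otherwise CPU
    devices=1,
    callbacks=[checkpoint_cb, early_stop_cb],
)
trainer.fit(model, train_loader, val_loader)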

Step 6: Model Testing and Evaluation

Once the model is trained, we can evaluate its performance on unseen data and save it for later use.

  • trainer.test() evaluates the trained model on unseen data.
  • Saving and loading checkpoints allows for easy model reuse or deployment.
Python
# Evaluate the trained model on the held-out data
trainer.test(model, val_loader)

# Save the trained weights and reload them later for reuse or deployment
trainer.save_checkpoint("image_classifier.ckpt")
model = ImageClassifier.load_from_checkpoint("image_classifier.ckpt")

Output:

[Screenshot: Output]
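
To use the restored model for a single prediction, a short inference sketch could look like the following; the class-name lookup relies on torchvision's CIFAR10 classes attribute.

Python
model.eval()                               # inference mode: no gradient updates needed
image, label = val_data[0]                 # one normalized test image and its true label
with torch.no_grad():
    log_probs = model(image.unsqueeze(0))  # add a batch dimension: [1, 3, 32, 32]
    pred = log_probs.argmax(dim=1).item()

print("Predicted:", val_data.classes[pred])
print("Actual:   ", val_data.classes[label])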

Image classification is a fundamental task in deep learning and PyTorch Lightning provides an elegant and efficient framework to build, train and scale image classification models. With its organized structure, automatic checkpointing and scalability features, PyTorch Lightning accelerates the research and development process while minimizing boilerplate code.


