PyTorch Lightning Tutorial: Simplifying Deep Learning with PyTorch
Last Updated: 23 Jul, 2025
PyTorch Lightning is an open-source library that extends PyTorch. It provides a structured approach to training and testing loops, keeping code simple and reducing boilerplate. It also supports multi-GPU training, distributed training, and more. Some other features of PyTorch Lightning are as follows:
- Integration with loggers such as the CSV logger and the TensorBoard logger.
- Checkpoints to save model weights during the training phase.
- Custom callbacks to access training details and metrics programmatically (the sketch after this list shows how these pieces plug into the Trainer).
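Below is a minimal sketch of how these pieces attach to the Trainer. CSVLogger and ModelCheckpoint are real pytorch_lightning classes; MyModel, train_loader, and val_loader are hypothetical placeholders for your own module and data loaders.
Python
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint

trainer = pl.Trainer(
    max_epochs=5,
    logger=CSVLogger("logs"),                        # metrics are written to CSV files under logs/
    callbacks=[ModelCheckpoint(monitor="val_loss")]  # keep the best checkpoint by validation loss
)
# trainer.fit(MyModel(), train_loader, val_loader)   # MyModel and the loaders are placeholders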
Setup and Installation
In this tutorial we will create a PyTorch model and use PyTorch Lightning to train and test it. Here we use the MNIST dataset. Before that, we need to install the required libraries using pip or conda.
pip install torch torchvision pytorch-lightning torchmetrics comet-ml
The article is divided into three progressive sections (Beginner, Intermediate, and Advanced tutorials), each designed to cater to a different level of expertise and to systematically build the reader's proficiency with PyTorch-Lightning.
Beginner Tutorial: Creating a PyTorch Model with PyTorch-Lightning
The Beginner Tutorial serves as an entry point, guiding newcomers through the essential steps of setting up their environment, creating a simple convolutional neural network (CNN) using the MNIST dataset, and executing basic training and testing procedures without the complexity of manual loops. This section emphasizes understanding the foundational architecture and leveraging PyTorch-Lightning's streamlined training mechanisms to reduce code verbosity and enhance clarity.
1. Creating a Model
After installing and importing the libraries, we create the architecture of the model. In this step we define the structure of our model.
- The layers and activation functions are defined in this step.
- The code uses two convolution layers, one max-pooling layer, two linear layers, and ReLU as the activation function.
- For accuracy we use the Accuracy metric from the torchmetrics library (see the standalone sketch below). We also define the forward pass method.
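To get a feel for the torchmetrics Accuracy metric on its own, here is a minimal standalone sketch using dummy tensors (the values are made up for illustration):
Python
import torch
from torchmetrics import Accuracy

accuracy = Accuracy(task='multiclass', num_classes=10)
preds = torch.randn(4, 10)           # raw logits for 4 samples over 10 classes; argmax is applied internally
target = torch.tensor([1, 7, 3, 0])  # integer class labels
print(accuracy(preds, target))       # fraction of correct predictions, as a tensor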
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torchmetrics import Accuracy # Use torchmetrics for accuracy
# Define the PyTorch Lightning model
class MNISTModel(pl.LightningModule):
    def __init__(self):
        super(MNISTModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)   # Conv layer 1
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)  # Conv layer 2
        self.pool = nn.MaxPool2d(2, 2)                 # Max pool layer
        self.fc1 = nn.Linear(64 * 5 * 5, 128)          # Linear layer 1 (64 * 5 * 5 = 1600)
        self.fc2 = nn.Linear(128, 10)                  # Output layer
        self.accuracy = Accuracy(task='multiclass', num_classes=10)  # Initialize accuracy metric

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # (batch, 1, 28, 28) -> (batch, 32, 13, 13)
        x = self.pool(F.relu(self.conv2(x)))  # (batch, 32, 13, 13) -> (batch, 64, 5, 5)
        x = x.view(x.size(0), -1)             # Flatten to (batch, 1600)
        x = F.relu(self.fc1(x))               # (batch, 128)
        x = self.fc2(x)                       # (batch, 10)
        return x
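As a quick sanity check on the shapes annotated above, you can pass a dummy MNIST-sized batch through the model:
Python
model = MNISTModel()
dummy = torch.randn(2, 1, 28, 28)  # batch of 2 grayscale 28x28 images
print(model(dummy).shape)          # expected: torch.Size([2, 10])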
2. Training and Optimizing our model
In this step we do not write any training loop ourselves. This is where PyTorch Lightning comes into play.
- For each batch, we unpack the inputs and targets, pass the inputs through the model, and compute the loss with the loss function.
- We use the self.log method to log the training loss (a variant showing self.log's common options follows the code below). Finally, configure_optimizers returns the Adam optimizer used to update the model weights.
Python
def configure_optimizers(self):
    return torch.optim.Adam(self.parameters(), lr=0.001)

def training_step(self, batch, batch_idx):
    data, target = batch
    output = self(data)
    loss = F.cross_entropy(output, target)
    self.log('train_loss', loss)
    return loss

def validation_step(self, batch, batch_idx):
    data, target = batch
    output = self(data)
    loss = F.cross_entropy(output, target)
    self.log('val_loss', loss)

def test_step(self, batch, batch_idx):
    data, target = batch
    output = self(data)
    loss = F.cross_entropy(output, target)
    acc = self.accuracy(output, target)  # Calculate accuracy using torchmetrics
    self.log('test_loss', loss)
    self.log('test_acc', acc)            # Log accuracy as well
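As mentioned above, self.log also accepts standard options controlling aggregation and display. Here is a sketch of a training_step variant using them; on_step, on_epoch, and prog_bar are documented Lightning arguments:
Python
def training_step(self, batch, batch_idx):
    data, target = batch
    loss = F.cross_entropy(self(data), target)
    # Log the raw per-batch value, an epoch-level average, and show it in the progress bar
    self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
    return loss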
3. Preparation of the dataset
In this step we load the MNIST dataset and split it into training, validation, and test sets. For validation we hold out about 20% of the training data at random. Finally we create data loaders with a batch size of 64.
Python
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
# Dataset and DataLoader
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = datasets.MNIST(root='data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='data', train=False, download=True, transform=transform)
# Split training and validation sets
train_size = int(0.8 * len(train_data))
val_size = len(train_data) - train_size
train_data, val_data = random_split(train_data, [train_size, val_size])
# Data loaders
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64)
test_loader = DataLoader(test_data, batch_size=64)
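Note that random_split draws a fresh random split on every run. If you want a reproducible validation set, you can pass a seeded generator, which is a standard argument of torch.utils.data.random_split:
Python
import torch
from torch.utils.data import random_split

# Replaces the random_split call above with a split that is deterministic across runs
generator = torch.Generator().manual_seed(42)
train_data, val_data = random_split(train_data, [train_size, val_size], generator=generator)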
4. Fit the data and test the model
In this step we initialize the Trainer, fit the model, and train it for 10 epochs. Then we use the test data loader to evaluate the model's performance.
Python
# Initialize and train the model
model = MNISTModel()
trainer = pl.Trainer(max_epochs=10)
# Train the model
trainer.fit(model, train_loader, val_loader)
# Test the model
trainer.test(model, test_loader)
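Lightning also checkpoints the run automatically (by default under lightning_logs/version_*/checkpoints/). A trained model can later be restored with load_from_checkpoint; the exact file name below is a hypothetical example:
Python
# Restore a trained model from a saved checkpoint (path is a hypothetical example)
model = MNISTModel.load_from_checkpoint("lightning_logs/version_0/checkpoints/epoch=9-step=7500.ckpt")
model.eval()  # switch to inference mode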
Full code implementation
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import pytorch_lightning as pl
from torchmetrics import Accuracy # Use torchmetrics for accuracy
# Define the PyTorch Lightning model
class MNISTModel(pl.LightningModule):
    def __init__(self):
        super(MNISTModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)   # Conv layer 1
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)  # Conv layer 2
        self.pool = nn.MaxPool2d(2, 2)                 # Max pool layer
        self.fc1 = nn.Linear(64 * 5 * 5, 128)          # Linear layer 1 (64 * 5 * 5 = 1600)
        self.fc2 = nn.Linear(128, 10)                  # Output layer
        self.accuracy = Accuracy(task='multiclass', num_classes=10)  # Initialize accuracy metric

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # (batch, 1, 28, 28) -> (batch, 32, 13, 13)
        x = self.pool(F.relu(self.conv2(x)))  # (batch, 32, 13, 13) -> (batch, 64, 5, 5)
        x = x.view(x.size(0), -1)             # Flatten to (batch, 1600)
        x = F.relu(self.fc1(x))               # (batch, 128)
        x = self.fc2(x)                       # (batch, 10)
        return x

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def training_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        acc = self.accuracy(output, target)  # Calculate accuracy using torchmetrics
        self.log('test_loss', loss)
        self.log('test_acc', acc)            # Log accuracy as well
# Dataset and DataLoader
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = datasets.MNIST(root='data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='data', train=False, download=True, transform=transform)
# Split training and validation sets
train_size = int(0.8 * len(train_data))
val_size = len(train_data) - train_size
train_data, val_data = random_split(train_data, [train_size, val_size])
# Data loaders
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64)
test_loader = DataLoader(test_data, batch_size=64)
# Initialize and train the model
model = MNISTModel()
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, train_loader, val_loader)
# Test the model
trainer.test(model, test_loader)
Output:
As we can see, the test accuracy of our model is about 98.7%.
Intermediate Tutorial: Mixed Precision Training and Custom Callbacks
Moving into the Intermediate Tutorial, the focus shifts to optimizing model performance and resource efficiency. Here, readers learn to implement mixed precision training, which balances 16-bit and 32-bit floating-point computations to accelerate training and minimize memory usage.
Additionally, this section introduces the concept of custom callbacks, allowing users to inject custom behaviors—such as printing epoch numbers—into the training loop, thereby providing greater control and flexibility over the training process.
Mixed precision training utilizes both 16-bit and 32-bit floating-point types (the sketch after this list shows roughly what Lightning automates):
- 16-bit (FP16): Reduces memory consumption and increases computational speed.
- 32-bit (FP32): Maintains model stability during weight updates.
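For intuition, this is roughly the pattern Lightning automates when precision is set to 16. Below is a minimal raw-PyTorch sketch using torch.cuda.amp; it assumes a CUDA device and an existing model, optimizer, and train_loader:
Python
import torch
import torch.nn.functional as F

scaler = torch.cuda.amp.GradScaler()  # scales the loss to avoid FP16 gradient underflow
for data, target in train_loader:     # train_loader, model, optimizer are assumed to exist
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():   # forward pass and loss computed in FP16 where safe
        loss = F.cross_entropy(model(data), target)
    scaler.scale(loss).backward()     # backward on the scaled loss
    scaler.step(optimizer)            # unscales gradients, then FP32 weight update
    scaler.update()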
Here we again create a model and define the training and test methods inside the class. Finally we initialize the model and the Trainer for training and testing. We also use a custom callback to print the epoch number during the training phase.
Modify the trainer to enable mixed precision and add a custom callback to monitor epochs:
Python
trainer = pl.Trainer(
    max_epochs=5,
    precision=16,                      # Enable mixed precision training
    callbacks=[PrintEpochCallback()]   # Add the custom epoch-printing callback
)
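The PrintEpochCallback used above is a small custom callback (it also appears in the full code below); it hooks into the start of every training epoch:
Python
from pytorch_lightning.callbacks import Callback

class PrintEpochCallback(Callback):
    def on_train_epoch_start(self, trainer, pl_module):
        print(f"Starting Epoch: {trainer.current_epoch + 1}")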
From the code we can see that enabling mixed precision training only requires passing the precision value when constructing the Trainer.
Here we pass 16: 16-bit floating-point numbers are typically used during the forward pass and gradient calculations, while 32-bit floating-point is retained for weight updates. (On recent Lightning releases, precision="16-mixed" is the preferred spelling of this setting.)
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
# Define the PyTorch Lightning model
class MNISTModel(pl.LightningModule):
    def __init__(self):
        super(MNISTModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)   # Conv layer 1
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)  # Conv layer 2
        self.pool = nn.MaxPool2d(2, 2)                 # Max pool layer
        self.fc1 = nn.Linear(64 * 5 * 5, 128)          # Linear layer 1
        self.fc2 = nn.Linear(128, 10)                  # Output layer

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def training_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log('test_loss', loss)
# Custom callback to print the epoch
class PrintEpochCallback(Callback):
    def on_train_epoch_start(self, trainer, pl_module):
        print(f"Starting Epoch: {trainer.current_epoch + 1}")

# Dataset and DataLoader
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = datasets.MNIST(root='data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='data', train=False, download=True, transform=transform)
# Split training and validation sets
train_size = int(0.8 * len(train_data))
val_size = len(train_data) - train_size
train_data, val_data = random_split(train_data, [train_size, val_size])
# Data loaders
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64)
test_loader = DataLoader(test_data, batch_size=64)
# Initialize and train the model
model = MNISTModel()
trainer = pl.Trainer(
    max_epochs=1,
    precision=16,                      # Enable mixed precision training
    callbacks=[PrintEpochCallback()]   # Add the custom callback
)
trainer.fit(model, train_loader, val_loader)
# Test the model
trainer.test(model, test_loader)
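Callbacks can hook into many other points of the loop as well. As a sketch, here is a hypothetical callback that prints the logged validation loss at the end of each validation epoch; trainer.callback_metrics holds the latest values recorded via self.log:
Python
from pytorch_lightning.callbacks import Callback

class PrintValLossCallback(Callback):
    def on_validation_epoch_end(self, trainer, pl_module):
        val_loss = trainer.callback_metrics.get('val_loss')  # set by self.log('val_loss', ...)
        if val_loss is not None:
            print(f"Epoch {trainer.current_epoch}: val_loss = {val_loss:.4f}")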
Output:
Advanced Tutorial: Integrating Comet Logger
Finally, the Advanced Tutorial delves into sophisticated integrations and experiment management techniques. It demonstrates how to incorporate external tools like Comet.ml for comprehensive experiment tracking and visualization, enabling users to log metrics, compare different training runs, and collaborate more effectively.
In this tutorial we will monitor the training phase. PyTorch Lightning integrates with many loggers, such as the TensorBoard logger and the Comet logger. Here we use the Comet logger to log our metrics and visualize them interactively. It also keeps track of hyperparameters and provides charts, reducing code complexity.
Setting Up Comet.ml:
- Sign Up: Create an account at Comet.ml.
- Obtain API Key: Navigate to Settings to find your API key.
- Install Comet-ML: Ensure it's installed via pip:
After signing up to Comet.ml, note your API key and workspace name. A workspace is created by default, and it contains your list of projects. We also need to install comet-ml using pip or conda.
pip install comet-ml
- Now we create a model, configure the optimizer, define the forward pass, and provide the train, validation, and test methods.
- Then we create the data loaders and finally initialize the Comet logger with the API key, workspace name, and project name.
Python
# Initialize CometLogger
comet_logger = CometLogger(
    api_key="YOUR_API_KEY",              # Replace with your Comet API key
    project_name="mnist-classification",
    workspace="YOUR_WORKSPACE"           # Replace with your workspace name
)
- CometLogger: Captures and logs metrics automatically whenever self.log is used in the model.
- Visualization: Access interactive dashboards on Comet.ml to monitor training progress, compare experiments, and analyze hyperparameters (a hyperparameter-logging sketch follows this list).
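Loggers can also record hyperparameters explicitly. For example, using log_hyperparams from Lightning's common logger interface (a sketch, assuming comet_logger is initialized as above; the values shown are this tutorial's settings):
Python
# Log hyperparameters so they appear alongside metrics in Comet's dashboard
comet_logger.log_hyperparams({"lr": 0.001, "batch_size": 64, "max_epochs": 5})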
Below is the full implementation of the code:
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import pytorch_lightning as pl
from pytorch_lightning.loggers import CometLogger
# Define the PyTorch Lightning model
class MNISTModel(pl.LightningModule):
    def __init__(self):
        super(MNISTModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)   # Conv layer 1
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)  # Conv layer 2
        self.pool = nn.MaxPool2d(2, 2)                 # Max pool layer
        self.fc1 = nn.Linear(64 * 5 * 5, 128)          # Linear layer 1
        self.fc2 = nn.Linear(128, 10)                  # Output layer

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def training_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log('test_loss', loss)
# Dataset and DataLoader
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = datasets.MNIST(root='data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='data', train=False, download=True, transform=transform)
# Split training and validation sets
train_size = int(0.8 * len(train_data))
val_size = len(train_data) - train_size
train_data, val_data = random_split(train_data, [train_size, val_size])
# Data loaders
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64)
test_loader = DataLoader(test_data, batch_size=64)
# Initialize CometLogger
comet_logger = CometLogger(
    api_key="YOUR_API_KEY",              # Replace with your Comet API key
    project_name="mnist-classification",
    workspace="YOUR_WORKSPACE"           # Replace with your workspace name
)
# Initialize and train the model
model = MNISTModel()
trainer = pl.Trainer(max_epochs=5, logger=comet_logger)
trainer.fit(model, train_loader, val_loader)
# Test the model
trainer.test(model, test_loader)
Benefits of Using Comet Logger:
- Real-Time Tracking: Monitor training metrics in real-time through interactive dashboards.
- Experiment Management: Compare different runs, track hyperparameters, and maintain reproducibility.
- Collaboration: Share experiments and results with team members seamlessly.
Conclusion
PyTorch-Lightning significantly simplifies the PyTorch workflow by abstracting complex training loops, enabling advanced features with minimal code changes, and integrating seamlessly with various tools for logging and monitoring. Whether you're a beginner aiming to build and train models efficiently or an advanced practitioner looking to optimize and monitor large-scale experiments, PyTorch-Lightning offers robust solutions to enhance your deep learning projects.