Saving and Loading Weights in PyTorch Lightning
In machine learning, it is important to save and load model weights efficiently. This preserves the state of a model during training, so we can resume later without starting from scratch. In this article, we discuss how to save and load weights in PyTorch Lightning, an easy-to-use library that simplifies PyTorch.
We will cover the steps involved in saving and loading weights, various configurations, and best practices for working with models in PyTorch Lightning.
Why is Saving and Loading Weights Important?
Let's first understand what PyTorch Lightning is: a lightweight wrapper around PyTorch that helps us organize our code and reduce boilerplate. It makes training models simpler and more efficient by providing built-in features for saving and loading weights, managing checkpoints, and more.
Saving and loading model weights is essential for the following reasons:
- Checkpointing: Regularly saving model weights ensures that you can resume training from the last saved state in case of interruptions.
- Inference: Once a model is trained, you can save its weights to disk and load them later for inference without having to retrain the model.
- Model Versioning: It allows you to keep different versions of your model with varying hyperparameters and architectures.
- Transfer Learning: Loading pre-trained weights enables fine-tuning models for different tasks.
Checkpoints in PyTorch Lightning
PyTorch Lightning provides built-in support for saving and loading model checkpoints. These checkpoints store more than just the model weights—they also include information about the optimizer, learning rate scheduler, and current epoch, making it easy to resume training seamlessly.
A checkpoint is essentially a snapshot of our model at a specific point during training. It saves not only the model's weights but also things such as:
- Current training epoch
- Optimizer states
- Learning rate scheduler states
- Hyperparameters used during training
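To see what a checkpoint actually contains, you can open it directly with plain PyTorch. Below is a minimal sketch; the exact set of keys varies between Lightning versions, the path is illustrative, and on recent PyTorch versions you may need to pass weights_only=False to torch.load:
Python
import torch

# Open a Lightning checkpoint as a plain dictionary (path is illustrative)
checkpoint = torch.load('checkpoints/best_model.ckpt', map_location='cpu')

# Typical top-level keys written by PyTorch Lightning
print(checkpoint.keys())
# e.g. dict_keys(['epoch', 'global_step', 'state_dict', 'optimizer_states', 'lr_schedulers', ...])

# The model weights themselves live under 'state_dict'
print(checkpoint['epoch'])
print(checkpoint['state_dict'].keys())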
Saving Model Weights in PyTorch Lightning
The ModelCheckpoint callback in PyTorch Lightning is designed to save the model's state at specified intervals or under certain conditions, such as when the validation accuracy improves.
Install PyTorch Lightning: In our Google Colab or Jupyter notebook, run the following command to install the library:
!pip install pytorch-lightning
Step 1: Import Required Libraries
First, we will import some required libraries:
- PyTorch for building the neural network and managing data.
- PyTorch Lightning to streamline the training process.
- ModelCheckpoint to save the model automatically based on the loss during training.
Python
import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning.callbacks import ModelCheckpoint
Step 2: Create a Sample Dataset
We will generate a simple dataset where the target y follows the formula y = 2x + 1. PyTorch's TensorDataset will hold the features x and labels y, and the DataLoader will handle batching the data during training.
Python
x = torch.rand(100, 1) # Random 100 data points
y = 2 * x + 1 # Linear relationship
dataset = TensorDataset(x, y)
dataloader = DataLoader(dataset, batch_size=10) # Batching the data
Step 3: Define the Model
We define a very simple neural network with just one linear layer using PyTorch's nn.Linear. This is a basic linear regression model, which tries to learn the relationship between input x and output y. In this model, training_step defines the training loop for one batch, computing the Mean Squared Error (MSE) loss, and configure_optimizers specifies the optimizer for the model parameters, in this case Stochastic Gradient Descent (SGD).
Python
class SimpleModel(pl.LightningModule):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(1, 1)  # Linear layer with 1 input, 1 output

    def forward(self, x):
        return self.linear(x)  # Forward pass

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)  # Model predictions
        loss = nn.MSELoss()(y_hat, y)  # Compute loss
        self.log('train_loss', loss)  # Log the training loss
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)  # Optimizer
Step 4: Add ModelCheckpoint Callback (Model Saving)
We use PyTorch Lightning's ModelCheckpoint callback to automatically save the model's weights during training. Here it saves the model every time a new minimum training loss is found. In the code below:
- monitor='train_loss': monitors the training loss.
- filename='best_model': the model is saved with this filename.
- save_top_k=1: only the best model (the one with the lowest training loss) is kept.
- mode='min': the checkpoint is saved when the monitored value (train_loss) decreases.
- dirpath='checkpoints/': specifies the directory where the checkpoint is saved.
During training, the model's best weights are saved in the checkpoints/best_model.ckpt file.
Python
checkpoint_callback = ModelCheckpoint(
    monitor='train_loss',
    filename='best_model',
    save_top_k=1,
    mode='min',
    dirpath='checkpoints/'
)
Step 5: Initialize the Trainer
The pl.Trainer is the core of PyTorch Lightning. Here, we pass the checkpoint_callback and set max_epochs=10.
Python
trainer = pl.Trainer(callbacks=[checkpoint_callback], max_epochs=10)
Step 6: Train the Model
Now we train the model using the trainer.fit() method. The model and data loader are passed as arguments, and training begins.
Python
trainer.fit(SimpleModel(), dataloader)
During training, the model's weights will be saved in the checkpoints/ directory, and the checkpoint with the best training loss will be saved as best_model.ckpt.
Loading Model Weights in PyTorch Lightning
After training is complete, we can load the best model from the checkpoint. This allows us to resume training, fine-tune the model, or use it for inference.
Python
loaded_model = SimpleModel.load_from_checkpoint('checkpoints/best_model.ckpt')
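From here you can either resume training or run inference. Below is a minimal sketch reusing the objects defined above; ckpt_path is a standard argument of Trainer.fit:
Python
# Resume training from the checkpoint; optimizer and epoch state are restored
trainer = pl.Trainer(max_epochs=20)
trainer.fit(SimpleModel(), dataloader, ckpt_path='checkpoints/best_model.ckpt')

# Or use the loaded model for inference
# (move the model and inputs to the same device if you trained on GPU)
loaded_model.eval()  # Switch to evaluation mode
with torch.no_grad():
    prediction = loaded_model(torch.tensor([[0.5]]))  # Should be close to 2 * 0.5 + 1 = 2.0
    print(prediction)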
Example: Saving and Loading Weights of a Simple Model
Here is the complete code, which shows how to build, train, and save a simple linear regression model using PyTorch Lightning.
Python
import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning.callbacks import ModelCheckpoint
# Sample dataset
x = torch.rand(100, 1)
y = 2 * x + 1
dataset = TensorDataset(x, y)
dataloader = DataLoader(dataset, batch_size=10)
# Define a simple model
class SimpleModel(pl.LightningModule):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.MSELoss()(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

# Create a ModelCheckpoint callback
checkpoint_callback = ModelCheckpoint(
    monitor='train_loss',
    filename='best_model',
    save_top_k=1,
    mode='min',
    dirpath='checkpoints/'
)
# Initialize the trainer with the checkpoint callback
trainer = pl.Trainer(callbacks=[checkpoint_callback], max_epochs=10)
# Train the model
trainer.fit(SimpleModel(), dataloader)
Output:
INFO:pytorch_lightning.callbacks.model_summary:
| Name | Type | Params | Mode
------------------------------------------
0 | linear | Linear | 2 | train
------------------------------------------
2 Trainable params
0 Non-trainable params
2 Total params
0.000 Total estimated model params size (MB)
1 Modules in train mode
0 Modules in eval mode
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (10) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
Epoch 9: 100% 10/10 [00:00<00:00, 205.39it/s, v_num=0]
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.
Python
# After training, load the model
loaded_model = SimpleModel.load_from_checkpoint('checkpoints/best_model.ckpt')
Best Practices for Saving and Loading Weights
1. Monitor the Right Metric
It's important to monitor the most relevant metric for your task when saving checkpoints. For example, for a classification task you might want to monitor validation accuracy (val_acc), while for a regression task you may want to track validation loss (val_loss).
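For example, a checkpoint callback for a classification task could look like the following sketch; it assumes your LightningModule logs a metric named val_acc via self.log in its validation step:
Python
# Save the checkpoint with the highest validation accuracy
checkpoint_callback = ModelCheckpoint(
    monitor='val_acc',  # assumes 'val_acc' is logged with self.log(...)
    mode='max',         # higher accuracy is better
    filename='best_model'
)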
2. Use save_top_k Wisely
The save_top_k argument in the ModelCheckpoint callback allows you to save only the best-performing models. This helps reduce storage overhead by not saving every checkpoint.
3. Use GPU/CPU Flexibility
When saving and loading models, PyTorch Lightning takes care of moving your model between CPUs and GPUs automatically. This means you can train your model on a GPU and load it for inference on a CPU without any changes.
# Load the model on a specific device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MyModel.load_from_checkpoint("checkpoint.ckpt", map_location=device)
4. Checkpoint Naming Conventions
Use meaningful naming conventions when saving checkpoints. Including metrics such as the epoch and validation loss/accuracy in the checkpoint filename helps you identify the best models easily.
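ModelCheckpoint supports this directly: names in curly braces in the filename are filled in from the epoch and any logged metrics. A sketch, assuming val_loss is logged:
Python
# Embed the epoch and monitored metric in each checkpoint's filename
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    filename='model-{epoch:02d}-{val_loss:.2f}'  # e.g. model-07-0.13.ckpt
)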
5. Resume Training with Frozen Weights
If you want to resume training with part of the model frozen (e.g., for fine-tuning), you can achieve this by manually setting the requires_grad flag of the layers you want to freeze.
# Load model and freeze some layers
model = MyModel.load_from_checkpoint("checkpoint.ckpt")
for param in model.feature_extractor.parameters():
    param.requires_grad = False
Common Pitfalls
- Forgetting to Save the Optimizer State: When saving a model for resuming training, ensure that you save the optimizer's state. PyTorch Lightning's checkpointing system saves the optimizer state automatically, but if you handle checkpoints manually, you need to include it yourself (see the sketch after this list).
- Overwriting Checkpoints: If you don’t specify a unique filename or directory for each checkpoint, you might end up overwriting previously saved models. To avoid this, use dynamic file names based on epochs and metrics.
- Loading Weights into a Different Model Architecture: If you attempt to load weights into a model that does not match the architecture of the saved model, PyTorch will throw an error. Always ensure that the architecture is identical when loading weights.
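For the first pitfall, here is a minimal sketch of manual checkpointing with plain PyTorch; model, optimizer, and epoch are assumed to already exist in your training loop:
Python
# Save model AND optimizer state so training can resume correctly
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'manual_checkpoint.pt')

# Later: restore both before resuming training
checkpoint = torch.load('manual_checkpoint.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1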
Conclusion
In this article, we walked through a basic workflow for training a model with PyTorch Lightning and using ModelCheckpoint to save the best-performing model. PyTorch Lightning automates many aspects of training, including managing the training loop and saving model checkpoints, which makes it easier to focus on building and fine-tuning the model.