Introduction:
PyTorch Lightning is a library that provides a high-level interface for PyTorch. The problem with plain PyTorch is that every time you start a project you have to rewrite the same training and testing loops. PyTorch Lightning fixes this by reducing boilerplate code while also providing added functionality that comes in handy when training your neural networks. One of the things I love about Lightning is that the code is very organized and reusable: it trims the training and testing loops while retaining the flexibility that PyTorch is known for. And once you learn how to use it, you'll see how similar the code is to plain PyTorch.
Installing PyTorch Lightning:
Installing Lightning is the same as installing any other Python library:
pip install pytorch-lightning
or, if you want to install it in a conda environment, you can use the following command:
conda install -c conda-forge pytorch-lightning
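To confirm the installation worked, you can import the package and print its version (a quick sanity check, not part of the original tutorial):
import pytorch_lightning as pl
print(pl.__version__)  # prints the installed Lightning version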
PyTorch Lightning Model Format:
If you have ever used PyTorch, you know that defining a PyTorch model follows this format:
from torch import nn

class model(nn.Module):
    def __init__(self):
        # Define Model Here

    def forward(self, x):
        # Define Forward Pass Here
That's how we define a model in PyTorch; the loss, optimizer, and training loop are then usually defined outside the class. In PyTorch Lightning, the way to define the model is similar, except that we add the loss, optimizer, and training steps to the model itself. To define a Lightning model we follow this format:
import pytorch_lightning as pl

class model(pl.LightningModule):
    def __init__(self):
        # Define Model Here

    def forward(self, x):
        # Define Forward Pass Here

    def configure_optimizers(self):
        # Define Optimizer Here

    def training_step(self, train_batch, batch_idx):
        # Define Training loop steps here

    def validation_step(self, valid_batch, batch_idx):
        # Define Validation loop steps here
Note: The names of the above methods must be exactly as shown, since Lightning calls them by name.
Training our Neural Network:
Loading Our Data:
For this tutorial we are going to use the MNIST dataset, so we'll start by loading our data and defining the model afterwards. To load data for a Lightning model you can either define DataLoaders as you do in PyTorch and pass the train and validation dataloaders to trainer.fit(), or you can use a LightningDataModule, which does the same thing except the steps live in a Python class. To create dataloaders we follow these steps:
Loading Data by Creating DataLoaders:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.ToTensor()
])

train = datasets.MNIST('', train=True, download=True, transform=transform)
test = datasets.MNIST('', train=False, download=True, transform=transform)

trainloader = DataLoader(train, batch_size=32, shuffle=True)
# the test set doubles as the validation set here; no need to shuffle it
testloader = DataLoader(test, batch_size=32, shuffle=False)
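As a quick check (my addition, not part of the original tutorial), you can pull one batch from the train loader and inspect its shape; MNIST images are 1x28x28, so a batch should be [32, 1, 28, 28]:
images, labels = next(iter(trainloader))
print(images.shape)  # torch.Size([32, 1, 28, 28])
print(labels.shape)  # torch.Size([32])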
To create a LightningDataModule we follow these steps:
Loading Data by Creating LightningDataModule:
import pytorch_lightning as pl
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

class Data(pl.LightningDataModule):
    def prepare_data(self):
        transform = transforms.Compose([
            transforms.ToTensor()
        ])
        self.train_data = datasets.MNIST('', train=True, download=True, transform=transform)
        self.test_data = datasets.MNIST('', train=False, download=True, transform=transform)

    def train_dataloader(self):
        return DataLoader(self.train_data, batch_size=32, shuffle=True)

    def val_dataloader(self):
        # validation data should not be shuffled
        return DataLoader(self.test_data, batch_size=32, shuffle=False)
Note: The names of the above methods must be exactly as shown, since Lightning calls them by name.
This is how you create a Lightning Data Module. Creating dataloaders can get messy, and that's why it's better to club all the dataset-related code together in a Data Module.
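One caveat worth knowing (an addition to the original tutorial): Lightning intends prepare_data() for one-time work such as downloading, while dataset objects are normally assigned in setup(), which runs on every process in distributed training. A minimal sketch of that split, using the same datasets as above:
class Data(pl.LightningDataModule):
    def prepare_data(self):
        # download only; called on a single process
        datasets.MNIST('', train=True, download=True)
        datasets.MNIST('', train=False, download=True)

    def setup(self, stage=None):
        # assign datasets; called on every process
        transform = transforms.Compose([transforms.ToTensor()])
        self.train_data = datasets.MNIST('', train=True, transform=transform)
        self.test_data = datasets.MNIST('', train=False, transform=transform)

    def train_dataloader(self):
        return DataLoader(self.train_data, batch_size=32, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.test_data, batch_size=32, shuffle=False)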
Defining Our Neural Network
Defining the model in PyTorch Lightning is pretty much the same as in PyTorch, except that now we are clubbing everything together inside our model class.
from torch import nn
import pytorch_lightning as pl
import torch.nn.functional as F
from torch.optim import SGD

class model(pl.LightningModule):
    def __init__(self):
        super(model, self).__init__()
        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 128)
        self.out = nn.Linear(128, 10)
        self.lr = 0.01
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        batch_size, _, _, _ = x.size()
        x = x.view(batch_size, -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.out(x)

    def configure_optimizers(self):
        return SGD(self.parameters(), lr=self.lr)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = self.loss(logits, y)
        return loss

    def validation_step(self, valid_batch, batch_idx):
        x, y = valid_batch
        logits = self.forward(x)
        loss = self.loss(logits, y)
        # log the validation loss so it shows up in the progress bar/logger
        self.log('val_loss', loss)
We'll further discuss how training_step() differs from the steps of the training loop in PyTorch, along with other differences between a Lightning model and a PyTorch model.
Training Our Model
To train a model in PyTorch, you have to write the training loop yourself, but the Trainer class in Lightning makes that task easier. To train a model in Lightning:
# Create Model Object
clf = model()
# Create Data Module Object
mnist = Data()
# Create Trainer Object
trainer = pl.Trainer(gpus=1,accelerator='dp',max_epochs=5)
trainer.fit(clf,mnist)
Note: `dp` is DataParallel (the batch is split among the GPUs of the same machine).
Note: If you have loaded the data by creating dataloaders instead, you can fit the trainer with trainer.fit(clf, trainloader, testloader).
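A version caveat (my addition, not from the original tutorial): the gpus and accelerator='dp' arguments above belong to the pre-2.0 Lightning API. In Lightning 2.x the equivalent single-GPU call looks like this:
trainer = pl.Trainer(accelerator='gpu', devices=1, max_epochs=5)
trainer.fit(clf, mnist)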
Difference Between PyTorch Model and Lightning Model:
As we can see, the first difference between a PyTorch model and a Lightning model is the class that the model inherits from:
PyTorch
class model(nn.Module):
PyTorch-Lightning
class model(pl.LightningModule):
__init__() method
In both the PyTorch and the Lightning model we use the __init__() method to define our layers. Since in Lightning we club everything together, we can also define other hyperparameters here, like the learning rate for the optimizer and the loss function.
PyTorch
def __init__(self):
    super(model, self).__init__()
    self.fc1 = nn.Linear(28*28, 256)
    self.fc2 = nn.Linear(256, 128)
    self.out = nn.Linear(128, 10)
PyTorch-Lightning
def __init__(self):
    super(model, self).__init__()
    self.fc1 = nn.Linear(28*28, 256)
    self.fc2 = nn.Linear(256, 128)
    self.out = nn.Linear(128, 10)
    self.lr = 0.01
    self.loss = nn.CrossEntropyLoss()
forward() method:
In both the PyTorch and the Lightning model we use the forward() method to define the forward pass, so it is the same for both.
PyTorch and PyTorch-Lightning
def forward(self, x):
    batch_size, _, _, _ = x.size()
    x = x.view(batch_size, -1)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    return self.out(x)
Defining Optimizer:
In PyTorch, we usually define our optimizers by directly creating their objects, but in PyTorch-Lightning we define our optimizers inside the configure_optimizers() method. Another thing to note is that in PyTorch we pass the model object's parameters as arguments to the optimizer, whereas in Lightning we pass self.parameters(), since the class is the model itself.
PyTorch
from torch.optim import SGD

clf = model()  # PyTorch model object
optimizer = SGD(clf.parameters(), lr=0.01)
PyTorch-Lightning
def configure_optimizers(self):
    return SGD(self.parameters(), lr=self.lr)
Note: You can return multiple optimizers (and learning-rate schedulers) in Lightning too, as sketched below.
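For instance, configure_optimizers() can return lists of optimizers and schedulers; here is a minimal sketch that pairs the SGD optimizer above with a StepLR scheduler (the scheduler choice is just an example):
from torch.optim.lr_scheduler import StepLR

def configure_optimizers(self):
    optimizer = SGD(self.parameters(), lr=self.lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=0.9)  # decay the lr every epoch
    return [optimizer], [scheduler]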
Training Loop (Step):
It won't be wrong to say that this is what makes Lightning stand out from PyTorch. In PyTorch we write the full training loop, while in Lightning we use the Trainer() to do the job; we only define the steps to be executed for each batch.
PyTorch
epochs = 5
criterion = nn.CrossEntropyLoss()   # loss defined outside the model
is_gpu = torch.cuda.is_available()
if is_gpu:
    clf = clf.cuda()                # move the model to the same device as the data

for i in range(epochs):
    train_loss = 0.0
    for data, label in trainloader:
        if is_gpu:
            data, label = data.cuda(), label.cuda()
        output = clf(data)
        optimizer.zero_grad()
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * data.size(0)
    print(f'Epoch: {i+1} / {epochs} \t\t\t Training Loss: {train_loss/len(trainloader.dataset)}')
PyTorch-Lightning
def training_step(self, train_batch, batch_idx):
    x, y = train_batch
    logits = self.forward(x)
    loss = self.loss(logits, y)
    return loss
See how in training_step() we only write the steps that are actually necessary.
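If you also want the per-batch loss reported (my addition, using Lightning's built-in self.log()), training_step() can log it before returning; Lightning then aggregates it and shows it in the progress bar:
def training_step(self, train_batch, batch_idx):
    x, y = train_batch
    logits = self.forward(x)
    loss = self.loss(logits, y)
    self.log('train_loss', loss, prog_bar=True)  # logged and averaged by Lightning
    return loss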
Code
import torch
from torch import nn
import pytorch_lightning as pl
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.optim import SGD

class model(pl.LightningModule):
    def __init__(self):
        super(model, self).__init__()
        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 128)
        self.out = nn.Linear(128, 10)
        self.lr = 0.01
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        batch_size, _, _, _ = x.size()
        x = x.view(batch_size, -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.out(x)

    def configure_optimizers(self):
        return SGD(self.parameters(), lr=self.lr)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = self.loss(logits, y)
        return loss

    def validation_step(self, valid_batch, batch_idx):
        x, y = valid_batch
        logits = self.forward(x)
        loss = self.loss(logits, y)
        self.log('val_loss', loss)

class Data(pl.LightningDataModule):
    def prepare_data(self):
        transform = transforms.Compose([
            transforms.ToTensor()
        ])
        self.train_data = datasets.MNIST('', train=True, download=True, transform=transform)
        self.test_data = datasets.MNIST('', train=False, download=True, transform=transform)

    def train_dataloader(self):
        return DataLoader(self.train_data, batch_size=32, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.test_data, batch_size=32, shuffle=False)

clf = model()
mnist = Data()
# pre-2.0 Lightning API; in Lightning >= 2.0 use accelerator='gpu', devices=1
trainer = pl.Trainer(gpus=1, accelerator='dp', max_epochs=5)
trainer.fit(clf, mnist)
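Once training finishes, a quick way to eyeball the result (an illustrative addition, not in the original article) is to run the trained model on one validation batch and check its accuracy:
clf = clf.cpu().eval()                  # move the model to CPU for this check
x, y = next(iter(mnist.val_dataloader()))
with torch.no_grad():
    preds = clf(x).argmax(dim=1)        # class with the highest logit
print('batch accuracy:', (preds == y).float().mean().item())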