data_parallelism
Table of Contents:
1. Introduction
2. Data Parallelism
3. Model Parallelism
4. Code Example of Model Parallelism in PyTorch
5. Saving and Serving a Model Trained with Model Parallelism
○ Saving the Model
○ Serving for Online Inference
○ Inference on Multiple vs. Single Devices
6. Conclusion
1. Introduction
In distributed deep learning, there are two primary strategies for scaling training across multiple
devices (e.g., GPUs): Data Parallelism and Model Parallelism. Understanding these
strategies is crucial for efficiently training large models or training on large datasets.
2. Data Parallelism
Definition: Each device (GPU) holds a full copy of the model. The dataset is split into batches
that are distributed across devices. Each GPU processes a separate batch, computes
gradients, and the gradients are then aggregated to update the model weights.
Pros:
● Straightforward to implement.
● Scales well with large datasets.
Cons:
● Requires the full model (and its optimizer state) to fit in each device's memory.
● Synchronizing gradients across devices adds communication overhead.
Data parallelism is best when the model comfortably fits into a single GPU’s memory and you
have a large amount of data.
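As a minimal sketch of this setup (not part of the original example), data parallelism in PyTorch can be expressed with nn.DataParallel, which replicates the model on each visible GPU and splits every batch across the replicas; the layer sizes below are hypothetical, and for serious multi-GPU training torch.nn.parallel.DistributedDataParallel is generally preferred.
python
import torch
import torch.nn as nn

# A small model that fits comfortably on one device (hypothetical layer sizes)
model = nn.Sequential(nn.Linear(1024, 512), nn.ReLU(), nn.Linear(512, 10))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

if torch.cuda.device_count() > 1:
    # Replicate the model on every visible GPU; each replica gets a slice of the batch
    model = nn.DataParallel(model)

batch = torch.randn(64, 1024).to(device)  # the batch is split across the replicas
output = model(batch)                     # per-replica outputs are gathered on device 0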
3. Model Parallelism
Definition: The model is split across multiple devices. Each device holds only a part of the
model. During the forward pass, intermediate outputs are passed between devices.
Pros:
● Enables training of very large models that cannot fit into a single GPU’s memory.
Cons:
● More complex to implement, since intermediate activations must be passed between devices.
● Devices can sit idle while waiting for each other unless pipelining is used.
Model parallelism is ideal when model size, rather than dataset size, is the bottleneck.
4. Code Example of Model Parallelism in PyTorch
Note: This is a simplified example assuming two GPUs, GPU 0 and GPU 1. The model’s first
half runs on GPU 0 and the second half on GPU 1.
python
import torch
import torch.nn as nn
import torch.optim as optim

# Device setup
device0 = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device1 = torch.device("cuda:1" if (torch.cuda.is_available() and torch.cuda.device_count() > 1) else "cpu")

class ModelParallelNN(nn.Module):
    def __init__(self):
        super(ModelParallelNN, self).__init__()
        # First half of the model on GPU 0
        self.fc1 = nn.Linear(1024, 512).to(device0)
        self.relu = nn.ReLU()
        # Second half of the model on GPU 1 (layer sizes chosen for illustration)
        self.fc2 = nn.Linear(512, 256).to(device1)
        self.fc3 = nn.Linear(256, 10).to(device1)

    def forward(self, x):
        # Forward pass starts on GPU 0, then activations move to GPU 1
        x = self.relu(self.fc1(x.to(device0)))
        x = self.relu(self.fc2(x.to(device1)))
        return self.fc3(x)

model = ModelParallelNN()

# Dummy data
data = torch.randn(64, 1024)  # 64 examples, 1024 features each
labels = torch.randint(0, 10, (64,)).to(device1)
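For completeness, a minimal training step on top of this class might look as follows (a sketch assuming the dummy data and labels above; the loss is computed on device1, where the output of the final layer lives):
python
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

for epoch in range(5):                 # a few iterations just for illustration
    optimizer.zero_grad()
    outputs = model(data)              # forward pass crosses from device0 to device1
    loss = criterion(outputs, labels)  # labels already live on device1
    loss.backward()                    # autograd handles the cross-device graph
    optimizer.step()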
5. Saving and Serving a Model Trained with Model Parallelism
Saving the Model:
Saving works similarly to standard PyTorch models. The state_dict includes all parameters
from all devices.
python
torch.save(model.state_dict(), 'model_parallel.pth')
Loading:
python
model = ModelParallelNN()
model.load_state_dict(torch.load('model_parallel.pth'))
# Ensure parts of model are on correct devices if re-instantiated
model.fc1.to(device0)
model.fc2.to(device1)
model.fc3.to(device1)
Serving for Online Inference:
For online inference, it is often simplest to consolidate the model onto a single device (if it fits) by remapping all saved tensors with map_location:
python
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = ModelParallelNN()
model.load_state_dict(torch.load('model_parallel.pth', map_location=device))
model.to(device)
A simple inference helper can then run the forward pass with gradients disabled:
python
def infer(input_data):
    input_data = input_data.to(device)
    with torch.no_grad():
        output = model(input_data)
    return output
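For example, with an input shape matching the hypothetical model above:
python
sample = torch.randn(1, 1024)
prediction = infer(sample).argmax(dim=1)  # predicted class index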
Inference on Multiple Devices:
If the model is too large to fit on one device, you can perform inference similarly to the training
forward pass, with parts of the model on different GPUs.
python
def infer_parallel(input_data):
    input_data = input_data.to(device0)
    with torch.no_grad():
        output = model(input_data)  # output ends up on device1, where the final layer lives
    return output
In Practice:
● If possible, consolidate the model onto one device for inference to reduce complexity
and overhead.
● Use frameworks like TorchServe or NVIDIA Triton to handle multi-GPU deployment and
scaling.
● Convert models to ONNX and use efficient inference engines if needed (see the sketch below).
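As a rough sketch, assuming the model has already been consolidated onto a single device as shown above (the file name and tensor names here are illustrative):
python
dummy_input = torch.randn(1, 1024).to(device)
torch.onnx.export(model, dummy_input, "model_parallel.onnx",
                  input_names=["input"], output_names=["logits"])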
6. Conclusion
● Data Parallelism replicates the model across devices so each one processes a different
slice of the dataset; it is the straightforward choice when the model fits on a single device.
● Model Parallelism is used when the model is too large for a single device, splitting it
across multiple devices.
● When serving models for online inference, consider consolidating onto a single device if
feasible. If the model is too large, maintain model parallelism for inference.
● Saving and loading model-parallel-trained models involves saving the state_dict and
carefully loading it onto the appropriate devices.