
Commit b01eac6

Merge branch 'master' into cuda-updates
2 parents bc52639 + 960192c

File tree: 4 files changed (+173, -4 lines)


_templates/layout.html (+5)

@@ -73,5 +73,10 @@
 src="https://www.facebook.com/tr?id=243028289693773&ev=PageView
 &noscript=1"/>
 </noscript>
+
+<script type="text/javascript">
+  var collapsedSections = [];
+</script>
+
 <img height="1" width="1" style="border-style:none;" alt="" src="https://www.googleadservices.com/pagead/conversion/795629140/?label=txkmCPmdtosBENSssfsC&amp;guid=ON&amp;script=0"/>
 {% endblock %}

index.rst (+3, -3)

@@ -179,7 +179,7 @@ Welcome to PyTorch Tutorials

 .. customcarditem::
    :header: Train a Mario-playing RL Agent
-   :card_description: Use PyTorch to train a Double Q-learning agent to play Mario .
+   :card_description: Use PyTorch to train a Double Q-learning agent to play Mario.
    :image: _static/img/mario.gif
    :link: intermediate/mario_rl_tutorial.html
    :tags: Reinforcement-Learning

@@ -199,14 +199,14 @@ Welcome to PyTorch Tutorials
    :card_description: Introduction to TorchScript, an intermediate representation of a PyTorch model (subclass of nn.Module) that can then be run in a high-performance environment such as C++.
    :image: _static/img/thumbnails/cropped/Introduction-to-TorchScript.png
    :link: beginner/Intro_to_TorchScript_tutorial.html
-   :tags: Production
+   :tags: Production,TorchScript

 .. customcarditem::
    :header: Loading a TorchScript Model in C++
    :card_description: Learn how PyTorch provides to go from an existing Python model to a serialized representation that can be loaded and executed purely from C++, with no dependency on Python.
    :image: _static/img/thumbnails/cropped/Loading-a-TorchScript-Model-in-Cpp.png
    :link: advanced/cpp_export.html
-   :tags: Production
+   :tags: Production,TorchScript

 .. customcarditem::
    :header: (optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime

recipes_source/recipes_index.rst (+10, -1)

@@ -225,6 +225,15 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch features
    :link: ../recipes/recipes/tuning_guide.html
    :tags: Model-Optimization

+.. Distributed Training
+
+.. customcarditem::
+   :header: Shard Optimizer States with ZeroRedundancyOptimizer
+   :card_description: How to use ZeroRedundancyOptimizer to reduce memory consumption.
+   :image: ../_static/img/thumbnails/cropped/profiler.png
+   :link: ../recipes/zero_redundancy_optimizer.html
+   :tags: Distributed-Training
+
 .. End of tutorial card section

 .. raw:: html

@@ -261,4 +270,4 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch features
    /recipes/torchscript_inference
    /recipes/deployment_with_flask
    /recipes/distributed_rpc_profiling
-   /recipes/distributed_rpc_profiling
+   /recipes/zero_redundancy_optimizer

recipes_source/zero_redundancy_optimizer.rst (new file, +155 lines)

Shard Optimizer States with ZeroRedundancyOptimizer
===================================================

.. note::
    ``ZeroRedundancyOptimizer`` is introduced in PyTorch 1.8 as a prototype
    feature. Its API is subject to change.

In this recipe, you will learn:

- The high-level idea of ``ZeroRedundancyOptimizer``.
- How to use ``ZeroRedundancyOptimizer`` in distributed training and its impact.


Requirements
------------

- PyTorch 1.8+
- `Getting Started With Distributed Data Parallel <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`_


What is ``ZeroRedundancyOptimizer``?
------------------------------------

The idea of ``ZeroRedundancyOptimizer`` comes from the
`DeepSpeed/ZeRO project <https://github.com/microsoft/DeepSpeed>`_ and
`Marian <https://github.com/marian-nmt/marian-dev>`_, which shard
optimizer states across distributed data-parallel processes to
reduce the per-process memory footprint. In the
`Getting Started With Distributed Data Parallel <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`_
tutorial, we have shown how to use
`DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html>`_
(DDP) to train models. In that tutorial, each process keeps a dedicated replica
of the optimizer. Since DDP has already synchronized gradients in the
backward pass, all optimizer replicas will operate on the same parameter and
gradient values in every iteration, and this is how DDP keeps model replicas in
the same state. Oftentimes, optimizers also maintain local states. For example,
the ``Adam`` optimizer uses per-parameter ``exp_avg`` and ``exp_avg_sq`` states. As a
result, the ``Adam`` optimizer's memory consumption is at least twice the model
size. Given this observation, we can reduce the optimizer memory footprint by
sharding optimizer states across DDP processes. More specifically, instead of
creating per-parameter states for all parameters, each optimizer instance in
different DDP processes only keeps optimizer states for a shard of all model
parameters. The optimizer ``step()`` function only updates the parameters in its
shard and then broadcasts its updated parameters to all other peer DDP
processes, so that all model replicas still land in the same state.
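
For a rough sense of the savings, the snippet below (a sketch, not part of the
recipe's code) estimates the extra memory ``Adam`` needs for its ``exp_avg`` and
``exp_avg_sq`` states, with and without sharding. It assumes float32 parameters,
two FP32 state tensors per parameter, and the same model shape and world size as
the example in the next section; measured numbers will differ somewhat because
of allocator overhead.

::

    import torch.nn as nn

    # same architecture as the example below: 20 x Linear(2000, 2000)
    model = nn.Sequential(*[nn.Linear(2000, 2000) for _ in range(20)])
    param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())

    world_size = 2
    adam_state_bytes = 2 * param_bytes                    # exp_avg + exp_avg_sq
    sharded_state_bytes = adam_state_bytes / world_size   # each rank keeps only its shard

    print(f"model parameters:         {param_bytes / 1e6:.0f} MB")            # ~320 MB
    print(f"Adam states (replicated): {adam_state_bytes / 1e6:.0f} MB per rank")     # ~640 MB
    print(f"Adam states (sharded):    {sharded_state_bytes / 1e6:.0f} MB per rank")  # ~320 MB

With two ranks, sharding brings the per-rank ``Adam`` state back down to roughly
the size of the model itself, which is the saving the recipe measures below.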

How to use ``ZeroRedundancyOptimizer``?
---------------------------------------

The code below demonstrates how to use ``ZeroRedundancyOptimizer``. The majority
of the code is similar to the simple DDP example presented in the
`Distributed Data Parallel notes <https://pytorch.org/docs/stable/notes/ddp.html>`_.
The main difference is the ``if-else`` clause in the ``example`` function, which
wraps the optimizer construction, toggling between ``ZeroRedundancyOptimizer``
and the ``Adam`` optimizer.


::

    import os
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    import torch.nn as nn
    import torch.optim as optim
    from torch.distributed.optim import ZeroRedundancyOptimizer
    from torch.nn.parallel import DistributedDataParallel as DDP

    def print_peak_memory(prefix, device):
        if device == 0:
            print(f"{prefix}: {torch.cuda.max_memory_allocated(device) // 1e6}MB ")

    def example(rank, world_size, use_zero):
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '29500'
        # create default process group
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        # create local model
        model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
        print_peak_memory("Max memory allocated after creating local model", rank)

        # construct DDP model
        ddp_model = DDP(model, device_ids=[rank])
        print_peak_memory("Max memory allocated after creating DDP", rank)

        # define loss function and optimizer
        loss_fn = nn.MSELoss()
        if use_zero:
            optimizer = ZeroRedundancyOptimizer(
                ddp_model.parameters(),
                optim=torch.optim.Adam,
                lr=0.01
            )
        else:
            optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)

        # forward pass
        outputs = ddp_model(torch.randn(20, 2000).to(rank))
        labels = torch.randn(20, 2000).to(rank)
        # backward pass
        loss_fn(outputs, labels).backward()

        # update parameters
        print_peak_memory("Max memory allocated before optimizer step()", rank)
        optimizer.step()
        print_peak_memory("Max memory allocated after optimizer step()", rank)

        print(f"params sum is: {sum(model.parameters()).sum()}")



    def main():
        world_size = 2
        print("=== Using ZeroRedundancyOptimizer ===")
        mp.spawn(example,
            args=(world_size, True),
            nprocs=world_size,
            join=True)

        print("=== Not Using ZeroRedundancyOptimizer ===")
        mp.spawn(example,
            args=(world_size, False),
            nprocs=world_size,
            join=True)

    if __name__=="__main__":
        main()

The output is shown below. When enabling ``ZeroRedundancyOptimizer`` with ``Adam``,
the optimizer ``step()`` peak memory consumption is half of vanilla ``Adam``'s
memory consumption. This agrees with our expectation, as we are sharding the
``Adam`` optimizer states across two processes. The output also shows that, with
``ZeroRedundancyOptimizer``, the model parameters still end up with the same
values after one iteration (the parameter sum is the same with and without
``ZeroRedundancyOptimizer``). A rough cross-check of these numbers follows the
output.

::

    === Using ZeroRedundancyOptimizer ===
    Max memory allocated after creating local model: 335.0MB
    Max memory allocated after creating DDP: 656.0MB
    Max memory allocated before optimizer step(): 992.0MB
    Max memory allocated after optimizer step(): 1361.0MB
    params sum is: -3453.6123046875
    params sum is: -3453.6123046875
    === Not Using ZeroRedundancyOptimizer ===
    Max memory allocated after creating local model: 335.0MB
    Max memory allocated after creating DDP: 656.0MB
    Max memory allocated before optimizer step(): 992.0MB
    Max memory allocated after optimizer step(): 1697.0MB
    params sum is: -3453.6123046875
    params sum is: -3453.6123046875
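
As a rough cross-check (an estimate, not part of the recipe), the memory that
``Adam`` adds for its states can be read off the output as the growth in peak
memory across ``step()``:

::

    # peak-memory values (MB) copied from the output above
    before_step = 992.0
    after_step_sharded = 1361.0     # with ZeroRedundancyOptimizer, 2 processes
    after_step_replicated = 1697.0  # with vanilla Adam

    print(f"state growth, sharded:    {after_step_sharded - before_step:.0f} MB")     # ~369 MB
    print(f"state growth, replicated: {after_step_replicated - before_step:.0f} MB")  # ~705 MB

The sharded growth is roughly half of the replicated growth, and both are in the
same ballpark as the ~320 MB and ~640 MB estimates from the earlier sketch; the
remaining gap is plausibly allocator rounding and temporary buffers used during
``step()``.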
