Shard Optimizer States with ZeroRedundancyOptimizer
===================================================

.. note::
    ``ZeroRedundancyOptimizer`` was introduced in PyTorch 1.8 as a prototype
    feature. Its API is subject to change.

In this recipe, you will learn:

- The high-level idea of ``ZeroRedundancyOptimizer``.
- How to use ``ZeroRedundancyOptimizer`` in distributed training and its impact.


Requirements
------------

- PyTorch 1.8+
- `Getting Started With Distributed Data Parallel <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`_


What is ``ZeroRedundancyOptimizer``?
------------------------------------

The idea of ``ZeroRedundancyOptimizer`` comes from the
`DeepSpeed/ZeRO project <https://github.com/microsoft/DeepSpeed>`_ and
`Marian <https://github.com/marian-nmt/marian-dev>`_, which shard
optimizer states across distributed data-parallel processes to
reduce the per-process memory footprint. In the
`Getting Started With Distributed Data Parallel <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`_
tutorial, we showed how to use
`DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html>`_
(DDP) to train models. In that tutorial, each process keeps a dedicated replica
of the optimizer. Since DDP has already synchronized gradients in the
backward pass, all optimizer replicas operate on the same parameter and
gradient values in every iteration, and this is how DDP keeps model replicas in
the same state. Oftentimes, optimizers also maintain local states. For example,
the ``Adam`` optimizer uses per-parameter ``exp_avg`` and ``exp_avg_sq`` states. As a
result, the ``Adam`` optimizer's memory consumption is at least twice the model
size. Given this observation, we can reduce the optimizer memory footprint by
sharding optimizer states across DDP processes. More specifically, instead of
creating per-parameter states for all parameters, each optimizer instance in
different DDP processes only keeps optimizer states for a shard of all model
parameters. The optimizer ``step()`` function only updates the parameters in its
shard and then broadcasts its updated parameters to all other peer DDP
processes, so that all model replicas still land in the same state.

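The sketch below illustrates this sharding idea in a single process. It is a
simplified, hypothetical example rather than the actual ``ZeroRedundancyOptimizer``
implementation (which handles partitioning and synchronization for you); the
round-robin assignment of parameters to "ranks" is an assumption made for brevity.

::

    import torch
    import torch.nn as nn

    def shard_parameters(params, world_size):
        # Assign each parameter to exactly one rank; only that rank keeps
        # optimizer states (e.g., Adam's exp_avg and exp_avg_sq) for it.
        shards = [[] for _ in range(world_size)]
        for i, p in enumerate(params):
            shards[i % world_size].append(p)
        return shards

    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
    shards = shard_parameters(list(model.parameters()), world_size=2)

    # Each "rank" builds a regular Adam instance over its own shard only, so
    # per-parameter states exist for roughly 1 / world_size of the model.
    optimizers = [torch.optim.Adam(shard, lr=0.01) for shard in shards]

    # In the real distributed setting, rank r would call optimizers[r].step()
    # after the backward pass and then broadcast its updated shard to the
    # other ranks, so that all model replicas stay in the same state.
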
How to use ``ZeroRedundancyOptimizer``?
---------------------------------------

The code below demonstrates how to use ``ZeroRedundancyOptimizer``. The majority
of the code is similar to the simple DDP example presented in the
`Distributed Data Parallel notes <https://pytorch.org/docs/stable/notes/ddp.html>`_.
The main difference is the ``if-else`` clause in the ``example`` function, which
wraps the optimizer construction, toggling between ``ZeroRedundancyOptimizer``
and the plain ``Adam`` optimizer.


::

    import os
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    import torch.nn as nn
    import torch.optim as optim
    from torch.distributed.optim import ZeroRedundancyOptimizer
    from torch.nn.parallel import DistributedDataParallel as DDP

    def print_peak_memory(prefix, device):
        if device == 0:
            print(f"{prefix}: {torch.cuda.max_memory_allocated(device) // 1e6}MB ")

    def example(rank, world_size, use_zero):
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '29500'
        # create default process group
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        # create local model
        model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
        print_peak_memory("Max memory allocated after creating local model", rank)

        # construct DDP model
        ddp_model = DDP(model, device_ids=[rank])
        print_peak_memory("Max memory allocated after creating DDP", rank)

        # define loss function and optimizer
        loss_fn = nn.MSELoss()
        if use_zero:
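            # ZeroRedundancyOptimizer wraps a regular optimizer class
            # (Adam here) and shards its per-parameter states, such as
            # exp_avg and exp_avg_sq, across the processes in the group.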
            optimizer = ZeroRedundancyOptimizer(
                ddp_model.parameters(),
                optim=torch.optim.Adam,
                lr=0.01
            )
        else:
            optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)

        # forward pass
        outputs = ddp_model(torch.randn(20, 2000).to(rank))
        labels = torch.randn(20, 2000).to(rank)
        # backward pass
        loss_fn(outputs, labels).backward()

        # update parameters
        print_peak_memory("Max memory allocated before optimizer step()", rank)
        optimizer.step()
        print_peak_memory("Max memory allocated after optimizer step()", rank)

        print(f"params sum is: {sum(model.parameters()).sum()}")



    def main():
        world_size = 2
        print("=== Using ZeroRedundancyOptimizer ===")
        mp.spawn(example,
            args=(world_size, True),
            nprocs=world_size,
            join=True)

        print("=== Not Using ZeroRedundancyOptimizer ===")
        mp.spawn(example,
            args=(world_size, False),
            nprocs=world_size,
            join=True)

    if __name__ == "__main__":
        main()

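If the script is saved to a file (say ``zero_example.py``, a hypothetical name)
and run directly with ``python zero_example.py``, it spawns ``world_size = 2``
processes on a single machine via ``mp.spawn``. It expects at least two CUDA
devices, since each rank places its model replica and inputs on ``cuda:rank``.
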
The output is shown below. When ``ZeroRedundancyOptimizer`` is enabled with ``Adam``,
the additional peak memory consumed by the optimizer ``step()`` is about half of
what vanilla ``Adam`` consumes. This agrees with our expectation, as we are sharding
the ``Adam`` optimizer states across two processes. The output also shows that, with
``ZeroRedundancyOptimizer``, the model parameters still end up with the same
values after one iteration (the parameter sum is the same with and without
``ZeroRedundancyOptimizer``).

::

    === Using ZeroRedundancyOptimizer ===
    Max memory allocated after creating local model: 335.0MB
    Max memory allocated after creating DDP: 656.0MB
    Max memory allocated before optimizer step(): 992.0MB
    Max memory allocated after optimizer step(): 1361.0MB
    params sum is: -3453.6123046875
    params sum is: -3453.6123046875
    === Not Using ZeroRedundancyOptimizer ===
    Max memory allocated after creating local model: 335.0MB
    Max memory allocated after creating DDP: 656.0MB
    Max memory allocated before optimizer step(): 992.0MB
    Max memory allocated after optimizer step(): 1697.0MB
    params sum is: -3453.6123046875
    params sum is: -3453.6123046875

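As a rough sanity check (assuming 4-byte ``float32`` parameters and ignoring
allocator rounding and temporary buffers), these numbers match the sharding
argument: the model holds about 20 * (2000 * 2000 + 2000) ~ 80M parameters,
i.e. roughly 320 MB, and ``Adam`` keeps two extra states per parameter.

::

    model size                     ~ 80M params * 4 bytes ~ 320 MB
    full Adam states               ~ 2 * 320 MB           ~ 640 MB
    step() increase, vanilla Adam  = 1697 MB - 992 MB     = 705 MB
    step() increase, sharded       = 1361 MB - 992 MB     = 369 MB  (about half)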