import logging
from dataclasses import dataclass, field
from typing import Any, List, Optional

import torch
import torch.fx.passes.split_module  # explicit import; split_module is used below via fx.passes
import torch.fx.traceback as fx_traceback
from torch import fx
from torch.fx.node import Node

log = logging.getLogger(__name__)

def args_str(args):
    # a debug helper
    if torch.is_tensor(args):
        return f"T[{args.shape}]"
    elif isinstance(args, tuple):
        return f"tuple({', '.join([args_str(x) for x in args])})"
    elif isinstance(args, list):
        return f"list({', '.join([args_str(x) for x in args])})"
    else:
        return str(args)
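
# For reference, args_str produces strings like (illustrative):
#   args_str((torch.ones(2, 3), [1, 2])) -> "tuple(T[torch.Size([2, 3])], list(1, 2))"
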
@dataclass
class Bucket:
    size: int = 0
    params: List[str] = field(default_factory=list)
    nodes: List[fx.Node] = field(default_factory=list)

    # param_ids is just used for unit testing
    param_ids: List = field(default_factory=list)

def pretty_print_buckets(buckets: List[Bucket]):
    headers = ("Index", "Size (b)", "Param Names")
    rows = []
    for idx, bucket in enumerate(reversed(buckets)):
        if len(bucket.params) > 0:
            rows.append((idx, bucket.size, bucket.params[0]))
            for param in bucket.params[1:]:
                rows.append((None, None, param))
    try:
        from tabulate import tabulate

        log.info(
            "\nDDPOptimizer bucket assignments\n"
            + tabulate(rows, headers=headers, tablefmt="simple_grid")
        )
    except ImportError:
        log.info(
            "Please `pip install tabulate` in order to pretty-print ddp bucket sizes"
        )

class DDPOptimizer:
    """
    DDPOptimizer applies when dynamo compiles models wrapped in DistributedDataParallel (DDP),
    breaking the dynamo graph into chunks to compile separately, with the breaks aligning to
    the boundaries of gradient-allreduce buckets chosen by DDP.

    Background/Motivation
    - DDP uses allreduce collectives to synchronize partial gradients computed on different workers.
    - DDP groups gradient allreduces into 'buckets' to optimize the communication efficiency of allreduce.
    - Parameters grouped into a bucket are assumed to become ready at around the same time during
      backward, and thus can share the same allreduce efficiently.
    - Allreduces must overlap with backward compute for optimal training performance.
    - DDP schedules allreduces using 'hooks' fired from the C++ autograd engine in PyTorch, which
      fire as individual grads become 'ready'.
    - Dynamo+AOTAutograd produces a single fused graph that runs 'atomically' from the perspective of
      the autograd engine, such that all gradients become 'ready' at the same time.  Hooks fire after
      the whole fused backward function executes, preventing any overlap of compute and communication.

    Algorithm
    - DDPOptimizer starts off with an FX graph traced by dynamo which represents forward.  It traverses
      this graph in reverse order to determine the true order in which gradients will become ready
      during backward.
    - Parameter sizes are counted in reverse order, up to a bucket size limit, at which point a new
      bucket is started and a graph break is introduced.
    - Each of the subgraphs is compiled by the compiler provided to dynamo by the user, and then fused
      back together into an outer module that is returned to the user.

    Notes
    - It would be better to enforce (by adding an API to DDP) that the bucket splits chosen here are
      used by DDP, and that DDP does not need to detect or optimize bucket order by observing execution
      at runtime, as it does in eager.
    - If dynamo can't capture a whole graph for the portion of the model wrapped by DDP, this algorithm
      will currently produce splits that do not necessarily align with the buckets used by DDP.  This
      should result in performance degradation approaching the baseline case where graph splits are not
      used, but not worse.
    - If the backend compiler fails to compile a single subgraph, that subgraph will execute eagerly
      while the rest of the subgraphs remain compiled.
    - DDP has a 'parameters_and_buffers_to_ignore' field, which DDPOptimizer attempts to honor by
      reading markers left by DDP on individual parameters.  In cases where other transformations, such
      as reparameterization, are also used, the ignore markers could be lost.  If DDPOptimizer fails to
      ignore a parameter ignored by DDP, it is not catastrophic but could impact performance by choosing
      sub-optimal bucket splits.
    - DDPOptimizer always ignores all buffers, regardless of their ignore flag, since buffers do not
      require gradients and therefore aren't allreduced by DDP.  (They are broadcast during forward, but
      this is not covered by DDPOptimizer.)

    Args:
        bucket_bytes_cap (int): Controls the size of buckets, in bytes, used to determine graph breaks.
            Should be set to match the equivalent parameter on the original DDP module.
        backend_compile_fn (callable): A dynamo compiler function, invoked to compile each subgraph.
        first_bucket_cap (int): Controls the size of the first bucket.  Should match DDP's first bucket
            cap.  DDP special-cases the first bucket size since it is sometimes optimal to start a small
            allreduce early.
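
    Example (illustrative sketch only; DDPOptimizer is normally constructed and wired up by dynamo
    internals, and my_backend below stands in for any dynamo backend callable, it is not defined in
    this module):

        ddp_model = DistributedDataParallel(model, bucket_cap_mb=25)
        ddp_optimizer = DDPOptimizer(
            bucket_bytes_cap=ddp_model.bucket_bytes_cap,
            backend_compile_fn=my_backend,
        )
        opt_model = torch._dynamo.optimize(ddp_optimizer.compile_fn)(ddp_model)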
"""
def __init__(
self,
bucket_bytes_cap: int,
backend_compile_fn,
first_bucket_cap: Optional[int] = None,
):
if first_bucket_cap is not None:
self.first_bucket_cap = first_bucket_cap
elif torch.distributed.is_available():
# this constant comes from C10D lib which is not always built
self.first_bucket_cap = torch.distributed._DEFAULT_FIRST_BUCKET_BYTES
else:
self.first_bucket_cap = bucket_bytes_cap
self.bucket_bytes_cap = bucket_bytes_cap
assert (
self.first_bucket_cap <= self.bucket_bytes_cap
), "First bucket should be smaller/equal to other buckets to get comms warmed up ASAP"
self.backend_compile_fn = backend_compile_fn
def _ignore_parameter(self, parameter):
return hasattr(parameter, "_ddp_ignored") and parameter._ddp_ignored
    def compile_fn(self, gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
        """
        Implements graph splitting, first determining a set of buckets by counting
        parameter sizes in reverse graph order, then invoking the user/backend compiler
        to compile each subgraph.  Finally, stitches the compiled graphs back into one
        GraphModule and returns its callable.
        """
        # 1: compute the partition map according to DDP bucket logic
        buckets = [Bucket()]  # (size, param_names)
        for node in reversed(gm.graph.nodes):
            if node.op in ("output", "placeholder"):
                continue
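
            # Start a new bucket once the current one is full.  Because we traverse the
            # graph in reverse, the first bucket created here holds the params whose grads
            # become ready earliest in backward, so it is capped at first_bucket_cap
            # (mirroring DDP's special-cased first bucket); later buckets use bucket_bytes_cap.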
            if (
                buckets[0].size >= self.bucket_bytes_cap
                or len(buckets) == 1
                and buckets[0].size >= self.first_bucket_cap
            ):
                buckets.insert(0, Bucket())

            if node.op == "call_module":
                target = gm.get_submodule(node.target)
                for name, p in target.named_parameters():
                    param = target.get_parameter(name)
                    if p.requires_grad and not self._ignore_parameter(param):
                        buckets[0].size += p._storage().nbytes()
                        buckets[0].params.append(f"{node.target}_{name}")
                        buckets[0].param_ids.append(id(param))
            elif node.op == "get_attr":
                maybe_param = getattr(gm, node.target)
                if maybe_param.requires_grad and not self._ignore_parameter(
                    maybe_param
                ):
                    buckets[0].size += maybe_param._storage().nbytes()
                    buckets[0].params.append(node.target)
                    buckets[0].param_ids.append(id(maybe_param))

            # All nodes have to be mapped to a bucket, even if they don't have their own params
            # Ignored params still end up in buckets, we just don't count them towards the capacity
            buckets[0].nodes.append(node)

        # stash buckets for testing/debugging purposes
        self.buckets = buckets
        log.info(
            f"DDPOptimizer used bucket cap {self.bucket_bytes_cap} and produced the following buckets:"
        )
        pretty_print_buckets(buckets)

        if len(buckets) == 1:
            # bypass split/fuse logic if there is only one bucket
            return self.backend_compile_fn(gm, example_inputs)
        # 2: partition the graphmodule according to bucket capacity
        partition_map = {}
        for idx, b in enumerate(buckets):
            for node in b.nodes:
                partition_map[node] = idx

        split_gm = fx.passes.split_module.split_module(
            gm, None, lambda node: partition_map[node]
        )
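
        # (split_module groups nodes by the integer returned from the callback, producing
        # an outer GraphModule whose submodules correspond one-to-one with the buckets.)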
        debug_str = (
            f"\n---orig graph---\n{gm.graph}\n"
            + f"\n---split graph---\n{split_gm.graph}\n"
        )
        for name, module in split_gm.named_modules():
            if "." not in name and len(name):
                # only print the submod graphs, not their children
                debug_str += f"\n---{name} graph---\n{module.graph}\n"
        debug_str += "\n---------------\n"
        log.debug(debug_str)

        # 3: compile each of the partitioned submodules using the user-provided compiler
        class SubmodCompiler(torch.fx.interpreter.Interpreter):
            def __init__(self, module, compiler):
                super().__init__(module)
                self.compiler = compiler

            def compile_submod(self, submod, args, kwargs):
                """
                Compile the submodule, using a wrapper to make sure its output is always
                a tuple, which is required by AOTAutograd-based compilers.
                """
                assert len(kwargs) == 0, "We assume only args for these modules"

                class WrapperModule(torch.nn.Module):
                    def __init__(self, compiled_submod, unwrap_singleton_tuple):
                        super().__init__()
                        self.compiled_submod = compiled_submod
                        self.unwrap_singleton_tuple = unwrap_singleton_tuple

                    def forward(self, *args):
                        x = self.compiled_submod(*args)
                        # TODO(whc)
                        # for some reason the isinstance check is necessary if I split one node per submod
                        # - even though I supposedly wrapped the output in a tuple in those cases, the real
                        # compiled module was still returning a tensor
                        if self.unwrap_singleton_tuple and isinstance(x, (tuple, list)):
                            return x[0]
                        return x

                unwrap_singleton_tuple = False
                for sn in submod.graph.nodes:
                    if sn.op == "output":
                        if not isinstance(sn.args[0], tuple):
                            unwrap_singleton_tuple = True
                            sn.args = (sn.args,)
                submod.recompile()

                wrapper = WrapperModule(
                    self.compiler(submod, args),
                    unwrap_singleton_tuple,
                )
                return wrapper

            def run_node(self, n: Node) -> Any:
                with fx_traceback.append_stack_trace(n.stack_trace):
                    args, kwargs = self.fetch_args_kwargs_from_env(n)
                    log.debug(f"run_node {n.op}, {n.target} got args {args_str(args)}")
                    assert isinstance(args, tuple)
                    assert isinstance(kwargs, dict)

                    # modify the currently running FX graph
                    # maybe this isn't sound in general, but only changing the target of a node might be ok?
                    if n.op == "call_module":
                        submod = self.fetch_attr(n.target)
                        log.debug(f"\n---{n.target} graph---\n" + str(submod.graph))
                        compiled_submod = self.compile_submod(submod, args, kwargs)
                        self.module.delete_submodule(n.target)
                        n.target = "compiled_" + n.target
                        self.module.add_submodule(n.target, compiled_submod)

                    # then we execute the modified node using the usual logic
                    return getattr(self, n.op)(n.target, args, kwargs)

        submod_compiler = SubmodCompiler(split_gm, self.backend_compile_fn)
        submod_compiler.run(*example_inputs)
        split_gm.recompile()

        log.debug("\n---final graph---\n" + str(split_gm.graph) + "\n---------------\n")
        return split_gm
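
# Note: backend_compile_fn follows the usual dynamo backend contract, i.e. a callable
# taking (fx.GraphModule, example_inputs) and returning a callable.  A minimal sketch
# (illustrative only; noop_backend is not part of this module):
#
#     def noop_backend(gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
#         # compile nothing; just run the captured graph eagerly
#         return gm.forward
#
#     opt = DDPOptimizer(
#         bucket_bytes_cap=25 * 1024 * 1024, backend_compile_fn=noop_backend
#     )
#     compiled_fn = opt.compile_fn(gm, example_inputs)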