-
Notifications
You must be signed in to change notification settings - Fork 505
/
Copy pathtest_torch_distributed_xla_backend.py
383 lines (330 loc) · 14.5 KB
/
test_torch_distributed_xla_backend.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
import contextlib
import functools
import os
import re
from unittest import mock, skipIf
from absl.testing import absltest, parameterized
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend
from torch_xla import runtime as xr
from datetime import timedelta
def get_process_group_xla(rank, size):
  """Create a process group straight from the registered 'XLA' backend plugin.

  Looks up ``dist.Backend._plugins['XLA']`` and calls its ``creator_fn``
  directly, with no prefix store and a one-minute timeout.

  Args:
    rank: Rank the created group should report for this process.
    size: Number of ranks the created group should report.

  Returns:
    Whatever the plugin's ``creator_fn`` returns; callers in this file use it
    as an XLA process group.
  """
  xla_plugin = dist.Backend._plugins['XLA']
  return xla_plugin.creator_fn(
      prefix_store=None,
      rank=rank,
      size=size,
      timeout=timedelta(minutes=1),
  )
def hlo_matches(hlo, expected_pattern, match_times=1):
  """Check that ``expected_pattern`` occurs exactly ``match_times`` times.

  Args:
    hlo: HLO text to search.
    expected_pattern: Regular expression passed to ``re.findall``.
    match_times: Exact number of non-overlapping matches required.

  Raises:
    AssertionError: If the match count differs. It is raised explicitly rather
      than via ``assert`` so the check still runs under ``python -O``.
  """
  matches = re.findall(expected_pattern, hlo)
  if len(matches) != match_times:
    raise AssertionError(
        f'expected {match_times} match(es) of {expected_pattern!r} but found '
        f'{len(matches)} in HLO:\n{hlo}')
@contextlib.contextmanager
def new_group_barrier_disabled():
  """Swap ``distributed_c10d._store_based_barrier`` for a mock while active.

  The original function is restored when the ``with`` block exits. Presumably
  this keeps ``dist.new_group`` from waiting on a store-based barrier in these
  single-process tests — confirm against torch's distributed_c10d.
  """
  c10d_module = torch.distributed.distributed_c10d
  barrier_patch = mock.patch.object(c10d_module, '_store_based_barrier')
  with barrier_patch:
    yield
@contextlib.contextmanager
def patch_world(rank, size):
  """Make the default WORLD group report a fake ``rank()`` and ``size()``.

  Usable as a ``with`` statement or, because it is built with
  ``contextlib.contextmanager``, as a test-method decorator.

  Args:
    rank: Value the mocked ``dist.group.WORLD.rank()`` returns.
    size: Value the mocked ``dist.group.WORLD.size()`` returns.
  """
  world_pg = dist.group.WORLD
  # Only meaningful when the default group was created by the XLA backend.
  assert isinstance(world_pg, torch_xla.distributed.xla_backend.ProcessGroupXla)
  rank_patch = mock.patch.object(world_pg, 'rank', return_value=rank)
  size_patch = mock.patch.object(world_pg, 'size', return_value=size)
  with rank_patch, size_patch:
    yield
class XlaBackendTest(parameterized.TestCase):
  """Tests for the ``xla`` torch.distributed backend (``ProcessGroupXla``).

  Most tests issue a collective through ``dist`` or a ``ProcessGroupXla`` and
  then check the HLO text of the affected tensors, obtained via
  ``torch_xla._XLAC._get_xla_tensors_hlo``, with ``hlo_matches``. Tests
  wrapped in ``patch_world`` see a WORLD group with a mocked rank and size.
  """

  @classmethod
  def setUpClass(cls):
    # Chain to absltest so its class-level setup still runs.
    super().setUpClass()
    # Add no-op all-reduce ops to HLO
    os.environ['XLA_ALWAYS_ALLREDUCE'] = '1'
    dist.init_process_group('xla', init_method='xla://')

  def tearDown(self) -> None:
    # Purge all computations attached to the device.
    xm.mark_step()
    super().tearDown()

  def test_xla_backend_exists(self):
    # torch_xla.distributed._register_xla_backend() should have been
    # automatically called.
    xla_backend = dist.Backend.XLA
    self.assertIsNotNone(xla_backend)

  def test_allreduce(self):
    device = xm.xla_device()
    tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
    all_reduce_pattern = r'%all\-reduce\.\d+ = .+ all\-reduce\('
    dist.all_reduce(tensor)
    hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
    hlo_matches(hlo, all_reduce_pattern)

  @patch_world(rank=3, size=6)
  def test_allreduce_with_mesh(self):
    device = xm.xla_device()
    tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()

    pg_options = {'xla_pg_options': {'spmd': True}}
    ranks = [2, 3]
    with new_group_barrier_disabled():
      new_pg = dist.new_group(ranks=ranks, pg_options=pg_options)

    opts = dist.AllreduceOptions()
    opts.reduceOp = dist.ReduceOp.SUM
    # The expected HLO lists replica groups for the whole mesh
    # ({0,1},{2,3},{4,5}), not just the requested [2, 3] group.
    all_reduce_pattern = (r'%all\-reduce\.\d+ = .+ all\-reduce\(.+\), .*'
                          r'replica_groups=\{\{0,1\},\{2,3\},\{4,5\}\}')
    new_pg.allreduce([tensor], opts)
    hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
    hlo_matches(hlo, all_reduce_pattern)

  @patch_world(rank=3, size=8)
  def test_allgather(self):
    device = xm.xla_device()
    tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
    output_tensors = [torch.zeros_like(tensor, device=device) for _ in range(8)]
    all_gather_pattern = r'%all\-gather\.\d+ = .+ all\-gather\('
    dist.all_gather(output_tensors, tensor)
    hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_tensors)
    hlo_matches(hlo, all_gather_pattern)

  @patch_world(rank=3, size=8)
  def test_all_scalar_allgather(self):
    device = xm.xla_device()
    tensor = torch.zeros((), device=device) + 1 + 2 * dist.get_rank()
    output_tensors = [torch.zeros_like(tensor, device=device) for _ in range(8)]
    all_gather_pattern = r'%all\-gather\.\d+ = .+ all\-gather\('
    dist.all_gather(output_tensors, tensor)
    hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_tensors)
    hlo_matches(hlo, all_gather_pattern)

  @patch_world(rank=3, size=8)
  def test_allgather_coalesced(self):
    device = xm.xla_device()
    tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
    tensor2 = torch.arange(5, device=device) + 1 + 2 * dist.get_rank()
    pg_xla = get_process_group_xla(rank=3, size=8)
    # One distinct output buffer per rank, as in test_allgather; `[t] * 8`
    # would alias a single tensor eight times.
    output_tensors = [torch.zeros_like(tensor) for _ in range(8)]
    output_tensors2 = [torch.zeros_like(tensor2) for _ in range(8)]
    # The expected all-gather keeps the inputs' shapes (s64[2], s64[5]) even
    # though the process group reports size 8.
    # NOTE(review): this was previously attributed to
    # os.environ[xenv.WORLD_SIZE] = '1', which this file does not set;
    # presumably it reflects the single-replica runtime — confirm.
    # Ex: %all-gather.26 = (s64[2]{0}, s64[5]{0}) all-gather(s64[2]{0} %get-tuple-element.24, s64[5]{0} %get-tuple-element.25), replica_groups={}, dimensions={0}
    all_gather_pattern = (
        r'%all-gather\.\d+ = \(s64\[2]\{0}, s64\[5]\{0}\) '
        r'all-gather\(s64\[2]\{0} %.+\.\d+, s64\[5]\{0} %.+\.\d+\)')
    pg_xla.allgather_coalesced([output_tensors, output_tensors2],
                               [tensor, tensor2])
    hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_tensors)
    hlo_matches(hlo, all_gather_pattern)

  def test_broadcast(self):
    device = xm.xla_device()
    tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
    # The expected HLO for broadcast contains an all-reduce.
    all_reduce_pattern = r'%all\-reduce\.\d+ = .+ all\-reduce\('
    dist.broadcast(tensor, 0)
    hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
    hlo_matches(hlo, all_reduce_pattern)

  # Needed for ZeRO stage 1
  def test_reduce_scatter(self):
    device = xm.xla_device()
    tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
    input_list = [tensor]
    output = torch.zeros_like(tensor)
    reduce_scatter_pattern = r'%reduce\-scatter\.\d+ = .+ reduce\-scatter\('
    dist.reduce_scatter(output, input_list)
    hlo = torch_xla._XLAC._get_xla_tensors_hlo([output])
    hlo_matches(hlo, reduce_scatter_pattern)

  @skipIf(xr.device_type() == 'CPU',
          "UNIMPLEMENTED: ReduceScatter is not implemented on CPU.")
  def test_reduce_scatter_coalesced(self):
    device = xm.xla_device()
    tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
    tensor2 = torch.arange(5, device=device) + 1 + 2 * dist.get_rank()
    input_tensors_list = [[tensor, tensor], [tensor2, tensor2]]
    output_list = [torch.zeros_like(tensor), torch.zeros_like(tensor2)]
    pg_xla = get_process_group_xla(rank=0, size=len(input_tensors_list[0]))
    opts = dist.ReduceScatterOptions()
    opts.reduceOp = dist.ReduceOp.SUM
    reduce_scatter_pattern = (
        r'%reduce\-scatter\.\d+ = \(s64\[2]\{0}, s64\[5]\{0}, s64\[]\) '
        r'reduce\-scatter\(s64\[4]\{0} %.+\.\d+, s64\[10]\{0} %.+\.\d+, '
        r's64\[] %.+\.\d+\)')
    pg_xla.reduce_scatter_coalesced(output_list, input_tensors_list, opts)
    hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_list)
    hlo_matches(hlo, reduce_scatter_pattern)
    # Purge all computations attached to the device.
    xm.mark_step()

  @patch_world(rank=0, size=6)
  def test_send(self):
    device = xm.xla_device()
    tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()

    # Pin the channel id so the HLO pattern below is deterministic.
    with mock.patch.object(
        torch_xla.distributed.xla_backend.ProcessGroupXla,
        'make_send_channel_id',
        new=lambda self, dst_rank, tag: dst_rank * 2):
      dist.send(tensor, 1)

    send_pattern = r'%send\.\d+ = .+ send\(.+\), channel_id=2'
    senddone_pattern = r'%send\-done\.\d+ = .+ send\-done\(.+\), channel_id=2'
    hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
    hlo_matches(hlo, send_pattern)
    hlo_matches(hlo, senddone_pattern)

    # Don't try to run Send on CPU because it's not implemented
    torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))

  @patch_world(rank=0, size=6)
  def test_recv(self):
    device = xm.xla_device()
    tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()

    # Pin the channel id so the HLO pattern below is deterministic.
    with mock.patch.object(
        torch_xla.distributed.xla_backend.ProcessGroupXla,
        'make_recv_channel_id',
        new=lambda self, src_rank, tag: src_rank * 3):
      dist.recv(tensor, 1)

    recv_pattern = r'%recv\.\d+ = .+ recv\(.+\), channel_id=3'
    recvdone_pattern = r'%recv\-done\.\d+ = .+ recv\-done\(.+\), channel_id=3'
    hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
    hlo_matches(hlo, recv_pattern)
    hlo_matches(hlo, recvdone_pattern)

    # Don't try to run Recv on CPU because it's not implemented
    torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))

  @patch_world(rank=0, size=12)
  def test_new_group_no_ranks(self):
    with new_group_barrier_disabled():
      pg = dist.new_group()
    self.assertIsInstance(pg, torch_xla.distributed.xla_backend.ProcessGroupXla)
    self.assertEqual(pg.size(), dist.get_world_size())

  def test_new_group_horizontal(self):
    # Contiguous 4-rank groups in a world of 12 are expected to produce a mesh
    # of the three contiguous rows.
    pg_options = {'xla_pg_options': {'spmd': True}}

    with patch_world(rank=5, size=12):
      ranks = [4, 5, 6, 7]
      with new_group_barrier_disabled():
        pg = dist.new_group(ranks=ranks, pg_options=pg_options)
      self.assertIsInstance(pg,
                            torch_xla.distributed.xla_backend.ProcessGroupXla)
      self.assertEqual(pg.size(), len(ranks))
      self.assertEqual(pg.rank(), ranks.index(5))
      self.assertListEqual(pg._mesh,
                           [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])

    with patch_world(rank=2, size=12):
      ranks = [0, 1, 2, 3]
      with new_group_barrier_disabled():
        pg = dist.new_group(ranks=ranks, pg_options=pg_options)
      self.assertIsInstance(pg,
                            torch_xla.distributed.xla_backend.ProcessGroupXla)
      self.assertEqual(pg.size(), len(ranks))
      self.assertEqual(pg.rank(), ranks.index(2))
      self.assertListEqual(pg._mesh,
                           [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])

    with patch_world(rank=11, size=12):
      ranks = [8, 9, 10, 11]
      with new_group_barrier_disabled():
        pg = dist.new_group(ranks=ranks, pg_options=pg_options)
      self.assertIsInstance(pg,
                            torch_xla.distributed.xla_backend.ProcessGroupXla)
      self.assertEqual(pg.size(), len(ranks))
      self.assertEqual(pg.rank(), ranks.index(11))
      self.assertListEqual(pg._mesh,
                           [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])

  def test_new_group_vertical(self):
    # Stride-4 groups in a world of 12 are expected to produce a mesh of the
    # four stride-4 columns.
    pg_options = {'xla_pg_options': {'spmd': True}}

    with patch_world(rank=5, size=12):
      ranks = [1, 5, 9]
      with new_group_barrier_disabled():
        pg = dist.new_group(ranks=ranks, pg_options=pg_options)
      self.assertIsInstance(pg,
                            torch_xla.distributed.xla_backend.ProcessGroupXla)
      self.assertEqual(pg.size(), len(ranks))
      self.assertEqual(pg.rank(), ranks.index(5))
      self.assertListEqual(pg._mesh,
                           [[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]])

    with patch_world(rank=0, size=12):
      ranks = [0, 4, 8]
      with new_group_barrier_disabled():
        pg = dist.new_group(ranks=ranks, pg_options=pg_options)
      self.assertIsInstance(pg,
                            torch_xla.distributed.xla_backend.ProcessGroupXla)
      self.assertEqual(pg.size(), len(ranks))
      self.assertEqual(pg.rank(), ranks.index(0))
      self.assertListEqual(pg._mesh,
                           [[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]])

    with patch_world(rank=11, size=12):
      ranks = [3, 7, 11]
      with new_group_barrier_disabled():
        pg = dist.new_group(ranks=ranks, pg_options=pg_options)
      self.assertIsInstance(pg,
                            torch_xla.distributed.xla_backend.ProcessGroupXla)
      self.assertEqual(pg.size(), len(ranks))
      self.assertEqual(pg.rank(), ranks.index(11))
      self.assertListEqual(pg._mesh,
                           [[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]])

  @patch_world(rank=5, size=12)
  def test_new_group_one_paticipant(self):
    pg_options = {'xla_pg_options': {'spmd': True}}
    ranks = [5]
    with new_group_barrier_disabled():
      pg = dist.new_group(ranks=ranks, pg_options=pg_options)
    self.assertIsInstance(pg, torch_xla.distributed.xla_backend.ProcessGroupXla)
    self.assertEqual(pg.size(), 1)
    self.assertEqual(pg.rank(), 0)
    self.assertListEqual(
        pg._mesh,
        [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11]])

  @patch_world(rank=5, size=12)
  def test_new_group_entire_world(self):
    ranks = range(12)
    with new_group_barrier_disabled():
      pg = dist.new_group(ranks=ranks)
    self.assertIsInstance(pg, torch_xla.distributed.xla_backend.ProcessGroupXla)
    self.assertEqual(pg.size(), 12)
    self.assertEqual(pg.rank(), 5)
    self.assertListEqual(pg._mesh, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]])

  def test_new_group_invalid_horizontal(self):
    pg_options = {'xla_pg_options': {'spmd': True}}

    with patch_world(rank=5, size=12):
      ranks = [4, 5, 6]
      with new_group_barrier_disabled():
        with self.assertRaises(ValueError):
          dist.new_group(ranks=ranks, pg_options=pg_options)

    with patch_world(rank=2, size=12):
      ranks = [0, 1, 2, 3, 4]
      with new_group_barrier_disabled():
        with self.assertRaises(ValueError):
          dist.new_group(ranks=ranks, pg_options=pg_options)

    with patch_world(rank=9, size=12):
      ranks = [7, 8, 9, 10]
      with new_group_barrier_disabled():
        with self.assertRaises(ValueError):
          dist.new_group(ranks=ranks, pg_options=pg_options)

  def test_new_group_invalid_vertical(self):
    pg_options = {'xla_pg_options': {'spmd': True}}

    with patch_world(rank=5, size=12):
      ranks = [1, 5]
      with new_group_barrier_disabled():
        with self.assertRaises(ValueError):
          dist.new_group(ranks=ranks, pg_options=pg_options)

    with patch_world(rank=4, size=12):
      ranks = [4, 7, 10]
      with new_group_barrier_disabled():
        with self.assertRaises(ValueError):
          dist.new_group(ranks=ranks, pg_options=pg_options)

  def test_new_group_invalid_ranks(self):
    # unevenly distributed
    pg_options = {'xla_pg_options': {'spmd': True}}
    with patch_world(rank=5, size=12):
      ranks = [1, 5, 10]
      with new_group_barrier_disabled():
        with self.assertRaises(ValueError):
          dist.new_group(ranks=ranks, pg_options=pg_options)

  def test_barrier(self):
    # nothing to verify. Just run it through.
    dist.barrier()

  @parameterized.parameters(
      'reduce',
      'allreduce_coalesced',
      'alltoall',
      'gather',
      'scatter',
      'recv_anysource',
      'monitored_barrier',
  )
  def test_unimplemented_op(self, op):
    device = xm.xla_device()
    tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
    pg_xla = dist.group.WORLD
    self.assertIsInstance(pg_xla,
                          torch_xla.distributed.xla_backend.ProcessGroupXla)
    with self.assertRaises(NotImplementedError):
      getattr(pg_xla, op)(tensor)
if __name__ == '__main__':
  import sys

  # Per the message below, these tests don't exercise device-specific
  # behavior, so the suite is only run when the device type is CPU.
  if xr.device_type() != 'CPU':
    # Both literals need the f-prefix; a plain literal joined to an f-string
    # would print "{...}" verbatim.
    print(f"Skipping XLA backend unit test as this test doesn't exercise "
          f"{xr.device_type()}-specific behaviors.")
    sys.exit(0)
  absltest.main()