-
Notifications
You must be signed in to change notification settings - Fork 507
/
Copy pathtest_dynamic_shapes.py
630 lines (547 loc) · 22.3 KB
/
test_dynamic_shapes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
import os
import sys
import unittest
import torch, torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
sys.path.insert(1, os.path.join(sys.path[0], '..'))
import test_utils
# Enable the Python dispatcher for the whole module. The returned object is
# kept in `pd`; the `__main__` block at the bottom deletes it to disable the
# dispatcher again.
pd = torch._C._EnablePythonDispatcher()
# XLA device shared by the tests in this file.
dev = xm.xla_device()
class TestDynamicShapes(test_utils.XlaTestCase):
def test_simple_expand(self):
    """Expand a one-element tensor to a dynamic length taken from nonzero()."""
    rows, cols = 5, 2
    source = torch.zeros([rows, cols], device=dev)
    source[3][0] = 1
    source[3][1] = 1
    # nonzero_idx has size [<=10, 2]
    nonzero_idx = torch.nonzero(source)
    ones = torch.ones(1, device=dev)
    expanded = ones.expand(nonzero_idx.size(0))
    # The upper bound must appear in the text returned for the expanded tensor.
    tensors_text = torch_xla._XLAC._get_xla_tensors_text([expanded])
    self.assertIn('<=10', tensors_text)
    # Two entries were set above, so the materialized length is 2.
    expanded_cpu = expanded.cpu()
    self.assertEqual(expanded_cpu.shape[0], 2)
def test_simple_expand_on_2d_tensor(self):
    """Expand a [1, 2] tensor using every supported way of passing sizes."""
    rows, cols = 5, 2
    source = torch.zeros([rows, cols], device=dev)
    source[3][0] = 1
    source[3][1] = 1
    # idx has size [<=10, 2]
    idx = torch.nonzero(source)
    base = torch.ones(1, cols, device=dev)

    def check(expanded):
        # Every variant must produce the same [2, cols] result.
        self.assertEqual(expanded.shape[0], 2)
        self.assertEqual(expanded.shape[1], cols)

    # varargs
    check(base.expand(idx.shape[0], idx.shape[1]))
    # shape list
    check(base.expand((idx.shape[0], idx.shape[1])))
    # mixed python symints and ints
    check(base.expand(idx.shape[0], cols))
    # mixed python symints and ints in a list
    check(base.expand((idx.shape[0], cols)))

    # size_clone should be called as part of decomposition from
    # the python dispatcher.
    self.assertGreater(met.counter_value("xla::size_clone"), 0)
def test_simple_expand_add_dimension(self):
    """Expand a 1-element tensor to two dims that both use a dynamic size."""
    rows, cols = 5, 2
    source = torch.zeros([rows, cols], device=dev)
    source[3][0] = 1
    source[3][1] = 1
    # idx has size [<=10, 2]
    idx = torch.nonzero(source)
    ones = torch.ones(1, device=dev)
    expanded = ones.expand(idx.shape[0], idx.shape[0])
    # Both output dims come from the same dynamic size, so each one must be a
    # SymInt that prints as '<=10' and compares equal to 2.
    for dim in (0, 1):
        self.assertIsInstance(expanded.shape[dim], torch.SymInt)
        self.assertEqual(str(expanded.shape[dim]), '<=10')
        self.assertEqual(expanded.shape[dim], 2)
def test_wrap(self):
    """Adding a plain int to a dynamic size must still yield a SymInt."""
    values = torch.tensor([[1, 0, 0, 5, 0, 6]], device=dev)
    nonzero_idx = torch.nonzero(values)
    self.assertTrue(nonzero_idx.shape[0] == 3)
    # tests wrap
    bumped = nonzero_idx.shape[0] + 3
    self.assertIsInstance(bumped, torch.SymInt)
def test_sizeAdd(self):
    """Add a dynamic size and a constant size and use the sum in expand()."""
    rows, cols = 5, 2
    source = torch.zeros([rows, cols], device=dev)
    source[3][0] = 1
    # idx has size [<=10, 2]
    idx = torch.nonzero(source)

    # Create a SizeAdd IR node.
    # idx.shape[1] generates a SizeConstant node.
    total = idx.shape[0] + idx.shape[1]
    # Exercises SizeAdd::getDynamicValue.
    self.assertEqual(int(total), 3)
    # Exercise SizeAdd::getStaticValue.
    self.assertEqual(str(total), '<=12')

    ones = torch.ones(1, device=dev)
    # Exercise SizeAdd::Lower.
    expanded = ones.expand(total)
    self.assertEqual(expanded.size(0), 3)
def test_sizeSub(self):
    """Subtract a constant size from a dynamic size and use it in expand()."""
    # Reset metrics first so the counter assertion below can only be satisfied
    # by the subtraction performed in this test, not by an earlier test.
    met.clear_all()

    size1 = 5
    size2 = 2
    t1 = torch.zeros([size1, size2], device=dev)
    t1[0][0] = 1
    t1[1][0] = 1
    t1[2][0] = 1
    # t2 has size [<=10, 2] with dynamic size=[3, 2]
    t2 = torch.nonzero(t1)
    dyn_size = t2.shape[0] - t2.shape[1]
    self.assertGreater(met.counter_value("xla::size_sub"), 0)
    # Exercises SizeSub::getDynamicValue.
    dynamic_size = int(dyn_size)
    self.assertEqual(dynamic_size, 1)
    # Exercise SizeSub::getStaticValue.
    self.assertEqual(str(dyn_size), '<=8')

    t3 = torch.ones(1, device=dev)
    # Exercise SizeSub::Lower.
    t4 = t3.expand(dyn_size)
    self.assertEqual(t4.size(0), 1)
def get_dynamic_tensor(self):
    """Return the nonzero() indices of a fixed 1x6 tensor on the XLA device.

    The input has three nonzero entries.
    """
    values = torch.tensor([[1, 0, 0, 5, 0, 6]], device=dev)
    return torch.nonzero(values)
def test_empty_symint(self):
    """torch.empty() given a dynamic shape must keep the dynamic dimension."""
    # dyn.shape= torch.Size([<=6, 2]) with real size [3, 2]
    dyn = self.get_dynamic_tensor()
    # Don't print dyn otherwise it would cause the test to crash.
    self.assertIsInstance(dyn.shape[0], torch.SymInt)

    created = torch.empty(dyn.shape, dtype=torch.int32, device=dev)
    # Leading dim stays symbolic.
    self.assertIsInstance(created.shape[0], torch.SymInt)
    self.assertEqual(str(created.shape[0]), '<=6')
    self.assertEqual(created.shape[0], 3)
    # Trailing dim is a plain int.
    self.assertIsInstance(created.shape[1], int)
    self.assertEqual(created.shape[1], 2)
def test_t_copy(self):
    """Transposing a dynamic tensor must move the dynamic dim with it."""
    values = torch.tensor([[1, 0, 0, 5, 0, 6], [1, 3, 2, 0, 0, 1]], device=dev)
    idx = torch.nonzero(values)
    # idx.shape=torch.Size([<=12, 2]) with real size [7, 2]
    self.assertEqual(str(idx.shape[0]), '<=12')
    self.assertEqual(str(idx.shape[1]), '2')

    transposed = torch.t(idx)
    # After the transpose, dim 0 is the static one and dim 1 the dynamic one.
    self.assertIsInstance(transposed.shape[0], int)
    self.assertIsInstance(transposed.shape[1], torch.SymInt)
    self.assertEqual(str(transposed.shape[0]), '2')
    self.assertEqual(str(transposed.shape[1]), '<=12')
    self.assertEqual(transposed.shape[0], 2)
    self.assertEqual(transposed.shape[1], 7)
def test_nonzero_shape(self):
    """The dim-0 size of nonzero() equals the number of nonzero inputs."""
    values = torch.tensor((0, 1, 2, 0, 3, 4), device=xm.xla_device())
    nonzero_idx = torch.nonzero(values, as_tuple=False)
    dim0_size = torch_xla._XLAC._get_xla_tensor_dimension_size(nonzero_idx, 0)
    self.assertEqual(dim0_size.item(), 4)
def test_nonzero_correctness(self):
    """nonzero() on the XLA device must match nonzero() on the CPU copy."""
    xla_values = torch.tensor([[1, 0, 0, 5, 0, 6]], device=dev)
    xla_idx = torch.nonzero(xla_values)
    cpu_values = xla_values.cpu()
    cpu_idx = torch.nonzero(cpu_values)
    self.assertEqual(xla_idx.cpu(), cpu_idx)
def test_masked_select_shape(self):
    """The dim-0 size of masked_select() equals the number of selected items."""
    values = torch.tensor((0, 1, 2, 0, 3, 4), device=xm.xla_device())
    # Selects the entries 2, 3 and 4.
    keep = values.ge(2)
    selected = torch.masked_select(values, keep)
    dim0_size = torch_xla._XLAC._get_xla_tensor_dimension_size(selected, 0)
    self.assertEqual(dim0_size.item(), 3)
def test_nonzero_cast(self):
    """Casting the result of nonzero() to another dtype must not raise."""
    ones = torch.ones(5, 2, device=xm.xla_device())
    # Result of the nonzero should be the index type. Currently
    # index type is s64 on cpu and gpu, but s32 on TPU. We should be
    # able to cast it to any other type without error.
    # NOTE(review): the result is bound to a local on purpose; presumably it
    # must stay referenced so the cast is part of the step below - confirm.
    as_float = torch.nonzero(ones.int()).float()
    xm.mark_step()
def test_expand_symint_correctness(self):
    """expand() with a dynamic leading size must match the eager CPU result."""
    device = xm.xla_device()
    rows, cols = 5, 2

    # Reference result computed without the XLA device.
    cpu_ones = torch.ones([rows, cols])
    expected = cpu_ones.expand(2, rows, cols)

    mask = torch.zeros([rows, cols], device=device)
    mask[3][0] = 1
    mask[3][1] = 1
    # idx has size [<=10, 2]
    idx = torch.nonzero(mask)
    xla_ones = torch.ones([rows, cols], device=device)
    actual = xla_ones.expand(idx.shape[0], rows, cols)

    self.assertEqual(idx.shape[0], 2)
    self.assertEqual(expected.cpu(), actual.cpu())
def test_unsqueeze_copy_dynamism(self):
    """unsqueeze() on a dynamic tensor must keep the dynamic dim symbolic."""
    data = [[1, 0, 0, 5, 0, 6], [1, 3, 2, 0, 0, 1]]
    xla_values = torch.tensor(data, device=dev)
    xla_idx = torch.nonzero(xla_values)
    # xla_idx.shape=torch.Size([<=12, 2]) with real size [7, 2]
    xla_out = torch.unsqueeze(xla_idx, 0)
    self.assertEqual(len(xla_out.size()), 3)

    # Per output dim: (expected Python type, expected str(), expected value).
    expected_dims = (
        (int, '1', 1),
        (torch.SymInt, '<=12', 7),
        (int, '2', 2),
    )
    for dim, (kind, _, _) in enumerate(expected_dims):
        self.assertIsInstance(xla_out.shape[dim], kind)
    for dim, (_, text, _) in enumerate(expected_dims):
        self.assertEqual(str(xla_out.shape[dim]), text)
    for dim, (_, _, value) in enumerate(expected_dims):
        self.assertEqual(xla_out.shape[dim], value)

    # test correctness
    cpu_values = torch.tensor(data)
    cpu_idx = torch.nonzero(cpu_values)
    cpu_out = torch.unsqueeze(cpu_idx, 0)
    self.assertEqual(xla_out.cpu(), cpu_out.cpu())
def test_view_copy_symint_with_static_input_dyn_input_shape(self):
    """view() of a static tensor with a single dynamic target size."""
    # If the input tensor and shape are "statically" incompatible, a
    # compilation error is raised.
    sparse = torch.tensor([1, 0, 3, 5, 0, 6], device=dev)
    # sparse_idx.shape=torch.Size([<=6, 1]) with real size [4, 1]
    # sparse_idx = [[0], [2], [3], [5]]
    sparse_idx = torch.nonzero(sparse)
    four_elems = torch.randint(10, (2, 2), device=dev)
    with self.assertRaises(RuntimeError):
        four_elems.view(sparse_idx.shape[0])

    # If their "dynamic" values are incompatible, a RuntimeError is raised.
    six_elems = torch.randint(10, (2, 3), device=dev)
    with self.assertRaises(RuntimeError):
        six_elems.view(sparse_idx.shape[0])

    # verify if dynamism is propagated correctly.
    dense_values = [1, 1, 3, 5, 1, 6]
    dense = torch.tensor(dense_values, device=dev)
    dense_idx = torch.nonzero(dense)
    xla_ones = torch.ones((2, 3), device=dev)
    # dense_idx.shape=torch.Size([<=6, 1]) with real size [6, 1]
    # dense_idx = [[0], [1], [2], [3], [4], [5]]
    flattened = xla_ones.view(dense_idx.shape[0])
    self.assertIsInstance(flattened.shape[0], torch.SymInt)
    self.assertEqual(str(flattened.shape[0]), '<=6')
    self.assertEqual(flattened.shape[0], 6)

    # verify correctness.
    cpu_dense = torch.tensor(dense_values)
    cpu_idx = torch.nonzero(cpu_dense)
    cpu_ones = torch.ones((2, 3))
    expected = cpu_ones.view(cpu_idx.shape[0])
    self.assertEqual(flattened.cpu(), expected.cpu())
def test_view_copy_symint_with_static_input_dyn_input_shape2(self):
    """view() of a static tensor with a full dynamic torch.Size target."""
    # If the input tensor and shape are "statically" incompatible, a
    # compilation error is raised.
    sparse = torch.tensor([[1, 0, 3]], device=dev)
    # sparse_idx.shape=torch.Size([<=3, 2]) with real size [2, 2]
    # sparse_idx = [[0, 0], [0, 2]]
    sparse_idx = torch.nonzero(sparse)
    eight_elems = torch.ones((2, 4), device=dev)
    # Should fail in pytorch utils.infer_size
    with self.assertRaises(RuntimeError):
        eight_elems.view(sparse_idx.shape)

    # If their "dynamic" values are incompatible, a RuntimeError is raised.
    six_elems = torch.ones((2, 3), device=dev)
    # Also fails in pytorch utils.infer_size
    with self.assertRaises(RuntimeError):
        six_elems.view(sparse_idx.shape)

    # verify if dynamism is propagated correctly.
    dense_values = [[1, 1, 3]]
    dense = torch.tensor(dense_values, device=dev)
    dense_idx = torch.nonzero(dense)
    # dense_idx.shape=[<=3, 2] with real size [3, 2]
    xla_ones = torch.ones((2, 3), device=dev)
    reshaped = xla_ones.view(dense_idx.shape)
    self.assertIsInstance(reshaped.shape[0], torch.SymInt)
    self.assertEqual(str(reshaped.shape[0]), '<=3')
    self.assertEqual(reshaped.shape[0], 3)
    self.assertIsInstance(reshaped.shape[1], int)
    self.assertEqual(str(reshaped.shape[1]), '2')
    self.assertEqual(reshaped.shape[1], 2)

    # verify correctness.
    cpu_dense = torch.tensor(dense_values)
    cpu_idx = torch.nonzero(cpu_dense)
    cpu_ones = torch.ones((2, 3))
    expected = cpu_ones.view(cpu_idx.shape)
    self.assertEqual(reshaped.cpu(), expected.cpu())
def test_view_copy_symint_with_dyn_input_static_input_shape(self):
    """view() of a dynamic tensor with a static target shape must raise."""
    # If the input tensor is dynamic and input shape is static,
    # it should fail because we will not likely have this case
    # in reality so we don't support this feature.
    values = torch.tensor([1, 1, 3, 5, 1, 6], device=dev)
    # idx.shape=torch.Size([<=6, 1]) with real size [6, 1]
    idx = torch.nonzero(values)
    with self.assertRaises(RuntimeError):
        idx.view(2, 3)
def test_view_copy_symint_with_dyn_input_dyn_input_shape(self):
    """view() of a dynamic tensor with a dynamic target size."""
    seven_values = [1, 0, 3, 5, 0, 6, 7]

    # If the input tensor and shape are "statically" incompatible, a
    # compilation error is raised.
    six_input = torch.tensor([1, 0, 3, 5, 0, 6], device=dev)
    # six_idx.shape=torch.Size([<=6, 1]) with real size [4, 1]
    # six_idx = [[0], [2], [3], [5]]
    six_idx = torch.nonzero(six_input)
    target_input = torch.tensor(seven_values, device=dev)
    # target_idx.shape=torch.Size([<=7, 1]) with real size [5, 1]
    target_idx = torch.nonzero(target_input)
    with self.assertRaises(RuntimeError):
        six_idx.view(target_idx.shape[0])

    # If their "dynamic" values are incompatible, a RuntimeError is raised.
    mismatch_input = torch.tensor([1, 2, 3, 4, 5, 6, 0], device=dev)
    # mismatch_idx.shape=torch.Size([<=7, 1]) with real size [6, 1]
    mismatch_idx = torch.nonzero(mismatch_input)
    # statically compatible but dynamically incompatible.
    # It will fail in pytorch layer.
    with self.assertRaises(RuntimeError):
        mismatch_idx.view(target_idx.shape[0])

    # verify if dynamism is propagated correctly.
    match_input = torch.tensor(seven_values, device=dev)
    match_idx = torch.nonzero(match_input)
    # match_idx.shape=torch.Size([<=7, 1]) with real size [5, 1]
    viewed = match_idx.view(target_idx.shape[0])
    self.assertIsInstance(viewed.shape[0], torch.SymInt)
    self.assertEqual(str(viewed.shape[0]), '<=7')
    self.assertEqual(viewed.shape[0], 5)

    # verify correctness.
    cpu_match = torch.tensor(seven_values)
    cpu_match_idx = torch.nonzero(cpu_match)
    # cpu_match_idx.size=[5, 1]
    cpu_target = torch.tensor(seven_values)
    cpu_target_idx = torch.nonzero(cpu_target)
    # cpu_target_idx.size=[5, 1]
    expected = cpu_match_idx.view(cpu_target_idx.shape[0])
    self.assertEqual(viewed.cpu(), expected.cpu())
def test_add_dyn_with_static_broadcastable(self):
    """add(dynamic, broadcastable static) keeps the dynamic dim and is correct."""
    values = [[1, 0, 3, 5, 0, 6]]
    addend = [[1, 1]]

    xla_idx = torch.nonzero(torch.tensor(values, device=dev))
    xla_addend = torch.tensor(addend, device=dev)
    # xla_idx.shape=torch.Size([<=6, 2]) with real size [4, 2]
    # xla_addend.shape=torch.Size([1, 2]) with real size [1, 2]
    xla_sum = torch.add(xla_idx, xla_addend)
    self.assertIsInstance(xla_sum.shape[0], torch.SymInt)
    self.assertEqual(str(xla_sum.shape[0]), '<=6')
    self.assertEqual(xla_sum.shape[0], 4)
    self.assertIsInstance(xla_sum.shape[1], int)
    self.assertEqual(str(xla_sum.shape[1]), '2')
    self.assertEqual(xla_sum.shape[1], 2)

    # test for correctness
    cpu_idx = torch.nonzero(torch.tensor(values))
    cpu_addend = torch.tensor(addend)
    cpu_sum = torch.add(cpu_idx, cpu_addend)
    self.assertEqual(xla_sum.cpu(), cpu_sum.cpu())
def test_add_dyn_with_static_not_broadcastable(self):
    """add() of a dynamic and a non-broadcastable static tensor must raise."""
    values = torch.tensor([[1, 0, 3, 5, 0, 6]], device=dev)
    dyn = torch.nonzero(values)
    static = torch.tensor([[1, 1], [1, 1]], device=dev)
    # dyn.shape=torch.Size([<=6, 2]) with real size [4, 2]
    # static.shape=torch.Size([2, 2]) with real size [2, 2]
    # Both operand orders are rejected.
    with self.assertRaises(RuntimeError):
        torch.add(dyn, static)
    with self.assertRaises(RuntimeError):
        torch.add(static, dyn)
def test_add_two_dynamic_tensors(self):
    """add() of two dynamic tensors must raise, whatever the operand order."""
    t1 = torch.tensor([[1, 0, 3, 5, 0, 6]], device=dev)
    t2 = torch.nonzero(t1)
    t3 = torch.tensor([[1]], device=dev)
    t4 = torch.nonzero(t3)
    # t2.shape=torch.Size([<=6, 2]) with real size [4, 2]
    # t4.shape=torch.Size([<=1, 2]) with real size [1, 2]
    self.assertRaises(RuntimeError, lambda: torch.add(t2, t4))
    self.assertRaises(RuntimeError, lambda: torch.add(t4, t2))

    # For now, we disallow if both operands have the same upper bound and real size.
    # This is consistent with PyTorch's behavior.
    # t2.shape=torch.Size([<=6, 2]) with real size [4, 2]
    # t6.shape=torch.Size([<=6, 2]) with real size [4, 2]
    t5 = torch.tensor([[1, 0, 3, 5, 0, 6]], device=dev)
    t6 = torch.nonzero(t5)
    self.assertRaises(RuntimeError, lambda: torch.add(t2, t6))
    # Check the swapped operand order too, as the matching sub test does.
    self.assertRaises(RuntimeError, lambda: torch.add(t6, t2))
def test_sub_dyn_with_static_broadcastable(self):
    """sub(dynamic, broadcastable static) keeps the dynamic dim and is correct."""
    values = [[1, 0, 3, 5, 0, 6]]
    subtrahend = [[1, 1]]

    xla_idx = torch.nonzero(torch.tensor(values, device=dev))
    xla_subtrahend = torch.tensor(subtrahend, device=dev)
    # xla_idx.shape=torch.Size([<=6, 2]) with real size [4, 2]
    # xla_subtrahend.shape=torch.Size([1, 2]) with real size [1, 2]
    xla_diff = torch.sub(xla_idx, xla_subtrahend)
    self.assertIsInstance(xla_diff.shape[0], torch.SymInt)
    self.assertEqual(str(xla_diff.shape[0]), '<=6')
    self.assertEqual(xla_diff.shape[0], 4)
    self.assertIsInstance(xla_diff.shape[1], int)
    self.assertEqual(str(xla_diff.shape[1]), '2')
    self.assertEqual(xla_diff.shape[1], 2)

    # test for correctness
    cpu_idx = torch.nonzero(torch.tensor(values))
    cpu_subtrahend = torch.tensor(subtrahend)
    cpu_diff = torch.sub(cpu_idx, cpu_subtrahend)
    self.assertEqual(xla_diff.cpu(), cpu_diff.cpu())
def test_sub_dyn_with_static_not_broadcastable(self):
    """sub() of a dynamic and a non-broadcastable static tensor must raise."""
    values = torch.tensor([[1, 0, 3, 5, 0, 6]], device=dev)
    dyn = torch.nonzero(values)
    static = torch.tensor([[1, 1], [1, 1]], device=dev)
    # dyn.shape=torch.Size([<=6, 2]) with real size [4, 2]
    # static.shape=torch.Size([2, 2]) with real size [2, 2]
    # Both operand orders are rejected.
    with self.assertRaises(RuntimeError):
        torch.sub(dyn, static)
    with self.assertRaises(RuntimeError):
        torch.sub(static, dyn)
def test_sub_two_dynamic_tensors(self):
    """sub() of two dynamic tensors must raise, whatever the operand order."""
    six_values = [[1, 0, 3, 5, 0, 6]]
    big = torch.nonzero(torch.tensor(six_values, device=dev))
    small = torch.nonzero(torch.tensor([[1]], device=dev))
    # big.shape=torch.Size([<=6, 2]) with real size [4, 2]
    # small.shape=torch.Size([<=1, 2]) with real size [1, 2]
    with self.assertRaises(RuntimeError):
        torch.sub(big, small)
    with self.assertRaises(RuntimeError):
        torch.sub(small, big)

    # For now, we disallow if both operands have the same upper bound and real size.
    # This is consistent with PyTorch's behavior.
    # big.shape=torch.Size([<=6, 2]) with real size [4, 2]
    # twin.shape=torch.Size([<=6, 2]) with real size [4, 2]
    twin = torch.nonzero(torch.tensor(six_values, device=dev))
    with self.assertRaises(RuntimeError):
        torch.sub(big, twin)
    with self.assertRaises(RuntimeError):
        torch.sub(twin, big)
def test_clone(self):
    """clone() of a dynamic tensor keeps its dynamic shape and its values."""
    values = torch.tensor([1, 0, 3, 5, 0, 6], device=dev)
    # idx.shape=torch.Size([<=6, 1]) with real size [4, 1]
    # idx = [[0], [2], [3], [5]]
    idx = torch.nonzero(values)
    copied = torch.clone(idx)

    # Dynamic leading dim.
    self.assertIsInstance(copied.shape[0], torch.SymInt)
    self.assertEqual(str(copied.shape[0]), '<=6')
    self.assertEqual(copied.shape[0], 4)
    # Static trailing dim.
    self.assertIsInstance(copied.shape[1], int)
    self.assertEqual(str(copied.shape[1]), '1')
    self.assertEqual(copied.shape[1], 1)

    # For correctness
    self.assertEqual(idx.cpu(), copied.cpu())
def test_xlatensor_memoize_symsizes(self):
    """Comparing two reads of the same dynamic size must not record CompileTime."""
    met.clear_all()

    values = torch.tensor([1, 0, 3, 5, 0, 6], device=dev)
    # idx.shape=torch.Size([<=6, 1]) with real size [4, 1]
    # idx = [[0], [2], [3], [5]]
    idx = torch.nonzero(values)
    # Read the same dimension twice.
    first_read = idx.shape[0]
    second_read = idx.shape[0]
    self.assertEqual(first_read, second_read)
    self.assertIsNone(met.metric_data('CompileTime'))
def test_abs(self):
    """abs() of a dynamic tensor keeps its dynamic shape and is correct."""
    data = [1, 0, 3, 5, 0, 6]
    xla_values = torch.tensor(data, device=dev)
    # xla_idx.shape=torch.Size([<=6, 1]) with real size [4, 1]
    # xla_idx = [[0], [2], [3], [5]]
    xla_idx = torch.nonzero(xla_values)
    xla_abs = torch.abs(xla_idx)
    self.assertIsInstance(xla_abs.shape[0], torch.SymInt)
    self.assertEqual(str(xla_abs.shape[0]), '<=6')
    self.assertEqual(xla_abs.shape[0], 4)
    self.assertIsInstance(xla_abs.shape[1], int)
    self.assertEqual(str(xla_abs.shape[1]), '1')
    self.assertEqual(xla_abs.shape[1], 1)

    # test for correctness
    cpu_values = torch.tensor(data)
    cpu_idx = torch.nonzero(cpu_values)
    cpu_abs = torch.abs(cpu_idx)
    self.assertEqual(xla_abs.cpu(), cpu_abs.cpu())
def test_fill_(self):
    """In-place fill_() on a dynamic tensor keeps its dynamic shape."""
    data = [1, 0, 3, 5, 0, 6]
    xla_values = torch.tensor(data, device=dev)
    # xla_idx.shape=torch.Size([<=6, 1]) with real size [4, 1]
    # xla_idx = [[0], [2], [3], [5]]
    xla_idx = torch.nonzero(xla_values)
    # Shape kinds before the in-place op.
    self.assertIsInstance(xla_idx.shape[0], torch.SymInt)
    self.assertIsInstance(xla_idx.shape[1], int)

    xla_idx.fill_(1)
    # Shape kinds, bounds and values after the in-place op.
    self.assertIsInstance(xla_idx.shape[0], torch.SymInt)
    self.assertEqual(str(xla_idx.shape[0]), '<=6')
    self.assertEqual(xla_idx.shape[0], 4)
    self.assertIsInstance(xla_idx.shape[1], int)
    self.assertEqual(str(xla_idx.shape[1]), '1')
    self.assertEqual(xla_idx.shape[1], 1)

    # test for correctness
    cpu_values = torch.tensor(data)
    cpu_idx = torch.nonzero(cpu_values)
    cpu_idx.fill_(1)
    self.assertEqual(xla_idx.cpu(), cpu_idx.cpu())
def test_sizeMod(self):
    """Take a dynamic size modulo a constant size and modulo a dynamic size."""
    met.clear_all()

    rows, cols = 5, 2
    source = torch.zeros([rows, cols], device=dev)
    source[3][0] = 1
    # first_idx has size [<=10, 2] with real size [1, 2]
    first_idx = torch.nonzero(source)

    # Create a SizeMod IR node.
    # first_idx.shape[1] generates a SizeConstant node.
    remainder = first_idx.shape[0] % first_idx.shape[1]
    self.assertGreater(met.counter_value("xla::size_mod"), 0)
    # Exercises SizeMod::getDynamicValue.
    self.assertEqual(int(remainder), 1)
    self.assertEqual(str(remainder), '<=0')

    # second_idx has size [<=10, 2] with real size [1, 2]
    second_idx = torch.nonzero(source)
    remainder = first_idx.shape[0] % second_idx.shape[0]
    self.assertEqual(int(remainder), 0)
    self.assertEqual(str(remainder), '<=0')
def test_sizeGe(self):
    """Compare a dynamic size with a constant size using >=."""
    met.clear_all()

    rows, cols = 5, 2
    source = torch.zeros([rows, cols], device=dev)
    source[3][0] = 1
    # idx has size [<=10, 2]
    idx = torch.nonzero(source)

    # The >= comparison must bump the xla::size_ge counter.
    # idx.shape[1] generates a SizeConstant node.
    is_ge = idx.shape[0] >= idx.shape[1]
    self.assertGreater(met.counter_value("xla::size_ge"), 0)
    # Exercises SizeGe::getDynamicValue.
    self.assertEqual(int(is_ge), 0)
def test_sizeLt(self):
    """Compare a dynamic size with a constant size using <."""
    met.clear_all()

    rows, cols = 5, 2
    source = torch.zeros([rows, cols], device=dev)
    source[3][0] = 1
    # idx has size [<=10, 2]
    idx = torch.nonzero(source)

    # Create a SizeLt IR node.
    # idx.shape[1] generates a SizeConstant node.
    is_lt = idx.shape[0] < idx.shape[1]
    self.assertGreater(met.counter_value("xla::size_lt"), 0)
    # Exercises SizeLt::getDynamicValue.
    self.assertEqual(int(is_lt), 1)
def test_sizeGt(self):
    """Compare a dynamic size with a constant size using >."""
    met.clear_all()

    rows, cols = 5, 2
    source = torch.zeros([rows, cols], device=dev)
    source[3][0] = 1
    # idx has size [<=10, 2]
    idx = torch.nonzero(source)

    # Create a SizeGt IR node.
    # idx.shape[1] generates a SizeConstant node.
    is_gt = idx.shape[0] > idx.shape[1]
    self.assertGreater(met.counter_value("xla::size_gt"), 0)
    # Exercises SizeGt::getDynamicValue.
    # To evaluate dynamic value (1 > 2), hence false.
    self.assertEqual(int(is_gt), 0)
def test_sizeNe(self):
    """Compare a dynamic size with a constant size using !=."""
    met.clear_all()

    rows, cols = 5, 2
    source = torch.zeros([rows, cols], device=dev)
    source[3][0] = 1
    # idx has size [<=10, 2]
    idx = torch.nonzero(source)

    # The != comparison must bump the xla::size_ne counter.
    # idx.shape[1] generates a SizeConstant node.
    is_ne = idx.shape[0] != idx.shape[1]
    self.assertGreater(met.counter_value("xla::size_ne"), 0)
    # Exercises SizeNe::getDynamicValue.
    self.assertEqual(int(is_ne), 1)
def test_SizeEq_should_not_compile_for_identical_symints(self):
    """Comparing a dynamic size with itself; pins the current compile count."""
    met.clear_all()

    values = torch.tensor([1, 0, 3, 5, 0, 6, 7], device=dev)
    idx = torch.nonzero(values)
    same_size = idx.shape[0]
    self.assertEqual(same_size, same_size)

    # Without the code change, met.metric_data('CompileTime')[0] returns 1.
    # self.assertIsNone(met.metric_data('CompileTime'))
    # TODO(ds): Uncomment the line above after we implement 0/1 specialization.
    # The extra compilation comes from the call `set_sizes_and_strides` in
    # XLATensorImpl::XLATensorImpl when we compare a SymInt with 0.
    compile_time = met.metric_data('CompileTime')
    self.assertEqual(compile_time[0], 1)
if __name__ == '__main__':
    # This script requires XLA_EXPERIMENTAL to be set to a non-empty value.
    # .get() makes an unset variable fail the same assertion as an empty one
    # instead of raising a bare KeyError.
    assert os.environ.get('XLA_EXPERIMENTAL', '') != '', (
        'XLA_EXPERIMENTAL must be set to a non-empty value to run these tests')
    # unittest.main() calls sys.exit() itself by default, which would make the
    # cleanup and the explicit exit below unreachable; exit=False returns the
    # TestProgram object instead.
    test = unittest.main(exit=False)
    # DISABLE PYTHON DISPATCHER FLAG
    del pd
    sys.exit(0 if test.result.wasSuccessful() else 1)