-
Notifications
You must be signed in to change notification settings - Fork 507
/
Copy pathtest_while_loop.py
116 lines (86 loc) · 3.47 KB
/
test_while_loop.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import sys
import unittest
from typing import Callable, Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch._higher_order_ops.while_loop import while_loop

import torch_xla
# We need to import the underlying implementation function to register with the dispatcher
import torch_xla.experimental.fori_loop
from torch_xla.experimental.fori_loop import fori_loop
import torch_xla.core.xla_model as xm
import torch_xla.core.xla_builder as xb
import torch_xla.utils.utils as xu
def _fake_while_loop(cond_fn, body_fn, operands):
# operands need to be more than one here
while cond_fn(*operands):
operands = body_fn(*operands)
return operands
class WhileLoopTest(unittest.TestCase):
  """Tests for the XLA lowering of the `while_loop` higher-order op.

  Each test runs the same loop twice — once through the `while_loop` op and
  once through a plain Python loop (`_fake_while_loop` or an eager method) —
  and asserts that both produce identical results.
  """

  def test_while_loop_addition(self):
    """`while_loop` that adds 1 per iteration matches the eager loop."""
    device = xm.xla_device()

    def cond_fn(iteri, x):
      return iteri > 0

    def body_fn(iteri, x):
      return iteri - 1, torch.add(x, 1)

    init_val = torch.tensor(3, dtype=torch.int32, device=device)
    iteri = torch.tensor(10, device=device)
    _, res_with_loop = while_loop(cond_fn, body_fn, (iteri, init_val))
    _, res_without_loop = _fake_while_loop(cond_fn, body_fn, (iteri, init_val))
    self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop)))

  def test_while_loop_addition_nested(self):
    """Body with a nested add expression still matches the eager loop."""
    device = xm.xla_device()

    def cond_fn(iteri, x):
      return iteri > 0

    def body_fn(iteri, x):
      return iteri - 1, torch.add(torch.add(x, 1), 1)

    init_val = torch.tensor(2, dtype=torch.int32, device=device)
    iteri = torch.tensor(10, device=device)
    _, res_with_loop = while_loop(cond_fn, body_fn, (iteri, init_val))
    _, res_without_loop = _fake_while_loop(cond_fn, body_fn, (iteri, init_val))
    self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop)))

  def test_while_loop_simple_linear_inside_loop(self):
    """A `nn.Linear` applied inside the loop body matches an eager while."""
    device = xm.xla_device()

    class SimpleLinear(torch.nn.Module):

      def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

      def forward(self, iteri, x):

        def cond_fn(iteri, x):
          return iteri > 0

        def body_fn(iteri, x):
          return iteri - 1, self.linear(x)

        return while_loop(cond_fn, body_fn, (iteri, x))

      def forward_without_while_loop_op(self, iteri, x):
        # Eager reference: same computation as forward(), as a Python loop.
        while (iteri > 0):
          x = self.linear(x)
          iteri -= 1
        return iteri, x

    # BUG FIX: the original called torch.set_grad_enabled(False), which
    # leaks disabled-grad state into every test that runs afterwards.
    # A no_grad() context scopes it to this test and restores grad mode.
    with torch.no_grad():
      linear_model = SimpleLinear()
      linear_model.to(device)
      l_in_0 = torch.randn(2, 2, dtype=torch.float32, device=device)
      iteri = torch.tensor(10, dtype=torch.int32, device=device)
      _, res_with_loop = linear_model(iteri, l_in_0)
      _, res_without_loop = linear_model.forward_without_while_loop_op(
          iteri, l_in_0)
      self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop)))

  # ====== fori_loop ======
  @unittest.skip("Fori_loop is not supported now due to unstable result.")
  def test_fori_loop_addition(self):
    """`fori_loop` adding 1 for (upper - lower) steps matches an eager loop."""
    device = xm.xla_device()
    lower = torch.tensor(0, device=device)
    upper = torch.tensor(50, device=device)
    init_val = torch.tensor(1, dtype=torch.int32, device=device)

    def body_fun(x):
      return torch.add(x, 1)

    # NOTE(review): `(init_val)` is just `init_val`, not a 1-tuple — if
    # fori_loop expects a tuple of operands this should be `(init_val,)`;
    # confirm against the fori_loop signature before enabling this test.
    _, res_with_loop = fori_loop(lower, upper, body_fun, (init_val))
    # === expected ===
    for i in range(upper - lower):
      init_val = torch.add(init_val, 1)
    res_without_loop = init_val
    # BUG FIX: the original computed the expected value but never compared
    # it, so this test could not fail regardless of fori_loop's output.
    self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop)))
if __name__ == '__main__':
  # BUG FIX: unittest.main() exits the process itself by default, which made
  # the explicit sys.exit() below unreachable; exit=False returns control so
  # we can translate the suite result into the process exit code. `sys` is
  # also imported at the top of the file (it was previously missing,
  # which would have raised NameError had this line ever run).
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)