train_resnet_amp.py
from train_resnet_base import TrainResNetBase

import itertools

import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast


# For more details check https://github.com/pytorch/xla/blob/master/docs/amp.md
class TrainResNetXLAAMP(TrainResNetBase):

  def train_loop_fn(self, loader, epoch):
    tracker = xm.RateTracker()
    self.model.train()
    # Limit the loader to the configured number of steps per epoch.
    loader = itertools.islice(loader, self.num_steps)
    for step, (data, target) in enumerate(loader):
      self.optimizer.zero_grad()
      # Enables autocasting for the forward pass.
      with autocast(xm.xla_device()):
        output = self.model(data)
        loss = self.loss_fn(output, target)
      # TPU AMP uses bf16, so gradient scaling is not necessary. If running with
      # XLA:GPU, check https://github.com/pytorch/xla/blob/master/docs/amp.md#amp-for-xlagpu
      # (a sketch follows below).
      loss.backward()
      self.run_optimizer()
      tracker.add(self.batch_size)
      if step % 10 == 0:
        # Report progress every 10 steps via a step closure so the host-side
        # update runs after the step's tensors are materialized.
        xm.add_step_closure(
            self._train_update, args=(step, loss, tracker, epoch))


if __name__ == '__main__':
  xla_amp = TrainResNetXLAAMP()
  xla_amp.start_training()
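
The XLA:GPU path mentioned in the comment above autocasts to float16, which can underflow gradients, so the linked amp.md recommends loss scaling. The class below is a hedged, hypothetical sketch of what the same loop might look like with torch_xla.amp.GradScaler; it is not part of this repository, assumes the same TrainResNetBase interface, and omits the multi-device gradient reduction that amp.md describes.

from train_resnet_base import TrainResNetBase

import itertools

import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast, GradScaler


# Hypothetical sketch: a float16 XLA:GPU variant of the loop above using
# GradScaler. The class and its structure are assumptions, not repository code.
class TrainResNetXLAGPUAMP(TrainResNetBase):

  def train_loop_fn(self, loader, epoch):
    tracker = xm.RateTracker()
    # float16 gradients can underflow, so scale the loss before backward and
    # let the scaler unscale (and skip inf/nan steps) inside step().
    scaler = GradScaler()
    self.model.train()
    loader = itertools.islice(loader, self.num_steps)
    for step, (data, target) in enumerate(loader):
      self.optimizer.zero_grad()
      with autocast(xm.xla_device()):
        output = self.model(data)
        loss = self.loss_fn(output, target)
      scaler.scale(loss).backward()
      # scaler.step() stands in for self.run_optimizer() here; multi-device
      # runs also need to all-reduce gradients first, as described in amp.md.
      scaler.step(self.optimizer)
      scaler.update()
      tracker.add(self.batch_size)
      if step % 10 == 0:
        xm.add_step_closure(
            self._train_update, args=(step, loss, tracker, epoch))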