# train_decoder_only_base.py
from decoder_only_model import DecoderOnlyConfig, DecoderOnlyModel

from torch_xla import runtime as xr
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

import argparse
import time
import itertools

import torch
import torch_xla
import torch.nn as nn
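

# A minimal single-device training loop for a decoder-only transformer on an
# XLA device. It trains on fake (all-zero) token batches so the example has no
# dataset dependency, and it is written as a base class so that variants can
# override individual pieces such as run_optimizer() or step_fn().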
class TrainDecoderOnlyBase:

  def __init__(self,
               decoder_cls=DecoderOnlyModel,
               num_steps: int = 200,
               config=DecoderOnlyConfig()):
    self.config = config
    # Use a smaller batch size on AWS Neuron devices.
    if xr.device_type() == 'NEURON':
      self.batch_size = 4
    else:
      self.batch_size = 16
    self.seq_len = 512
    self.num_steps = num_steps
    self.num_epochs = 1
    self.train_dataset_len = 1200000  # Roughly the size of the ImageNet dataset.
    # For the purpose of this example, we are going to use fake data.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(self.batch_size, self.seq_len, dtype=torch.int64),
              torch.zeros(self.batch_size, self.seq_len, dtype=torch.int64)),
        sample_count=self.train_dataset_len // self.batch_size)
    self.device = torch_xla.device()
    # MpDeviceLoader preloads each batch onto the XLA device in the background.
    self.train_device_loader = pl.MpDeviceLoader(train_loader, self.device)
    self.model = decoder_cls(self.config).to(self.device)
    self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0001)
    self.loss_fn = nn.CrossEntropyLoss()
    # Compile the step fn; full_graph=True requires the whole step to trace
    # into a single XLA graph.
    self.compiled_step_fn = torch_xla.compile(
        self.step_fn, full_graph=True, name="decoder_step_fn")

  def _train_update(self, step, loss, tracker, epoch):
    print(f'epoch: {epoch}, step: {step}, loss: {loss}, rate: {tracker.rate()}')
    assert not torch.isnan(loss).item(), "Loss became NaN!"

  def run_optimizer(self):
    # Factored out so subclasses can change how the optimizer step is
    # performed (e.g. xm.optimizer_step for multi-replica training).
    self.optimizer.step()

  def step_fn(self, data, target):
    self.optimizer.zero_grad()
    logits = self.model(data)
    loss = self.loss_fn(
        logits.view(-1, self.config.vocab_size), target.view(-1))
    loss.backward()
    self.run_optimizer()
    return loss

  def train_loop_fn(self, loader, epoch):
    tracker = xm.RateTracker()
    self.model.train()
    loader = itertools.islice(loader, self.num_steps)
    for step, (data, target) in enumerate(loader):
      loss = self.compiled_step_fn(data, target)
      tracker.add(self.batch_size)
      if step % 10 == 0:
        # Report progress via a step closure so the loss is only read once
        # the step's results are available on the host.
        xm.add_step_closure(
            self._train_update, args=(step, loss, tracker, epoch))

  def start_training(self):
    for epoch in range(1, self.num_epochs + 1):
      xm.master_print('Epoch {} train begin {}'.format(
          epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
      self.train_loop_fn(self.train_device_loader, epoch)
      xm.master_print('Epoch {} train end {}'.format(
          epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
    xm.wait_device_ops()
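
# A sketch of how a multi-replica variant might build on this base class
# (hypothetical subclass, not part of this file): override run_optimizer so
# gradients are averaged across replicas before the optimizer step.
#
# class TrainDecoderOnlyDDP(TrainDecoderOnlyBase):
#
#   def run_optimizer(self):
#     # xm.optimizer_step performs a cross-replica gradient reduction
#     # before delegating to optimizer.step().
#     xm.optimizer_step(self.optimizer)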

if __name__ == '__main__':
  parser = argparse.ArgumentParser(description="Train a decoder-only model")
  parser.add_argument(
      "cls_name",
      type=str,
      nargs="?",
      default=None,
      help="The decoder model to train, as a fully qualified Python class name. "
      "Defaults to decoder_only_model.DecoderOnlyModel")
  parser.add_argument(
      "--num-steps",
      type=int,
      default=200,
      help="Number of steps to train the model for")
  parser.add_argument(
      "--hidden-size",
      type=int,
      default=1024,
      help="Hidden size of the model, aka the embedding size")
  parser.add_argument(
      "--num-layers",
      type=int,
      default=2,
      help="Number of decoder layers in the model",
  )
  parser.add_argument(
      "--num-attention-heads",
      type=int,
      default=8,
      help="Number of attention heads in the model",
  )
  parser.add_argument(
      "--num-key-value-heads",
      type=int,
      default=4,
      help="Number of key/value heads in the model",
  )
  parser.add_argument(
      "--intermediate-size",
      type=int,
      default=32 * 1024,
      help="Intermediate size of the model, aka the up-projection output size",
  )
  parser.add_argument(
      "--print-metrics",
      action="store_true",
      help="Print torch_xla metrics at the end of the training",
  )
  args = parser.parse_args()

  # Seed the RNGs for deterministic results.
  torch.manual_seed(42)
  torch_xla.manual_seed(42)

  # Figure out the decoder model to use, importing it dynamically when a
  # fully qualified class name is given on the command line.
  decoder_cls = None
  if args.cls_name is not None:
    xm.master_print(f'Using decoder class: {args.cls_name}')
    module, cls_name = args.cls_name.rsplit('.', 1)
    decoder_cls = getattr(__import__(module, fromlist=[cls_name]), cls_name)

  # Initialize the model config from the command-line arguments.
  config = DecoderOnlyConfig(
      hidden_size=args.hidden_size,
      num_hidden_layers=args.num_layers,
      num_attention_heads=args.num_attention_heads,
      num_key_value_heads=args.num_key_value_heads,
      intermediate_size=args.intermediate_size,
  )

  # Fall back to the default decoder class when none was given.
  params = []
  if decoder_cls is not None:
    params.append(decoder_cls)
  base = TrainDecoderOnlyBase(*params, num_steps=args.num_steps, config=config)
  base.start_training()

  if args.print_metrics:
    print(torch_xla._XLAC._xla_metrics_report())
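
# Example invocations (a sketch based on the flags defined above; the class
# name in the last command is hypothetical):
#
#   python train_decoder_only_base.py
#   python train_decoder_only_base.py --num-steps 100 --hidden-size 512
#   python train_decoder_only_base.py my_module.MyDecoder --print-metrics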