-
Notifications
You must be signed in to change notification settings - Fork 9.6k
/
Copy pathtrain_utils.py
102 lines (83 loc) · 3.11 KB
/
train_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import torch
import torch.distributed as dist
from datetime import datetime
import tqdm
from transformers import AutoTokenizer, GPT2TokenizerFast
from transformers import T5Tokenizer, T5ForConditionalGeneration
# Number of bytes in one gibibyte; divisor used when reporting byte counts in GB.
g_gigabyte = 2**30
def setup():
    """Initialize the default torch.distributed process group with the NCCL backend."""
    dist.init_process_group(backend="nccl")
def cleanup():
    """Tear down the torch.distributed process group."""
    dist.destroy_process_group()
def get_date_of_run():
    """Return the current local date and time as a string, for unique file names.

    The value is also printed to stdout.
    Example: 2022-05-07-08:31:12_PM
    """
    timestamp = f"{datetime.now():%Y-%m-%d-%I:%M:%S_%p}"
    print(f"--> current date and time of run = {timestamp}")
    return timestamp
def format_metrics_to_gb(item):
    """Convert a byte count to gigabytes, rounded to 4 decimal places."""
    return round(item / g_gigabyte, ndigits=4)
def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
    """Run one training epoch and return the loss averaged over all samples and ranks.

    Requires the ``LOCAL_RANK`` environment variable and an initialized
    ``torch.distributed`` process group (``dist.all_reduce`` is called).

    Args:
        args: Not read by this function; kept so existing callers keep working.
        model: Called with ``input_ids``, ``attention_mask`` and ``labels``;
            its output must provide ``output["loss"]``.
        rank: Global rank; only rank 0 shows the progress bar and prints.
        world_size: Not read by this function; kept for caller compatibility.
        train_loader: Iterable yielding dict batches whose tensors include the
            keys ``"source_ids"``, ``"source_mask"`` and ``"target_ids"``.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, passed to ``sampler.set_epoch`` and printed.
        sampler: Optional sampler; when given, ``set_epoch(epoch)`` is called.

    Returns:
        A 0-dim tensor on the local device: the sample-weighted loss summed
        over all ranks, divided by the total number of samples seen.
    """
    model.train()
    local_rank = int(os.environ["LOCAL_RANK"])
    # Slot 0 accumulates the sample-weighted loss, slot 1 the sample count.
    fsdp_loss = torch.zeros(2).to(local_rank)

    # Identity check: samplers may define __len__, which makes truthiness
    # depend on their length rather than on whether one was supplied.
    if sampler is not None:
        sampler.set_epoch(epoch)

    inner_pbar = None
    if rank == 0:
        inner_pbar = tqdm.tqdm(
            range(len(train_loader)), colour="blue", desc="r0 Training Epoch"
        )

    for batch in train_loader:
        for key in batch.keys():
            batch[key] = batch[key].to(local_rank)
        optimizer.zero_grad()
        output = model(
            input_ids=batch["source_ids"],
            attention_mask=batch["source_mask"],
            labels=batch["target_ids"],
        )
        loss = output["loss"]
        loss.backward()
        optimizer.step()

        # `batch` is a dict, so len(batch) would count its keys, not samples.
        # Use the leading dimension of the input ids instead
        # (assumes source_ids is batch-first, as model input_ids -- confirm).
        batch_size = batch["source_ids"].size(0)
        # Weight the batch loss by batch size so the final division by the
        # sample count is a per-sample average (assumes the model returns a
        # mean loss per batch -- confirm).
        fsdp_loss[0] += loss.item() * batch_size
        fsdp_loss[1] += batch_size

        if inner_pbar is not None:
            inner_pbar.update(1)

    # Sum loss and sample counts across every rank before averaging.
    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
    train_loss = fsdp_loss[0] / fsdp_loss[1]

    if inner_pbar is not None:
        inner_pbar.close()
        print(
            f"Train Epoch: \t{epoch}, Loss: \t{train_loss:.4f}"
        )
    return train_loss
def validation(model, rank, world_size, val_loader):
    """Evaluate the model on ``val_loader`` and return the loss averaged over all samples and ranks.

    Requires the ``LOCAL_RANK`` environment variable and an initialized
    ``torch.distributed`` process group (``dist.all_reduce`` is called).

    Args:
        model: Called with ``input_ids``, ``attention_mask`` and ``labels``;
            its output must provide ``output["loss"]``.
        rank: Global rank; only rank 0 shows the progress bar and prints.
        world_size: Not read by this function; kept for caller compatibility.
        val_loader: Iterable yielding dict batches whose tensors include the
            keys ``"source_ids"``, ``"source_mask"`` and ``"target_ids"``.

    Returns:
        A 0-dim tensor on the local device: the sample-weighted loss summed
        over all ranks, divided by the total number of samples seen.
    """
    model.eval()
    local_rank = int(os.environ["LOCAL_RANK"])
    # Slot 0 accumulates the sample-weighted loss, slot 1 the sample count.
    fsdp_loss = torch.zeros(2).to(local_rank)

    inner_pbar = None
    if rank == 0:
        inner_pbar = tqdm.tqdm(
            range(len(val_loader)), colour="green", desc="Validation Epoch"
        )

    with torch.no_grad():
        for batch in val_loader:
            for key in batch.keys():
                batch[key] = batch[key].to(local_rank)
            output = model(
                input_ids=batch["source_ids"],
                attention_mask=batch["source_mask"],
                labels=batch["target_ids"],
            )

            # `batch` is a dict, so len(batch) would count its keys, not
            # samples. Use the leading dimension of the input ids instead
            # (assumes source_ids is batch-first, as model input_ids -- confirm).
            batch_size = batch["source_ids"].size(0)
            # Weight the batch loss by batch size so the final division by
            # the sample count is a per-sample average (assumes the model
            # returns a mean loss per batch -- confirm).
            fsdp_loss[0] += output["loss"].item() * batch_size
            fsdp_loss[1] += batch_size

            if inner_pbar is not None:
                inner_pbar.update(1)

    # Sum loss and sample counts across every rank before averaging.
    dist.all_reduce(fsdp_loss, op=dist.ReduceOp.SUM)
    val_loss = fsdp_loss[0] / fsdp_loss[1]

    if inner_pbar is not None:
        inner_pbar.close()
        print(f"Validation Loss: {val_loss:.4f}")
    return val_loss
def setup_model(model_name):
    """Load a pretrained T5 model and its tokenizer.

    Args:
        model_name: Identifier passed to both ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple; the tokenizer is created with
        ``legacy=False``.
    """
    # Tuple elements evaluate left to right, so the model loads first.
    return (
        T5ForConditionalGeneration.from_pretrained(model_name),
        T5Tokenizer.from_pretrained(model_name, legacy=False),
    )