import math
import torch_xla.test.test_utils as test_utils
from torch.optim.lr_scheduler import _LRScheduler


class WarmupAndExponentialDecayScheduler(_LRScheduler):
"""Update the learning rate of wrapped optimizer based on epoch and step.
Args:
optimizer: Instance of torch.optim.Optimizer. Learning rate will be changed.
num_steps_per_epoch: int, the number of steps required to finish 1 epoch.
divide_every_n_epochs: After this number of epochs, learning rate will be
divided by the `divisor` param.
divisor: The learning rate will be divided by this amount when epoch %
divide_every_n_epochs == 0 (epoch 0 is excluded).
num_warmup_epochs: Float. Learning rate will ramp up from 0 to max learning
rate over this many epochs. Note that partial epochs are allowed, e.g. 0.5
epochs.
min_delta_to_update_lr: If the new learning rate does not differ much from
the learning rate of the previous step, don't bother updating the
optimizer's learning rate.
summary_writer: Instance of `torch.utils.tensorboard.SummaryWriter`. If
provided, learning rate will be logged during calls to step if step is
called with write_to_summary=True. If summary_writer is None, then no
logging happens.
"""

  def __init__(self,
               optimizer,
               num_steps_per_epoch,
               divide_every_n_epochs=20,
               divisor=5,
               num_warmup_epochs=0.9,
               min_delta_to_update_lr=1e-6,
               summary_writer=None):
    self._num_steps_per_epoch = num_steps_per_epoch
    self._divide_every_n_epochs = divide_every_n_epochs
    self._divisor = divisor
    self._num_warmup_epochs = num_warmup_epochs
    self._min_delta_to_update_lr = min_delta_to_update_lr
    self._previous_lr = -1
    self._max_lr = optimizer.param_groups[0]['lr']
    self._summary_writer = summary_writer
    super(WarmupAndExponentialDecayScheduler, self).__init__(optimizer)

  def _epoch(self):
    return self._step_count // self._num_steps_per_epoch

  def _is_warmup_epoch(self):
    return self._epoch() < math.ceil(self._num_warmup_epochs)

  def get_lr(self):
    epoch = self._epoch()
    lr = 0.0

    if self._is_warmup_epoch():
      # Ramp up learning rate from 0.0 to self._max_lr using a linear slope.
      num_warmup_steps = self._num_warmup_epochs * self._num_steps_per_epoch
      lr = min(self._max_lr,
               self._max_lr * ((self._step_count + 1.0) / num_warmup_steps))
    else:
      # Normal epoch. Use an exponential decay determined by init params.
      lr = self._max_lr / (
          self._divisor**(epoch // self._divide_every_n_epochs))

    # _LRScheduler expects a list of learning rates like this.
    return [lr for _ in self.base_lrs]
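
  # Worked example of the schedule above (illustrative values, not from the
  # original file): with max_lr=0.1, num_steps_per_epoch=100,
  # num_warmup_epochs=0.9, divisor=5 and divide_every_n_epochs=20, steps 0-89
  # ramp linearly from 0.1 * 1/90 ~= 0.0011 up to 0.1 (and stay capped at 0.1
  # for the rest of the warmup epoch); epochs 1-19 use 0.1, epochs 20-39 use
  # 0.1 / 5 = 0.02, epochs 40-59 use 0.1 / 25 = 0.004, and so on.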

  def step(self, epoch=None):
    current_lr = self.get_lr()[0]

    # Outside of warmup epochs, we use the same learning rate for every step
    # in an epoch. Don't bother updating learning rate if it hasn't changed.
    if abs(current_lr - self._previous_lr) >= self._min_delta_to_update_lr:
      super(WarmupAndExponentialDecayScheduler, self).step()
      self._previous_lr = current_lr
    else:
      self._step_count += 1  # This normally happens in super().step().

    # Add current learning rate to Tensorboard metrics. For warmup epochs,
    # log the learning rate at every step. For non-warmup epochs, log only
    # the first step since the entire epoch will use the same learning rate.
    if self._summary_writer:
      if self._is_warmup_epoch() or (self._step_count %
                                     self._num_steps_per_epoch == 0):
        test_utils.write_to_summary(
            self._summary_writer,
            self._step_count,
            dict_to_write={
                'LearningRate': self.optimizer.param_groups[0]['lr']
            },
            write_xla_metrics=False)
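

def _example_scheduler_usage():
  """Illustrative sketch, not part of the original module.

  Shows how WarmupAndExponentialDecayScheduler might be driven directly in a
  training loop. The toy model, SGD optimizer, and hyperparameters below are
  assumptions chosen only for demonstration.
  """
  import torch

  model = torch.nn.Linear(8, 2)
  optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
  scheduler = WarmupAndExponentialDecayScheduler(
      optimizer,
      num_steps_per_epoch=100,
      divide_every_n_epochs=20,
      divisor=5,
      num_warmup_epochs=0.9)
  for _ in range(2 * 100):  # Two epochs of 100 steps each.
    optimizer.step()  # A real loop would compute a loss and backprop first.
    scheduler.step()
  return optimizer.param_groups[0]['lr']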


def wrap_optimizer_with_scheduler(optimizer,
                                  scheduler_type=None,
                                  scheduler_divisor=None,
                                  scheduler_divide_every_n_epochs=None,
                                  num_steps_per_epoch=None,
                                  summary_writer=None):
  """Wraps an optimizer in a `torch.optim.lr_scheduler` object.

  Args:
    optimizer: Instance of `torch.optim.Optimizer`. Will be modified by the
      scheduler to overwrite the learning rate.
    scheduler_type: string, type of learning rate scheduler to use. If None,
      this method returns None.
    scheduler_divisor: int, required for WarmupAndExponentialDecayScheduler.
    scheduler_divide_every_n_epochs: int, required for
      WarmupAndExponentialDecayScheduler.
    num_steps_per_epoch: int, the number of steps that occur in each epoch.
      Required for WarmupAndExponentialDecayScheduler.
    summary_writer: Instance of `torch.utils.tensorboard.SummaryWriter` that
      will be passed into the scheduler to log the learning rate during
      training.

  Returns:
    An instance of the requested scheduler, or None if scheduler_type is None.

  Raises:
    ValueError: If the requested scheduler_type is unrecognized or if any
      required params are missing for the requested scheduler_type.
  """
  if not scheduler_type:
    return None

  if scheduler_type == 'WarmupAndExponentialDecayScheduler':
    if scheduler_divisor is None:
      raise ValueError('scheduler_divisor is required for '
                       'WarmupAndExponentialDecayScheduler.')
    if scheduler_divide_every_n_epochs is None:
      raise ValueError('scheduler_divide_every_n_epochs is required for '
                       'WarmupAndExponentialDecayScheduler.')
    if num_steps_per_epoch is None:
      raise ValueError('num_steps_per_epoch is required for '
                       'WarmupAndExponentialDecayScheduler.')
    return WarmupAndExponentialDecayScheduler(
        optimizer,
        num_steps_per_epoch,
        divide_every_n_epochs=scheduler_divide_every_n_epochs,
        divisor=scheduler_divisor,
        summary_writer=summary_writer)
  else:
    raise ValueError('Unknown scheduler_type: {}'.format(scheduler_type))
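

if __name__ == '__main__':
  # Minimal usage sketch, not part of the original module: wraps an SGD
  # optimizer for a toy model via wrap_optimizer_with_scheduler and steps it
  # for a few epochs, printing the learning rate at each epoch boundary. The
  # model and the step/epoch counts are placeholder values for demonstration.
  import torch

  model = torch.nn.Linear(4, 2)
  optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
  num_steps_per_epoch = 10
  scheduler = wrap_optimizer_with_scheduler(
      optimizer,
      scheduler_type='WarmupAndExponentialDecayScheduler',
      scheduler_divisor=5,
      scheduler_divide_every_n_epochs=20,
      num_steps_per_epoch=num_steps_per_epoch)
  for step in range(3 * num_steps_per_epoch):
    optimizer.step()  # A real loop would compute a loss and backprop first.
    scheduler.step()
    if step % num_steps_per_epoch == 0:
      print('step', step, 'lr', optimizer.param_groups[0]['lr'])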