-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_hook.py
415 lines (340 loc) · 12.9 KB
/
test_hook.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import shutil
import sys
import tempfile
from unittest.mock import MagicMock, Mock, call, patch
import numpy as np
import pytest
import torch
import torch.nn as nn
from mmcv.runner import (CheckpointHook, IterTimerHook, PaviLoggerHook,
build_runner)
from torch.nn.init import constant_
from torch.utils.data import DataLoader, Dataset
from mmdet.core.hook import ExpMomentumEMAHook, YOLOXLrUpdaterHook
from mmdet.core.hook.sync_norm_hook import SyncNormHook
from mmdet.core.hook.sync_random_size_hook import SyncRandomSizeHook
def _build_demo_runner_without_hook(runner_type='EpochBasedRunner',
                                    max_epochs=1,
                                    max_iters=None,
                                    multi_optimziers=False):
    """Create a runner around a toy model without registering any hook.

    Args:
        runner_type (str): Runner type name passed to ``build_runner``.
        max_epochs (int): Forwarded to the runner as ``max_epochs``.
        max_iters (int | None): Forwarded to the runner as ``max_iters``.
        multi_optimziers (bool): When True, a dict with one SGD optimizer
            per sub-module (keys ``'model1'`` and ``'model2'``) is used;
            otherwise a single SGD optimizer covers all parameters.

    Returns:
        The object returned by ``build_runner``. Its ``work_dir`` is a
        directory freshly created with ``tempfile.mkdtemp``; this function
        does not remove it.
    """

    class Model(nn.Module):
        """Toy network: ``forward`` only uses the linear layer."""

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(2, 1)
            self.conv = nn.Conv2d(3, 3, 3)

        def forward(self, x):
            return self.linear(x)

        def train_step(self, x, optimizer, **kwargs):
            return {'loss': self(x)}

        def val_step(self, x, optimizer, **kwargs):
            return {'loss': self(x)}

    model = Model()
    sgd = torch.optim.SGD

    if not multi_optimziers:
        optimizer = sgd(model.parameters(), lr=0.02, momentum=0.95)
    else:
        # One optimizer per sub-module, each with its own lr and momentum.
        optimizer = {
            'model1': sgd(model.linear.parameters(), lr=0.02, momentum=0.95),
            'model2': sgd(model.conv.parameters(), lr=0.01, momentum=0.9),
        }

    runner_args = {
        'model': model,
        'work_dir': tempfile.mkdtemp(),
        'optimizer': optimizer,
        'logger': logging.getLogger(),
        'max_epochs': max_epochs,
        'max_iters': max_iters,
    }
    return build_runner({'type': runner_type}, default_args=runner_args)
def _build_demo_runner(runner_type='EpochBasedRunner',
                       max_epochs=1,
                       max_iters=None,
                       multi_optimziers=False):
    """Create the demo runner and attach checkpoint and text-logger hooks.

    All arguments are handed unchanged to
    ``_build_demo_runner_without_hook``. On the resulting runner a
    checkpoint hook with ``interval=1`` and a logger config containing a
    single ``TextLoggerHook`` (also ``interval=1``) are registered.

    Returns:
        The configured runner.
    """
    demo_runner = _build_demo_runner_without_hook(
        runner_type=runner_type,
        max_epochs=max_epochs,
        max_iters=max_iters,
        multi_optimziers=multi_optimziers)

    demo_runner.register_checkpoint_hook({'interval': 1})
    demo_runner.register_logger_hooks({
        'interval': 1,
        'hooks': [{'type': 'TextLoggerHook'}],
    })
    return demo_runner
@pytest.mark.parametrize('multi_optimziers', (True, False))
def test_yolox_lrupdater_hook(multi_optimziers):
    """Train one epoch with YOLOXLrUpdaterHook and check the logged scalars.

    The learning rate and momentum values sent to the (mocked) pavi writer
    at steps 1, 7 and 10 are compared against hard-coded expectations, for
    both the single-optimizer and the multi-optimizer runner.
    """
    # Constructing the hook directly is only a guard against program errors.
    YOLOXLrUpdaterHook(0, min_lr_ratio=0.05)

    # Replace the ``pavi`` module with a mock before PaviLoggerHook runs.
    sys.modules['pavi'] = MagicMock()
    data_loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner(multi_optimziers=multi_optimziers)

    lr_hook_cfg = {
        'type': 'YOLOXLrUpdaterHook',
        'warmup': 'exp',
        'by_epoch': False,
        'warmup_by_epoch': True,
        'warmup_ratio': 1,
        'warmup_iters': 5,  # 5 epoch
        'num_last_epochs': 15,
        'min_lr_ratio': 0.05,
    }
    runner.register_hook_from_cfg(lr_hook_cfg)
    runner.register_hook_from_cfg({'type': 'IterTimerHook'})
    runner.register_hook(IterTimerHook())

    # add pavi hook
    pavi_hook = PaviLoggerHook(
        interval=1, add_graph=False, add_last_ckpt=True)
    runner.register_hook(pavi_hook)

    runner.run([data_loader], [('train', 1)])
    shutil.rmtree(runner.work_dir)

    # TODO: use a more elegant way to check values
    assert hasattr(pavi_hook, 'writer')

    # (step, lr of model1 / single optimizer, lr of model2)
    lr_table = [
        (1, 8.000000000000001e-06, 4.000000000000001e-06),
        (7, 0.00039200000000000004, 0.00019600000000000002),
        (10, 0.0008000000000000001, 0.0004000000000000001),
    ]

    expected_calls = []
    for step, lr_first, lr_second in lr_table:
        if multi_optimziers:
            scalars = {
                'learning_rate/model1': lr_first,
                'learning_rate/model2': lr_second,
                'momentum/model1': 0.95,
                'momentum/model2': 0.9,
            }
        else:
            scalars = {'learning_rate': lr_first, 'momentum': 0.95}
        expected_calls.append(call('train', scalars, step))

    pavi_hook.writer.add_scalars.assert_has_calls(
        expected_calls, any_order=True)
def test_ema_hook():
    """Check that ExpMomentumEMAHook state is checkpointed and resumable.

    Steps:
        1. Run one train/val epoch with the EMA hook and verify that the
           saved ``epoch_1.pth`` holds four ``ema`` entries.
        2. Overwrite those entries with ones and save the checkpoint back.
        3. Build a second runner whose EMA hook resumes from that file, run
           it up to epoch 2 and verify both the ``ema`` entries and the
           regular weight/bias values stored in ``epoch_2.pth``.
    """

    class DemoModel(nn.Module):
        """Conv + BatchNorm model whose parameters all start at zero."""

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(
                in_channels=1,
                out_channels=2,
                kernel_size=1,
                padding=1,
                bias=True)
            self.bn = nn.BatchNorm2d(2)
            self._init_weight()

        def _init_weight(self):
            constant_(self.conv.weight, 0)
            constant_(self.conv.bias, 0)
            constant_(self.bn.weight, 0)
            constant_(self.bn.bias, 0)

        def forward(self, x):
            return self.bn(self.conv(x)).sum()

        def train_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))

        def val_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))

    loader = DataLoader(torch.ones((1, 1, 1, 1)))
    runner = _build_demo_runner()
    demo_model = DemoModel()
    runner.model = demo_model
    ema_hook = ExpMomentumEMAHook(
        momentum=0.0002,
        total_iter=1,
        skip_buffers=True,
        interval=2,
        resume_from=None)
    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
    runner.register_hook(ema_hook, priority='HIGHEST')
    runner.register_hook(checkpointhook)
    runner.run([loader, loader], [('train', 1), ('val', 1)])

    # Inspect the first checkpoint and set every EMA entry to ones so the
    # resumed run starts from a known state.
    checkpoint = torch.load(f'{runner.work_dir}/epoch_1.pth')
    num_ema_params = 0
    for name, value in checkpoint['state_dict'].items():
        if 'ema' in name:
            num_ema_params += 1
            value.fill_(1)
    assert num_ema_params == 4
    torch.save(checkpoint, f'{runner.work_dir}/epoch_1.pth')
    work_dir = runner.work_dir

    resume_ema_hook = ExpMomentumEMAHook(
        momentum=0.5,
        total_iter=10,
        skip_buffers=True,
        interval=1,
        resume_from=f'{work_dir}/epoch_1.pth')
    runner = _build_demo_runner(max_epochs=2)
    runner.model = demo_model
    runner.register_hook(resume_ema_hook, priority='HIGHEST')
    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
    runner.register_hook(checkpointhook)
    runner.run([loader, loader], [('train', 1), ('val', 1)])

    checkpoint = torch.load(f'{runner.work_dir}/epoch_2.pth')
    num_ema_params = 0
    desired_output = [0.9094, 0.9094]
    for name, value in checkpoint['state_dict'].items():
        if 'ema' in name:
            num_ema_params += 1
            assert value.sum() == 2
        elif ('weight' in name) or ('bias' in name):
            # The comparison has to be asserted: a bare ``np.allclose``
            # call only returns a bool and verifies nothing.
            assert np.allclose(
                value.data.cpu().numpy().reshape(-1),
                desired_output,
                rtol=1e-4)
    assert num_ema_params == 4

    # Remove the temporary directories of both runners.
    shutil.rmtree(runner.work_dir)
    shutil.rmtree(work_dir)
def test_sync_norm_hook():
    """Smoke-test SyncNormHook over one train and one val epoch."""
    # Direct construction is only a guard against program errors.
    SyncNormHook()

    data_loader = DataLoader(torch.ones((5, 2)))
    runner = _build_demo_runner()
    runner.register_hook_from_cfg({'type': 'SyncNormHook'})

    workflow = [('train', 1), ('val', 1)]
    runner.run([data_loader, data_loader], workflow)

    # The runner's temporary work dir is no longer needed.
    shutil.rmtree(runner.work_dir)
def test_sync_random_size_hook():
    """Smoke-test SyncRandomSizeHook on CPU and, when available, on CUDA."""
    # Direct construction is only a guard against program errors.
    SyncRandomSizeHook()

    class DemoDataset(Dataset):
        """Dataset of five constant samples with a no-op scale setter."""

        def __len__(self):
            return 5

        def __getitem__(self, item):
            return torch.ones(2)

        def update_dynamic_scale(self, dynamic_scale):
            # Intentionally empty; presumably the hook calls this on the
            # dataset - verify against SyncRandomSizeHook.
            pass

    loader = DataLoader(DemoDataset())

    def _run_with_device(device):
        """Run a train+val workflow with the hook bound to ``device``."""
        runner = _build_demo_runner()
        runner.register_hook_from_cfg({
            'type': 'SyncRandomSizeHook',
            'device': device,
        })
        runner.run([loader, loader], [('train', 1), ('val', 1)])
        shutil.rmtree(runner.work_dir)

    _run_with_device('cpu')

    if torch.cuda.is_available():
        _run_with_device('cuda')
@pytest.mark.parametrize('set_loss', [
    dict(set_loss_nan=False, set_loss_inf=False),
    dict(set_loss_nan=True, set_loss_inf=False),
    dict(set_loss_nan=False, set_loss_inf=True)
])
def test_check_invalid_loss_hook(set_loss):
    """Check that CheckInvalidLossHook fails training on a NaN/inf loss.

    Args:
        set_loss (dict): Flags ``set_loss_nan`` and ``set_loss_inf`` that
            make the demo model return a NaN or an infinite loss.
    """

    class DemoModel(nn.Module):
        """Linear model that can be forced to return an invalid loss."""

        def __init__(self, set_loss_nan=False, set_loss_inf=False):
            super().__init__()
            self.set_loss_nan = set_loss_nan
            self.set_loss_inf = set_loss_inf
            self.linear = nn.Linear(2, 1)

        def forward(self, x):
            return self.linear(x)

        def train_step(self, x, optimizer, **kwargs):
            # The NaN flag takes precedence over the inf flag.
            if self.set_loss_nan:
                return {'loss': torch.tensor(float('nan'))}
            if self.set_loss_inf:
                return {'loss': torch.tensor(float('inf'))}
            return {'loss': self(x)}

    data_loader = DataLoader(torch.ones((5, 2)))
    runner = _build_demo_runner()
    runner.model = DemoModel(**set_loss)
    runner.register_hook_from_cfg({
        'type': 'CheckInvalidLossHook',
        'interval': 1,
    })

    loss_is_invalid = set_loss['set_loss_nan'] or set_loss['set_loss_inf']
    if loss_is_invalid:
        # A NaN or inf loss is expected to trip an assertion during the run.
        with pytest.raises(AssertionError):
            runner.run([data_loader], [('train', 1)])
    else:
        # A valid loss must let training finish normally.
        runner.run([data_loader], [('train', 1)])

    shutil.rmtree(runner.work_dir)
def test_set_epoch_info_hook():
    """Test SetEpochInfoHook.

    Trains the demo model for three epochs and checks that the last value
    passed to ``set_epoch`` is ``2``.
    """

    class DemoModel(nn.Module):
        """Linear model that records the epoch given to ``set_epoch``."""

        def __init__(self):
            super().__init__()
            self.epoch = 0
            self.linear = nn.Linear(2, 1)

        def forward(self, x):
            return self.linear(x)

        def train_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))

        def set_epoch(self, epoch):
            self.epoch = epoch

    loader = DataLoader(torch.ones((5, 2)))
    runner = _build_demo_runner(max_epochs=3)
    demo_model = DemoModel()
    runner.model = demo_model
    runner.register_hook_from_cfg(dict(type='SetEpochInfoHook'))
    runner.run([loader], [('train', 1)])

    # Remove the runner's temporary work dir before asserting, so that it
    # is cleaned up even when the check below fails.
    shutil.rmtree(runner.work_dir)

    assert demo_model.epoch == 2
def test_memory_profiler_hook():
    """Test MemoryProfilerHook import errors and its ``after_iter`` calls.

    The first two checks expect ``ImportError`` while ``psutil`` and then
    ``memory_profiler`` are missing. Both modules are then replaced by
    mocks in ``sys.modules`` and the memory query functions are patched,
    so that ``after_iter`` can be verified to call each of them.
    """
    from collections import namedtuple

    # test ImportError without psutil and memory_profiler
    with pytest.raises(ImportError):
        from mmdet.core.hook import MemoryProfilerHook
        MemoryProfilerHook(1)

    # test ImportError without memory_profiler
    sys.modules['psutil'] = MagicMock()
    with pytest.raises(ImportError):
        from mmdet.core.hook import MemoryProfilerHook
        MemoryProfilerHook(1)

    sys.modules['memory_profiler'] = MagicMock()

    def _mock_virtual_memory():
        virtual_memory_type = namedtuple(
            'virtual_memory', ['total', 'available', 'percent', 'used'])
        return virtual_memory_type(
            total=270109085696,
            available=250416816128,
            percent=7.3,
            used=17840881664)

    def _mock_swap_memory():
        swap_memory_type = namedtuple('swap_memory', [
            'total',
            'used',
            'percent',
        ])
        return swap_memory_type(total=8589930496, used=0, percent=0.0)

    def _mock_memory_usage():
        return [40.22265625]

    mock_virtual_memory = Mock(return_value=_mock_virtual_memory())
    mock_swap_memory = Mock(return_value=_mock_swap_memory())
    mock_memory_usage = Mock(return_value=_mock_memory_usage())

    @patch('psutil.swap_memory', mock_swap_memory)
    @patch('psutil.virtual_memory', mock_virtual_memory)
    @patch('memory_profiler.memory_usage', mock_memory_usage)
    def _test_memory_profiler_hook():
        from mmdet.core.hook import MemoryProfilerHook
        hook = MemoryProfilerHook(1)
        runner = _build_demo_runner()
        try:
            # None of the three memory queries may run before after_iter.
            assert not mock_memory_usage.called
            assert not mock_swap_memory.called
            assert not mock_virtual_memory.called

            hook.after_iter(runner)

            # Every patched query, including virtual_memory, must be used.
            assert mock_memory_usage.called
            assert mock_swap_memory.called
            assert mock_virtual_memory.called
        finally:
            # The demo runner creates a temporary work dir; remove it.
            shutil.rmtree(runner.work_dir)

    _test_memory_profiler_hook()