-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_split_batch.py
95 lines (90 loc) · 3.62 KB
/
test_split_batch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from copy import deepcopy
import mmcv
import numpy as np
import torch
from mmdet.utils import split_batch
def test_split_batch():
    """Check that ``split_batch`` groups a mixed batch by each meta's ``tag``.

    A batch of copies of one image is tagged ``sup`` (1 image),
    ``unsup_teacher`` (4) and ``unsup_student`` (4). Each tag is given a
    distinct ``scale_factor`` so that, after splitting, every image meta can
    be checked to have landed in the right group. The per-image ``kwargs``
    lists (``gt_bboxes`` / ``gt_labels``) must be split in step with the
    images, with only the ``sup`` group keeping the real annotations.
    """
    img_path = osp.join(osp.dirname(__file__), '../data/color.jpg')
    img = mmcv.imread(img_path, 'color')
    h, w, _ = img.shape
    gt_bboxes = np.array([[0.2 * w, 0.2 * h, 0.4 * w, 0.4 * h],
                          [0.6 * w, 0.6 * h, 0.8 * w, 0.8 * h]],
                         dtype=np.float32)
    gt_labels = np.ones(gt_bboxes.shape[0], dtype=np.int64)

    meta = dict()
    meta['filename'] = img_path
    # mmdet's Collect pipeline documents these shapes as (h, w, c), so record
    # them from the HWC ndarray before it is converted to a CHW tensor.
    meta['ori_shape'] = img.shape
    meta['img_shape'] = img.shape
    meta['img_norm_cfg'] = {
        'mean': np.array([103.53, 116.28, 123.675], dtype=np.float32),
        'std': np.array([1., 1., 1.], dtype=np.float32),
        'to_rgb': False
    }
    meta['pad_shape'] = img.shape
    img = torch.tensor(img).permute(2, 0, 1)

    # A distinct scale_factor per tag makes each group identifiable after
    # splitting: sup -> 0.5, unsup_teacher -> 1.0, unsup_student -> 2.0.
    scale_factors = {
        'sup': [0.5, 0.5, 0.5, 0.5],
        'unsup_teacher': [1.0, 1.0, 1.0, 1.0],
        'unsup_student': [2.0, 2.0, 2.0, 2.0],
    }
    tags = [
        'sup', 'unsup_teacher', 'unsup_student', 'unsup_teacher',
        'unsup_student', 'unsup_teacher', 'unsup_student', 'unsup_teacher',
        'unsup_student'
    ]
    imgs = img.unsqueeze(0).repeat(len(tags), 1, 1, 1)

    img_metas = []
    for tag in tags:
        img_meta = deepcopy(meta)
        # Copy so no two metas share the same list object.
        img_meta['scale_factor'] = list(scale_factors[tag])
        img_meta['tag'] = tag
        img_metas.append(img_meta)

    # Only supervised images carry annotations; unlabeled images get empty
    # placeholders so every kwargs list has exactly one entry per image.
    kwargs = dict()
    kwargs['gt_bboxes'] = [
        torch.tensor(gt_bboxes) if tag == 'sup' else torch.zeros(0, 4)
        for tag in tags
    ]
    kwargs['gt_labels'] = [
        torch.tensor(gt_labels)
        if tag == 'sup' else torch.zeros(0, dtype=torch.int64)
        for tag in tags
    ]

    data_groups = split_batch(imgs, img_metas, kwargs)

    assert set(data_groups.keys()) == set(tags)
    for tag, expected_scale in scale_factors.items():
        group = data_groups[tag]
        num_in_group = tags.count(tag)
        # The group's images are stacked back into one (n, C, H, W) tensor.
        assert group['img'].shape == (num_in_group, 3, h, w)
        assert len(group['img_metas']) == num_in_group
        for img_meta in group['img_metas']:
            assert img_meta['tag'] == tag
            assert img_meta['scale_factor'] == expected_scale
        # Per-image kwargs lists are split in step with the images.
        assert len(group['gt_bboxes']) == num_in_group
        assert len(group['gt_labels']) == num_in_group

    # Only the supervised group keeps the real annotations.
    assert torch.equal(data_groups['sup']['gt_bboxes'][0],
                       torch.tensor(gt_bboxes))
    assert torch.equal(data_groups['sup']['gt_labels'][0],
                       torch.tensor(gt_labels))
    for tag in ('unsup_teacher', 'unsup_student'):
        assert all(bboxes.shape == (0, 4)
                   for bboxes in data_groups[tag]['gt_bboxes'])
        assert all(labels.numel() == 0
                   for labels in data_groups[tag]['gt_labels'])