-
Notifications
You must be signed in to change notification settings - Fork 7k
/
Copy pathroi_align.py
294 lines (246 loc) · 11.1 KB
/
roi_align.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
import functools
from typing import List, Union
import torch
import torch.fx
from torch import nn, Tensor
from torch._dynamo.utils import is_compile_supported
from torch.jit.annotations import BroadcastingList2
from torch.nn.modules.utils import _pair
from torchvision.extension import _assert_has_ops, _has_ops
from ..utils import _log_api_usage_once
from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format
def lazy_compile(**compile_kwargs):
    """Build a decorator that defers ``torch.compile`` until the first call.

    Nothing is compiled at decoration time, which avoids eagerly importing
    dynamo. On the first invocation the decorated function is handed to
    ``torch.compile(fn, **compile_kwargs)``, the compiled callable is stored
    in this module's globals under the function's own name, and the call is
    forwarded to it.

    Args:
        **compile_kwargs: forwarded verbatim to ``torch.compile``.

    Returns:
        A decorator taking a function and returning the deferred-compile hook.
    """

    def _defer(target):
        def _first_call(*args, **kwargs):
            optimized = torch.compile(target, **compile_kwargs)
            # Rebind the module-level name so later lookups by name reach the
            # compiled callable directly instead of this hook.
            globals()[target.__name__] = functools.wraps(target)(optimized)
            return optimized(*args, **kwargs)

        # Make the hook carry the metadata (name, docstring, ...) of the target.
        return functools.wraps(target)(_first_call)

    return _defer
# NB: all inputs are tensors
def _bilinear_interpolate(
input, # [N, C, H, W]
roi_batch_ind, # [K]
y, # [K, PH, IY]
x, # [K, PW, IX]
ymask, # [K, IY]
xmask, # [K, IX]
):
_, channels, height, width = input.size()
# deal with inverse element out of feature map boundary
y = y.clamp(min=0)
x = x.clamp(min=0)
y_low = y.int()
x_low = x.int()
y_high = torch.where(y_low >= height - 1, height - 1, y_low + 1)
y_low = torch.where(y_low >= height - 1, height - 1, y_low)
y = torch.where(y_low >= height - 1, y.to(input.dtype), y)
x_high = torch.where(x_low >= width - 1, width - 1, x_low + 1)
x_low = torch.where(x_low >= width - 1, width - 1, x_low)
x = torch.where(x_low >= width - 1, x.to(input.dtype), x)
ly = y - y_low
lx = x - x_low
hy = 1.0 - ly
hx = 1.0 - lx
# do bilinear interpolation, but respect the masking!
# TODO: It's possible the masking here is unnecessary if y and
# x were clamped appropriately; hard to tell
def masked_index(
y, # [K, PH, IY]
x, # [K, PW, IX]
):
if ymask is not None:
assert xmask is not None
y = torch.where(ymask[:, None, :], y, 0)
x = torch.where(xmask[:, None, :], x, 0)
return input[
roi_batch_ind[:, None, None, None, None, None],
torch.arange(channels, device=input.device)[None, :, None, None, None, None],
y[:, None, :, None, :, None], # prev [K, PH, IY]
x[:, None, None, :, None, :], # prev [K, PW, IX]
] # [K, C, PH, PW, IY, IX]
v1 = masked_index(y_low, x_low)
v2 = masked_index(y_low, x_high)
v3 = masked_index(y_high, x_low)
v4 = masked_index(y_high, x_high)
# all ws preemptively [K, C, PH, PW, IY, IX]
def outer_prod(y, x):
return y[:, None, :, None, :, None] * x[:, None, None, :, None, :]
w1 = outer_prod(hy, hx)
w2 = outer_prod(hy, lx)
w3 = outer_prod(ly, hx)
w4 = outer_prod(ly, lx)
val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
return val
# TODO: this doesn't actually cache
# TODO: main library should make this easier to do
def maybe_cast(tensor):
    """Return ``tensor`` as float32 when autocast applies to it, else unchanged.

    The conversion (``tensor.float()``) happens only if autocast is enabled,
    the tensor is a CUDA tensor, and its dtype is not ``torch.double``. In
    every other case the very same tensor object is returned.
    """
    if not torch.is_autocast_enabled():
        return tensor
    if not tensor.is_cuda or tensor.dtype == torch.double:
        return tensor
    return tensor.float()
# This is a pure Python and differentiable implementation of roi_align. When
# run in eager mode, it uses a lot of memory, but when compiled it has
# acceptable memory usage. The main point of this implementation is that
# its backwards is deterministic.
# It is transcribed directly off of the roi_align CUDA kernel, see
# https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266
@lazy_compile(dynamic=True)
def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
    """Pure-PyTorch RoI Align forward pass.

    Args:
        input: feature maps, [N, C, H, W].
        rois: [K, 5]; column 0 is read as the batch index, columns 1-4 as
            x1, y1, x2, y2 (multiplied by ``spatial_scale`` here).
        spatial_scale: factor applied to the box coordinates.
        pooled_height: number of output bins along the height (PH).
        pooled_width: number of output bins along the width (PW).
        sampling_ratio: if > 0, that many sampling points per bin and axis;
            otherwise ``ceil(roi_size / pooled_size)`` points per RoI and axis.
        aligned: if True, coordinates are shifted by -0.5 and the RoI size is
            not clamped; if False, RoI width/height are clamped to >= 1.

    Returns:
        Tensor [K, C, PH, PW] in the dtype ``input`` had on entry.
    """
    out_dtype = input.dtype

    input = maybe_cast(input)
    rois = maybe_cast(rois)

    _, _, height, width = input.size()
    device = input.device

    bin_rows = torch.arange(pooled_height, device=device)  # [PH]
    bin_cols = torch.arange(pooled_width, device=device)  # [PW]

    # input: [N, C, H, W]
    # rois: [K, 5]
    batch_index = rois[:, 0].int()  # [K]

    shift = 0.5 if aligned else 0.0
    x_begin = rois[:, 1] * spatial_scale - shift  # [K]
    y_begin = rois[:, 2] * spatial_scale - shift  # [K]
    x_end = rois[:, 3] * spatial_scale - shift  # [K]
    y_end = rois[:, 4] * spatial_scale - shift  # [K]

    box_w = x_end - x_begin  # [K]
    box_h = y_end - y_begin  # [K]
    if not aligned:
        # Legacy mode: force every RoI to be at least 1x1.
        box_w = torch.clamp(box_w, min=1.0)  # [K]
        box_h = torch.clamp(box_h, min=1.0)  # [K]

    bin_h = box_h / pooled_height  # [K]
    bin_w = box_w / pooled_width  # [K]

    exact_sampling = sampling_ratio > 0

    # NOTE: the bare string below is a no-op expression kept from the original.
    """
    iy, ix = dims(2)
    """

    if exact_sampling:
        grid_h = sampling_ratio  # scalar
        grid_w = sampling_ratio  # scalar
        count = max(grid_h * grid_w, 1)  # scalar
        iy = torch.arange(grid_h, device=device)  # [IY]
        ix = torch.arange(grid_w, device=device)  # [IX]
        ymask = None
        xmask = None
    else:
        grid_h = torch.ceil(box_h / pooled_height)  # [K]
        grid_w = torch.ceil(box_w / pooled_width)  # [K]
        count = torch.clamp(grid_h * grid_w, min=1)  # [K]
        # When doing adaptive sampling, the number of samples we need to do
        # is data-dependent based on how big the ROIs are.  This is a bit
        # awkward because first-class dims can't actually handle this.
        # So instead, we inefficiently suppose that we needed to sample ALL
        # the points and mask out things that turned out to be unnecessary
        iy = torch.arange(height, device=device)  # [IY]
        ix = torch.arange(width, device=device)  # [IX]
        ymask = iy[None, :] < grid_h[:, None]  # [K, IY]
        xmask = ix[None, :] < grid_w[:, None]  # [K, IX]

    def sample_coords(start, bin_size, bin_idx, sample_idx, grid):
        # start + bin_idx * bin_size + (sample_idx + 0.5) * (bin_size / grid),
        # broadcast to [K, bins, samples].
        step = bin_size / grid
        return (
            start[:, None, None]
            + bin_idx[None, :, None] * bin_size[:, None, None]
            + (sample_idx[None, None, :] + 0.5).to(input.dtype) * step[:, None, None]
        )

    y = sample_coords(y_begin, bin_h, bin_rows, iy, grid_h)  # [K, PH, IY]
    x = sample_coords(x_begin, bin_w, bin_cols, ix, grid_w)  # [K, PW, IX]

    val = _bilinear_interpolate(input, batch_index, y, x, ymask, xmask)  # [K, C, PH, PW, IY, IX]

    # Mask out samples that weren't actually adaptively needed
    if not exact_sampling:
        val = torch.where(ymask[:, None, None, None, :, None], val, 0)
        val = torch.where(xmask[:, None, None, None, None, :], val, 0)

    output = val.sum((-1, -2))  # remove IY, IX ~> [K, C, PH, PW]

    # Average over the sampling points; a tensor count is per-RoI.
    divisor = count[:, None, None, None] if isinstance(count, torch.Tensor) else count
    output /= divisor

    return output.to(out_dtype)
@torch.fx.wrap
def roi_align(
    input: Tensor,
    boxes: Union[Tensor, List[Tensor]],
    output_size: BroadcastingList2[int],
    spatial_scale: float = 1.0,
    sampling_ratio: int = -1,
    aligned: bool = False,
) -> Tensor:
    """
    Apply the Region of Interest (RoI) Align operator with average pooling, as described in Mask R-CNN.

    Args:
        input (Tensor[N, C, H, W]): batch of ``N`` elements, each holding ``C`` feature maps
            of size ``H x W``. For a quantized tensor a batch size of ``N == 1`` is expected.
        boxes (Tensor[K, 5] or List[Tensor[L, 4]]): boxes in (x1, y1, x2, y2) format from which
            the regions are taken. Coordinates must satisfy ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
            When a single Tensor is given, its first column holds the index of the batch element
            the box belongs to, i.e. a number in ``[0, N - 1]``.
            When a list of Tensors is given, each Tensor holds the boxes of batch element i.
        output_size (int or Tuple[int, int]): size of the output (in bins or pixels) after pooling,
            as (height, width).
        spatial_scale (float): factor mapping box coordinates to input coordinates. For instance,
            with boxes defined on a 224x224 image and a 112x112 feature map as input (a 0.5x
            scaling of the original image), use 0.5. Default: 1.0
        sampling_ratio (int): number of sampling points in the interpolation grid used to compute
            each pooled output bin. If > 0, exactly ``sampling_ratio x sampling_ratio`` points per
            bin are used. If <= 0, an adaptive number of grid points is used (computed as
            ``ceil(roi_width / output_width)``, and likewise for height). Default: -1
        aligned (bool): If False, use the legacy implementation.
            If True, shift the box coordinates by -0.5 pixel for a better alignment with the two
            neighboring pixel indices. This version is used in Detectron2

    Returns:
        Tensor[K, C, output_size[0], output_size[1]]: The pooled RoIs.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(roi_align)
    check_roi_boxes_shape(boxes)

    output_size = _pair(output_size)
    pooled_height = output_size[0]
    pooled_width = output_size[1]

    rois = boxes
    if not isinstance(rois, torch.Tensor):
        rois = convert_boxes_to_roi_format(rois)

    # This branch must stay a separate ``if`` so it is skipped when scripting.
    if not torch.jit.is_scripting():
        # Use the Python implementation when `_has_ops()` is false, or when
        # deterministic algorithms are requested on a CUDA/MPS/XPU input,
        # as long as the device type supports compilation.
        prefer_python = not _has_ops() or (
            torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps or input.is_xpu)
        )
        if prefer_python and is_compile_supported(input.device.type):
            return _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned)

    _assert_has_ops()
    return torch.ops.torchvision.roi_align(
        input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned
    )
class RoIAlign(nn.Module):
    """
    Module form of :func:`roi_align`; see that function for the meaning of the arguments.
    """

    def __init__(
        self,
        output_size: BroadcastingList2[int],
        spatial_scale: float,
        sampling_ratio: int,
        aligned: bool = False,
    ):
        super().__init__()
        _log_api_usage_once(self)
        # Stored as given and handed to roi_align on every forward call.
        self.output_size = output_size
        self.spatial_scale = spatial_scale
        self.sampling_ratio = sampling_ratio
        self.aligned = aligned

    def forward(self, input: Tensor, rois: Union[Tensor, List[Tensor]]) -> Tensor:
        """Pool ``rois`` out of ``input`` with the settings stored on this module."""
        return roi_align(
            input,
            rois,
            self.output_size,
            self.spatial_scale,
            self.sampling_ratio,
            self.aligned,
        )

    def __repr__(self) -> str:
        """Return ``ClassName(output_size=..., spatial_scale=..., sampling_ratio=..., aligned=...)``."""
        parts = [
            f"output_size={self.output_size}",
            f"spatial_scale={self.spatial_scale}",
            f"sampling_ratio={self.sampling_ratio}",
            f"aligned={self.aligned}",
        ]
        return f"{self.__class__.__name__}(" + ", ".join(parts) + ")"