-
Notifications
You must be signed in to change notification settings - Fork 7k
/
Copy pathdrop_block.py
155 lines (128 loc) · 5.72 KB
/
drop_block.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import torch
import torch.fx
import torch.nn.functional as F
from torch import nn, Tensor
from ..utils import _log_api_usage_once
def drop_block2d(
    input: Tensor, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06, training: bool = True
) -> Tensor:
    """
    Implements DropBlock2d from `"DropBlock: A regularization method for convolutional networks"
    <https://arxiv.org/abs/1810.12890>`.

    Args:
        input (Tensor[N, C, H, W]): The input tensor of 4 dimensions with the first one
            being its batch i.e. a batch with ``N`` rows.
        p (float): Probability of an element to be dropped.
        block_size (int): Size of the block to drop. Values larger than ``H`` or ``W`` are
            clamped to the smaller spatial dimension. Both odd and even sizes are supported.
        inplace (bool): If set to ``True``, will do this operation in-place. Default: ``False``.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-6.
        training (bool): apply dropblock if is ``True``. Default: ``True``.

    Returns:
        Tensor[N, C, H, W]: The randomly zeroed tensor after dropblock. When ``training`` is
        ``False`` or ``p`` is ``0``, ``input`` itself is returned unchanged.

    Raises:
        ValueError: If ``p`` is outside ``[0, 1]``, if ``input`` is not 4-dimensional, or if
            dropblock has to be applied and ``block_size`` is smaller than 1.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(drop_block2d)
    if p < 0.0 or p > 1.0:
        raise ValueError(f"drop probability has to be between 0 and 1, but got {p}.")
    if input.ndim != 4:
        raise ValueError(f"input should be 4 dimensional. Got {input.ndim} dimensions.")
    if not training or p == 0.0:
        return input
    # Checked only once dropblock is really applied, so calls that used to be a no-op stay one.
    if block_size < 1:
        raise ValueError(f"block_size has to be a positive integer, but got {block_size}.")

    N, C, H, W = input.size()
    block_size = min(block_size, W, H)
    # compute the gamma of Bernoulli distribution
    gamma = (p * H * W) / ((block_size**2) * ((H - block_size + 1) * (W - block_size + 1)))
    # One Bernoulli draw per position where a whole block fits inside the feature map.
    noise = torch.empty((N, C, H - block_size + 1, W - block_size + 1), dtype=input.dtype, device=input.device)
    noise.bernoulli_(gamma)

    # Zero-pad by ``block_size - 1`` on every side and pool without further padding: each
    # sampled seed grows into one full block_size x block_size block and the mask comes back
    # at exactly (H, W). Padding by ``block_size // 2`` both here and inside the pooling call
    # only restores the input size for odd block sizes; for even sizes the mask ended up two
    # elements too large per dimension and could not be multiplied with ``input``.
    noise = F.pad(noise, [block_size - 1] * 4, value=0)
    noise = F.max_pool2d(noise, stride=(1, 1), kernel_size=(block_size, block_size), padding=0)
    # Flip the drop mask into a keep mask (1 = keep, 0 = drop).
    noise = 1 - noise
    # Rescale by (total elements / kept elements); ``eps`` guards against an all-dropped mask.
    normalize_scale = noise.numel() / (eps + noise.sum())
    if inplace:
        input.mul_(noise).mul_(normalize_scale)
    else:
        input = input * noise * normalize_scale
    return input
def drop_block3d(
    input: Tensor, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06, training: bool = True
) -> Tensor:
    """
    Implements DropBlock3d from `"DropBlock: A regularization method for convolutional networks"
    <https://arxiv.org/abs/1810.12890>`.

    Args:
        input (Tensor[N, C, D, H, W]): The input tensor of 5 dimensions with the first one
            being its batch i.e. a batch with ``N`` rows.
        p (float): Probability of an element to be dropped.
        block_size (int): Size of the block to drop. Values larger than ``D``, ``H`` or ``W``
            are clamped to the smallest of them. Both odd and even sizes are supported.
        inplace (bool): If set to ``True``, will do this operation in-place. Default: ``False``.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-6.
        training (bool): apply dropblock if is ``True``. Default: ``True``.

    Returns:
        Tensor[N, C, D, H, W]: The randomly zeroed tensor after dropblock. When ``training``
        is ``False`` or ``p`` is ``0``, ``input`` itself is returned unchanged.

    Raises:
        ValueError: If ``p`` is outside ``[0, 1]``, if ``input`` is not 5-dimensional, or if
            dropblock has to be applied and ``block_size`` is smaller than 1.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(drop_block3d)
    if p < 0.0 or p > 1.0:
        raise ValueError(f"drop probability has to be between 0 and 1, but got {p}.")
    if input.ndim != 5:
        raise ValueError(f"input should be 5 dimensional. Got {input.ndim} dimensions.")
    if not training or p == 0.0:
        return input
    # Checked only once dropblock is really applied, so calls that used to be a no-op stay one.
    if block_size < 1:
        raise ValueError(f"block_size has to be a positive integer, but got {block_size}.")

    N, C, D, H, W = input.size()
    block_size = min(block_size, D, H, W)
    # compute the gamma of Bernoulli distribution
    gamma = (p * D * H * W) / ((block_size**3) * ((D - block_size + 1) * (H - block_size + 1) * (W - block_size + 1)))
    # One Bernoulli draw per position where a whole block fits inside the volume.
    noise = torch.empty(
        (N, C, D - block_size + 1, H - block_size + 1, W - block_size + 1), dtype=input.dtype, device=input.device
    )
    noise.bernoulli_(gamma)

    # Zero-pad by ``block_size - 1`` on every side and pool without further padding: each
    # sampled seed grows into one full cubic block and the mask comes back at exactly
    # (D, H, W). Padding by ``block_size // 2`` both here and inside the pooling call only
    # restores the input size for odd block sizes; for even sizes the mask ended up two
    # elements too large per dimension and could not be multiplied with ``input``.
    noise = F.pad(noise, [block_size - 1] * 6, value=0)
    noise = F.max_pool3d(noise, stride=(1, 1, 1), kernel_size=(block_size, block_size, block_size), padding=0)
    # Flip the drop mask into a keep mask (1 = keep, 0 = drop).
    noise = 1 - noise
    # Rescale by (total elements / kept elements); ``eps`` guards against an all-dropped mask.
    normalize_scale = noise.numel() / (eps + noise.sum())
    if inplace:
        input.mul_(noise).mul_(normalize_scale)
    else:
        input = input * noise * normalize_scale
    return input
# Register ``drop_block2d`` by name with torch.fx so that, per the torch.fx.wrap
# documentation, symbolic tracing keeps a call to it as a single "leaf" node
# instead of tracing through its body.
torch.fx.wrap("drop_block2d")
class DropBlock2d(nn.Module):
    """
    Module wrapper around :func:`drop_block2d`.

    The constructor arguments are stored as attributes and forwarded to
    :func:`drop_block2d` on every call, together with the module's current
    ``training`` flag.
    """

    def __init__(self, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06) -> None:
        """
        Args:
            p (float): Probability of an element to be dropped.
            block_size (int): Size of the block to drop.
            inplace (bool): If ``True``, the input is modified in-place. Default: ``False``.
            eps (float): Value added to the normalization denominator. Default: 1e-6.
        """
        super().__init__()

        self.eps = eps
        self.inplace = inplace
        self.block_size = block_size
        self.p = p

    def forward(self, input: Tensor) -> Tensor:
        """
        Args:
            input (Tensor): Feature map in which some areas will be randomly dropped.

        Returns:
            Tensor: The result of :func:`drop_block2d` for ``input``.
        """
        return drop_block2d(
            input,
            p=self.p,
            block_size=self.block_size,
            inplace=self.inplace,
            eps=self.eps,
            training=self.training,
        )

    def __repr__(self) -> str:
        # Uses the runtime class name so subclasses print their own name.
        cls_name = type(self).__name__
        return f"{cls_name}(p={self.p}, block_size={self.block_size}, inplace={self.inplace})"
# Register ``drop_block3d`` by name with torch.fx so that, per the torch.fx.wrap
# documentation, symbolic tracing keeps a call to it as a single "leaf" node
# instead of tracing through its body.
torch.fx.wrap("drop_block3d")
class DropBlock3d(DropBlock2d):
    """
    Module wrapper around :func:`drop_block3d`.

    Reuses the attribute storage of :class:`DropBlock2d` and only swaps the
    functional call made in ``forward``.
    """

    def __init__(self, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06) -> None:
        """
        Args:
            p (float): Probability of an element to be dropped.
            block_size (int): Size of the block to drop.
            inplace (bool): If ``True``, the input is modified in-place. Default: ``False``.
            eps (float): Value added to the normalization denominator. Default: 1e-6.
        """
        super().__init__(p=p, block_size=block_size, inplace=inplace, eps=eps)

    def forward(self, input: Tensor) -> Tensor:
        """
        Args:
            input (Tensor): Feature map in which some areas will be randomly dropped.

        Returns:
            Tensor: The result of :func:`drop_block3d` for ``input``.
        """
        return drop_block3d(
            input,
            p=self.p,
            block_size=self.block_size,
            inplace=self.inplace,
            eps=self.eps,
            training=self.training,
        )