-
Notifications
You must be signed in to change notification settings - Fork 7k
/
Copy pathlraspp.py
178 lines (147 loc) · 7.46 KB
/
lraspp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
from collections import OrderedDict
from functools import partial
from typing import Any, Dict, Optional
from torch import nn, Tensor
from torch.nn import functional as F
from ...transforms._presets import SemanticSegmentation
from ...utils import _log_api_usage_once
from .._api import register_model, Weights, WeightsEnum
from .._meta import _VOC_CATEGORIES
from .._utils import _ovewrite_value_param, handle_legacy_interface, IntermediateLayerGetter
from ..mobilenetv3 import mobilenet_v3_large, MobileNet_V3_Large_Weights, MobileNetV3
# Public API of this module: the model class, its pretrained-weights enum, and the builder function.
__all__ = ["LRASPP", "LRASPP_MobileNet_V3_Large_Weights", "lraspp_mobilenet_v3_large"]
class LRASPP(nn.Module):
    """
    Implements a Lite R-ASPP Network for semantic segmentation from
    `"Searching for MobileNetV3"
    <https://arxiv.org/abs/1905.02244>`_.

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            The backbone should return an OrderedDict[Tensor], with the key being
            "high" for the high level feature map and "low" for the low level feature map.
        low_channels (int): the number of channels of the low level features.
        high_channels (int): the number of channels of the high level features.
        num_classes (int): number of output classes of the model (including the background).
        inter_channels (int, optional): the number of channels for intermediate computations.
            Default: 128.
    """

    def __init__(
        self, backbone: nn.Module, low_channels: int, high_channels: int, num_classes: int, inter_channels: int = 128
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.backbone = backbone
        self.classifier = LRASPPHead(low_channels, high_channels, num_classes, inter_channels)

    def forward(self, input: Tensor) -> Dict[str, Tensor]:
        """Run the backbone and the LR-ASPP head, then resize to the input's spatial size.

        Args:
            input (Tensor): the image batch fed to ``self.backbone``.

        Returns:
            An ``OrderedDict`` with a single key ``"out"``. It holds the per-class score maps
            produced by the head (no activation applied). They are bilinearly resized to the
            last two dimensions of ``input``.
        """
        features = self.backbone(input)
        out = self.classifier(features)
        # The head's output is at the resolution of the "low" feature map. Upsample it
        # back to the spatial size of the input.
        out = F.interpolate(out, size=input.shape[-2:], mode="bilinear", align_corners=False)

        result = OrderedDict()
        result["out"] = out
        return result
class LRASPPHead(nn.Module):
    """Lite R-ASPP segmentation head.

    The high-level features are projected (1x1 conv -> BN -> ReLU) and multiplied by a
    sigmoid gate computed from their global average. The result is upsampled to the
    low-level resolution. A 1x1 classifier is applied to each path and the two score
    maps are summed.

    Args:
        low_channels (int): number of channels of ``input["low"]``.
        high_channels (int): number of channels of ``input["high"]``.
        num_classes (int): number of output channels (one score map per class).
        inter_channels (int): width of the projected high-level features.
    """

    def __init__(self, low_channels: int, high_channels: int, num_classes: int, inter_channels: int) -> None:
        super().__init__()
        # NOTE: the attribute names and the Sequential positions define the state_dict keys
        # (e.g. "cbr.0.weight", "scale.1.weight"). The construction order fixes the order of
        # random parameter initialization. Keep both stable.

        # Projection path: 1x1 conv -> BN -> ReLU on the high-level features.
        self.cbr = nn.Sequential(
            nn.Conv2d(in_channels=high_channels, out_channels=inter_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(num_features=inter_channels),
            nn.ReLU(inplace=True),
        )
        # Gating path: global average pool -> 1x1 conv -> sigmoid, giving one value in
        # (0, 1) per channel of the projected features.
        self.scale = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=1),
            nn.Conv2d(in_channels=high_channels, out_channels=inter_channels, kernel_size=1, bias=False),
            nn.Sigmoid(),
        )
        # Per-path 1x1 classifiers; their outputs are added together in forward().
        self.low_classifier = nn.Conv2d(in_channels=low_channels, out_channels=num_classes, kernel_size=1)
        self.high_classifier = nn.Conv2d(in_channels=inter_channels, out_channels=num_classes, kernel_size=1)

    def forward(self, input: Dict[str, Tensor]) -> Tensor:
        """Fuse ``input["low"]`` and ``input["high"]`` into class score maps.

        The output has the spatial size of ``input["low"]``.
        """
        low_feat, high_feat = input["low"], input["high"]

        # Left operand first: cbr() runs before scale(), as in a two-statement form.
        gated = self.cbr(high_feat) * self.scale(high_feat)
        gated = F.interpolate(gated, size=low_feat.shape[-2:], mode="bilinear", align_corners=False)

        return self.low_classifier(low_feat) + self.high_classifier(gated)
def _lraspp_mobilenetv3(backbone: MobileNetV3, num_classes: int) -> LRASPP:
    """Assemble an :class:`LRASPP` model on top of ``backbone.features``.

    The "low" branch taps the fourth-from-last stage boundary. The "high" branch taps
    the last block of ``backbone.features``.

    Args:
        backbone (MobileNetV3): network whose ``features`` sequence provides the feature maps.
        num_classes (int): number of output classes of the segmentation head.

    Returns:
        LRASPP: the assembled segmentation model.
    """
    features = backbone.features

    # Stage boundaries: block 0 (C0, conv1) and the last block (Cn) are always included.
    # In between, every block flagged with ``_is_cn`` is a strided block (C1, ..., Cn-1).
    strided = [idx for idx, block in enumerate(features) if getattr(block, "_is_cn", False)]
    stage_indices = [0, *strided, len(features) - 1]

    low_pos = stage_indices[-4]  # C2, which has output_stride = 8
    high_pos = stage_indices[-1]  # C5, which has output_stride = 16

    low_channels = features[low_pos].out_channels
    high_channels = features[high_pos].out_channels

    extractor = IntermediateLayerGetter(features, return_layers={str(low_pos): "low", str(high_pos): "high"})
    return LRASPP(extractor, low_channels, high_channels, num_classes)
class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum):
    """Pretrained weights available for :func:`lraspp_mobilenet_v3_large`."""

    COCO_WITH_VOC_LABELS_V1 = Weights(
        # Official PyTorch download host. A proxied/mirrored URL here would send the
        # checkpoint download through a third party.
        url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth",
        transforms=partial(SemanticSegmentation, resize_size=520),
        meta={
            "num_params": 3221538,
            "categories": _VOC_CATEGORIES,
            "min_size": (1, 1),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#lraspp_mobilenet_v3_large",
            "_metrics": {
                "COCO-val2017-VOC-labels": {
                    "miou": 57.9,
                    "pixel_acc": 91.2,
                }
            },
            "_ops": 2.086,
            "_file_size": 12.49,
            "_docs": """
                These weights were trained on a subset of COCO, using only the 20 categories that are present in the
                Pascal VOC dataset.
            """,
        },
    )
    DEFAULT = COCO_WITH_VOC_LABELS_V1
@register_model()
@handle_legacy_interface(
    weights=("pretrained", LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1),
    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def lraspp_mobilenet_v3_large(
    *,
    weights: Optional[LRASPP_MobileNet_V3_Large_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    **kwargs: Any,
) -> LRASPP:
    """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone from
    `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`_ paper.

    .. betastatus:: segmentation module

    Args:
        weights (:class:`~torchvision.models.segmentation.LRASPP_MobileNet_V3_Large_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.segmentation.LRASPP_MobileNet_V3_Large_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background).
            When ``weights`` is given, the value is reconciled with the number of categories in
            the weights' metadata. Otherwise it defaults to 21.
        aux_loss (bool, optional): Not supported by this model. Passing a truthy value raises
            ``NotImplementedError``.
        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The pretrained
            weights for the backbone. Default is ``MobileNet_V3_Large_Weights.IMAGENET1K_V1``.
            Ignored when ``weights`` is given, because the full model checkpoint is loaded instead.
        **kwargs: Any other keyword arguments are accepted but currently unused by this builder.
            They are not forwarded to the ``torchvision.models.segmentation.LRASPP`` class.

    .. autoclass:: torchvision.models.segmentation.LRASPP_MobileNet_V3_Large_Weights
        :members:
    """
    # This architecture has no auxiliary classifier, so reject a request for one.
    if kwargs.pop("aux_loss", False):
        raise NotImplementedError("This model does not use auxiliary loss")

    weights = LRASPP_MobileNet_V3_Large_Weights.verify(weights)
    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)

    if weights is not None:
        # The full-model checkpoint loaded below also covers the backbone, so backbone
        # weights are not loaded separately.
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 21

    # NOTE(review): dilated=True presumably keeps the last stage at output_stride 16, as the
    # stride comments in _lraspp_mobilenetv3 expect. Confirm in mobilenetv3.
    backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True)
    model = _lraspp_mobilenetv3(backbone, num_classes)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model