TorchVision has a new backwards-compatible API for building models with multi-weight support. The new API allows loading different pre-trained weights on the same model variant, keeps track of vital meta-data such as the classification labels, and includes the preprocessing transforms necessary for using the models. In this blog post, we review the prototype API, showcase its features and highlight key differences with the existing one.

We are hoping to get your thoughts about the API prior to finalizing it.

TorchVision currently provides pre-trained models that can serve as a starting point for transfer learning or be used as-is in Computer Vision applications. The typical way to instantiate a pre-trained model and make a prediction is:
```python
import torch

from PIL import Image
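from torchvision import models as M
from torchvision import transforms as T

# What follows is a sketch of the rest of the typical old-API workflow;
# the image path and the exact transform values (standard ImageNet
# preprocessing) are assumptions, not verbatim from this post.
img = Image.open("path/to/image.jpg")  # hypothetical input image

# Initialize the model with its single set of pre-trained weights
model = M.resnet50(pretrained=True)
model.eval()

# The user must define the inference transforms by hand
preprocess = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.PILToTensor(),
    T.ConvertImageDtype(torch.float),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Apply the transforms and make a prediction
batch = preprocess(img).unsqueeze(0)
prediction = model(batch).squeeze(0).softmax(0)
label = prediction.argmax().item()
```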
There are a few limitations with the above approach:

1. **Inability to support multiple pre-trained weights:** Since the `pretrained` variable is boolean, we can only offer one set of weights. This poses a severe limitation when we significantly [improve the accuracy of existing models](https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/) and want to make those improvements available to the community. It also stops us from offering pre-trained weights of the same model variant on different datasets.
2. **Missing inference/preprocessing transforms:** The user is forced to define the necessary transforms prior to using the model. The inference transforms are usually linked to the training process and the dataset used to estimate the weights. Any minor discrepancies in these transforms (such as interpolation values, resize/crop sizes, etc.) can lead to major reductions in accuracy or unusable models.
3. **Lack of meta-data:** Critical pieces of information about the weights are unavailable to the users. For example, one needs to look into external sources and the documentation to find things like the [category labels](https://github.com/pytorch/vision/issues/1946), the training recipe, the accuracy metrics, etc.

The new API addresses the above limitations and reduces the amount of boilerplate code needed for standard tasks.
## Overview of the prototype API
Let’s see how we can achieve exactly the same results as above using the new API:
```python
from PIL import Image
from torchvision.prototype import models as PM
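# What follows is a sketch of the equivalent new-API workflow; the image
# path is a placeholder and the weight choice mirrors the old
# pretrained=True behavior.
img = Image.open("path/to/image.jpg")  # hypothetical input image

# Initialize the model with an explicit weights enum
weights = PM.ResNet50_Weights.IMAGENET1K_V1
model = PM.resnet50(weights=weights)
model.eval()

# The inference transforms ship with the weights; nothing hand-written
preprocess = weights.transforms()

# Apply the transforms and make a prediction
batch = preprocess(img).unsqueeze(0)
prediction = model(batch).squeeze(0).softmax(0)
label = prediction.argmax().item()

# The meta-data provides the human-readable category name
category_name = weights.meta["categories"][label]
```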
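Because the weights are enums rather than a boolean flag, the same builder can load different checkpoints for one model variant. A short sketch of what this looks like, using only the enum values referenced elsewhere in this post:

```python
from torchvision.prototype.models import resnet50, ResNet50_Weights

# The original weights, equivalent to the old pretrained=True
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

# The improved weights from TorchVision's newer training recipe
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Always resolves to the best currently available weights
model = resnet50(weights=ResNet50_Weights.DEFAULT)

# No pre-trained weights; random initialization
model = resnet50(weights=None)
```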
The weights of each model are associated with meta-data. The type of information we store depends on the task of the model (Classification, Detection, Segmentation, etc.). Typical information includes a link to the training recipe, the interpolation mode, and details such as the categories and validation metrics. These values are programmatically accessible via the `meta` attribute:
```python
from torchvision.prototype.models import ResNet50_Weights

# Accessing a single record (the "categories" key mentioned above)
print(ResNet50_Weights.IMAGENET1K_V2.meta["categories"][:5])

# Iterating the items of the meta-data dictionary
for k, v in ResNet50_Weights.IMAGENET1K_V2.meta.items():
    print(k, v)
```
Additionally, each weights entry is associated with the necessary preprocessing transforms. All current preprocessing transforms are JIT-scriptable and can be accessed via the `transforms` attribute. Prior to using them with the data, the transforms need to be initialized/constructed. This lazy initialization scheme is done to ensure the solution is memory-efficient. The input of the transforms can be either a `PIL.Image` or a `Tensor` read using `torchvision.io`.
```python
from torchvision.prototype.models import ResNet50_Weights

# Initializing preprocessing at standard 224x224 resolution
preprocess = ResNet50_Weights.IMAGENET1K_V2.transforms()

# Once initialized, the callable accepts a PIL.Image or a Tensor
# (e.g., the img loaded in the earlier snippets)
img_preprocessed = preprocess(img)
```
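Because the transforms are JIT-scriptable, the initialized preprocessing callable can also be scripted; a minimal sketch, assuming the `preprocess` object from the snippet above and a hypothetical image path:

```python
import torch
from torchvision.io import read_image

# Scripted transforms operate on Tensors read via torchvision.io
scripted_preprocess = torch.jit.script(preprocess)
img_tensor = read_image("path/to/image.jpg")  # hypothetical path
img_preprocessed = scripted_preprocess(img_tensor)
```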
Associating the weights with their meta-data and preprocessing will boost transparency, improve reproducibility and make it easier to document how a set of weights was produced.
### Get weights by name
The ability to link the weights directly with their properties (meta-data, preprocessing callables, etc.) is the reason why our implementation uses Enums instead of Strings. Nevertheless, for cases where only the name of the weights is available, we offer a method capable of linking Weight names to their Enums:
```python
from torchvision.prototype.models import get_weight

# Fetching a weights enum from its string name (illustrative usage)
weights = get_weight("ResNet50_Weights.IMAGENET1K_V2")
```
In the new API, the boolean `pretrained` and `pretrained_backbone` parameters, which were previously used to load weights to the full model or to its backbone, are deprecated. The current implementation is fully backwards compatible, as it seamlessly maps the old parameters to the new ones. Passing the old parameters to the new builders emits the following deprecation warnings:
```python
>>> model = torchvision.prototype.models.resnet50(pretrained=True)
UserWarning: The parameter 'pretrained' is deprecated, please use 'weights' instead.
UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated.
The current behavior is equivalent to passing `weights=ResNet50_Weights.IMAGENET1K_V1`.
You can also use `weights=ResNet50_Weights.DEFAULT` to get the most up-to-date weights.
```
Additionally, the builder methods require using keyword parameters. The use of positional parameters is deprecated, and using them emits the following warning:
```python
>>> model = torchvision.prototype.models.resnet50(None)
UserWarning: Using 'weights' as positional parameter(s) is deprecated.
Please use keyword parameter(s) instead.
```
Please spare a few minutes to provide your feedback on the new API, as this is crucial for graduating it from prototype and including it in the next release. You can do this on the dedicated [GitHub issue](https://github.com/pytorch/vision/issues/5088). We are looking forward to reading your comments!
### TorchMultimodal Release (Beta)
Please watch for upcoming blogs in early November that will introduce TorchMultimodal, a PyTorch domain library for training SoTA multi-task multimodal models at scale, in more detail; in the meantime, play around with the library and models through our [tutorial](https://pytorch.org/tutorials/beginner/flava_finetuning_tutorial.html).