-
Notifications
You must be signed in to change notification settings - Fork 24k
/
Copy pathlinear.py
302 lines (246 loc) · 11 KB
/
linear.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
# mypy: allow-untyped-defs
import math
from typing import Any
import torch
from torch import Tensor
from torch.nn import functional as F, init
from torch.nn.parameter import Parameter, UninitializedParameter
from .lazy import LazyModuleMixin
from .module import Module
# Names exported by a star-import of this module.
# NonDynamicallyQuantizableLinear is defined in this file but is not listed.
__all__ = [
    "Bilinear",
    "Identity",
    "LazyLinear",
    "Linear",
]
class Identity(Module):
    r"""A placeholder module whose output is its input, whatever it was built with.

    Args:
        args: any argument (unused)
        kwargs: any keyword argument (unused)

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 20])

    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Constructor arguments are accepted but never forwarded or stored.
        super().__init__()

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input`` itself, without copying or modifying it."""
        return input
class Linear(Module):
    r"""Applies an affine linear transformation to the incoming data: :math:`y = xA^T + b`.

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    On certain ROCm devices, when using float16 inputs this module will use
    :ref:`different precision<fp16_on_mi200>` for backward.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``
        device: passed to :func:`torch.empty` when the parameters are allocated
        dtype: passed to :func:`torch.empty` when the parameters are allocated

    Shape:
        - Input: :math:`(*, H_\text{in})` where :math:`*` means any number of
          dimensions including none and :math:`H_\text{in} = \text{in\_features}`.
        - Output: :math:`(*, H_\text{out})` where all but the last dimension
          are the same shape as the input and :math:`H_\text{out} = \text{out\_features}`.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in\_features})`. The values are
            initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in\_features}}`
        bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
            If :attr:`bias` is ``True``, the values are initialized from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{1}{\text{in\_features}}`

    Examples::

        >>> m = nn.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Parameters are allocated uninitialized; reset_parameters() below
        # fills them in.
        weight_storage = torch.empty(
            (out_features, in_features), device=device, dtype=dtype
        )
        self.weight = Parameter(weight_storage)
        if not bias:
            # Register the name with no tensor so ``self.bias`` is ``None``.
            self.register_parameter("bias", None)
        else:
            self.bias = Parameter(
                torch.empty(out_features, device=device, dtype=dtype)
            )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-draw ``weight`` and, when present, ``bias`` in place."""
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is None:
            return
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
        # A non-positive fan-in collapses the range to zero instead of dividing by it.
        limit = 0 if fan_in <= 0 else 1 / math.sqrt(fan_in)
        init.uniform_(self.bias, -limit, limit)

    def forward(self, input: Tensor) -> Tensor:
        """Return ``F.linear(input, weight, bias)``."""
        return F.linear(input, self.weight, self.bias)

    def extra_repr(self) -> str:
        """Return the feature sizes and whether a bias is present, as text."""
        return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"
# This class exists solely to avoid triggering an obscure error when scripting
# an improperly quantized attention layer. See this issue for details:
# https://github.com/pytorch/pytorch/issues/58969
# TODO: fail fast on quantization API usage error, then remove this class
# and replace uses of it with plain Linear
class NonDynamicallyQuantizableLinear(Linear):
    """A :class:`Linear` subclass that adds no behavior of its own.

    It exists solely to avoid triggering an obscure error when scripting an
    improperly quantized attention layer. The constructor passes every
    argument straight through to :class:`Linear`.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        forwarded = {"bias": bias, "device": device, "dtype": dtype}
        super().__init__(in_features, out_features, **forwarded)
class Bilinear(Module):
    r"""Applies a bilinear transformation to the incoming data: :math:`y = x_1^T A x_2 + b`.

    Args:
        in1_features: size of each first input sample, must be > 0
        in2_features: size of each second input sample, must be > 0
        out_features: size of each output sample, must be > 0
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``
        device: passed to :func:`torch.empty` when the parameters are allocated
        dtype: passed to :func:`torch.empty` when the parameters are allocated

    Raises:
        ValueError: if ``in1_features`` is not positive (the only size that
            the constructor checks).

    Shape:
        - Input1: :math:`(*, H_\text{in1})` where :math:`H_\text{in1}=\text{in1\_features}` and
          :math:`*` means any number of additional dimensions including none. All but the last dimension
          of the inputs should be the same.
        - Input2: :math:`(*, H_\text{in2})` where :math:`H_\text{in2}=\text{in2\_features}`.
        - Output: :math:`(*, H_\text{out})` where :math:`H_\text{out}=\text{out\_features}`
          and all but the last dimension are the same shape as the input.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in1\_features}, \text{in2\_features})`.
            The values are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in1\_features}}`
        bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
            If :attr:`bias` is ``True``, the values are initialized from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in1\_features}}`

    Examples::

        >>> m = nn.Bilinear(20, 30, 40)
        >>> input1 = torch.randn(128, 20)
        >>> input2 = torch.randn(128, 30)
        >>> output = m(input1, input2)
        >>> print(output.size())
        torch.Size([128, 40])
    """

    __constants__ = ["in1_features", "in2_features", "out_features"]
    in1_features: int
    in2_features: int
    out_features: int
    weight: Tensor

    def __init__(
        self,
        in1_features: int,
        in2_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        # reset_parameters() divides by sqrt(in1_features), so it must be positive.
        if in1_features <= 0:
            raise ValueError(f"in1_features must be > 0, but got {in1_features}")
        self.in1_features = in1_features
        self.in2_features = in2_features
        self.out_features = out_features
        weight_shape = (out_features, in1_features, in2_features)
        self.weight = Parameter(
            torch.empty(weight_shape, device=device, dtype=dtype)
        )
        if not bias:
            # Register the name with no tensor so ``self.bias`` is ``None``.
            self.register_parameter("bias", None)
        else:
            self.bias = Parameter(
                torch.empty(out_features, device=device, dtype=dtype)
            )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-draw ``weight`` and, when present, ``bias`` from one shared uniform range."""
        # Dimension 1 of the weight is in1_features.
        limit = 1 / math.sqrt(self.weight.size(1))
        init.uniform_(self.weight, -limit, limit)
        if self.bias is None:
            return
        init.uniform_(self.bias, -limit, limit)

    def forward(self, input1: Tensor, input2: Tensor) -> Tensor:
        """Return ``F.bilinear(input1, input2, weight, bias)``."""
        return F.bilinear(input1, input2, self.weight, self.bias)

    def extra_repr(self) -> str:
        """Return the three feature sizes and whether a bias is present, as text."""
        summary = (
            f"in1_features={self.in1_features}, in2_features={self.in2_features}, "
            f"out_features={self.out_features}, bias={self.bias is not None}"
        )
        return summary
class LazyLinear(LazyModuleMixin, Linear):
    r"""A :class:`torch.nn.Linear` module where `in_features` is inferred.

    In this module, the `weight` and `bias` are of :class:`torch.nn.UninitializedParameter`
    class. They will be initialized after the first call to ``forward`` is done and the
    module will become a regular :class:`torch.nn.Linear` module. The ``in_features`` argument
    of the :class:`Linear` is inferred from the ``input.shape[-1]``.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``
        device: passed to :class:`torch.nn.UninitializedParameter`
        dtype: passed to :class:`torch.nn.UninitializedParameter`

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in\_features})`. The values are
            initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in\_features}}`
        bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
            If :attr:`bias` is ``True``, the values are initialized from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{1}{\text{in\_features}}`


    """

    cls_to_become = Linear  # type: ignore[assignment]
    weight: UninitializedParameter
    bias: UninitializedParameter  # type: ignore[assignment]

    def __init__(
        self, out_features: int, bias: bool = True, device=None, dtype=None
    ) -> None:
        # bias is hardcoded to False to avoid creating tensor
        # that will soon be overwritten.
        super().__init__(0, 0, False)
        # Swap the (0, 0) placeholder weight for an uninitialized parameter.
        self.weight = UninitializedParameter(device=device, dtype=dtype)
        self.out_features = out_features
        if bias:
            self.bias = UninitializedParameter(device=device, dtype=dtype)

    def reset_parameters(self) -> None:
        """Run :meth:`Linear.reset_parameters` only once shapes are known."""
        # Skip while any parameter is still uninitialized or in_features is
        # still the 0 placeholder set by the constructor.
        if self.has_uninitialized_params() or self.in_features == 0:
            return
        super().reset_parameters()

    def initialize_parameters(self, input) -> None:  # type: ignore[override]
        """Materialize the parameters and set ``in_features`` from ``input.shape[-1]``."""
        if self.has_uninitialized_params():
            with torch.no_grad():
                self.in_features = input.shape[-1]
                weight_shape = (self.out_features, self.in_features)
                self.weight.materialize(weight_shape)
                if self.bias is not None:
                    self.bias.materialize((self.out_features,))
                self.reset_parameters()
        if self.in_features != 0:
            return
        # Reached when in_features is still 0 even though no parameter is
        # uninitialized: check the input against the existing weight, then
        # record the size.
        assert input.shape[-1] == self.weight.shape[-1], (
            f"The in_features inferred from input: {input.shape[-1]} "
            f"is not equal to in_features from self.weight: "
            f"{self.weight.shape[-1]}"
        )
        self.in_features = input.shape[-1]
# TODO: PartialLinear - maybe in sparse?