Broadcasting behaviour for linear algebra solvers #52915
Would this need anything else other than prefixing these functions with something like the following?

```cpp
TORCH_CHECK(A.dim() >= 2);
const auto compatible = [](IntArrayRef a, IntArrayRef b) {
  // b is a batch of vectors iff A.shape[:-1] == b.shape
  return a.size() == b.size() + 1 &&
         std::equal(a.begin(), a.end() - 1, b.begin());
};
const Tensor B_ = (b.dim() == 1 || compatible(A.sizes(), b.sizes()))
    ? b.unsqueeze(-1)
    : b;
```
Actually, that does not take into account broadcasting. This should be done inside the
FYI, the ambiguous behavior was removed in NumPy 2.0: b is now treated as a batch of matrices unless it is exactly 1-dimensional. This also aligns with the array API (data-apis/array-api#285). The logic currently in PyTorch doesn't even match NumPy 1.26, because it only applies the vector heuristic when `A.shape[:-1] == b.shape`, whereas NumPy 1.x keys off the relative `ndim` alone:

```python
>>> torch.linalg.solve(torch.rand((2, 1, 2, 2)), torch.rand((2, 2, 2))).shape
torch.Size([2, 2, 2, 2])
>>> np.linalg.solve(np.random.random((2, 1, 2, 2)), np.random.random((2, 2, 2))).shape  # numpy 1.26
(2, 2, 2)
```

But regardless, PyTorch should just update to match the array API and NumPy 2.0 behavior, because it removes inherent ambiguity in the definition and matches the behavior of matmul. Any users relying on the old behavior can easily work around it with an explicit `unsqueeze(-1)` on b (and a squeeze of the result).
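A sketch of that workaround on the example above (shapes taken from the snippet; the `unsqueeze`/`squeeze` pair spells out the vector semantics explicitly):

```python
import torch

A = torch.rand(2, 1, 2, 2)  # a (2, 1)-batch of 2x2 matrices
b = torch.rand(2, 2, 2)

# Current PyTorch rule: A.shape[:-1] != b.shape, so b is taken as a batch of
# matrices and batch-broadcast against A.
print(torch.linalg.solve(A, b).shape)  # torch.Size([2, 2, 2, 2])

# Explicit vector semantics: append a trailing dimension, solve, drop it again.
# This reproduces the NumPy 1.26 result shape quoted above.
print(torch.linalg.solve(A, b.unsqueeze(-1)).squeeze(-1).shape)  # torch.Size([2, 2, 2])
```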
🚀 Feature Discussion
The question is whether solve-like functions need to support batch broadcasting for b of shape `(n,)`.

PyTorch currently includes several functions for problems of the type: find x s.t. ||Ax - b|| is minimized (Ax = b). Let's call them "solve-like functions".

For A with shape `(n, n)`, only `torch.linalg.solve` allows a b input of shape `(n,)` or `(n, nrhs)` (see the sketch below). The other functions require the b input to be a 2-dimensional tensor of shape `(n, nrhs)`.

Supporting 1-dimensional b input is NumPy- and SciPy-compatible. SciPy doesn't support batched inputs; NumPy supports batched inputs for `numpy.linalg.solve`, but not for `numpy.linalg.lstsq`.
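For concreteness, a quick sketch of the b shapes `torch.linalg.solve` accepts for a single `(n, n)` matrix A:

```python
import torch

A = torch.randn(3, 3)

print(torch.linalg.solve(A, torch.randn(3)).shape)     # torch.Size([3]):    b of shape (n,)
print(torch.linalg.solve(A, torch.randn(3, 2)).shape)  # torch.Size([3, 2]): b of shape (n, nrhs)
```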
`numpy.linalg.solve` supports batch-wise broadcasting only for `(n, nrhs)`-type b inputs, as sketched below.
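A minimal sketch of that restriction under NumPy 1.x; the shapes here are assumptions, chosen to match the `torch.Size([2, 3, 1, 3])` quoted just below:

```python
import numpy as np

a = np.random.rand(2, 3, 1, 3, 3)  # a (2, 3, 1)-batch of 3x3 matrices

# (n, nrhs)-shaped b broadcasts against a's batch dimensions:
print(np.linalg.solve(a, np.random.rand(3, 1)).shape)  # (2, 3, 1, 3, 1)

# A 1-dimensional b does not: since b.ndim != a.ndim - 1, NumPy 1.x takes the
# matrix code path, which requires b to be at least 2-dimensional.
# np.linalg.solve(a, np.random.rand(3))  # raises
```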
NumPy's behaviour makes sense because `(a.inverse() @ b).shape = torch.Size([2, 3, 1, 3])`, and batched matrix multiplication doesn't work for `a @ (a.inverse() @ b)` but does work for `a @ (a.inverse() @ b).unsqueeze(-1)`.
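A sketch of that reasoning; `a` of shape `(2, 3, 1, 3, 3)` and a 1-D `b` of shape `(3,)` are assumptions that reproduce the shape quoted above:

```python
import torch

a = torch.randn(2, 3, 1, 3, 3)
b = torch.randn(3)

x = a.inverse() @ b  # matmul treats a 1-D operand as a vector
print(x.shape)       # torch.Size([2, 3, 1, 3])

# a @ x fails: x now looks like a (2, 3)-batch of 1x3 matrices, which neither
# matches 3x3 @ 1x3 nor broadcasts against a's (2, 3, 1) batch dimensions.

print((a @ x.unsqueeze(-1)).shape)  # torch.Size([2, 3, 1, 3, 1]): a stack of 3x1 columns
```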
For NumPy compatibility, we need to support batch broadcasting for b of shape `(n, nrhs)` for `torch.linalg.solve`, and consequently apply the same behavior to all solve-like functions.

Do solve-like functions need to support batch broadcasting for b of shape `(n,)`?

The problem here is the ambiguity in deciding whether b is a matrix or a vector. For example, for A of shape `(3, 3, 3)`, how should we interpret b of shape `(3, 3)`: as a single matrix input to be batch-broadcast, or as a batch of vectors?

Currently `torch.linalg.solve` treats the matrix case as primary, and b is regarded as a vector if `b.ndim == 1 or ((A.ndim - b.ndim == 1) and (A.shape[:-1] == b.shape))`; a sketch of this rule follows below.
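A sketch of this heuristic on the ambiguous example above (`treats_b_as_vector` is a hypothetical helper spelling out the rule, not PyTorch API):

```python
import torch

def treats_b_as_vector(A: torch.Tensor, b: torch.Tensor) -> bool:
    # Hypothetical helper: the rule quoted above, written out.
    return b.ndim == 1 or (A.ndim - b.ndim == 1 and A.shape[:-1] == b.shape)

A = torch.randn(3, 3, 3)
b = torch.randn(3, 3)

print(treats_b_as_vector(A, b))        # True: b matches A.shape[:-1], so a batch of vectors
print(torch.linalg.solve(A, b).shape)  # torch.Size([3, 3])

print(treats_b_as_vector(A, b.unsqueeze(-1)))        # False: (3, 3, 1) is a batch of 3x1 matrices
print(torch.linalg.solve(A, b.unsqueeze(-1)).shape)  # torch.Size([3, 3, 1])
```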
This rule is compatible with NumPy; see https://github.com/numpy/numpy/blob/v1.20.0/numpy/linalg/linalg.py#L384-L389. Interestingly, NumPy fails for this case:
Additional context
Memory inefficiency of the current implementation is discussed in #49252.
cc @jianyuh @nikitaved @pearu @mruberry @heitorschueroff @walterddr @IvanYashchuk