Broadcasting behaviour for linear algebra solvers #52915

Open
IvanYashchuk opened this issue Feb 26, 2021 · 3 comments

Labels
module: linear algebra — Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply (matmul)
triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@IvanYashchuk
Collaborator

IvanYashchuk commented Feb 26, 2021

🚀 Feature Discussion

The question is whether solve-like functions need to support batch broadcasting for b of shape (n,).

PyTorch currently includes several functions for problems of the form: find x such that ||Ax - b|| is minimized (i.e. Ax = b). Let's call them "solve-like functions":

torch.linalg.solve
torch.solve
torch.cholesky_solve
torch.triangular_solve
torch.lu_solve
torch.lstsq

For A with shape (n, n), only torch.linalg.solve accepts a b input of shape (n,) or (n, nrhs); the other functions require b to be a 2-dimensional tensor of shape (n, nrhs).
Supporting a 1-dimensional b input is compatible with both NumPy and SciPy. SciPy doesn't support batched inputs at all; NumPy supports batched inputs for numpy.linalg.solve, but not for numpy.linalg.lstsq.

numpy.linalg.solve supports batch-wise broadcasting only for b inputs of the (n, nrhs) form:

import torch
import numpy as np
a = torch.randn(2, 3, 1, 3, 3)
b = torch.randn(3)
np.linalg.solve(a, b) # doesn't work
np.linalg.solve(a, b.unsqueeze(-1)) # this one works
# both cases work currently for torch.linalg.solve
a = torch.randn(3, 3)
b = torch.randn(2, 3, 1, 3)
np.linalg.solve(a, b) # doesn't work
np.linalg.solve(a, b.unsqueeze(-1)) # this one works
# torch.linalg.solve currently has the same behaviour

NumPy's behaviour makes sense because (a.inverse() @ b).shape = torch.Size([2, 3, 1, 3]), and batched matrix multiplication doesn't work for a @ (a.inverse() @ b) but does work for a @ (a.inverse() @ b).unsqueeze(-1).
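
To make the shape mismatch concrete, here is a minimal sketch reusing the shapes from the first example above (assuming the imports from the earlier snippet):

a = torch.randn(2, 3, 1, 3, 3)
b = torch.randn(3)
x = a.inverse() @ b             # shape (2, 3, 1, 3): matmul treats 1-D b as a vector
# a @ x fails: x is now interpreted as a batch of (1, 3) matrices,
# and (3, 3) @ (1, 3) has mismatched inner dimensions.
check = a @ x.unsqueeze(-1)     # works: each (3, 3) @ (3, 1) -> (3, 1)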

For NumPy compatibility, we need to support batch broadcasting for b of shape (n, nrhs) for torch.linalg.solve and consequently apply the same behavior to all solve-like functions.

Do solve-like functions need to support batch broadcasting for b of shape (n,)?

The problem here is the ambiguity in deciding whether b is a matrix or a vector. For example, for A of shape (3, 3, 3), how should we interpret b of shape (3, 3): as a single matrix input to be batch-broadcast, or as a batch of vectors?

Currently torch.linalg.solve treats the matrix case as primary, and b is regarded as a vector if b.ndim == 1 or ((A.ndim - b.ndim == 1) and (A.shape[:-1] == b.shape)). This rule is compatible with NumPy; see https://github.com/numpy/numpy/blob/v1.20.0/numpy/linalg/linalg.py#L384-L389.
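
For illustration, that rule can be written as a small helper (hypothetical code, not the actual implementation):

def treat_as_vector(A, b):
    # b is a (batch of) vector(s) if it is 1-D, or if it has exactly one
    # dimension fewer than A and its shape matches A.shape[:-1]
    return b.ndim == 1 or (A.ndim - b.ndim == 1 and A.shape[:-1] == b.shape)

A = torch.randn(3, 3, 3)
treat_as_vector(A, torch.randn(3))        # True: a plain vector
treat_as_vector(A, torch.randn(3, 3))     # True: a batch of vectors
treat_as_vector(A, torch.randn(3, 3, 3))  # False: a batch of matrices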

Interestingly, NumPy fails for this case:

a = torch.randn(3, 3, 3)
b = torch.randn(3)
np.linalg.solve(a, b) # doesn't work
np.linalg.solve(a, b.unsqueeze(-1)) # also doesn't work
# both cases work with torch.linalg.solve

Additional context

Memory inefficiency of the current implementation is discussed in #49252.

cc @jianyuh @nikitaved @pearu @mruberry @heitorschueroff @walterddr @IvanYashchuk

@IvanYashchuk IvanYashchuk added the module: linear algebra label Feb 26, 2021
@anjali411 anjali411 added the triaged label Feb 26, 2021
@lezcano
Collaborator

lezcano commented Aug 27, 2021

Would this need anything more than prefixing these functions with something like:

TORCH_CHECK(A.dim() >= 2);
// Treat b as a (batch of) vector(s) if it is 1-D, or if it has exactly one
// dimension fewer than A and its shape matches A.shape[:-1]; then promote it to a column.
const auto compatible = [](IntArrayRef a, IntArrayRef b) {
  return a.size() - 1 == b.size() && std::equal(a.begin(), a.end() - 1, b.begin());
};
const Tensor B_ = b.dim() == 1 || compatible(A.sizes(), b.sizes()) ? b.unsqueeze(-1) : b;

?

@lezcano
Collaborator

lezcano commented Aug 27, 2021

Actually, that does not take broadcasting into account. This should be done inside _linalg_broadcast_batch_dims, and we should remember to unsqueeze the dimension before returning the result.
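
A Python-level sketch of that wrapping pattern (illustrative only, not the actual C++ helper; torch.linalg.solve stands in for the underlying matrix-only routine):

def solve_vector_or_matrix(A, b):
    # Promote a vector-like b to a column matrix, solve, then drop the
    # extra dimension again so the output shape mirrors the input.
    vector_case = b.ndim == 1 or (A.ndim - b.ndim == 1 and A.shape[:-1] == b.shape)
    B = b.unsqueeze(-1) if vector_case else b
    X = torch.linalg.solve(A, B)
    return X.squeeze(-1) if vector_case else X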

@asmeurer
Collaborator

FYI, the ambiguous behavior of np.linalg.solve has been removed in NumPy 2.0. See numpy/numpy#25914. The logic now follows matmul, where solve(a, b) only applies the single-column logic to b if it is 1-dimensional, and treats it as a batch of 2-D matrices in all other cases.

This also aligns with the array API (data-apis/array-api#285).
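
For concreteness, under that rule the shapes come out like this (a sketch assuming NumPy >= 2.0):

a = np.random.random((2, 3, 3))
np.linalg.solve(a, np.random.random(3)).shape          # (2, 3): 1-D b is a vector
np.linalg.solve(a, np.random.random((3, 3))).shape     # (2, 3, 3): a single (n, nrhs) matrix, broadcast over the batch
np.linalg.solve(a, np.random.random((2, 3, 3))).shape  # (2, 3, 3): a batch of matrices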

The logic currently in PyTorch doesn't even match NumPy 1.26, because it only applies the a.ndim - 1 == b.ndim logic when there is no other broadcasting (no size 1 dimensions). If there is a size 1 dimension, torch.linalg.solve falls back to the batched matrix case, whereas NumPy 1.26 solve always uses the batched-column case whenever a.ndim - 1 == b.ndim:

>>> torch.linalg.solve(torch.rand((2, 1, 2, 2)), torch.rand((2, 2, 2))).shape
torch.Size([2, 2, 2, 2])
>>> np.linalg.solve(np.random.random((2, 1, 2, 2)), np.random.random((2, 2, 2))).shape # numpy 1.26
(2, 2, 2)

But regardless, PyTorch should just update to match the array API and NumPy 2.0 behavior, because it removes inherent ambiguity in the definition and matches the behavior of matmul. Any users relying on the old behavior can easily work around it with solve(a, b[..., None])[..., 0].
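
For example, a call that previously relied on the batched-vector interpretation could be rewritten like this (illustrative shapes):

a = torch.randn(2, 3, 3)
b = torch.randn(2, 3)                            # old rule: treated as a batch of vectors
x = torch.linalg.solve(a, b[..., None])[..., 0]  # same result under matmul-style semantics; x.shape == (2, 3)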
