TL;DR
We are planning to deprecate and later remove the undocumented torch.testing.assert_allclose in favor of torch.testing.assert_close, which was introduced with torch==1.9.0. They serve the same purpose, with the latter being stricter by default but also more configurable.
Status quo
torch.testing.assert_allclose has existed in the torch.testing namespace since torch==0.4.0. It was introduced as a testing utility for numerics, similar to numpy.testing.assert_allclose, but was never documented.
Although TestCase.assertEqual has since taken over in most places internally, torch.testing.assert_allclose is still used in some places. More importantly, due to the lack of public tensor testing functions, torch.testing.assert_allclose is used in downstream projects. A simple GitHub search turns up approximately 1k hits.
In torch==1.9.0 we introduced torch.testing.assert_close, which has the same purpose but is fully documented. There are a few differences compared to torch.testing.assert_allclose that are showcased next.
Differences
The design goal of torch.testing.assert_close is to make it very clear what is being tested without making the user jump through hoops to achieve this. In general that means assert_close is stricter by default than assert_allclose, but it is highly configurable.
Default tolerances
assert_allclose uses
pytorch/torch/testing/_core.py
Lines 411 to 415 in 0263865
```python
_default_tolerances = {
    'float64': (1e-5, 1e-8),  # NumPy default
    'float32': (1e-4, 1e-5),  # This may need to be changed
    'float16': (1e-3, 1e-3),  # This may need to be changed
}
```
whereas assert_close uses
pytorch/torch/testing/_asserts.py
Lines 30 to 38 in 0263865
```python
_DTYPE_PRECISIONS = {
    torch.float16: (0.001, 1e-5),
    torch.bfloat16: (0.016, 1e-5),
    torch.float32: (1.3e-6, 1e-5),
    torch.float64: (1e-7, 1e-7),
    torch.complex32: (0.001, 1e-5),
    torch.complex64: (1.3e-6, 1e-5),
    torch.complex128: (1e-7, 1e-7),
}
```
For both, the tolerances are 0 if the dtype is not in the mapping. That means assert_allclose does not check for closeness but rather for equality if the inputs are complex or of dtype bfloat16.
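For illustration only (this example is not from the issue and assumes assert_allclose is still available), here is a hypothetical float32 comparison that passes under the legacy defaults but fails under the stricter assert_close defaults:
```python
import torch

actual = torch.tensor(1.0)
expected = torch.tensor(1.0001)

# Passes: the legacy float32 defaults are rtol=1e-4, atol=1e-5, so the allowed
# error (~1.1e-4) is larger than the ~1e-4 difference.
torch.testing.assert_allclose(actual, expected)

# Fails: the new float32 defaults are rtol=1.3e-6, atol=1e-5, so the allowed
# error is only about 1.1e-5.
try:
    torch.testing.assert_close(actual, expected)
except AssertionError as e:
    print(e)

# Passing explicit tolerances restores the old behavior for this call.
torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-5)
```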
Input types
assert_allclose converts everything to a tensor:
pytorch/torch/testing/_core.py
Lines 254 to 257 in 0263865
```python
if not isinstance(actual, torch.Tensor):
    actual = torch.tensor(actual)
if not isinstance(expected, torch.Tensor):
    expected = torch.tensor(expected, dtype=actual.dtype)
```
whereas assert_close requires a direct relation between the input types:
pytorch/torch/testing/_asserts.py
Lines 573 to 581 in 0263865
```python
msg_fmtstr = f"Except for Python scalars, {{}}, but got {type(actual)} and {type(expected)} instead."
directly_related = isinstance(actual, type(expected)) or isinstance(expected, type(actual))
if not directly_related:
    return _TestingErrorMeta(AssertionError, msg_fmtstr.format("input types need to be directly related"))

if allow_subclasses or type(actual) is type(expected):
    return None

return _TestingErrorMeta(AssertionError, msg_fmtstr.format("type equality is required if allow_subclasses=False"))
```
That means assert_close does not support comparing a torch.Tensor to a numpy.ndarray or to a sequence of scalars, which is possible with assert_allclose. On the other hand, this enables assert_close to effectively check (nested) containers, i.e. sequences or mappings, and to provide a traceback of where the failure happened.
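As a hypothetical sketch of how this plays out in practice (the inputs below are made up for illustration):
```python
import numpy as np
import torch

tensor = torch.tensor([1.0, 2.0])
array = np.array([1.0, 2.0])

# assert_close rejects a tensor-vs-ndarray comparison; convert explicitly instead.
torch.testing.assert_close(tensor, torch.from_numpy(array).to(tensor.dtype))

# In return, (nested) containers are compared entry by entry, and a failure
# message points to the offending key or index.
actual = {"logits": torch.tensor([0.1, 0.2]), "loss": torch.tensor(1.0)}
expected = {"logits": torch.tensor([0.1, 0.2]), "loss": torch.tensor(1.0)}
torch.testing.assert_close(actual, expected)
```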
Attribute checking / equal_nan
assert_allclose does not support any attribute checking besides the shape, which is necessary anyway to compare the values. assert_close supports, for example, checking the devices, dtypes, or strides for equality (the former two are checked by default). Together with its default of equal_nan=True, assert_allclose is in general more permissive than assert_close.
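A minimal illustration (not from the issue) of how the stricter defaults surface and how to relax them per call:
```python
import torch

a = torch.tensor([1.0, float("nan")], dtype=torch.float32)
b = torch.tensor([1.0, float("nan")], dtype=torch.float64)

# assert_close checks dtypes by default and treats NaN != NaN by default,
# so this comparison has to be relaxed explicitly.
torch.testing.assert_close(a, b, check_dtype=False, equal_nan=True)
```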
Non-default tensors
In contrast to assert_allclose, assert_close has support for sparse COO / CSR and quantized tensors. These changes are already available in the nightly releases and will make their way into the next release.
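A minimal sketch, assuming a build recent enough to include the sparse support mentioned above:
```python
import torch

indices = torch.tensor([[0, 1], [1, 0]])
values = torch.tensor([1.0, 2.0])
actual = torch.sparse_coo_tensor(indices, values, size=(2, 2))
expected = torch.sparse_coo_tensor(indices, values, size=(2, 2))

# assert_close compares the sparse tensors directly; assert_allclose has no
# comparable support.
torch.testing.assert_close(actual, expected)
```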
Upgrade guide
We encourage all users to try torch.testing.assert_close as a 1-to-1 replacement for torch.testing.assert_allclose if the use case allows it. Otherwise, a thin wrapper around torch.testing.assert_close can be used to get the exact behavior of assert_allclose back:
```python
from typing import Any, Optional, Tuple

import torch

_DTYPE_PRECISIONS = {
    torch.float16: (1e-3, 1e-3),
    torch.float32: (1e-4, 1e-5),
    torch.float64: (1e-5, 1e-8),
}


def _get_default_rtol_and_atol(actual: torch.Tensor, expected: torch.Tensor) -> Tuple[float, float]:
    actual_rtol, actual_atol = _DTYPE_PRECISIONS.get(actual.dtype, (0.0, 0.0))
    expected_rtol, expected_atol = _DTYPE_PRECISIONS.get(expected.dtype, (0.0, 0.0))
    return max(actual_rtol, expected_rtol), max(actual_atol, expected_atol)


def assert_allclose(
    actual: Any,
    expected: Any,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
    equal_nan: bool = True,
    msg: str = "",
) -> None:
    if not isinstance(actual, torch.Tensor):
        actual = torch.tensor(actual)
    if not isinstance(expected, torch.Tensor):
        expected = torch.tensor(expected, dtype=actual.dtype)

    if rtol is None and atol is None:
        rtol, atol = _get_default_rtol_and_atol(actual, expected)

    torch.testing.assert_close(
        actual,
        expected,
        rtol=rtol,
        atol=atol,
        equal_nan=equal_nan,
        check_device=True,
        check_dtype=False,
        check_stride=False,
        check_is_coalesced=False,
        msg=msg or None,
    )
```
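To illustrate (a hypothetical call, not from the issue, assuming the wrapper above works on the installed version), the wrapper can then be dropped in wherever the old function was used:
```python
# Comparing a tensor to a plain Python sequence, which assert_close alone
# would reject, still works through the wrapper.
assert_allclose(torch.tensor([1.0, 2.0]), [1.0, 2.0])
```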
RFC
Although assert_allclose is not documented at all, the current plan is to deprecate it anyway due to its widespread usage (see #61841). Please weigh in here if you disagree with this plan.
cc'ing maintainers of projects in the official PyTorch ecosystem that currently use torch.testing.assert_allclose and thus would be affected by its retirement (apologies in advance if I picked the "wrong" maintainer).