
RFC: retire torch.testing.assert_allclose in favor of torch.testing.assert_close #61844

@pmeier


TL;DR

We are planning to deprecate and later remove the undocumented torch.testing.assert_allclose in favor of torch.testing.assert_close, which was introduced with torch==1.9.0. Both serve the same purpose, but the latter is stricter by default and also more configurable.

Status quo

assert_allclose has existed in the torch.testing namespace since torch==0.4.0. It was introduced as a testing utility for numerics, similar to numpy.testing.assert_allclose, but was never documented.

Although TestCase.assertEqual has since taken over in most places internally, torch.testing.assert_allclose is still used in some places. More importantly, due to the lack of public tensor testing functions, torch.testing.assert_allclose is used in downstream projects: a simple GitHub search turns up approximately 1k hits.

In torch==1.9.0 we introduced torch.testing.assert_close, which serves the same purpose but is fully documented. It differs from torch.testing.assert_allclose in a few ways, which are showcased next.

Differences

The design goal of torch.testing.assert_close is to make it very clear what is being tested without making the user jump through hoops to achieve this. In general, that means assert_close is stricter by default than assert_allclose, but highly configurable.

Default tolerances

assert_allclose uses

_default_tolerances = {
    'float64': (1e-5, 1e-8),  # NumPy default
    'float32': (1e-4, 1e-5),  # This may need to be changed
    'float16': (1e-3, 1e-3),  # This may need to be changed
}

whereas assert_close uses

_DTYPE_PRECISIONS = {
    torch.float16: (0.001, 1e-5),
    torch.bfloat16: (0.016, 1e-5),
    torch.float32: (1.3e-6, 1e-5),
    torch.float64: (1e-7, 1e-7),
    torch.complex32: (0.001, 1e-5),
    torch.complex64: (1.3e-6, 1e-5),
    torch.complex128: (1e-7, 1e-7),
}

In both mappings the values are (rtol, atol) pairs, and for both functions the tolerances are 0 if the dtype is not in the mapping. Thus, assert_allclose checks for equality rather than closeness if the inputs are complex or of dtype bfloat16.
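
To make the difference concrete, here is a minimal pure-Python sketch that applies both default tables through the elementwise criterion documented for torch.isclose, |actual - expected| <= atol + rtol * |expected|. It works on plain Python floats, so it ignores rounding to the actual dtype; keying the tables by dtype-name strings and the helper is_close are assumptions made for illustration only.

```python
# Default (rtol, atol) pairs copied from the two mappings above, keyed by dtype name.
ASSERT_ALLCLOSE_DEFAULTS = {
    "float64": (1e-5, 1e-8),
    "float32": (1e-4, 1e-5),
    "float16": (1e-3, 1e-3),
}

ASSERT_CLOSE_DEFAULTS = {
    "float16": (0.001, 1e-5),
    "bfloat16": (0.016, 1e-5),
    "float32": (1.3e-6, 1e-5),
    "float64": (1e-7, 1e-7),
    "complex32": (0.001, 1e-5),
    "complex64": (1.3e-6, 1e-5),
    "complex128": (1e-7, 1e-7),
}


def is_close(actual: float, expected: float, dtype: str, table: dict) -> bool:
    """Hypothetical helper: the torch.isclose criterion with defaults looked up in `table`."""
    # Dtypes missing from the table fall back to zero tolerances, i.e. exact equality.
    rtol, atol = table.get(dtype, (0.0, 0.0))
    return abs(actual - expected) <= atol + rtol * abs(expected)


# A difference of 5e-5 passes the old float32 defaults but not the new ones.
print(is_close(1.00005, 1.0, "float32", ASSERT_ALLCLOSE_DEFAULTS))  # True
print(is_close(1.00005, 1.0, "float32", ASSERT_CLOSE_DEFAULTS))     # False

# One bfloat16 step near 1.0 (2**-7) passes the new table; the old one demands equality.
print(is_close(1.0 + 2**-7, 1.0, "bfloat16", ASSERT_CLOSE_DEFAULTS))     # True
print(is_close(1.0 + 2**-7, 1.0, "bfloat16", ASSERT_ALLCLOSE_DEFAULTS))  # False
```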

Input types

assert_allclose converts everything to a tensor, casting a non-tensor expected to the dtype of actual:

if not isinstance(actual, torch.Tensor):
    actual = torch.tensor(actual)
if not isinstance(expected, torch.Tensor):
    expected = torch.tensor(expected, dtype=actual.dtype)

whereas assert_close requires a direct relation of the input types:

msg_fmtstr = f"Except for Python scalars, {{}}, but got {type(actual)} and {type(expected)} instead."
directly_related = isinstance(actual, type(expected)) or isinstance(expected, type(actual))
if not directly_related:
    return _TestingErrorMeta(AssertionError, msg_fmtstr.format("input types need to be directly related"))
if allow_subclasses or type(actual) is type(expected):
    return None
return _TestingErrorMeta(AssertionError, msg_fmtstr.format("type equality is required if allow_subclasses=False"))

That means assert_close does not support comparing a torch.Tensor to a numpy.ndarray or to a sequence of scalars, which is possible with assert_allclose. On the other hand, this enables assert_close to effectively check (nested) containers, i.e. sequences or mappings, and to provide a traceback pointing to where the failure happened.
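
The container support can be pictured with the following pure-Python sketch. It is not assert_close's actual implementation: it applies the "directly related" type rule quoted above (with the same exemption for Python scalars), recurses into sequences and mappings, and includes the path to the offending element in the error message. The function name compare_nested and its fixed float32-like default tolerances are hypothetical.

```python
from collections.abc import Mapping, Sequence
from numbers import Number


def compare_nested(actual, expected, rtol=1.3e-6, atol=1e-5, path="inputs"):
    """Hypothetical sketch: recursively compare containers of Python numbers."""
    # Python scalars are exempt; everything else must be directly related by type.
    scalars = isinstance(actual, Number) and isinstance(expected, Number)
    related = isinstance(actual, type(expected)) or isinstance(expected, type(actual))
    if not (scalars or related):
        raise AssertionError(
            f"{path}: input types need to be directly related, "
            f"but got {type(actual)} and {type(expected)} instead."
        )
    if isinstance(actual, Mapping):
        if actual.keys() != expected.keys():
            raise AssertionError(f"{path}: keys differ")
        for key in actual:
            compare_nested(actual[key], expected[key], rtol, atol, f"{path}[{key!r}]")
    elif isinstance(actual, Sequence) and not isinstance(actual, str):
        if len(actual) != len(expected):
            raise AssertionError(f"{path}: lengths differ ({len(actual)} vs {len(expected)})")
        for i, (a, e) in enumerate(zip(actual, expected)):
            compare_nested(a, e, rtol, atol, f"{path}[{i}]")
    elif abs(actual - expected) > atol + rtol * abs(expected):
        raise AssertionError(f"{path}: {actual} is not close to {expected}")


# Passes: same structure, values within tolerance.
compare_nested({"loss": 0.5, "grads": [1.0, 2.0]}, {"loss": 0.5, "grads": [1.0, 2.0]})

try:
    compare_nested({"loss": 0.5, "grads": [1.0, 2.0]}, {"loss": 0.5, "grads": [1.0, 2.5]})
except AssertionError as error:
    print(error)  # inputs['grads'][1]: 2.0 is not close to 2.5
```

Note that a list and a tuple are rejected as unrelated types, mirroring why a torch.Tensor cannot be compared to a numpy.ndarray.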

Attribute checking / equal_nan

assert_allclose does not support any attribute checking besides the shape, which has to match anyway to compare the values. assert_close, for example, supports checking the devices, dtypes, or strides for equality (the first two are checked by default). Together with its default of equal_nan=True, whereas assert_close defaults to equal_nan=False, assert_allclose is in general more permissive than assert_close.
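
The following sketch shows two of these differences with the public torch.testing.assert_close API: a dtype mismatch fails by default and is ignored with check_dtype=False, and NaNs only compare equal with equal_nan=True. It assumes a torch version that ships assert_close (1.9 or later); the helper raises_assertion is hypothetical.

```python
import torch


def raises_assertion(fn) -> bool:
    """Hypothetical helper: True if calling fn raises an AssertionError."""
    try:
        fn()
    except AssertionError:
        return True
    return False


actual = torch.tensor([1.0, 2.0], dtype=torch.float32)
expected = torch.tensor([1.0, 2.0], dtype=torch.float64)

# dtypes are checked by default ...
print(raises_assertion(lambda: torch.testing.assert_close(actual, expected)))  # True
# ... but the check can be switched off.
print(raises_assertion(lambda: torch.testing.assert_close(actual, expected, check_dtype=False)))  # False

nan_a = torch.tensor([float("nan")])
nan_b = torch.tensor([float("nan")])
# NaNs are unequal by default (equal_nan=False) ...
print(raises_assertion(lambda: torch.testing.assert_close(nan_a, nan_b)))  # True
# ... and equal with equal_nan=True, which is assert_allclose's default.
print(raises_assertion(lambda: torch.testing.assert_close(nan_a, nan_b, equal_nan=True)))  # False
```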

Non-default tensors

In contrast to assert_allclose, assert_close supports sparse COO / CSR and quantized tensors. This support is already available in the nightly releases and will ship with the next release.

Upgrade guide

We encourage all users to try torch.testing.assert_close as a 1-to-1 replacement for torch.testing.assert_allclose if their use case allows it. Otherwise, a thin wrapper around torch.testing.assert_close can be used to get the exact behavior of assert_allclose back:

from typing import Any, Optional, Tuple

import torch

# Default (rtol, atol) pairs of the old assert_allclose; other dtypes fall back to (0.0, 0.0).
_DTYPE_PRECISIONS = {
    torch.float16: (1e-3, 1e-3),
    torch.float32: (1e-4, 1e-5),
    torch.float64: (1e-5, 1e-8),
}


def _get_default_rtol_and_atol(actual: torch.Tensor, expected: torch.Tensor) -> Tuple[float, float]:
    # Use the looser of the two inputs' default tolerances.
    actual_rtol, actual_atol = _DTYPE_PRECISIONS.get(actual.dtype, (0.0, 0.0))
    expected_rtol, expected_atol = _DTYPE_PRECISIONS.get(expected.dtype, (0.0, 0.0))
    return max(actual_rtol, expected_rtol), max(actual_atol, expected_atol)


def assert_allclose(
    actual: Any,
    expected: Any,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
    equal_nan: bool = True,
    msg: str = "",
) -> None:
    # Like the old assert_allclose: convert non-tensor inputs, casting expected to actual's dtype.
    if not isinstance(actual, torch.Tensor):
        actual = torch.tensor(actual)
    if not isinstance(expected, torch.Tensor):
        expected = torch.tensor(expected, dtype=actual.dtype)

    # Fill in defaults only if neither tolerance was given; assert_close rejects passing just one.
    if rtol is None and atol is None:
        rtol, atol = _get_default_rtol_and_atol(actual, expected)

    torch.testing.assert_close(
        actual,
        expected,
        rtol=rtol,
        atol=atol,
        equal_nan=equal_nan,
        check_device=True,
        check_dtype=False,  # assert_allclose never compared dtypes
        check_stride=False,
        check_is_coalesced=False,
        msg=msg or None,  # None keeps assert_close's default error message
    )
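
For call sites that relied on the old float32 tolerances, one way to migrate without the wrapper is to pass those tolerances to torch.testing.assert_close explicitly. The sketch below uses values that fall between the two sets of defaults; the variable names are illustrative only.

```python
import torch

# 1.00005 rounds to roughly 1.0000499 in float32; the difference is still about 5e-5.
actual = torch.tensor([1.00005, 2.0], dtype=torch.float32)
expected = torch.tensor([1.0, 2.0], dtype=torch.float32)

# assert_close's float32 defaults (rtol=1.3e-6, atol=1e-5) reject the first element.
try:
    torch.testing.assert_close(actual, expected)
    default_failed = False
except AssertionError:
    default_failed = True

# Passing the old assert_allclose float32 defaults explicitly accepts it.
torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-5)
legacy_passed = True

print(default_failed, legacy_passed)  # True True
```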

RFC

Although assert_allclose is not documented at all, the current plan is to deprecate it first rather than remove it outright, due to its widespread usage (see #61841). Please weigh in here if you disagree with this plan.

cc'ing maintainers of projects in the official PyTorch ecosystem that currently use torch.testing.assert_allclose and thus would be affected by its retirement (apologies in advance if I picked the "wrong" maintainer).

cc @mruberry @VitalyFedyunin @walterddr

Labels: module: deprecation, module: testing, module: tests, triaged
