
Resolved: add torch.testing.assert_close() #56544

@mruberry

Today PyTorch's test suite principally uses assertEqual(), a TestCase method, to compare tensors (and containers of tensors, scalars, and NumPy arrays).

def assertEqual(self, x, y, msg: Optional[str] = None, *,

Unfortunately, this method is imperfect:

  • it's a method on PyTorch's TestCase, which makes it tricky to use for comparing tensors in libraries built on PyTorch; in practice, library tests make significant use of torch.equal() instead, but torch.equal() only checks whether two tensors are exactly equal
  • the exact_device kwarg, which requires that the tensors' devices match, still defaults to False for backwards compatibility; allowing tensors on different devices to compare as equal has caused bugs in the past
  • it's called "assertEqual," but by default it actually asserts "closeness" and not "equality"
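The difference between exact equality and closeness matters as soon as floating-point rounding is involved. A small illustration, using torch.allclose() as a stand-in for the tolerance-based check that assertEqual() performs by default (an assumption made only because assertEqual() lives on PyTorch's internal TestCase):

```python
import torch

# Two float64 tensors that differ only by floating-point rounding error.
a = torch.tensor([0.1 + 0.2], dtype=torch.float64)  # 0.30000000000000004
b = torch.tensor([0.3], dtype=torch.float64)

# torch.equal() requires the same size and exactly equal elements,
# so rounding error alone makes the comparison fail.
print(torch.equal(a, b))     # False

# A tolerance-based closeness check treats the two values as the same.
print(torch.allclose(a, b))  # True
```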

To provide a tensor comparison function suitable for both PyTorch's test suite and the test suites of libraries built on PyTorch, this RFC proposes creating a new function, torch.testing.assert_close(), with the following signature:

```python
# Proposed testing function signature
torch.testing.assert_close(actual, expected, *, msg=None, rtol=None, atol=None, equal_nan=True, check_device=True, check_dtype=True)

# Current assertEqual signature for comparison
torch.testing.TestCase.assertEqual(self, x, y, msg=None, *, rtol=None, atol=None, equal_nan=True, exact_device=False, exact_dtype=True)
```
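To make the proposed semantics concrete, here is a minimal sketch of the proposed signature built on torch.isclose(). The name assert_close_sketch and the per-dtype default tolerances in _DEFAULT_TOLERANCES are hypothetical, chosen for illustration only, and the sketch handles pairs of tensors, not containers, scalars, or NumPy arrays:

```python
import torch

# Hypothetical default tolerances per dtype: (rtol, atol).
# Dtypes not listed fall back to exact comparison.
_DEFAULT_TOLERANCES = {
    torch.float16: (1e-3, 1e-5),
    torch.bfloat16: (1.6e-2, 1e-5),
    torch.float32: (1.3e-6, 1e-5),
    torch.float64: (1e-7, 1e-7),
}


def assert_close_sketch(actual, expected, *, msg=None, rtol=None, atol=None,
                        equal_nan=True, check_device=True, check_dtype=True):
    """Hypothetical sketch of the proposed semantics for two tensors."""
    if check_device and actual.device != expected.device:
        raise AssertionError(msg or f"Device mismatch: {actual.device} != {expected.device}")
    if check_dtype and actual.dtype != expected.dtype:
        raise AssertionError(msg or f"Dtype mismatch: {actual.dtype} != {expected.dtype}")
    if actual.shape != expected.shape:
        raise AssertionError(msg or f"Shape mismatch: {actual.shape} != {expected.shape}")

    # If the checks above were disabled, move both tensors to a common
    # dtype and device before comparing values.
    dtype = torch.promote_types(actual.dtype, expected.dtype)
    actual = actual.to(dtype=dtype)
    expected = expected.to(device=actual.device, dtype=dtype)

    default_rtol, default_atol = _DEFAULT_TOLERANCES.get(dtype, (0.0, 0.0))
    rtol = default_rtol if rtol is None else rtol
    atol = default_atol if atol is None else atol

    # Elementwise |actual - expected| <= atol + rtol * |expected|;
    # NaNs compare equal when equal_nan=True (the proposed default).
    close = torch.isclose(actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan)
    if not close.all():
        mismatches = int((~close).sum())
        raise AssertionError(msg or f"{mismatches} of {close.numel()} elements are not close")
```

Because msg and every option after it follow the bare `*`, callers must pass them by keyword, which is one of the signature changes the proposal lists.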

The new signature and the current signature are very similar. Their differences are:

  • the new function is named assert_close to reflect its behavior more accurately and to avoid conflicting with NumPy's numpy.testing.assert_equal
  • msg is a keyword-only argument
  • "exact_device" and "exact_dtype" are renamed to "check_device" and "check_dtype" for consistency
  • the device check defaults to True (check_device=True, whereas assertEqual's exact_device defaults to False; changing that default is a TODO for assertEqual())
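For reference, the function that eventually shipped as torch.testing.assert_close can be exercised as below. This assumes a PyTorch release that includes it. The shipped defaults do not match this proposal in every respect (for example, equal_nan defaults to False there), so the example only illustrates the keyword-only msg and the check_dtype flag:

```python
import torch
from torch.testing import assert_close

actual = torch.tensor([1.0, 2.0, 3.0])
expected = torch.tensor([1.0, 2.0, 3.0 + 1e-6])  # within float32 default tolerances

# Passes: the values are close, and device and dtype match.
assert_close(actual, expected, msg="tensors differ")

# msg is keyword-only, so passing it positionally is a TypeError.
try:
    assert_close(actual, expected, "tensors differ")
except TypeError:
    print("msg must be passed by keyword")

# check_dtype=True is the default: float32 vs. float64 fails even though the values match.
try:
    assert_close(actual, actual.double())
except AssertionError:
    print("dtype mismatch rejected")

# With check_dtype=False, the values are compared after type promotion.
assert_close(actual, actual.double(), check_dtype=False)
```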

The new function's behavior will be almost identical to assertEqual()'s, except for the different default values and possibly new features, like improved support for error handling.

I'd further propose that, if and when we're happy with testing.assert_close, we replace uses of assertEqual() in our test suite with it. This will create a period of some confusion in which some tests use the newer assert_close() and others use the older assertEqual(), but our test suite has been through more significant revisions without issue, and even assertEqual()'s behavior has changed over time in confusing ways. In fact, there are still lingering remnants of those changes in the codebase, like assertEqualIgnoreType().
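Because assert_close is a free function rather than a TestCase method, the same call works in a plain unittest.TestCase, which is what makes it usable in libraries built on PyTorch. A sketch of what a migrated test might look like, assuming a PyTorch release that ships torch.testing.assert_close; the class MyLibraryTest and its test are hypothetical:

```python
import unittest

import torch
from torch.testing import assert_close


class MyLibraryTest(unittest.TestCase):  # hypothetical downstream test case
    def test_matmul_matches_reference(self):
        x = torch.randn(4, 3, dtype=torch.float64)
        y = torch.randn(3, 2, dtype=torch.float64)
        # Reference result via broadcasting: (4, 3, 1) * (1, 3, 2) summed over dim 1.
        reference = (x.unsqueeze(-1) * y.unsqueeze(0)).sum(dim=1)
        # Previously this would be self.assertEqual(x @ y, reference),
        # which requires PyTorch's own TestCase.
        assert_close(x @ y, reference)
```

The test can be run with `python -m unittest`; the two computations differ only in summation order, so they agree within the float64 tolerances.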

cc @ngimel @ezyang @cpuhrsch @nairbv @anjali411 @mthrok @pmeier
