Description
Today PyTorch's test suite principally uses assertEqual(), a TestCase method, to compare tensors (and containers of tensors, scalars, and NumPy arrays).
It is defined in pytorch/torch/testing/_internal/common_utils.py (line 1229 at commit ea4af15), beginning:
def assertEqual(self, x, y, msg: Optional[str] = None, *, ...
Unfortunately, this method is imperfect:
- it's a method on PyTorch's TestCase, which makes it awkward to use for comparing tensors in libraries built on PyTorch; in practice, library tests make significant use of torch.equal() instead, but torch.equal() only checks two tensors for exact equality
- the exact_device kwarg, which requires that devices match, still defaults to False for backwards compatibility; allowing tensors on different devices to compare as equal has been the source of bugs in the past
- it's called "assertEqual," but by default it actually asserts "closeness," not "equality"
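To make the "equality" versus "closeness" distinction concrete, here is a minimal pure-Python sketch. It assumes the elementwise tolerance rule |actual - expected| <= atol + rtol * |expected| (the rule documented for torch.isclose()); the helper names is_close_elementwise and is_exactly_equal are hypothetical, and plain lists of floats stand in for tensors.

```python
import math


def is_close_elementwise(actual, expected, *, rtol, atol, equal_nan=True):
    """Hypothetical helper: compare two equal-length sequences of floats.

    Each pair must satisfy |a - e| <= atol + rtol * |e|. NaN matches NaN
    only when equal_nan is True; infinities only match an identical infinity.
    """
    if len(actual) != len(expected):
        return False
    for a, e in zip(actual, expected):
        if math.isnan(a) or math.isnan(e):
            # NaN only "matches" another NaN, and only if equal_nan is set
            if not (equal_nan and math.isnan(a) and math.isnan(e)):
                return False
            continue
        if a == e:
            # exact matches, including equal infinities
            continue
        if math.isinf(a) or math.isinf(e):
            # an infinity that is not exactly equal can never be "close"
            return False
        if abs(a - e) > atol + rtol * abs(e):
            return False
    return True


def is_exactly_equal(actual, expected):
    """Hypothetical stand-in for an exact check in the spirit of torch.equal()."""
    return len(actual) == len(expected) and all(a == e for a, e in zip(actual, expected))


actual = [0.1 + 0.2, 1.0]
expected = [0.3, 1.0]
print(is_exactly_equal(actual, expected))  # False: 0.1 + 0.2 != 0.3 in binary floating point
print(is_close_elementwise(actual, expected, rtol=1e-7, atol=1e-7))  # True
```

An exact check rejects results that differ only by floating-point rounding, which is why a closeness check with explicit tolerances is usually what numerical tests want.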
To provide a tensor comparison function suitable for both PyTorch's test suite and the test suites of libraries built on PyTorch, this RFC proposes creating a new function, torch.testing.assert_close(), with the following signature:
# Proposed testing function signature
torch.testing.assert_close(actual, expected, *, msg=None, rtol=None, atol=None, equal_nan=True, check_device=True, check_dtype=True)
# Current assertEqual signature for comparison
torch.testing._internal.common_utils.TestCase.assertEqual(self, x, y, msg=None, *, rtol=None, atol=None, equal_nan=True, exact_device=False, exact_dtype=True)
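The following is a minimal pure-Python sketch of what the proposed keyword-only signature implies, not the actual torch.testing implementation. It assumes a hypothetical FakeTensor stand-in (a flat list of floats plus dtype and device strings) instead of real tensors, and the per-dtype fallback tolerances used when rtol or atol is None are an assumption for illustration, since the RFC does not specify them. The name assert_close_sketch is hypothetical.

```python
import math
from dataclasses import dataclass
from typing import List


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor: flat values plus dtype and device tags."""
    values: List[float]
    dtype: str = "float32"
    device: str = "cpu"


# Assumed per-dtype fallback tolerances (rtol, atol), used only for illustration.
_DEFAULT_TOLERANCES = {
    "float16": (1e-3, 1e-5),
    "float32": (1.3e-6, 1e-5),
    "float64": (1e-7, 1e-7),
}


def assert_close_sketch(actual, expected, *, msg=None, rtol=None, atol=None,
                        equal_nan=True, check_device=True, check_dtype=True):
    """Hypothetical sketch mirroring the proposed keyword-only signature."""
    # Metadata checks run first and are on by default.
    if check_device and actual.device != expected.device:
        raise AssertionError(msg or f"device mismatch: {actual.device} vs {expected.device}")
    if check_dtype and actual.dtype != expected.dtype:
        raise AssertionError(msg or f"dtype mismatch: {actual.dtype} vs {expected.dtype}")
    # Fall back to the assumed per-dtype defaults for any tolerance left as None.
    default_rtol, default_atol = _DEFAULT_TOLERANCES.get(expected.dtype, (0.0, 0.0))
    rtol = default_rtol if rtol is None else rtol
    atol = default_atol if atol is None else atol
    if len(actual.values) != len(expected.values):
        raise AssertionError(msg or "shape mismatch")
    for i, (a, e) in enumerate(zip(actual.values, expected.values)):
        if equal_nan and math.isnan(a) and math.isnan(e):
            continue
        if a == e:
            continue
        # NaNs (when not matched above) and unequal infinities always fail.
        if (math.isnan(a) or math.isnan(e) or math.isinf(a) or math.isinf(e)
                or abs(a - e) > atol + rtol * abs(e)):
            raise AssertionError(msg or f"values differ at index {i}: {a} vs {e}")


a = FakeTensor([1.0, 2.0])
b = FakeTensor([1.0, 2.0 + 1e-7])
assert_close_sketch(a, b)  # passes: the difference is within the assumed float32 tolerances

c = FakeTensor([1.0, 2.0], device="cuda")
try:
    assert_close_sketch(a, c)  # fails: devices are checked by default
except AssertionError as err:
    print(err)  # device mismatch: cpu vs cuda
assert_close_sketch(a, c, check_device=False)  # passes once device checking is disabled
```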
The new signature and the current signature are very similar. Their differences are:
- the new function is named assert_close to reflect its behavior more accurately and to avoid confusion with NumPy's numpy.testing.assert_equal(), which checks for exact equality
- msg is keyword-only (in assertEqual() it can be passed positionally)
- "exact_device" and "exact_dtype" are renamed to "check_device" and "check_dtype" for consistency
- device checking is on by default (check_device=True), whereas assertEqual() still defaults to exact_device=False (changing that default is a TODO for assertEqual())
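Because of the renames and the changed device default, migrating a call is more than a find-and-replace. The hypothetical helper translate_assert_equal_call below sketches how an assertEqual(x, y, msg, **kwargs) call could be mapped onto the proposed assert_close() keywords. It assumes the assertEqual() defaults quoted in this RFC (exact_device=False, exact_dtype=True) and is not part of PyTorch.

```python
# Hypothetical mapping from assertEqual() flag names to the proposed names.
_RENAMED_KWARGS = {"exact_device": "check_device", "exact_dtype": "check_dtype"}

# assertEqual() defaults for those flags, as quoted in the RFC's signature.
_ASSERT_EQUAL_DEFAULTS = {"exact_device": False, "exact_dtype": True}


def translate_assert_equal_call(x, y, msg=None, **kwargs):
    """Hypothetical helper: map assertEqual(x, y, msg, **kwargs) onto
    assert_close(actual, expected, *, ...) arguments.

    Returns (positional_args, keyword_args) for assert_close.
    """
    new_kwargs = {}
    # Start from assertEqual's defaults so behavior is preserved; in particular
    # exact_device=False must be carried over, or device checking flips on.
    for old_name, default in _ASSERT_EQUAL_DEFAULTS.items():
        new_kwargs[_RENAMED_KWARGS[old_name]] = kwargs.pop(old_name, default)
    if msg is not None:
        new_kwargs["msg"] = msg  # msg becomes keyword-only
    # These keep their names in the proposed signature.
    for name in ("rtol", "atol", "equal_nan"):
        if name in kwargs:
            new_kwargs[name] = kwargs.pop(name)
    if kwargs:
        raise TypeError(f"unsupported keyword arguments: {sorted(kwargs)}")
    return (x, y), new_kwargs


args, kwargs = translate_assert_equal_call("x", "y", "values differ", atol=1e-5, rtol=1e-3)
print(args)    # ('x', 'y')
print(kwargs["check_device"], kwargs["msg"])  # False values differ
```

Carrying exact_device=False over explicitly is the key design choice here: translating only the names would silently enable device checking for every migrated call.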
The behavior of this new function will be almost identical to assertEqual()'s behavior, except for the different default values and possibly new features like improved error handling.
I'd further propose that, if and when we're happy with torch.testing.assert_close(), we replace uses of assertEqual() with it in our test suite. This will create a period of some confusion in which some tests use the newer assert_close() and others use the older assertEqual(), but our test suite has gone through more significant revisions without issue, and even assertEqual()'s behavior has confusingly changed over time. In fact, there are still lingering remnants of those changes in the codebase, like assertEqualIgnoreType().
cc @ngimel @ezyang @cpuhrsch @nairbv @anjali411 @mthrok @pmeier