To compare two tensors element-wise in PyTorch, use the torch.eq() function. It compares corresponding elements and returns a Boolean tensor that holds True where the two elements are equal and False where they differ. The two tensors may have the same or a different number of dimensions, but their shapes must be broadcastable: aligned from the last dimension backwards, the sizes must either match or one of them must be 1 (a singleton dimension).
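The shape requirement can be illustrated with a minimal sketch. The tensor names (col, row, a, b) and the failed flag are made up for this illustration and do not come from the examples that follow; the second comparison is expected to raise a RuntimeError because the sizes differ at a non-singleton dimension.

```python
import torch

# Shapes (3, 1) and (1, 4) are compatible: in each dimension the sizes
# either match or one of them is 1, so the result has shape (3, 4).
col = torch.tensor([[1], [2], [3]])
row = torch.tensor([[1, 2, 3, 4]])
mask = torch.eq(col, row)
print(mask.shape)  # torch.Size([3, 4])

# Shapes (3,) and (4,) differ at a non-singleton dimension,
# so torch.eq() cannot line the elements up and raises an error.
a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 2, 3, 4])
try:
    torch.eq(a, b)
    failed = False
except RuntimeError as err:
    failed = True
    print("RuntimeError:", err)
```

In the first comparison, every element of col is compared with every element of row, so the True values appear only where the two values coincide.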
Steps
Import the required library. In all the following Python examples, the required Python library is torch. Make sure you have already installed it.
Create two PyTorch tensors and print them.
Compute torch.eq(input1, input2). It compares the two tensors element-wise and returns a Boolean tensor that is True where the corresponding elements are equal and False elsewhere.
Print the returned tensor.
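The steps above can be sketched as follows. The integer tensors a and b are placeholder values chosen for this sketch, not data from the examples below; the result is converted with tolist() only to make the printed values easy to read.

```python
import torch

# Create two tensors of the same shape and print them
a = torch.tensor([1, 2, 3, 4])
b = torch.tensor([1, 0, 3, 0])
print("a:", a)
print("b:", b)

# Compare the tensors element-wise
result = torch.eq(a, b)

# Print the returned tensor and its data type
print(result.tolist())  # [True, False, True, False]
print(result.dtype)     # torch.bool
```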
Example 1
The following Python program shows how to compare two 1-D tensors element-wise.
# import necessary library
import torch

# Create two tensors
T1 = torch.Tensor([2.4, 5.4, -3.44, -5.43, 43.5])
T2 = torch.Tensor([2.4, 5.5, -3.44, -5.43, 43])

# print above created tensors
print("T1:", T1)
print("T2:", T2)

# Compare tensors T1 and T2 element-wise
print(torch.eq(T1, T2))
Output
T1: tensor([ 2.4000, 5.4000, -3.4400, -5.4300, 43.5000])
T2: tensor([ 2.4000, 5.5000, -3.4400, -5.4300, 43.0000])
tensor([ True, False, True, True, False])
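The comparison in Example 1 succeeds where both tensors were built from the same literal values. As a side note that goes beyond the original example, torch.eq() tests for exact equality, so floating-point values that come out of arithmetic may compare as unequal because of rounding. The sketch below assumes that a tolerance-based check is acceptable in that situation and uses torch.isclose() with its default tolerances; the names x, y, exact and close are illustrative only.

```python
import torch

# 0.1 + 0.2 is not exactly 0.3 in 64-bit floating point
x = torch.tensor([0.1], dtype=torch.float64) + torch.tensor([0.2], dtype=torch.float64)
y = torch.tensor([0.3], dtype=torch.float64)

# Exact element-wise comparison
exact = torch.eq(x, y)
print(exact.tolist())  # [False]

# Comparison within a small tolerance
close = torch.isclose(x, y)
print(close.tolist())  # [True]
```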
Example 2
The following Python program shows how to compare two 2-D tensors element-wise.
# import necessary library
import torch

# create two 4x3 2D tensors
T1 = torch.Tensor([[2, 3, -32],
                   [43, 4, -53],
                   [4, 37, -4],
                   [3, 75, 34]])
T2 = torch.Tensor([[2, 3, -32],
                   [4, 4, -53],
                   [4, 37, 4],
                   [3, -75, 34]])

# print above created tensors
print("T1:", T1)
print("T2:", T2)

# Compare tensors T1 and T2 element-wise
print(torch.eq(T1, T2))
Output
T1: tensor([[ 2., 3., -32.],
        [ 43., 4., -53.],
        [ 4., 37., -4.],
        [ 3., 75., 34.]])
T2: tensor([[ 2., 3., -32.],
        [ 4., 4., -53.],
        [ 4., 37., 4.],
        [ 3., -75., 34.]])
tensor([[ True, True, True],
        [False, True, True],
        [ True, True, False],
        [ True, False, True]])
Example 3
The following Python program shows how to compare a 1-D tensor with a 2-D tensor element-wise. The 1-D tensor has 5 elements and each row of the 2-D tensor has 5 elements, so the 1-D tensor is broadcast and compared against every row.
# import necessary library
import torch

# Create two tensors
T1 = torch.Tensor([2.4, 5.4, -3.44, -5.43, 43.5])
T2 = torch.Tensor([[2.4, 5.5, -3.44, -5.43, 7],
                   [1.0, 5.4, 3.88, 4.0, 5.78]])

# Print above created tensors
print("T1:", T1)
print("T2:", T2)

# Compare the tensors T1 and T2 element-wise
print(torch.eq(T1, T2))
Output
T1: tensor([ 2.4000, 5.4000, -3.4400, -5.4300, 43.5000])
T2: tensor([[ 2.4000, 5.5000, -3.4400, -5.4300, 7.0000],
        [ 1.0000, 5.4000, 3.8800, 4.0000, 5.7800]])
tensor([[ True, False, True, True, False],
        [False, True, False, False, False]])
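To make the broadcasting in Example 3 more concrete, the following sketch compares the implicit result with one obtained by first expanding T1 to the shape of T2. The use of expand_as() and torch.equal() here, and the names T1_expanded, broadcast_result and explicit_result, are additions for illustration; the tensors hold the same values as in Example 3.

```python
import torch

T1 = torch.tensor([2.4, 5.4, -3.44, -5.43, 43.5])
T2 = torch.tensor([[2.4, 5.5, -3.44, -5.43, 7.0],
                   [1.0, 5.4, 3.88, 4.0, 5.78]])

# Implicit broadcasting: T1 (shape (5,)) is compared with each row of T2
broadcast_result = torch.eq(T1, T2)

# Explicit version: view T1 as a (2, 5) tensor first, then compare
T1_expanded = T1.expand_as(T2)
explicit_result = torch.eq(T1_expanded, T2)

print(broadcast_result.shape)                          # torch.Size([2, 5])
print(torch.equal(broadcast_result, explicit_result))  # True
```

Both routes give the same Boolean tensor, which suggests that the 1-D tensor behaves as if it were repeated once per row of the 2-D tensor.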