Torch Argmax Method in Python PyTorch
To find the indices of the maximum values of the elements in an input tensor, we can use the torch.argmax() function. It returns only the indices, not the element values. If the input tensor contains several maximal values, the function returns the index of the first one. When no dimension is given, torch.argmax() treats the tensor as flattened and returns a single index into that flattened view. When the dim argument is given, it computes the indices of the maximum values along that dimension.
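The rules above (an index is returned, the first maximal element wins on ties, and the tensor is treated as flattened when no dimension is given) can be illustrated in plain Python. The sketch below is a hypothetical model of those rules on nested lists, not PyTorch's implementation. The helper names argmax_flat and argmax_2d are made up for illustration.

```python
def argmax_flat(values):
    """Hypothetical helper: index of the first maximal element of a flat sequence."""
    if not values:
        raise ValueError("argmax of an empty sequence is undefined")
    best_index = 0
    for i, v in enumerate(values):
        # Strict '>' keeps the earliest index when values tie.
        if v > values[best_index]:
            best_index = i
    return best_index


def argmax_2d(matrix, dim=None):
    """Hypothetical helper: argmax over a list of equal-length rows."""
    if dim is None:
        # No dim: flatten row by row and return an index into the flat sequence.
        flat = [v for row in matrix for v in row]
        return argmax_flat(flat)
    if dim == 1:
        # One result per row: the column holding that row's maximum.
        return [argmax_flat(list(row)) for row in matrix]
    if dim == 0:
        # One result per column: the row holding that column's maximum.
        return [argmax_flat(list(col)) for col in zip(*matrix)]
    raise ValueError("dim must be None, 0 or 1")


print(argmax_flat([0.0, -1.0, 2.0, 8.0]))  # 3
print(argmax_flat([1, 5, 5, 2]))           # 1 (first of the tied maxima)

m = [[1, 9, 3],
     [7, 2, 9]]
print(argmax_2d(m))         # 1
print(argmax_2d(m, dim=0))  # [1, 0, 1]
print(argmax_2d(m, dim=1))  # [1, 2]
```

The strict comparison is what produces the "first maximal element" rule: a later element that only equals the current best never replaces it.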
Syntax
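torch.argmax(input)
torch.argmax(input, dim, keepdim=False)
Here dim is the dimension to reduce, and keepdim controls whether the output keeps that dimension with size 1.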
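The call forms can be compared on a small tensor. This is a minimal sketch assuming a recent PyTorch release. The tensor values are made up for illustration and contain no ties, so the results do not depend on tie handling.

```python
import torch

x = torch.tensor([[1., 9., 3.],
                  [7., 2., 8.]])

# No dim: one index into the flattened tensor, returned as a 0-dim tensor.
flat_idx = torch.argmax(x)
print(flat_idx)  # tensor(1)

# dim=1: one index per row; the reduced dimension is dropped.
row_idx = torch.argmax(x, dim=1)
print(row_idx, row_idx.shape)  # tensor([1, 2]) torch.Size([2])

# keepdim=True keeps the reduced dimension with size 1.
row_idx_kept = torch.argmax(x, dim=1, keepdim=True)
print(row_idx_kept.shape)  # torch.Size([2, 1])

# The result holds integer positions (int64), not the maximum values.
print(flat_idx.dtype)  # torch.int64
```

keepdim=True is handy when the indices are later combined with the original tensor, because the result still broadcasts against it.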
Steps
We can use the following steps to find the index of the maximum value among all elements of the input tensor:
Import the required library. In all the following examples, the required Python library is torch. Make sure you have already installed it.
import torch
Define an input tensor input.
input = torch.randn(3,4)
Compute the indices of the maximum values of all the elements in the tensor input.
indices = torch.argmax(input)
Print the computed indices.
print("Indices:\n", indices)
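The steps above can be combined into one short script. The call to torch.manual_seed is an addition not present in the steps, and the seed value 0 is an arbitrary choice. It only makes the random input repeatable between runs. The final check confirms that the returned index points at the maximum element of the flattened tensor.

```python
import torch

# Arbitrary seed so torch.randn produces the same tensor on every run.
torch.manual_seed(0)

# Define the input tensor (name kept from the steps above).
input = torch.randn(3, 4)

# Index of the maximum over all 12 elements (flattened view).
indices = torch.argmax(input)

print("Input Tensor:\n", input)
print("Indices:\n", indices)

# Without dim, the index refers to the flattened tensor,
# so it selects the maximum from input.flatten().
assert input.flatten()[indices] == input.max()
```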
Example 1
# Import the required library
import torch

# Define an input tensor
input = torch.tensor([0., -1., 2., 8.])

# Print the above defined tensor
print("Input Tensor:\n", input)

# Compute the index of the maximum value
indices = torch.argmax(input)

# Print the indices
print("Indices:\n", indices)
Output
Input Tensor:
 tensor([ 0., -1.,  2.,  8.])
Indices:
 tensor(3)
In the above Python example, we find the index of the maximum element of a 1D input tensor. The maximum value in the tensor is 8, and it is at index 3.
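When the maximum value is needed as well as its position, torch.max with a dim argument returns both the values and the indices. This is a minimal sketch reusing the tensor from Example 1. For this input, the index it reports is the same one torch.argmax returns.

```python
import torch

input = torch.tensor([0., -1., 2., 8.])

# torch.max with dim returns a (values, indices) named tuple.
values, indices = torch.max(input, dim=0)
print(values, indices)  # tensor(8.) tensor(3)

# argmax gives only the index; indexing with it recovers the value.
idx = torch.argmax(input)
print(input[idx])  # tensor(8.)
```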
Example 2
In this program, we compute the indices of the maximum values of a 2D tensor, first over all elements and then along dim 0 and dim 1.
# Import the required library
import torch

# Define an input tensor of random values
input = torch.randn(4, 4)

# Print the above defined tensor
print("Input Tensor:\n", input)

# Compute the index of the maximum value over all elements
indices = torch.argmax(input)
print("Indices:\n", indices)

# Compute the indices of the maximum values in dim 0
indices = torch.argmax(input, dim=0)
print("Indices in dim 0:\n", indices)

# Compute the indices of the maximum values in dim 1
indices = torch.argmax(input, dim=1)
print("Indices in dim 1:\n", indices)
Output
Input Tensor:
 tensor([[-1.6729,  1.2613, -1.2882, -0.8133],
        [ 0.9192,  0.9301, -0.2372,  0.0162],
        [-0.4669,  0.6604, -0.7982,  0.2621],
        [ 0.6436,  1.0328,  2.4573,  0.0606]])
Indices:
 tensor(14)
Indices in dim 0:
 tensor([1, 0, 3, 2])
Indices in dim 1:
 tensor([1, 1, 1, 2])
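In the above Python example, we find the indices of the maximum values of a 2D input tensor, both over the whole tensor and along each dimension. Without dim, the result is an index into the flattened tensor: the maximum, 2.4573, is in row 3, column 2, so its flat index is 3 * 4 + 2 = 14. With dim=0, each entry is the row of a column's maximum; with dim=1, each entry is the column of a row's maximum. Because the input is generated with torch.randn(), you will likely get a different input tensor and different indices when you run the program.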
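To reproduce the output above exactly, the random tensor can be written out by hand. The sketch below does that, then converts the flat index back to a (row, column) pair with Python's built-in divmod, relying on the row-major flattening described above. Using divmod here is an illustrative choice, not part of the original example.

```python
import torch

# The random tensor from the output above, written out so results are fixed.
input = torch.tensor([[-1.6729,  1.2613, -1.2882, -0.8133],
                      [ 0.9192,  0.9301, -0.2372,  0.0162],
                      [-0.4669,  0.6604, -0.7982,  0.2621],
                      [ 0.6436,  1.0328,  2.4573,  0.0606]])

flat = torch.argmax(input)               # index into the flattened 4x4 tensor
n_cols = input.shape[1]
row, col = divmod(flat.item(), n_cols)   # flat = row * n_cols + col
print(flat, (row, col))                  # tensor(14) (3, 2)

print(torch.argmax(input, dim=0))        # tensor([1, 0, 3, 2]): row of each column's max
print(torch.argmax(input, dim=1))        # tensor([1, 1, 1, 2]): column of each row's max
```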