numpy.squeeze() in Python
Last Updated: 15 Apr, 2025
numpy.squeeze() removes axes of length one from the shape of a NumPy array. It is handy for discarding redundant dimensions (dimensions of size 1) that are left over after operations that introduce extra dimensions.
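As a minimal sketch of that situation, assuming a reduction called with keepdims=True as the operation that leaves the extra axis behind, the snippet below sums each row of a 2x3 array and then squeezes the (2, 1) result back to a flat 1D array:

```python
import numpy as np

# A 2x3 array of sample values
data = np.array([[1, 2, 3],
                 [4, 5, 6]])

# Summing with keepdims=True keeps the reduced axis as a size-1 dimension
row_sums = np.sum(data, axis=1, keepdims=True)
print(row_sums.shape)   # (2, 1)

# squeeze() drops that size-1 axis, leaving a flat 1D result
flat_sums = np.squeeze(row_sums)
print(flat_sums)        # [ 6 15]
print(flat_sums.shape)  # (2,)
```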
Basic usage of numpy.squeeze()
Python
import numpy as np

# A 3D array whose first dimension has size 1
in_arr = np.array([[[2, 2, 2], [2, 2, 2]]])
print("Input array : ", in_arr)
print("Shape of input array : ", in_arr.shape)

# Remove every dimension of size 1
out_arr = np.squeeze(in_arr)
print("output squeezed array : ", out_arr)
print("Shape of output array : ", out_arr.shape)
Output :
Input array : [[[2 2 2]
[2 2 2]]]
Shape of input array : (1, 2, 3)
output squeezed array : [[2 2 2]
[2 2 2]]
Shape of output array : (2, 3)
Explanation:
- Input: A 3D array of shape (1, 2, 3); the first dimension has size 1.
- Output: numpy.squeeze() removes that size-1 dimension, resulting in a 2D array of shape (2, 3).
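When an array has several size-1 dimensions, the default call removes all of them at once while keeping the remaining axes in their original order. A short sketch, using a zero-filled array purely as a placeholder:

```python
import numpy as np

# Shape (1, 3, 1, 2): two of the four axes have length 1
arr = np.zeros((1, 3, 1, 2))

# With axis=None (the default), every length-1 axis is removed
squeezed = np.squeeze(arr)
print(arr.shape, "->", squeezed.shape)  # (1, 3, 1, 2) -> (3, 2)

# Axes whose length is not 1 are left untouched, in their original order
print(squeezed.ndim)  # 2
```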
Syntax of numpy.squeeze() in Python
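numpy.squeeze(a, axis=None)
Parameters:
- a: Input array (array_like).
- axis: None, an int, or a tuple of ints (optional). Selects a subset of the length-1 entries in the shape to remove. With the default None, all length-1 axes are removed. If an axis is selected whose shape entry is greater than one, a ValueError is raised.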
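Because axis also accepts a tuple and negative indices, you can remove only some of the size-1 axes. A small sketch on a placeholder array of shape (1, 3, 1, 1):

```python
import numpy as np

arr = np.zeros((1, 3, 1, 1))

# Remove only axes 0 and 2; axis 3 (also length 1) is kept
partial = np.squeeze(arr, axis=(0, 2))
print(partial.shape)  # (3, 1)

# Negative indices count from the end, so axis=-1 refers to the last axis
last_only = np.squeeze(arr, axis=-1)
print(last_only.shape)  # (1, 3, 1)
```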
Return Type: An ndarray with all, or the selected subset, of the length-1 dimensions removed. The result is the input array itself or a view of it, so the data is not copied.
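Two consequences of the return type are easy to miss: the squeezed array shares memory with the input, and squeezing away every axis gives a 0-d array rather than a Python scalar. The sketch below checks both:

```python
import numpy as np

in_arr = np.array([[[1, 2, 3]]])  # shape (1, 1, 3)
out_arr = np.squeeze(in_arr)      # shape (3,)

# The result shares memory with the input, so writes are visible in both
print(np.shares_memory(in_arr, out_arr))  # True
out_arr[0] = 99
print(in_arr)  # [[[99  2  3]]]

# Squeezing away every axis yields a 0-d array, not a Python scalar
single = np.squeeze(np.array([[7]]))
print(single.shape, single.ndim)  # () 0
print(type(single).__name__)      # ndarray
```

If an independent array is needed, calling .copy() on the result avoids the shared-memory behaviour.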
Example 1: Using the axis parameter to squeeze a specific dimension
Python
import numpy as np

# A 3D array of shape (1, 3, 3)
in_arr = np.arange(9).reshape(1, 3, 3)
print("Input array : ", in_arr)

# Remove only axis 0, which has size 1
out_arr = np.squeeze(in_arr, axis=0)
print("output array : ", out_arr)
print("The shapes of Input and Output array : ")
print(in_arr.shape, out_arr.shape)
Output :
Input array : [[[0 1 2]
[3 4 5]
[6 7 8]]]
output array : [[0 1 2]
[3 4 5]
[6 7 8]]
The shapes of Input and Output array :
(1, 3, 3) (3, 3)
Explanation: The input array has shape (1, 3, 3). By specifying axis=0, we remove the first dimension, resulting in an output array of shape (3, 3).
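The same operation is also available as the ndarray.squeeze() method, and a negative axis can name the same dimension counted from the end. A brief sketch comparing the forms on the array from Example 1:

```python
import numpy as np

in_arr = np.arange(9).reshape(1, 3, 3)

# ndarray.squeeze() is the method form of np.squeeze()
via_method = in_arr.squeeze(axis=0)

# axis=-3 names the same axis as axis=0 for a 3D array
via_negative = np.squeeze(in_arr, axis=-3)

print(via_method.shape, via_negative.shape)      # (3, 3) (3, 3)
print(np.array_equal(via_method, via_negative))  # True
```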
Example 2: Error when trying to squeeze a non-singleton dimension
Python
import numpy as np

in_arr = np.arange(9).reshape(1, 3, 3)
print("Input array: ", in_arr)

# Try to squeeze axis 1, which has size 3 rather than 1
try:
    out_arr = np.squeeze(in_arr, axis=1)
except ValueError as e:
    print("Error: ", e)
Output :
Input array: [[[0 1 2]
[3 4 5]
[6 7 8]]]
Error: cannot select an axis to squeeze out which has size not equal to one
Explanation: The input array has shape (1, 3, 3). Axis 1 is the second dimension, which has size 3, so attempting to squeeze it raises a ValueError: only dimensions of size 1 can be squeezed.
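If you would rather skip a non-singleton axis than catch the exception, you can check the axis length first. The helper below, squeeze_if_singleton, is a hypothetical name used only for illustration (it is not part of NumPy) and handles a single int axis:

```python
import numpy as np

def squeeze_if_singleton(arr, axis):
    """Hypothetical helper: squeeze `axis` only when its length is 1.

    Returns the array unchanged instead of raising ValueError.
    Supports a single int axis (negative values allowed).
    """
    if arr.shape[axis] == 1:
        return np.squeeze(arr, axis=axis)
    return arr

in_arr = np.arange(9).reshape(1, 3, 3)
print(squeeze_if_singleton(in_arr, 0).shape)  # (3, 3)
print(squeeze_if_singleton(in_arr, 1).shape)  # (1, 3, 3)
```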