tf.nn.max_pool
Performs max pooling on the input.
tf.nn.max_pool(
input, ksize, strides, padding, data_format=None, name=None
)
For a given window of ksize, takes the maximum value within that window.
Max pooling downsamples the input, which reduces computation and can help limit overfitting.
Consider an example of pooling with 2x2, non-overlapping windows:
import tensorflow as tf

matrix = tf.constant([
    [0, 0, 1, 7],
    [0, 2, 0, 0],
    [5, 2, 0, 0],
    [0, 0, 9, 8],
])
# Reshape to [batch, height, width, channels], the layout of the default "NHWC" format.
reshaped = tf.reshape(matrix, (1, 4, 4, 1))
tf.nn.max_pool(reshaped, ksize=2, strides=2, padding="SAME")
<tf.Tensor: shape=(1, 2, 2, 1), dtype=int32, numpy=
array([[[[2],
         [7]],
        [[5],
         [9]]]], dtype=int32)>
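To see where these four values come from, we can reproduce the 2x2, stride-2 case without TensorFlow. The following is a minimal pure-Python sketch, not TensorFlow's implementation. The helper name `pool_2x2_blocks` is hypothetical, and it assumes the input height and width are even so that no padding is needed.

```python
def pool_2x2_blocks(grid):
    """Max over non-overlapping 2x2 blocks of a 2D list (even height and width assumed)."""
    pooled = []
    for top in range(0, len(grid), 2):
        row = []
        for left in range(0, len(grid[0]), 2):
            # Collect the four values of the block whose upper-left corner is (top, left).
            block = [grid[top + dy][left + dx] for dy in range(2) for dx in range(2)]
            row.append(max(block))
        pooled.append(row)
    return pooled


matrix = [
    [0, 0, 1, 7],
    [0, 2, 0, 0],
    [5, 2, 0, 0],
    [0, 0, 9, 8],
]
print(pool_2x2_blocks(matrix))  # [[2, 7], [5, 9]]
```

Each output entry is the largest value of one quadrant of the input, which matches the tensor returned above once the batch and channel dimensions are dropped.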
We can adjust the window size using the ksize parameter. For example, if we
were to expand the window to 3:
tf.nn.max_pool(reshaped, ksize=3, strides=2, padding="SAME")
<tf.Tensor: shape=(1, 2, 2, 1), dtype=int32, numpy=
array([[[[5],
         [7]],
        [[9],
         [9]]]], dtype=int32)>
We've now picked up two additional large numbers (5 and 9) in two of the
pooled spots.
Note that our windows now overlap, since we're still moving by 2 units
on each iteration. As a result, the same 9 appears twice in the output,
because it lies inside two overlapping windows.
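The overlap can be made concrete by listing which input indices each window covers along one dimension. This is an illustrative sketch only: `window_spans` is a hypothetical helper, and `pad_before=0` is an assumption based on the usual "SAME" rule, under which a size-4 input with a window of 3 and a stride of 2 receives one padded position at the end and none at the start.

```python
def window_spans(in_size, ksize, stride, out_size, pad_before):
    """Return the in-bounds input indices covered by each window along one dimension."""
    spans = []
    for i in range(out_size):
        start = i * stride - pad_before
        # Clip to the real input; padded positions never contribute a maximum.
        covered = [j for j in range(start, start + ksize) if 0 <= j < in_size]
        spans.append(covered)
    return spans


spans = window_spans(in_size=4, ksize=3, stride=2, out_size=2, pad_before=0)
print(spans)  # [[0, 1, 2], [2, 3]]

# Indices that belong to both windows.
shared = sorted(set(spans[0]) & set(spans[1]))
print(shared)  # [2]
```

Index 2 belongs to both windows. The 9 sits at column index 2 (zero-based) of the input's last row, so both the left and the right window of the bottom output row see it.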
We can adjust how far we move our window with each iteration using the
strides parameter. Updating this to the same value as our window size
eliminates the overlap:
tf.nn.max_pool(reshaped, ksize=3, strides=3, padding="SAME")
<tf.Tensor: shape=(1, 2, 2, 1), dtype=int32, numpy=
array([[[[2],
         [7]],
        [[5],
         [9]]]], dtype=int32)>
Because the window does not neatly fit into our input, padding is added around
the edges, giving us the same result as when we used a 2x2 window. We can skip
padding altogether and simply drop the windows that do not fully fit into our
input by instead passing "VALID" to the padding argument:
tf.nn.max_pool(reshaped, ksize=3, strides=3, padding="VALID")
<tf.Tensor: shape=(1, 1, 1, 1), dtype=int32, numpy=array([[[[5]]]], dtype=int32)>
Now we've grabbed the largest value in the 3x3 window starting from the
upper-left corner. Since no other windows fit in our input, they are dropped.
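The difference between the two padding modes can be sketched in pure Python. This is a reference sketch under stated assumptions, not TensorFlow's implementation: the function names are hypothetical, the window is assumed to be square, and the "SAME" sizes follow the commonly documented rule that the output size is ceil(input_size / stride), with any odd padding amount placed at the end.

```python
import math


def same_padding(in_size, ksize, stride):
    """Output size and (before, after) padding along one dimension for "SAME"."""
    out_size = math.ceil(in_size / stride)
    total = max((out_size - 1) * stride + ksize - in_size, 0)
    before = total // 2
    return out_size, before, total - before


def valid_output_size(in_size, ksize, stride):
    """Output size along one dimension for "VALID": only windows that fit fully."""
    return max((in_size - ksize) // stride + 1, 0)


def max_pool_2d(grid, ksize, stride, padding):
    """Reference max pooling over a 2D list with a square window."""
    height, width = len(grid), len(grid[0])
    if padding == "SAME":
        out_h, pad_top, _ = same_padding(height, ksize, stride)
        out_w, pad_left, _ = same_padding(width, ksize, stride)
    elif padding == "VALID":
        out_h = valid_output_size(height, ksize, stride)
        out_w = valid_output_size(width, ksize, stride)
        pad_top = pad_left = 0
    else:
        raise ValueError(f"unsupported padding: {padding!r}")

    pooled = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            top, left = i * stride - pad_top, j * stride - pad_left
            # Padded positions are skipped, so they can never be the maximum.
            window = [
                grid[y][x]
                for y in range(top, top + ksize)
                for x in range(left, left + ksize)
                if 0 <= y < height and 0 <= x < width
            ]
            row.append(max(window))
        pooled.append(row)
    return pooled


matrix = [
    [0, 0, 1, 7],
    [0, 2, 0, 0],
    [5, 2, 0, 0],
    [0, 0, 9, 8],
]
print(same_padding(4, 3, 3))               # (2, 1, 1)
print(max_pool_2d(matrix, 3, 2, "SAME"))   # [[5, 7], [9, 9]]
print(max_pool_2d(matrix, 3, 3, "SAME"))   # [[2, 7], [5, 9]]
print(max_pool_2d(matrix, 3, 3, "VALID"))  # [[5]]
```

With a window and stride of 3, "SAME" adds one padded position on each side of the size-4 input so that two windows fit per dimension, while "VALID" keeps only the single window that lies entirely inside the input.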
Args |
input | Tensor of rank N+2, of shape [batch_size] + input_spatial_shape + [num_channels] if data_format does not start with "NC" (default), or [batch_size, num_channels] + input_spatial_shape if data_format starts with "NC". Pooling happens over the spatial dimensions only.
ksize | An int or list of ints that has length 1, N or N+2. The size of the window for each dimension of the input tensor.
strides | An int or list of ints that has length 1, N or N+2. The stride of the sliding window for each dimension of the input tensor.
padding | Either the string "SAME" or "VALID" indicating the type of padding algorithm to use, or a list indicating the explicit paddings at the start and end of each dimension. See the notes on padding in the tf.nn module documentation for more information. When explicit padding is used and data_format is "NHWC", this should be in the form [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]. When explicit padding is used and data_format is "NCHW", this should be in the form [[0, 0], [0, 0], [pad_top, pad_bottom], [pad_left, pad_right]]. When using explicit padding, the size of the paddings cannot be greater than the sliding window size.
data_format | A string. Specifies the channel dimension. For N=1 it can be either "NWC" (default) or "NCW", for N=2 it can be either "NHWC" (default) or "NCHW", and for N=3 either "NDHWC" (default) or "NCDHW".
name | Optional name for the operation.
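The accepted lengths for ksize and strides can be read as shorthand for a full per-dimension list of length N+2, with 1 in the batch and channel positions. The sketch below illustrates that reading. `expand_window_arg` is a hypothetical helper written for this example and is not part of the TensorFlow API; the normalization TensorFlow performs internally may differ in its details.

```python
def expand_window_arg(value, n_spatial, channels_first=False):
    """Expand an int or list of length 1, N, or N+2 to a full list of length N+2."""
    values = [value] if isinstance(value, int) else list(value)
    if len(values) == n_spatial + 2:
        return values  # Already one entry per input dimension.
    if len(values) == 1:
        values = values * n_spatial  # Same size for every spatial dimension.
    if len(values) != n_spatial:
        raise ValueError(
            f"expected length 1, {n_spatial} or {n_spatial + 2}, got {len(values)}"
        )
    # Batch and channel dimensions are not pooled, so they get a window of 1.
    return [1, 1] + values if channels_first else [1] + values + [1]


print(expand_window_arg(2, n_spatial=2))                            # [1, 2, 2, 1]
print(expand_window_arg([3, 2], n_spatial=2))                       # [1, 3, 2, 1]
print(expand_window_arg([3, 2], n_spatial=2, channels_first=True))  # [1, 1, 3, 2]
print(expand_window_arg([1, 2, 2, 1], n_spatial=2))                 # [1, 2, 2, 1]
```

Under this reading, ksize=2 on a rank-4 "NHWC" input means a 2x2 spatial window applied independently to every batch element and every channel.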
Returns |
A Tensor of format specified by data_format. The max pooled output tensor.
Raises |
ValueError | If explicit padding is used with an input tensor of rank 5, or if explicit padding is used with data_format='NCHW_VECT_C'.
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates. Some content is licensed under the numpy license.
Last updated 2024-04-26 UTC.