tf.expand_dims
Returns a tensor with a length-1 axis inserted at index axis.
tf.expand_dims(
input, axis, name=None
)
Given a tensor input, this operation inserts a dimension of length 1 at the dimension index axis of input's shape. The dimension index follows Python indexing rules: it is zero-based, and a negative index is counted backward from the end.
This operation is useful to:
- Add an outer "batch" dimension to a single element.
- Align axes for broadcasting.
- Add an inner vector length axis to a tensor of scalars.
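The last two uses can be illustrated with a short sketch, assuming TensorFlow 2.x running eagerly. The tensor names (features, scales, labels) are purely illustrative. A per-example scale of shape [4] cannot broadcast directly against a [4, 3] batch because the trailing axes (3 and 4) differ; inserting a length-1 axis at the end turns it into a [4, 1] column that broadcasts across each row.

```python
import tensorflow as tf

# Illustrative batch of 4 feature vectors, each of length 3.
features = tf.ones([4, 3])
# One illustrative scale factor per example, shape [4].
scales = tf.constant([1.0, 2.0, 3.0, 4.0])

# [4] -> [4, 1], which broadcasts against [4, 3]: row i is multiplied by scales[i].
scaled = features * tf.expand_dims(scales, axis=-1)
print(scaled.shape.as_list())  # [4, 3]

# Turn a vector of scalars into a column of length-1 vectors: [3] -> [3, 1].
labels = tf.constant([0, 1, 1])
column = tf.expand_dims(labels, axis=1)
print(column.shape.as_list())  # [3, 1]
```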
For example:
If you have a single image of shape [height, width, channels]:
import tensorflow as tf
image = tf.zeros([10, 10, 3])
You can add an outer batch axis by passing axis=0:
tf.expand_dims(image, axis=0).shape.as_list()  # [1, 10, 10, 3]
For a non-negative axis, the new axis location matches Python list.insert(axis, 1):
tf.expand_dims(image, axis=1).shape.as_list()  # [10, 1, 10, 3]
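The list.insert analogy can be checked directly. This sketch, assuming TensorFlow 2.x in eager mode, compares the shape produced by tf.expand_dims with the result of list.insert for every non-negative axis of the 3-D image. The analogy does not carry over to negative values: Python's list.insert(-1, 1) inserts before the last element, whereas tf.expand_dims(image, -1) appends the new axis at the end.

```python
import tensorflow as tf

image = tf.zeros([10, 10, 3])
rank = len(image.shape)  # D = 3

results = {}
# For each non-negative axis 0..D, compare expand_dims with list.insert.
for axis in range(rank + 1):
    expected = image.shape.as_list()
    expected.insert(axis, 1)
    actual = tf.expand_dims(image, axis=axis).shape.as_list()
    assert actual == expected, (axis, actual, expected)
    results[axis] = actual
    print(axis, actual)
```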
Following standard Python indexing rules, a negative axis counts from the end, so axis=-1 adds an innermost dimension:
tf.expand_dims(image, -1).shape.as_list()  # [10, 10, 3, 1]
This operation requires that axis is a valid index for input.shape, following Python indexing rules:
-1-tf.rank(input) <= axis <= tf.rank(input)
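The range above admits D + 2 values for an input of rank D, because the new axis can go before the first dimension or after the last. A negative axis is equivalent to the insertion index axis + D + 1. The sketch below, assuming TensorFlow 2.x in eager mode, encodes that mapping in a hypothetical helper, normalize_axis (not a TensorFlow API), and checks it against tf.expand_dims for every valid axis of the 3-D image.

```python
import tensorflow as tf

def normalize_axis(axis, rank):
    """Hypothetical helper (not part of TensorFlow): map an expand_dims axis
    in [-(rank+1), rank] to the equivalent non-negative insertion index."""
    if not -rank - 1 <= axis <= rank:
        raise ValueError(f"axis {axis} out of range [{-rank - 1}, {rank}]")
    return axis if axis >= 0 else axis + rank + 1

image = tf.zeros([10, 10, 3])
rank = len(image.shape)  # D = 3, so valid axes are -4..3

for axis in range(-rank - 1, rank + 1):
    expected = image.shape.as_list()
    expected.insert(normalize_axis(axis, rank), 1)
    assert tf.expand_dims(image, axis=axis).shape.as_list() == expected
```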
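This operation is related to:
- tf.squeeze, which removes dimensions of size 1.
- tf.reshape, which provides more flexible reshaping capability.
- tf.sparse.expand_dims, which provides this functionality for tf.SparseTensor.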
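The relationship between tf.expand_dims, tf.squeeze, and tf.reshape can be shown with a small sketch, assuming TensorFlow 2.x in eager mode. Squeezing the same axis undoes the expansion, and tf.reshape reaches the same result when the full target shape is spelled out. In every case the element values are unchanged; only the shape differs.

```python
import tensorflow as tf

x = tf.reshape(tf.range(6), [2, 3])

# expand_dims followed by squeeze on the same axis restores the original shape.
expanded = tf.expand_dims(x, axis=1)     # shape [2, 1, 3]
restored = tf.squeeze(expanded, axis=1)  # shape [2, 3]

# tf.reshape gives the same tensor, but needs the whole target shape.
reshaped = tf.reshape(x, [2, 1, 3])

print(expanded.shape.as_list(), restored.shape.as_list())  # [2, 1, 3] [2, 3]
print(bool(tf.reduce_all(tf.equal(expanded, reshaped))))   # True
```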
Args:
- input: A Tensor.
- axis: Integer specifying the dimension index at which to expand the shape of input. Given an input of D dimensions, axis must be in range [-(D+1), D] (inclusive).
- name: Optional string. The name of the output Tensor.
Returns:
A tensor with the same data as input, with an additional dimension inserted at the index specified by axis.
Raises:
- TypeError: If axis is not specified.
- InvalidArgumentError: If axis is out of range [-(D+1), D].
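Both documented errors can be triggered with a minimal sketch, assuming TensorFlow 2.x under eager execution (the default); other execution modes, such as tracing inside tf.function, may report the out-of-range case differently. For the rank-3 image, the valid axes are -4 through 3, so axis=4 is one past the upper bound.

```python
import tensorflow as tf

image = tf.zeros([10, 10, 3])  # D = 3, valid axes are -4..3

caught = []

# axis one past the valid upper bound D.
try:
    tf.expand_dims(image, axis=4)
except tf.errors.InvalidArgumentError:
    caught.append("InvalidArgumentError")

# axis omitted entirely: it is a required argument.
try:
    tf.expand_dims(image)
except TypeError:
    caught.append("TypeError")

print(caught)  # ['InvalidArgumentError', 'TypeError']
```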
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates. Some content is licensed under the numpy license.
Last updated 2024-04-26 UTC.