tf.keras.ops.scatter_update
Update inputs via updates at scattered (sparse) indices.
tf.keras.ops.scatter_update(
inputs, indices, updates
)
At a high level, this operation does inputs[indices] = updates. Assume inputs is an n-D tensor of shape (D0, D1, ..., D(n-1)). There are two main usages of scatter_update:
1. indices is a 2D tensor of shape (num_updates, n), where num_updates is the number of updates to perform, and updates is a 1D tensor of shape (num_updates,). For example, if inputs is zeros((4, 4, 4)) and we want to set inputs[1, 2, 3] and inputs[0, 1, 3] to 1, we can use:
import numpy as np
import keras

inputs = np.zeros((4, 4, 4))
indices = [[1, 2, 3], [0, 1, 3]]
updates = np.array([1., 1.])
inputs = keras.ops.scatter_update(inputs, indices, updates)
2. indices is a 2D tensor of shape (num_updates, k), where num_updates is the number of updates to perform and k (k < n) is the length of each index in indices. updates is a tensor of shape (num_updates, *inputs.shape[k:]). For example, if inputs = np.zeros((4, 4, 4)) and we want to set inputs[1, 2, :] and inputs[2, 3, :] to [1, 1, 1, 1], then indices has shape (num_updates, 2) (k = 2) and updates has shape (num_updates, 4) (inputs.shape[2:] = (4,)). See the code below:
import numpy as np
import keras

inputs = np.zeros((4, 4, 4))
indices = [[1, 2], [2, 3]]
updates = np.array([[1., 1., 1., 1.], [1., 1., 1., 1.]])
inputs = keras.ops.scatter_update(inputs, indices, updates)
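For intuition, both examples above can be reproduced with NumPy advanced indexing on a copy of the array. This is a minimal sketch that assumes in-bounds indices with no duplicates; it only illustrates the semantics and is not how Keras implements the op. The helper name numpy_scatter_update is hypothetical.

```python
import numpy as np


def numpy_scatter_update(inputs, indices, updates):
    """Hypothetical NumPy reference for scatter_update (not the Keras implementation)."""
    out = np.array(inputs, copy=True)
    idx = np.asarray(indices)
    # Transpose (num_updates, k) into k arrays of length num_updates, one per
    # leading axis, so out[...] selects num_updates slices of shape inputs.shape[k:].
    out[tuple(idx.T)] = updates
    return out


# Usage 1: full indices (k == inputs.ndim), updates has shape (num_updates,).
full = numpy_scatter_update(np.zeros((4, 4, 4)), [[1, 2, 3], [0, 1, 3]], np.array([1., 1.]))

# Usage 2: partial indices (k = 2), updates has shape (num_updates, *inputs.shape[2:]).
partial = numpy_scatter_update(np.zeros((4, 4, 4)), [[1, 2], [2, 3]], np.ones((2, 4)))

print(full[1, 2, 3], full[0, 1, 3], full.sum())  # 1.0 1.0 2.0
print(partial.sum())  # 8.0
```

The same transpose-then-index trick covers both usages because a partial index simply leaves the trailing axes unselected, which is exactly why updates must carry the extra inputs.shape[k:] dimensions.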
Args
inputs: A tensor, the tensor to be updated.
indices: A tensor or list/tuple of shape (N, k), where k <= inputs.ndim, specifying the indices to update. N is the number of indices to update and must equal the first dimension of updates.
updates: A tensor, the new values to write into inputs at indices.
Returns
A tensor with the same shape and dtype as inputs.
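The argument and return contracts can be made concrete with a slower, explicit loop that first checks the shapes described above and then writes one slice per index row. This is a hedged NumPy sketch: the function name scatter_update_checked and its error messages are hypothetical and do not mirror the validation Keras itself performs.

```python
import numpy as np


def scatter_update_checked(inputs, indices, updates):
    """Hypothetical reference: validate shapes, then apply updates row by row."""
    inputs = np.asarray(inputs)
    indices = np.asarray(indices)
    updates = np.asarray(updates)

    if indices.ndim != 2:
        raise ValueError("indices must be 2D with shape (N, k)")
    n_updates, k = indices.shape
    if k > inputs.ndim:
        raise ValueError("index length k cannot exceed inputs.ndim")
    # N must match the first dimension of updates, and each update slice
    # must match the trailing shape inputs.shape[k:].
    expected = (n_updates,) + inputs.shape[k:]
    if updates.shape != expected:
        raise ValueError(f"updates must have shape {expected}, got {updates.shape}")

    # Write into a copy, so the result keeps the shape and dtype of inputs.
    out = inputs.copy()
    for row, value in zip(indices, updates):
        out[tuple(row)] = value
    return out


result = scatter_update_checked(np.zeros((4, 4, 4)), [[1, 2], [2, 3]], np.ones((2, 4)))
print(result.shape, result.dtype, result.sum())  # (4, 4, 4) float64 8.0
```

Writing into a copy keeps the original array untouched in this sketch; values in updates are cast to the dtype of inputs on assignment.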
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates. Some content is licensed under the numpy license.
Last updated 2024-06-07 UTC.