tf.keras.ops.custom_gradient
Decorator to define a function with a custom gradient.
tf.keras.ops.custom_gradient(
    f
)
This decorator allows fine-grained control over the gradients of a sequence
of operations. This may be useful for multiple reasons, including providing
a more efficient or numerically stable gradient for a sequence of
operations.
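To see why a numerically stable gradient can matter, consider the log1pexp example used below. The naive derivative of log(1 + e^x), written as e^x / (1 + e^x), overflows for large x, while the algebraically equivalent form 1 - 1 / (1 + e^x) stays finite. A minimal NumPy sketch (float32, x = 100 chosen for illustration):

```python
import numpy as np

with np.errstate(over="ignore", invalid="ignore"):
    e = np.exp(np.float32(100.0))   # overflows to inf in float32
    naive = e / (1.0 + e)           # inf / inf -> nan
    stable = 1.0 - 1.0 / (1.0 + e)  # 1 - 0 -> 1.0

print(naive, stable)
```

The rewritten form is what the custom gradients in the examples below compute.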
Args:
    f: Function f(*args) that returns a tuple (output, grad_fn), where:
        - args is a sequence of (nested structures of) tensor inputs to
          the function.
        - output is a (nested structure of) tensor outputs of applying
          the operations in f to args.
        - grad_fn is a function with the signature grad_fn(*args, upstream)
          which returns a tuple of tensors the same size as (flattened)
          args: the derivatives of the tensors in output with respect to
          the tensors in args. upstream is a tensor or sequence of tensors
          holding the incoming gradients for each tensor in output.
Returns:
    A function h(*args) which returns the same value as f(*args)[0] and
    whose gradient is determined by f(*args)[1].
Examples:
- Backend-agnostic example:

    @ops.custom_gradient
    def log1pexp(x):
        e = ops.exp(x)

        def grad(*args, upstream=None):
            if upstream is None:
                (upstream,) = args
            return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))

        return ops.log(1 + e), grad
Note that the signature of the grad function depends on the backend in
use. With the JAX and TensorFlow backends, the upstream gradient is
passed as a single positional argument, so grad(upstream) is sufficient.
With the PyTorch backend, it is passed as the upstream keyword argument,
so the grad function must be defined as def grad(*args, upstream).
Writing grad as in the previous example, with upstream defaulting to
None and falling back to the positional args, makes @ops.custom_gradient
compatible with all backends.
- Here's a JAX- and TensorFlow-specific example:

    @ops.custom_gradient
    def log1pexp(x):
        e = ops.exp(x)

        def grad(upstream):
            return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))

        return ops.log(1 + e), grad
- Lastly, here's a PyTorch-specific example, using *args and upstream:

    @ops.custom_gradient
    def log1pexp(x):
        e = ops.exp(x)

        def grad(*args, upstream):
            return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))

        return ops.log(1 + e), grad
Last updated 2024-06-07 UTC.