tf.switch_case
Create a switch/case operation, i.e. an integer-indexed conditional.

```python
tf.switch_case(
    branch_index, branch_fns, default=None, name='switch_case'
)
```
See also `tf.case`.

This op can be substantially more efficient than `tf.case` when exactly one branch will be selected. `tf.switch_case` is more like a C++ switch/case statement than `tf.case`, which is more like an if/elif/elif/else chain.
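To make the comparison concrete, here is a minimal sketch, assuming TensorFlow 2.x running eagerly (the default), that selects the same value both ways. `tf.case` needs one boolean predicate per branch, while `tf.switch_case` uses the integer directly. The variable names are illustrative only.

```python
import tensorflow as tf

x = tf.constant(1)

# tf.case: an if/elif/else chain of (predicate, callable) pairs.
# Each predicate is a boolean tensor that has to be built explicitly.
via_case = tf.case(
    [(tf.equal(x, 0), lambda: tf.constant(17)),
     (tf.equal(x, 1), lambda: tf.constant(31))],
    default=lambda: tf.constant(-1))

# tf.switch_case: the integer itself selects the branch,
# so no per-branch predicate tensors are needed.
via_switch = tf.switch_case(
    x,
    branch_fns=[lambda: tf.constant(17), lambda: tf.constant(31)],
    default=lambda: tf.constant(-1))

print(int(via_case), int(via_switch))  # 31 31
```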
The `branch_fns` parameter is either a dict from `int` to callables, or a list of (`int`, callable) pairs, or simply a list of callables (in which case the list index is implicitly the key). The `branch_index` `Tensor` selects the element of `branch_fns` whose `int` key matches; if no key matches, it falls back to `default`, or to the branch with key `max(keys)` if no `default` is provided. The keys must form a contiguous set from `0` to `len(branch_fns) - 1`.
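The three accepted forms of `branch_fns` can be sketched as follows, assuming TensorFlow 2.x in eager mode. The helper functions `f0`, `f1` and `f2` are hypothetical. All three calls should select the branch keyed `2`; in the list-of-pairs form, the keys rather than the list order decide the match.

```python
import tensorflow as tf

def f0(): return tf.constant(10)
def f1(): return tf.constant(20)
def f2(): return tf.constant(30)

idx = tf.constant(2)

# 1. A dict from int keys to callables.
a = tf.switch_case(idx, branch_fns={0: f0, 1: f1, 2: f2})

# 2. A list of (int, callable) pairs; the keys, not the list order, are matched.
b = tf.switch_case(idx, branch_fns=[(2, f2), (0, f0), (1, f1)])

# 3. A plain list of callables; each list position is the key.
c = tf.switch_case(idx, branch_fns=[f0, f1, f2])

print(int(a), int(b), int(c))  # 30 30 30
```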
`tf.switch_case` supports nested structures as implemented in `tf.nest`. All callables must return the same (possibly nested) value structure of lists, tuples, and/or named tuples.
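A sketch of branches that return a nested structure, assuming TensorFlow 2.x. `Stats`, `small`, `large` and `summarize` are hypothetical names introduced for illustration. Wrapping the call in `tf.function` builds a graph, where the branch output structures and dtypes have to agree, so both branches return a named tuple holding a scalar and a tuple of two scalars.

```python
import collections

import tensorflow as tf

# Hypothetical named tuple used only for this illustration.
Stats = collections.namedtuple("Stats", ["total", "parts"])


def small():
    return Stats(total=tf.constant(3.0),
                 parts=(tf.constant(1.0), tf.constant(2.0)))


def large():
    return Stats(total=tf.constant(30.0),
                 parts=(tf.constant(10.0), tf.constant(20.0)))


@tf.function
def summarize(i):
    # Both branches return Stats(scalar, (scalar, scalar)) with float32 leaves,
    # so their structures and dtypes match.
    return tf.switch_case(i, branch_fns=[small, large])


out = summarize(tf.constant(1))
print(float(out.total), [float(p) for p in out.parts])  # 30.0 [10.0, 20.0]
```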
Example:

Pseudocode:

```
switch (branch_index) {  // c-style switch
  case 0: return 17;
  case 1: return 31;
  default: return -1;
}
```

or, in plain Python:

```python
branch_index = 1  # example value
branches = {0: lambda: 17, 1: lambda: 31}
branches.get(branch_index, lambda: -1)()  # 31
```

Expressions:

```python
import tensorflow as tf

def f1(): return tf.constant(17)
def f2(): return tf.constant(31)
def f3(): return tf.constant(-1)

branch_index = tf.constant(1)  # example value; any int Tensor
r = tf.switch_case(branch_index, branch_fns={0: f1, 1: f2}, default=f3)
# Equivalent: tf.switch_case(branch_index, branch_fns={0: f1, 1: f2, 2: f3})
```
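The same branches can also be selected at run time inside a `tf.function`, where `branch_index` is a symbolic tensor rather than a concrete value. This is a minimal sketch assuming TensorFlow 2.x; `pick` is a hypothetical wrapper. Index `5` matches no key, so `default` (`f3`) runs.

```python
import tensorflow as tf

def f1(): return tf.constant(17)
def f2(): return tf.constant(31)
def f3(): return tf.constant(-1)


@tf.function
def pick(branch_index):
    # Inside tf.function, branch_index is symbolic; the branch is chosen
    # when the graph executes, not when it is traced.
    return tf.switch_case(branch_index, branch_fns={0: f1, 1: f2}, default=f3)


picked = [int(pick(tf.constant(i))) for i in (0, 1, 5)]
print(picked)  # [17, 31, -1]
```

Because every call passes a scalar `int32` tensor, the function is traced once and reused for all three indices.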
Args:

| Argument | Description |
|---|---|
| `branch_index` | An int `Tensor` specifying which of `branch_fns` should be executed. |
| `branch_fns` | A `dict` mapping `int`s to callables, or a `list` of (`int`, callable) pairs, or simply a list of callables (in which case the index serves as the key). Each callable must return a matching structure of tensors. |
| `default` | Optional callable that returns a structure of tensors. |
| `name` | A name for this operation (optional). |
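In practice `branch_index` is often computed rather than written as a constant. A sketch, assuming TensorFlow 2.x in eager mode: the index comes from `tf.math.argmax`, with `output_type=tf.int32` chosen here (an assumption of this sketch) so the index has the same dtype as `tf.constant(1)`. The `scores` values and the branch labels are made up.

```python
import tensorflow as tf

scores = tf.constant([0.1, 0.7, 0.2])

# An int32 scalar tensor holding the position of the largest score (1 here).
idx = tf.math.argmax(scores, output_type=tf.int32)

result = tf.switch_case(
    idx,
    branch_fns=[lambda: tf.constant("low"),
                lambda: tf.constant("mid"),
                lambda: tf.constant("high")],
    name="pick_label")  # names the op; it does not change which branch runs

print(result.numpy())  # b'mid'
```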
Returns:

The tensors returned by the callable identified by `branch_index`, or those returned by `default` if no key matches and `default` was provided, or those returned by the max-keyed callable in `branch_fns` if no `default` is provided.
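The three return cases can be checked side by side. A minimal sketch, assuming TensorFlow 2.x in eager mode; `branches` and `fallback` are illustrative names.

```python
import tensorflow as tf


def fallback():
    return tf.constant(-1)


# Keys 0 and 1 only, so max(keys) is 1.
branches = {0: lambda: tf.constant(17), 1: lambda: tf.constant(31)}

# No key matches 7 and a default is given: `default` runs.
with_default = tf.switch_case(tf.constant(7), branches, default=fallback)

# No key matches 7 and there is no default: the max-keyed branch (key 1) runs.
without_default = tf.switch_case(tf.constant(7), branches)

# A matching key is used whether or not a default is given.
matched = tf.switch_case(tf.constant(0), branches, default=fallback)

print(int(with_default), int(without_default), int(matched))  # -1 31 17
```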
Raises:

| Exception | Condition |
|---|---|
| `TypeError` | If `branch_fns` is not a list or dictionary. |
| `TypeError` | If `branch_fns` is a list but does not contain 2-tuples or callables. |
| `TypeError` | If any branch function in `branch_fns` is not callable, or `default` is not callable. |
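A sketch of inputs that hit the `TypeError` cases above, assuming TensorFlow 2.x in eager mode. `raises_type_error` is a hypothetical helper for this illustration; it only reports whether a `TypeError` was raised, not the exact message, which may differ between versions.

```python
import tensorflow as tf

idx = tf.constant(0)


def raises_type_error(fn):
    """Return True if calling fn() raises TypeError, else False."""
    try:
        fn()
    except TypeError:
        return True
    return False


# branch_fns is a bare callable, not a list or dict.
bad_container = raises_type_error(
    lambda: tf.switch_case(idx, branch_fns=lambda: tf.constant(1)))

# branch_fns is a list whose entries are neither callables nor 2-tuples.
bad_entries = raises_type_error(
    lambda: tf.switch_case(idx, branch_fns=[17, 31]))

# A dict value is not callable.
bad_value = raises_type_error(
    lambda: tf.switch_case(idx, branch_fns={0: 17}))

print(bad_container, bad_entries, bad_value)  # True True True
```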
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates. Some content is licensed under the numpy license.
Last updated 2024-04-26 UTC.