tf.case
Create a case operation.
tf.case(
    pred_fn_pairs,
    default=None,
    exclusive=False,
    strict=False,
    name='case'
)
See also tf.switch_case.
The pred_fn_pairs parameter is a list of N pairs. Each pair contains a boolean scalar tensor and a Python callable that creates the tensors to be returned if the boolean evaluates to True. default is a callable generating a list of tensors. All the callables in pred_fn_pairs, as well as default (if provided), should return the same number and types of tensors.
If exclusive==True, all predicates are evaluated, and an exception is raised if more than one of the predicates evaluates to True.
If exclusive==False, execution stops at the first predicate that evaluates to True, and the tensors generated by the corresponding function are returned immediately. If none of the predicates evaluates to True, this operation returns the tensors generated by default.
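The first-match rule and the fallback to default described above can be checked in eager execution. This is a minimal sketch rather than part of the original reference; the helper name classify and the input values are illustrative assumptions.

```python
import tensorflow as tf

def classify(x, y, z):
    """Hypothetical helper: yields 17, 23, or -1 via tf.case."""
    pairs = [
        (tf.less(x, y), lambda: tf.constant(17)),
        (tf.greater(x, z), lambda: tf.constant(23)),
    ]
    # exclusive defaults to False, so the first True predicate wins.
    return tf.case(pairs, default=lambda: tf.constant(-1))

# Both predicates are True (1 < 2 and 1 > 0): the first pair is used.
both_true = classify(tf.constant(1), tf.constant(2), tf.constant(0))
# Only the second predicate is True (5 > 0): the second pair is used.
second_only = classify(tf.constant(5), tf.constant(2), tf.constant(0))
# No predicate is True (5 < 2 and 5 > 9 both fail): default is used.
none_true = classify(tf.constant(5), tf.constant(2), tf.constant(9))
```

Note that the order of pred_fn_pairs matters when exclusive is False, since an earlier True predicate shadows later ones.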
tf.case supports nested structures as implemented in tf.nest. All of the callables must return the same (possibly nested) value structure of lists, tuples, and/or named tuples. Singleton lists and tuples form the only exceptions to this: when returned by a callable, they are implicitly unpacked to single values. This behavior is disabled by passing strict=True.
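The effect of strict on singleton structures can be illustrated as follows. This is a sketch under the assumption of eager execution; the names pred and singleton are hypothetical and not part of the API.

```python
import tensorflow as tf

pred = tf.constant(True)

def singleton():
    # Returns a one-element list, i.e. a singleton structure.
    return [tf.constant(17)]

# strict=False (the default): the one-element list is unpacked to a tensor.
unpacked = tf.case([(pred, singleton)], default=singleton)

# strict=True: the one-element list structure is kept as returned.
preserved = tf.case([(pred, singleton)], default=singleton, strict=True)
```

Passing strict=True may be preferable when downstream code expects the callables' return structure to be preserved exactly.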
Example 1:
Pseudocode:
if (x < y) return 17;
else return 23;
Expressions:
f1 = lambda: tf.constant(17)
f2 = lambda: tf.constant(23)
r = tf.case([(tf.less(x, y), f1)], default=f2)
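The expressions in Example 1 leave x and y undefined. One way to make them runnable, sketched here with a hypothetical wrapper named pick and illustrative inputs, is to place the call inside a tf.function:

```python
import tensorflow as tf

@tf.function
def pick(x, y):
    # Same callables and predicate as Example 1.
    f1 = lambda: tf.constant(17)
    f2 = lambda: tf.constant(23)
    return tf.case([(tf.less(x, y), f1)], default=f2)

r_less = pick(tf.constant(1), tf.constant(2))      # x < y holds
r_not_less = pick(tf.constant(3), tf.constant(2))  # x < y fails, default runs
```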
Example 2:
Pseudocode:
if (x < y && x > z) raise OpError("Only one predicate may evaluate to True");
if (x < y) return 17;
else if (x > z) return 23;
else return -1;
Expressions:
def f1(): return tf.constant(17)
def f2(): return tf.constant(23)
def f3(): return tf.constant(-1)
r = tf.case([(tf.less(x, y), f1), (tf.greater(x, z), f2)],
            default=f3, exclusive=True)
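Example 2 can be exercised with concrete values as sketched below. The wrapper select and the inputs are assumptions made for illustration. Catching tf.errors.OpError follows the pseudocode above; the exact subclass raised may depend on the execution mode.

```python
import tensorflow as tf

def f1(): return tf.constant(17)
def f2(): return tf.constant(23)
def f3(): return tf.constant(-1)

def select(x, y, z):
    # Hypothetical wrapper around the Example 2 expression.
    return tf.case([(tf.less(x, y), f1), (tf.greater(x, z), f2)],
                   default=f3, exclusive=True)

only_first = select(tf.constant(1), tf.constant(2), tf.constant(5))   # x < y only
only_second = select(tf.constant(3), tf.constant(2), tf.constant(0))  # x > z only
neither = select(tf.constant(3), tf.constant(2), tf.constant(5))      # default

try:
    # Both x < y and x > z hold, which exclusive=True forbids.
    select(tf.constant(1), tf.constant(2), tf.constant(0))
    raised = False
except tf.errors.OpError:
    raised = True
```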
Args
| pred_fn_pairs | List of pairs of a boolean scalar tensor and a callable which returns a list of tensors. |
| default | Optional callable that returns a list of tensors. |
| exclusive | True iff at most one predicate is allowed to evaluate to True. |
| strict | A boolean that enables/disables 'strict' mode; see above. |
| name | A name for this operation (optional). |
Returns
| The tensors returned by the first pair whose predicate evaluated to True, or those returned by default if none does. |
Raises
| TypeError | If pred_fn_pairs is not a list/tuple. |
| TypeError | If pred_fn_pairs is a list but does not contain 2-tuples. |
| TypeError | If fns[i], the callable of the i-th pair, is not callable for any i, or default is not callable. |
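The first two argument checks listed above can be observed with a short sketch. The helper raises_type_error and the malformed inputs are hypothetical, chosen only to trigger the documented checks.

```python
import tensorflow as tf

pred = tf.constant(True)
fn = lambda: tf.constant(1)

def raises_type_error(bad_pairs):
    """Hypothetical helper: True if tf.case rejects bad_pairs with TypeError."""
    try:
        tf.case(bad_pairs, default=fn)
    except TypeError:
        return True
    return False

# A string is neither a list nor a tuple of pairs.
not_a_sequence = raises_type_error("pred, fn")
# A list whose entry is a 3-tuple rather than a 2-tuple.
not_two_tuples = raises_type_error([(pred, fn, fn)])
```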
v2 compatibility
pred_fn_pairs could be a dictionary in v1. However, tf.Tensor and tf.Variable are no longer hashable in v2, so they cannot be used as keys in a dictionary. Please use a list or a tuple instead.
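The hashability restriction can be seen in a short sketch, assuming default TensorFlow 2 behavior with eager execution. The variable names are illustrative only.

```python
import tensorflow as tf

pred = tf.constant(True)
fn = lambda: tf.constant(1)

try:
    # v1-style dictionary keyed by a tensor; building it requires hashing pred.
    pairs_v1_style = {pred: fn}
    dict_ok = True
except TypeError:
    dict_ok = False

# The v2 form: a list (or tuple) of (predicate, callable) pairs.
result = tf.case([(pred, fn)], default=lambda: tf.constant(0))
```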
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates. Some content is licensed under the numpy license.
Last updated 2024-04-26 UTC.