# orbit.utils.OptionalSummariesFunction

Wrapper that provides versions of a function with and without summaries.

    orbit.utils.OptionalSummariesFunction(
        function, **tf_function_kwargs
    )

This is a utility class for implementing optimized summary recording via a
two-function approach, which is particularly important on TPUs. Two
`tf.function` versions of a given `function` are created: one with soft device
placement enabled (for use on steps that require summary writing), and one
with summary writing and soft device placement entirely disabled (for use on
all other steps). This removes any performance impact of summaries on steps
where they aren't recorded (b/148418718).
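The idea behind the two-function approach can be sketched without TensorFlow at all. The snippet below is a hypothetical, framework-free analogue (the class and attribute names are illustrative, not the real Orbit implementation): one function is wrapped into two variants, one that pays an extra "logging" cost and one that skips it entirely, mirroring the with/without-summaries split.

```python
import functools


class TwoVariantFunction:
    """Hypothetical sketch of the two-function approach: one variant pays
    the logging cost, the other variant skips it entirely."""

    def __init__(self, function):
        # Variant for steps that must log: runs the extra work each call.
        @functools.wraps(function)
        def with_logging(*args, **kwargs):
            result = function(*args, **kwargs)
            print(f"log: {function.__name__} -> {result}")
            return result

        self.with_logging = with_logging
        # Variant for all other steps: the bare function, zero overhead.
        self.without_logging = function


wrapped = TwoVariantFunction(lambda x: x * 2)
```

Callers then pick the variant per step, so steps that don't log never touch the logging path at all.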
This class can be used as a base class to implement summary optimizations for
a function with a specific signature. For example, to implement efficient TPU
summaries for a standard `train()` method (as in `orbit.AbstractTrainer`):

    class TrainFunctionWithSummaries(orbit.utils.OptionalSummariesFunction):
      '''Implements a two-program approach for summaries on TPU.'''

      def __call__(self, num_steps):
        if tf.summary.should_record_summaries():
          output = self.with_summaries(tf.constant(1))
          num_steps -= 1
        if num_steps >= 1:
          output = self.without_summaries(num_steps)
        return output
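The `__call__` logic above splits a `num_steps`-step call so that at most one step runs through the summary-enabled program and the remainder run through the summary-free one. That splitting logic can be isolated in a small framework-free sketch (the helper name here is hypothetical, for illustration only):

```python
def split_steps(num_steps, record_this_call):
    """Return (steps_with_summaries, steps_without_summaries).

    Mirrors the __call__ logic above: when recording, exactly one step
    pays the summary cost; otherwise every step uses the cheap program.
    """
    if record_this_call:
        return 1, num_steps - 1
    return 0, num_steps
```

For example, a 100-step call with recording enabled becomes 1 step with summaries and 99 without; with recording disabled, all 100 steps avoid the summary program entirely.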
This can be used directly or to implement a decorator:

    def train_function_with_summaries(function=None, **kwargs):
      if function is not None:
        return TrainFunctionWithSummaries(function, **kwargs)
      return functools.partial(TrainFunctionWithSummaries, **kwargs)

The decorator can be applied directly to `train()` methods:

    @train_function_with_summaries
    def train(self, num_steps):
      ...

A similar approach can be implemented for functions with different
signatures.

Note: The above approach assumes that the frequency of summary writing is
based on a step interval that is divisible by the number of steps executed in
each call to the `train()` function. This is enforced by `orbit.Controller`.

This wrapper properly handles instance methods (see `__get__`).

Args:

`function`
  The underlying function to wrap.
`**tf_function_kwargs`
  Additional arguments to pass to `tf.function`.

Attributes:

`with_summaries`
  A wrapped version of the underlying function with summaries enabled
  (using whatever the active predicate is for `tf.summary.record_if`), and
  placed inside a "soft device placement" context to enable summary
  recording on TPU.
`without_summaries`
  A wrapped version of the underlying function with all summary recording
  disabled.
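`OptionalSummariesFunction` supports decorating instance methods via the descriptor protocol (`__get__`): when a wrapper object is stored as a class attribute, `__get__` is what binds `self` so calls look like normal methods. A hypothetical minimal sketch of that mechanism (not the actual Orbit code):

```python
import functools


class MethodWrapper:
    """Wraps a function; __get__ lets it behave as an instance method."""

    def __init__(self, function):
        self.function = function

    def __call__(self, *args, **kwargs):
        return self.function(*args, **kwargs)

    def __get__(self, instance, owner=None):
        if instance is None:
            # Accessed on the class itself: return the wrapper unchanged.
            return self
        # Accessed on an instance: bind it, so obj.method(x) forwards
        # as function(obj, x).
        return functools.partial(self.__call__, instance)


class Trainer:
    @MethodWrapper
    def train(self, num_steps):
        return num_steps * 2
```

Without `__get__`, `Trainer().train(5)` would fail to receive `self`, because a plain callable class attribute is not bound the way ordinary functions are.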
[View source on GitHub](https://github.com/tensorflow/models/blob/v2.19.1/orbit/utils/tpu_summaries.py#L34-L145)

Last updated 2025-04-18 UTC.