tfp.experimental.util.make_trainable
Constructs a distribution or bijector instance with trainable parameters.
tfp.experimental.util.make_trainable(
*args, seed=None, **kwargs
)
This is a convenience method that instantiates a class with trainable
parameters. Parameters are randomly initialized, and transformed to enforce
any domain constraints. This method assumes that the class exposes a
parameter_properties
method annotating its trainable parameters, and that
the caller provides any additional (non-trainable) arguments required by the
class.
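For example, here is a minimal sketch (the choice of tfd.Gamma and the seed value are illustrative assumptions, not prescribed by this API) of constructing a trainable distribution and inspecting its variables:

import tensorflow_probability as tfp

tfd = tfp.distributions

# Both of tfd.Gamma's parameters (concentration and rate) are trainable, so
# one variable is created per parameter; each is randomly initialized and
# passed through a constraining bijector to keep it positive.
gamma = tfp.experimental.util.make_trainable(tfd.Gamma, seed=42)
print(gamma.trainable_variables)  # ==> two tf.Variable objects.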
Args

cls: Python class that implements cls.parameter_properties(), e.g., a TFP
  distribution (tfd.Normal) or bijector (tfb.Scale).
initial_parameters: a dictionary containing initial values for some or all of
  the parameters to cls, OR a Python callable with signature
  value = parameter_init_fn(parameter_name, shape, dtype, seed,
  constraining_bijector). If a dictionary is provided, any parameters not
  specified will be initialized to a random value in their domain.
  Default value: None (equivalent to {}; all parameters are initialized
  randomly).
batch_and_event_shape: Optional int Tensor giving the desired shape of samples
  (for distributions) or inputs (for bijectors), used to determine the shape
  of the trainable parameters.
  Default value: ().
parameter_dtype: Optional float dtype for trainable variables.
**init_kwargs: Additional keyword arguments passed to cls.__init__() to
  specify any non-trainable parameters. If a value is passed for an
  otherwise-trainable parameter (for example, make_trainable(tfd.Normal,
  scale=1.)), it will be taken as a fixed value and no variable will be
  constructed for that parameter. See the sketch following this list.
seed: PRNG seed; see tfp.random.sanitize_seed for details.
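For instance, here is a hedged sketch (tfd.Normal, the initial loc value, and the fixed scale are illustrative choices) showing how initial_parameters and fixed keyword arguments interact:

import tensorflow_probability as tfp

tfd = tfp.distributions

# loc is initialized at the given value and backed by a trainable variable;
# scale is passed as a keyword argument, so it is treated as a fixed value
# and no variable is created for it.
model = tfp.experimental.util.make_trainable(
    tfd.Normal,
    initial_parameters={'loc': 0.},
    scale=1.,
    seed=42)
print(model.trainable_variables)  # Only the variable for loc.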
Returns

instance: an instance of cls parameterized by trainable tf.Variable objects.
Example
Suppose we want to fit a normal distribution to observed data. We could
of course just examine the empirical mean and standard deviation of the data:
samples = [4.57, 6.37, 5.93, 7.98, 2.03, 3.59, 8.55, 3.45, 5.06, 6.44]
model = tfd.Normal(
    loc=tf.reduce_mean(samples),        # ==> 5.40
    scale=tf.math.reduce_std(samples))  # ==> 1.95
and this would be a very sensible approach. But that's boring, so instead,
let's do way more work to get the same result. We'll build a trainable normal
distribution, and explicitly optimize to find the maximum-likelihood estimate
for the parameters given our data:
model = tfp.experimental.util.make_trainable(tfd.Normal)
losses = tfp.math.minimize(
    lambda: -model.log_prob(samples),
    optimizer=tf.optimizers.Adam(0.1),
    num_steps=200)
print('Fit Normal distribution with mean {} and stddev {}'.format(
    model.mean(),
    model.stddev()))
In this trivial case, doing the explicit optimization has few advantages over
the first approach in which we simply matched the empirical moments of the
data. However, trainable distributions are useful more generally. For example,
they can enable maximum-likelihood estimation of distributions when a
moment-matching estimator is not available, and they can also serve as
surrogate posteriors in variational inference.
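For instance, here is a minimal variational-inference sketch (the toy target density and the optimizer settings are illustrative assumptions) in which a trainable Normal serves as the surrogate posterior:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Toy target: an unnormalized log-density over a scalar latent variable.
def target_log_prob_fn(z):
  return tfd.Normal(loc=2., scale=0.5).log_prob(z)

# Trainable Normal to act as the surrogate posterior q(z).
surrogate = tfp.experimental.util.make_trainable(tfd.Normal, seed=42)

losses = tfp.vi.fit_surrogate_posterior(
    target_log_prob_fn,
    surrogate_posterior=surrogate,
    optimizer=tf.optimizers.Adam(0.1),
    num_steps=200)
print(surrogate.mean(), surrogate.stddev())  # Should approach 2. and 0.5.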