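tfp.sts.one_step_predictive
Compute one-step-ahead predictive distributions for all timesteps. (deprecated argument values)
tfp.sts.one_step_predictive(model, observed_time_series, parameter_samples, timesteps_are_event_shape=True)
Given samples from the posterior over parameters, return the predictive
distribution over observations at each time T, given observations up
through time T-1.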
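The phrase "predictive distribution at time T given observations through T-1" can be made concrete with a small NumPy sketch. It assumes a scalar local-level model (a random-walk level observed with Gaussian noise) with fixed noise scales instead of posterior samples. The helper `local_level_one_step_predictive` and its parameters are hypothetical and are not TFP's implementation. STS models are linear Gaussian state space models, and this sketch only shows the Kalman-filter recursion that yields one-step-ahead predictive means and variances for such a model. It also shows how a NaN observation, treated as missing in the same way as in `observed_time_series`, skips the update step.

```python
import numpy as np


def local_level_one_step_predictive(y, level_scale, obs_scale,
                                    initial_mean=0.0, initial_var=1e4):
    """One-step-ahead predictive means/variances for a local-level model.

    Hypothetical illustration (not part of TFP):
      level[t] = level[t-1] + Normal(0, level_scale**2)
      y[t]     = level[t]   + Normal(0, obs_scale**2)
    NaNs in `y` are treated as missing observations.
    Entry t of the outputs describes p(y[t] | y[0], ..., y[t-1]).
    """
    y = np.asarray(y, dtype=float)
    means = np.empty_like(y)
    variances = np.empty_like(y)
    m, p = initial_mean, initial_var  # Diffuse prior over level[0].
    for t, obs in enumerate(y):
        # Predictive distribution of y[t] given all earlier observations.
        means[t] = m
        variances[t] = p + obs_scale**2
        if not np.isnan(obs):
            # Kalman update: condition the level on the observed y[t].
            gain = p / variances[t]
            m = m + gain * (obs - m)
            p = (1.0 - gain) * p
        # Random-walk transition to the next step: mean unchanged, variance grows.
        p = p + level_scale**2
    return means, variances
```

For example, `local_level_one_step_predictive([1.0, 1.2, np.nan, 1.1], level_scale=0.1, obs_scale=0.5)` returns the same predictive mean at steps 2 and 3, because the missing value at step 2 contributes no update.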
Args
model
An instance of StructuralTimeSeries representing a
time-series model. This represents a joint distribution over
time-series and their parameters with batch shape [b1, ..., bN].
observed_time_series
float Tensor of shape
concat([sample_shape, model.batch_shape, [num_timesteps, 1]]) where
sample_shape corresponds to i.i.d. observations, and the trailing [1]
dimension may (optionally) be omitted if num_timesteps > 1. Any NaNs
are interpreted as missing observations; missingness may also be
explicitly specified by passing a tfp.sts.MaskedTimeSeries instance.
parameter_samples
Python list of Tensors representing posterior samples
of model parameters, with shapes [concat([[num_posterior_draws],
param.prior.batch_shape, param.prior.event_shape]) for param in
model.parameters]. This may optionally also be a map (Python dict) of
parameter names to Tensor values.
timesteps_are_event_shape
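Deprecated; retained for backwards compatibility only.
If False, the predictive distribution returns per-timestep
probabilities (timesteps are part of the batch shape). If True, it
returns a single probability for the entire series (timesteps are
part of the event shape).
Default value: True.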
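The two settings of timesteps_are_event_shape can be compared in a small NumPy sketch. It assumes that the returned distribution is an equally weighted mixture of Normal one-step predictive distributions, with one component per posterior draw. The equal weighting, the helper names, and the toy numbers are assumptions made for illustration. For a single draw, the product of one-step predictive densities equals that draw's density for the whole series (chain rule). With timesteps as event shape, each draw therefore scores the whole series before mixing. With per-timestep probabilities, mixing happens at every step, so summing the per-timestep log probabilities does not in general reproduce the whole-series value.

```python
import numpy as np


def normal_logpdf(x, loc, scale):
    """Elementwise log density of Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return -0.5 * z**2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)


def logmeanexp(a, axis):
    """Numerically stable log(mean(exp(a))) along `axis`."""
    a_max = np.max(a, axis=axis, keepdims=True)
    out = a_max + np.log(np.mean(np.exp(a - a_max), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)


def series_log_probs(y, means, scales):
    """Hypothetical helper comparing the two conventions.

    y: [num_timesteps]; means, scales: [num_posterior_draws, num_timesteps],
    holding each draw's one-step predictive Normal parameters.
    Returns (event_shape_log_prob, summed_per_timestep_log_prob).
    """
    lp = normal_logpdf(y, means, scales)  # [num_draws, num_timesteps]
    # Timesteps as event shape: each draw scores the whole series, then mix.
    event_shape_lp = logmeanexp(np.sum(lp, axis=-1), axis=0)
    # Timesteps as batch shape: mix at every step, then sum over steps.
    per_timestep_lp = logmeanexp(lp, axis=0)  # [num_timesteps]
    return event_shape_lp, np.sum(per_timestep_lp)
```

With a single posterior draw the two values coincide. With several draws they generally differ, and either one can be the larger.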
Returns
predictive_dist
A tfd.MixtureSameFamily instance with event shape
[num_timesteps] if timesteps_are_event_shape else [] and
batch shape concat([sample_shape, model.batch_shape,
[] if timesteps_are_event_shape else [num_timesteps]]), with
num_posterior_draws mixture components. The t-th step represents the
forecast distribution p(observed_time_series[t] |
observed_time_series[0:t-1], parameter_samples), that is, the
distribution of step t conditioned on the observations at steps 0
through t-1.
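The summary statistics used in the examples below (mean() and stddev() of the returned mixture) can be sketched directly from per-draw predictive means and variances. This NumPy sketch assumes equally weighted components, one per posterior draw. That weighting is an assumption of the sketch, and the helper name is hypothetical. The mixture mean is the average of the component means. By the law of total variance, the mixture variance is the average component variance plus the variance of the component means.

```python
import numpy as np


def mixture_mean_and_stddev(draw_means, draw_variances):
    """Mean and stddev of an equally weighted mixture of Normals.

    Hypothetical helper (not part of TFP). draw_means and draw_variances have
    shape [num_posterior_draws, ...]; the leading axis indexes the mixture
    components (posterior draws).
    """
    draw_means = np.asarray(draw_means, dtype=float)
    draw_variances = np.asarray(draw_variances, dtype=float)
    mean = np.mean(draw_means, axis=0)
    # Law of total variance: E[Var] + Var[E] across mixture components.
    variance = (np.mean(draw_variances, axis=0)
                + np.mean((draw_means - mean) ** 2, axis=0))
    return mean, np.sqrt(variance)
```

Disagreement between posterior draws therefore widens the predictive interval beyond what any single draw's variance implies.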
Deprecated: some argument values are deprecated: (timesteps_are_event_shape=True). They will be removed after 2021-12-31.
Instructions for updating: Predictive distributions returned by tfp.sts.one_step_predictive will soon compute per-timestep probabilities (treating timesteps as part of the batch shape) instead of a single probability for an entire series (the current approach, in which timesteps are treated as event shape). Please update your code to pass timesteps_are_event_shape=False (this will soon be the default) and to explicitly sum over the per-timestep log probabilities if this is required.
Examples
Suppose we've built a model and fit it to data using HMC:

```python
import tensorflow as tf
import tensorflow_probability as tfp

day_of_week = tfp.sts.Seasonal(
    num_seasons=7,
    observed_time_series=observed_time_series,
    name='day_of_week')
local_linear_trend = tfp.sts.LocalLinearTrend(
    observed_time_series=observed_time_series,
    name='local_linear_trend')
model = tfp.sts.Sum(components=[day_of_week, local_linear_trend],
                    observed_time_series=observed_time_series)

samples, kernel_results = tfp.sts.fit_with_hmc(model, observed_time_series)
```

Passing the posterior samples into one_step_predictive, we construct a
one-step-ahead predictive distribution:

```python
one_step_predictive_dist = tfp.sts.one_step_predictive(
    model, observed_time_series, parameter_samples=samples)

predictive_means = one_step_predictive_dist.mean()
predictive_scales = one_step_predictive_dist.stddev()
```

If using variational inference instead of HMC, we'd construct a forecast using
samples from the variational posterior:

```python
surrogate_posterior = tfp.sts.build_factored_surrogate_posterior(
    model=model)
loss_curve = tfp.vi.fit_surrogate_posterior(
    target_log_prob_fn=model.joint_distribution(observed_time_series).log_prob,
    surrogate_posterior=surrogate_posterior,
    optimizer=tf.optimizers.Adam(learning_rate=0.1),
    num_steps=200)
samples = surrogate_posterior.sample(30)

one_step_predictive_dist = tfp.sts.one_step_predictive(
    model, observed_time_series, parameter_samples=samples)
```

We can visualize the forecast by plotting:

```python
import numpy as np
from matplotlib import pylab as plt

def plot_one_step_predictive(observed_time_series,
                             forecast_mean,
                             forecast_scale):
  plt.figure(figsize=(12, 6))
  num_timesteps = forecast_mean.shape[-1]
  c1, c2 = (0.12, 0.47, 0.71), (1.0, 0.5, 0.05)
  plt.plot(observed_time_series, label="observed time series", color=c1)
  plt.plot(forecast_mean, label="one-step prediction", color=c2)
  # Shade an approximate 95% band: two predictive standard deviations.
  plt.fill_between(np.arange(num_timesteps),
                   forecast_mean - 2 * forecast_scale,
                   forecast_mean + 2 * forecast_scale,
                   alpha=0.1, color=c2)
  plt.legend()

plot_one_step_predictive(observed_time_series,
                         forecast_mean=predictive_means,
                         forecast_scale=predictive_scales)
```

To detect anomalous timesteps, we check whether the observed value at each
step lies outside an approximate 95% predictive interval, i.e., more than two
standard deviations from the predictive mean. Because step t of the predictive
distribution already conditions only on steps 0 through t-1, observations and
predictions are compared at the same index:

```python
num_timesteps = predictive_means.shape[-1]
# Step t of the predictive distribution forecasts observed_time_series[t],
# so no shift between observations and predictions is needed.
z_scores = (observed_time_series - predictive_means) / predictive_scales
anomalous_timesteps = tf.boolean_mask(
    tf.range(num_timesteps),
    tf.abs(z_scores) > 2.0)
```

Used in the tutorials: Structural Time Series Modeling Case Studies: Atmospheric CO2 and Electricity Demand (https://www.tensorflow.org/probability/examples/Structural_Time_Series_Modeling_Case_Studies_Atmospheric_CO2_and_Electricity_Demand)
View source on GitHub: https://github.com/tensorflow/probability/blob/v0.23.0/tensorflow_probability/python/sts/forecast.py#L38-L194
Last updated 2023-11-21 UTC.