tfp.stats.expected_calibration_error_quantiles
Expected calibration error via quantiles(exp(pred_log_prob), num_buckets).
tfp.stats.expected_calibration_error_quantiles(
hit,
pred_log_prob,
num_buckets=20,
axis=0,
log_space_buckets=False,
name=None
)
Calibration is a measure of how well a model reports its own uncertainty. A
model is said to be "calibrated" if, within each bucket of predicted
probabilities, the average accuracy equals the average predicted probability.
The expected calibration error is the average absolute difference between the
predicted probability and the (bucket) average accuracy, weighted by the
number of observations in each bucket. That is:
bucket_weight = bucket_count / tf.reduce_sum(bucket_count, axis=0)
bucket_error = abs(bucket_accuracy - bucket_confidence)
ece = tf.reduce_sum(bucket_weight * bucket_error, axis=0)
where bucket_accuracy, bucket_confidence, and bucket_count are statistics
aggregated over the num_buckets-quantiles of tf.math.exp(pred_log_prob).
Note: the bucket_* statistics always have size num_buckets along the zeroth
dimension.
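To make the bucketing concrete, here is a minimal NumPy sketch for a single 1-D batch of predictions. It is not TFP's implementation: the function name ece_quantile_sketch is hypothetical, the bucket edges are picked from the sorted probabilities with a nearest-rank rule (an assumption about how quantiles between sample points are resolved), and empty buckets are reported with zero accuracy and confidence rather than whatever TFP does in that case. With the inputs of Example 1 it yields the same counts, accuracies, confidences, and ECE that the example reports.

```python
import numpy as np


def ece_quantile_sketch(hit, pred_log_prob, num_buckets=20):
    """Quantile-bucketed expected calibration error for 1-D inputs.

    Hypothetical NumPy sketch for illustration; not the TFP implementation.
    """
    hit = np.asarray(hit, dtype=np.float64)
    prob = np.exp(np.asarray(pred_log_prob, dtype=np.float64))

    # Bucket edges: num_buckets + 1 quantiles of the predicted probabilities,
    # chosen nearest-rank style from the sorted values (an assumption).
    sorted_prob = np.sort(prob)
    levels = np.linspace(0.0, 1.0, num_buckets + 1)
    edges = sorted_prob[np.round(levels * (prob.size - 1)).astype(int)]

    # Bucket i holds edges[i] <= p < edges[i + 1]; the largest value(s) are
    # folded into the last bucket.
    bucket = np.clip(np.searchsorted(edges, prob, side="right") - 1,
                     0, num_buckets - 1)

    # Per-bucket totals, then within-bucket averages.
    count = np.bincount(bucket, minlength=num_buckets).astype(np.float64)
    total_hits = np.bincount(bucket, weights=hit, minlength=num_buckets)
    total_prob = np.bincount(bucket, weights=prob, minlength=num_buckets)
    denom = np.maximum(count, 1.0)  # empty buckets report 0 instead of NaN
    accuracy = total_hits / denom
    confidence = total_prob / denom

    # Count-weighted mean of |accuracy - confidence|.
    ece = np.sum(count * np.abs(accuracy - confidence)) / np.sum(count)
    return ece, accuracy, confidence, count, bucket


# Same inputs as Example 1.
hit = [False, False, True, False, True, True]
pred_log_prob = np.log([0.1, 0.05, 0.5, 0.2, 0.99, 0.99])
ece, acc, conf, cnt, bucket = ece_quantile_sketch(hit, pred_log_prob, num_buckets=3)
print(cnt)                   # [2. 1. 3.]
print(acc)                   # [0. 0. 1.]
print(round(float(ece), 3))  # 0.145
```

Folding the maximum into the last bucket keeps every observation counted, which is why the bucket statistics always have exactly num_buckets entries.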
Args
  hit: bool Tensor where True means the model prediction was correct
    and False means the model prediction was incorrect. Shape must
    broadcast with pred_log_prob.
  pred_log_prob: Tensor representing the model's predicted log probability
    for the given hit. Shape must broadcast with hit.
  num_buckets: int representing the number of buckets over which to
    aggregate hits. Buckets are quantiles of exp(pred_log_prob).
    Default value: 20.
  axis: Dimension over which to compute buckets and aggregate stats.
    Default value: 0.
  log_space_buckets: When False, bucket edges are computed from
    tf.math.exp(pred_log_prob); when True, bucket edges are computed from
    pred_log_prob.
    Default value: False.
  name: Python str name used for ops created by this function.
    Default value: None (i.e., "expected_calibration_error_quantiles").
Returns
  ece: Expected calibration error; tf.reduce_sum(abs(bucket_accuracy -
    bucket_confidence) * bucket_count, axis=0) / tf.reduce_sum(bucket_count,
    axis=0).
  bucket_accuracy: Tensor representing the within-bucket average hits, i.e.,
    total bucket hits divided by bucket count. Has shape
    tf.concat([[num_buckets], tf.shape(tf.reduce_sum(pred_log_prob,
    axis=axis))], axis=0).
  bucket_confidence: Tensor representing the within-bucket average
    probability, i.e., total bucket predicted probability divided by bucket
    count. Has shape tf.concat([[num_buckets],
    tf.shape(tf.reduce_sum(pred_log_prob, axis=axis))], axis=0).
  bucket_count: Tensor representing the total number of observations in each
    bucket. Has shape tf.concat([[num_buckets],
    tf.shape(tf.reduce_sum(pred_log_prob, axis=axis))], axis=0).
  bucket_pred_log_prob: Tensor representing pred_log_prob bucket edges.
    Always in log space, regardless of the value of log_space_buckets.
  bucket: int Tensor representing the bucket within which pred_log_prob
    lies.
Examples
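# Imports used by both examples.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
# Example 1: Generic use.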
label = tf.cast([0, 0, 1, 0, 1, 1], dtype=tf.bool)
log_pred = tf.math.log([0.1, 0.05, 0.5, 0.2, 0.99, 0.99])
(
ece,
acc,
conf,
cnt,
edges,
bucket,
) = tfp.stats.expected_calibration_error_quantiles(
label, log_pred, num_buckets=3)
# ece ==> tf.Tensor(0.145, shape=(), dtype=float32)
# acc ==> tf.Tensor([0. 0. 1.], shape=(3,), dtype=float32)
# conf ==> tf.Tensor([0.075, 0.2, 0.826665], shape=(3,), dtype=float32)
# cnt ==> tf.Tensor([2. 1. 3.], shape=(3,), dtype=float32)
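# Example 2: Categorical classification.
# Assume we have evidence `x`, integer targets `y`, and model function `dnn`
# returning logits of shape [batch_size, num_classes].
d = tfd.Categorical(logits=dnn(x))
def all_categories(d):
  # Every class index, shaped [num_classes, 1, ..., 1] so that it broadcasts
  # against the distribution's batch shape.
  num_classes = tf.shape(d.logits_parameter())[-1]
  batch_ndims = tf.size(d.batch_shape_tensor())
  expand_shape = tf.pad(
      [num_classes], paddings=[[0, batch_ndims]], constant_values=1)
  return tf.reshape(tf.range(num_classes, dtype=d.dtype), expand_shape)
# Log probability of every class; shape [num_classes] + batch_shape.
all_pred_log_prob = d.log_prob(all_categories(d))
# Predicted class: the most probable category for each batch member.
yhat = tf.argmax(all_pred_log_prob, axis=0)
def rollaxis(x, shift):
  # Cyclically permute the dimensions of `x` by `shift`.
  return tf.transpose(x, tf.roll(tf.range(tf.rank(x)), shift=shift, axis=0))
# Move the class dimension last, then pick the log probability of `yhat`.
pred_log_prob = tf.gather(rollaxis(all_pred_log_prob, shift=-1),
                          yhat,
                          batch_dims=len(d.batch_shape))
# `tf.argmax` returns int64, so cast `y` to the same dtype before comparing.
hit = tf.equal(tf.cast(y, yhat.dtype), yhat)
(
    ece,
    acc,
    conf,
    cnt,
    edges,
    bucket,
) = tfp.stats.expected_calibration_error_quantiles(
    hit, pred_log_prob, num_buckets=10)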
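When the Categorical distribution has a single batch dimension, the all_categories / rollaxis / gather steps in Example 2 amount to taking each row's largest log-softmax value. The NumPy sketch below shows that reduced computation; the helper name predicted_log_prob_and_hit is hypothetical, the hard-coded logits and labels are made-up stand-ins for dnn(x) and y, and extra batch dimensions are not handled.

```python
import numpy as np


def predicted_log_prob_and_hit(logits, y):
    """Log probability of the predicted class, and whether it was correct.

    Hypothetical NumPy stand-in for the tensor manipulation in Example 2,
    for logits of shape [batch_size, num_classes].
    """
    logits = np.asarray(logits, dtype=np.float64)
    # Numerically stable log-softmax, i.e. the log probability of every class
    # under a categorical distribution parameterized by these logits.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Predicted class and its log probability.
    yhat = np.argmax(log_probs, axis=-1)
    pred_log_prob = np.take_along_axis(log_probs, yhat[:, None], axis=-1)[:, 0]
    hit = np.asarray(y) == yhat
    return pred_log_prob, hit, yhat


logits = np.array([[2.0, 1.0, 0.0],
                   [0.0, 3.0, 1.0],
                   [1.0, 1.0, 4.0]])
y = np.array([0, 2, 2])
pred_log_prob, hit, yhat = predicted_log_prob_and_hit(logits, y)
print(yhat.tolist())  # [0, 1, 2]
print(hit.tolist())   # [True, False, True]
```

The resulting hit and pred_log_prob are one-dimensional, matching the layout that Example 2 passes on with the default axis=0.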
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies.
Last updated 2023-11-21 UTC.