Rethinking Function-Space Variational Inference in Bayesian Neural Networks
Anonymous Authors
Anonymous Institution
Abstract
Bayesian neural networks (bnns) define distributions over functions induced by distributions
over parameters. In practice, this model specification makes it difficult to define and
use meaningful prior distributions over functions that could aid in training. What’s
more, previous attempts at defining an explicit function-space variational objective for
approximate inference in bnns require approximations that do not scale to high-dimensional
data. We propose a new function-space approach to variational inference in bnns and
derive a tractable variational objective by linearizing the bnn’s posterior predictive distribution about
its mean parameters, allowing function-space variational inference to be scaled to large
and high-dimensional datasets. We evaluate this approach empirically and show that it
leads to models with competitive predictive accuracy and significantly improved predictive
uncertainty estimates compared to parameter-space variational inference.
1. Introduction
Approximate inference in Bayesian neural networks (bnns) typically involves performing
probabilistic inference directly over a set of stochastic network parameters. Unfortunately,
standard approaches for parameter-space inference in bnns often do not result in approximate
posterior predictive distributions that reliably exhibit high predictive uncertainty away from
the training data or under distribution shift, making them of limited use in practice.
Instead of explicitly performing approximate inference over bnn parameters, we propose a
method for tractable and efficient approximate inference in bnns by inferring an approximate
posterior distribution over the network parameters implicitly and optimizing a variational
objective over the induced distribution over functions instead. This way, it is possible
to better control the distribution over functions induced by the network parameters and
obtain higher-quality uncertainty estimates than state-of-the-art Bayesian and non-Bayesian
methods. We evaluate the resulting approximate posterior predictive distribution empirically
on a number of high-dimensional prediction tasks and demonstrate that it outperforms
related methods in terms of accuracy, out-of-distribution detection, and calibration.
The main contribution of this paper is the conceptualization and formalization of a new
approach to function-space variational inference in bnns that is more scalable and better
performing than previously proposed approaches. We address the conceptual and practical
limitations of prior work, perform an extensive empirical evaluation on high-dimensional
prediction tasks, and conduct ablation studies on the proposed method.
Rethinking function-Space Variational Inference in Bayesian Neural Networks
[Figure: plot residue only. Recoverable legend entries: Predictive Mean, Function Draw, Training Data. Recoverable panel labels: E[y|D; x] and V[y|D; x]. The caption was not recovered.]
2. Preliminaries
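We consider supervised learning tasks on data $\mathcal{D} \overset{\text{def}}{=} \{(x_n, y_n)\}_{n=1}^{N}$ with inputs $x_n \in \mathcal{X} \subseteq \mathbb{R}^{D}$
and targets $y_n \in \mathcal{Y}$, where $\mathcal{Y} \subseteq \mathbb{R}^{Q}$ for regression and $\mathcal{Y} \subseteq \{0, 1\}^{Q}$ for classification tasks.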
targets y, we can express this minimization problem as maximizing the variational objective
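$$\mathcal{F}(q(f;\theta)) \overset{\text{def}}{=} \mathbb{E}_{q(f_{X_D};\theta)}[\log p(y \mid f_{X_D})] - D_{\mathrm{KL}}(q(f;\theta) \,\|\, p(f)), \tag{1}$$
where $D_{\mathrm{KL}}(q(f;\theta) \,\|\, p(f))$ is again a KL divergence between distributions over functions.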
For a measure-theoretic derivation of this result, see Appendix 1. Unfortunately, it is not
immediately clear how to evaluate such a KL divergence if q(f ; θ) and p(f ) are bnn posterior
and prior predictive distributions. In an effort to make this objective more tractable, Sun
et al. (2019) show that $D_{\mathrm{KL}}(q(f;\theta) \,\|\, p(f))$ can be expressed as the supremum of the KL
divergence from $q(f;\theta)$ to $p(f)$ over all finite sets of evaluation points, $X_I$, that is,
$$\mathcal{F}(q(f;\theta)) = \mathbb{E}_{q(f_{X_D};\theta)}[\log p(y \mid f_{X_D})] - \sup_{n \in \mathbb{N},\, X_I \in \mathcal{X}^n} D_{\mathrm{KL}}(q(f_{X_I};\theta) \,\|\, p(f_{X_I})). \tag{2}$$
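We denote the distribution over functions induced by q(θ) as q(f ; θ). To obtain a tractable
distribution over functions, we make a local approximation about the bnn’s mean parameters:
$$\tilde{f}(x;\theta) = f(x;\mu) + J_\mu(x)(\theta - \mu),$$
where $J_\mu(x)$ is the Jacobian of $f(x;\theta)$ with respect to θ, evaluated at θ = µ.
Due to local linearity, the approximation f˜(x; θ) will be accurate for realizations θ̂ close to
µ, and hence, the distribution over f˜(x; θ) (induced by θ) will be close to the distribution
over f (x; θ) for small variance parameters Σ. Under the two assumptions above, we obtain
a locally accurate approximate distribution over functions:
$$\tilde{q}(\tilde{f};\theta) = \mathcal{GP}\big(\tilde{f} \mid f(\cdot\,;\mu),\; J_\mu(\cdot)\,\Sigma\,J_\mu(\cdot)^\top\big).$$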
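As a rough numerical illustration of the local approximation, the sketch below linearizes a toy network about mean parameters µ using the first-order expansion f(x; µ) + Jµ(x)(θ − µ) that appears in the proof of Proposition 1. The toy architecture, the finite-difference Jacobian, and all numerical values are assumptions made for illustration and are not the setup used in the experiments.

```python
import numpy as np

def f(x, theta):
    """Toy network with one hidden layer of 3 tanh units; theta packs [w1, b1, w2]."""
    w1, b1, w2 = theta[:3], theta[3:6], theta[6:9]
    return np.tanh(x * w1 + b1) @ w2

def jacobian(x, mu, eps=1e-6):
    """Central finite-difference Jacobian of f(x; theta) w.r.t. theta at theta = mu."""
    J = np.zeros_like(mu)
    for i in range(mu.size):
        step = np.zeros_like(mu)
        step[i] = eps
        J[i] = (f(x, mu + step) - f(x, mu - step)) / (2 * eps)
    return J

def f_lin(x, theta, mu):
    """Linearized network: f(x; mu) + J_mu(x) (theta - mu)."""
    return f(x, mu) + jacobian(x, mu) @ (theta - mu)

# Hypothetical mean parameters, input, and unit-norm perturbation direction.
mu = np.array([0.5, -1.0, 1.5, 0.1, -0.2, 0.3, 1.0, -0.5, 0.8])
x = 0.5
direction = np.ones(9) / 3.0

errors = {}
for scale in (1e-1, 1e-2):
    theta = mu + scale * direction
    errors[scale] = abs(f(x, theta) - f_lin(x, theta, mu))
    print(f"|theta - mu| = {scale:g}: linearization error = {errors[scale]:.2e}")
```

Because the expansion is first order, the error is expected to shrink quickly as θ approaches µ, which is the sense in which the approximation is accurate for small variance parameters Σ.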
Here, p̃(f˜) is a local approximation to a prior distribution over functions induced by some
Gaussian prior distribution over the network parameters. The approximate variational
objective in Equation (4) includes a KL divergence between two Gaussian processes q̃(f˜)
and p̃(f˜). Unfortunately, in the absence of additional assumptions about q̃(f˜) and p̃(f˜), this
KL divergence is still intractable (de G. Matthews et al., 2016).
To obtain a tractable variational objective, we make an assumption about how the
variational predictive distribution $\tilde{q}(\tilde{f};\theta) = \tilde{q}(\tilde{f}_{X_*}, \tilde{f}_{X_D}, \tilde{f}_{X_I};\theta)$ factorizes, where $X_D$ is a
finite set of training inputs, $X_I$ is a finite set of so-called inducing points, $X_* \overset{\text{def}}{=} \mathcal{X} \setminus \{X_D, X_I\}$
is an infinite set of evaluation points containing all points in the data space except for $X_D$
and $X_I$, and $\tilde{f}_X$ are function values at evaluation points $X$. Specifically, we assume prior
conditional matching, that is:
Assumption 3 (Prior Conditional Matching):
Let the variational distribution factorize as
$$\tilde{q}(\tilde{f}_{X_*}, \tilde{f}_{X_D}, \tilde{f}_{X_I};\theta) \overset{\text{def}}{=} \tilde{p}(\tilde{f}_{X_*} \mid \tilde{f}_{X_D})\,\tilde{p}(\tilde{f}_{X_D} \mid \tilde{f}_{X_I})\,\tilde{q}(\tilde{f}_{X_I};\theta). \tag{5}$$
Under Assumption 3, we can now simplify the function-space variational objective as follows:
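Proposition 2 Under Assumptions 1, 2, and 3, we obtain the variational objective
$$\mathcal{F}(q(f;\theta)) = \mathbb{E}_{q(f_{X_D};\theta)}[\log p(y \mid f_{X_D})] - D_{\mathrm{KL}}(\tilde{q}(\tilde{f}_{X_I};\theta) \,\|\, \tilde{p}(\tilde{f}_{X_I})). \tag{6}$$
Under prior conditional matching, the KL divergence between $\tilde{q}(\tilde{f})$ and $\tilde{p}(\tilde{f})$ simplifies to
$D_{\mathrm{KL}}(\tilde{q}(\tilde{f}_{X_I};\theta) \,\|\, \tilde{p}(\tilde{f}_{X_I}))$, which is a KL divergence between multivariate
Gaussian distributions and can be expressed analytically. For a full, measure-theoretic proof,
see de G. Matthews et al. (2016, Sections 3.2 and 3.3).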
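The analytic form referred to here is the standard closed-form KL divergence between two multivariate Gaussians. The sketch below evaluates it at a finite set of inducing points, building the variational covariance as JΣJ⊤ in line with Proposition 1. The Jacobian values, the parameter covariance, the zero prior mean, and the prior scale are hypothetical placeholders, and the jitter term is an added assumption for numerical stability.

```python
import numpy as np

def gaussian_kl(m_q, S_q, m_p, S_p):
    """KL( N(m_q, S_q) || N(m_p, S_p) ) for full covariance matrices."""
    k = m_q.shape[0]
    diff = m_p - m_q
    # Solve linear systems instead of forming the explicit inverse of S_p.
    trace_term = np.trace(np.linalg.solve(S_p, S_q))
    quad_term = diff @ np.linalg.solve(S_p, diff)
    # slogdet returns (sign, log|det|); both matrices are assumed positive definite.
    logdet_q = np.linalg.slogdet(S_q)[1]
    logdet_p = np.linalg.slogdet(S_p)[1]
    return 0.5 * (trace_term + quad_term - k + logdet_p - logdet_q)

# Hypothetical Jacobian of a scalar-output network at 3 inducing points (4 parameters).
J_I = np.array([[1.0, 0.0, 0.5, 0.0],
                [0.0, 1.0, 0.0, 0.5],
                [0.5, 0.0, 1.0, 0.0]])
Sigma = np.diag([0.1, 0.2, 0.1, 0.2])          # hypothetical parameter covariance
m_q = np.array([0.3, -0.1, 0.2])               # hypothetical values of f(X_I; mu)
S_q = J_I @ Sigma @ J_I.T + 1e-6 * np.eye(3)   # linearized covariance plus jitter

m_p = np.zeros(3)                              # assumed zero-mean prior
S_p = 10.0 * np.eye(3)                         # assumed diagonal prior covariance

kl = gaussian_kl(m_q, S_q, m_p, S_p)
print(f"KL at the inducing points: {kl:.4f}")
```

Solving linear systems rather than inverting the prior covariance is a design choice for numerical robustness; with a diagonal prior covariance the same quantity reduces to a sum over inducing points.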
5. Empirical Evaluation
We evaluate the proposed function-space variational inference (fsvi) method on illustrative
regression and classification tasks as well as on high-dimensional classification tasks that prior
work (Sun et al., 2019) was unable to scale to. We show that fsvi (sometimes significantly)
outperforms existing Bayesian and non-Bayesian methods in terms of in-distribution
uncertainty calibration and out-of-distribution uncertainty estimation.
Table 1: Comparison of in- and out-of-distribution performance metrics. AUROC: area under ROC
curve. ECE: expected calibration error. ¹ Computed from mutual information scores. ² Computed
from predictive entropy scores. ³ bnn trained with mfvi (Blundell et al., 2015). ⁴ bnn map estimate
obtained by training a deterministic neural network with weight regularization.

                     fashionMNIST/MNIST                  CIFAR-10/SVHN
Model         Accuracy  ECE   AUROC¹  AUROC²     Accuracy  ECE   AUROC¹  AUROC²
Ours: fsvi    91%       0.02  91%     86%        85%       0.03  54%     99%
mfvi³         91%       0.04  85%     84%        85%       0.04  68%     97%
map⁴          91%       0.07  72%     76%        86%       0.09  72%     90%
Ensemble      93%       0.02  91%     90%        91%       0.03  88%     98%
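The two AUROC columns are computed from mutual information and predictive entropy scores. The sketch below shows one common way to obtain both scores from Monte Carlo samples of class probabilities. The exact definitions used for Table 1 are not given in this text, so the formulas here (entropy of the mean prediction, and that entropy minus the mean per-sample entropy) are an assumption, as are the array shapes and the small constant added inside the logarithm.

```python
import numpy as np

def entropy(p, axis=-1, eps=1e-12):
    """Shannon entropy in nats along the class axis."""
    return -np.sum(p * np.log(p + eps), axis=axis)

def uncertainty_scores(probs):
    """probs has shape (S, N, C): S function draws, N inputs, C classes."""
    mean_probs = probs.mean(axis=0)                  # predictive mean, shape (N, C)
    predictive_entropy = entropy(mean_probs)         # total uncertainty, shape (N,)
    expected_entropy = entropy(probs).mean(axis=0)   # average per-draw entropy
    mutual_information = predictive_entropy - expected_entropy
    return predictive_entropy, mutual_information

# Two draws that agree on a uniform prediction: high entropy, no disagreement.
agree = np.array([[[0.5, 0.5]], [[0.5, 0.5]]])
# Two confident draws that disagree: high entropy, high mutual information.
disagree = np.array([[[1.0, 0.0]], [[0.0, 1.0]]])

pe_a, mi_a = uncertainty_scores(agree)
pe_d, mi_d = uncertainty_scores(disagree)
print(pe_a, mi_a, pe_d, mi_d)
```

Under these definitions, predictive entropy is high in both toy cases, while mutual information is high only when the function draws disagree, which is why the two scores can rank out-of-distribution inputs differently.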
[Figure: plot residue only. Recoverable legend entries: FSVI, MFVI, Ensemble of NNs, Single NN. Recoverable axis labels: Entropy (nats), False Positive Rate, True Positive Rate, Confidence Threshold τ. The caption was not recovered.]
6. Conclusion
We proposed a fundamentally new approach to variational inference in bnns, where the
parameters are inferred indirectly by performing inference over the induced distribution over
functions. We showed that fsvi exhibits in- and out-of-distribution predictive performance
on par with or better than related state-of-the-art Bayesian and non-Bayesian approaches.
References
Charles Blundell, Julien Cornebise, Koray Kavukcuoglu, and Daan Wierstra. Weight uncertainty
in neural network. Volume 37 of Proceedings of Machine Learning Research, pages
1613–1622, Lille, France, 07–09 Jul 2015. PMLR. URL http://proceedings.mlr.press/
v37/blundell15.html.
Shengyang Sun, Guodong Zhang, Jiaxin Shi, and Roger B. Grosse. Functional variational
Bayesian neural networks. In 7th International Conference on Learning Representations,
ICLR 2019, New Orleans, LA, USA, May 6-9, 2019. OpenReview.net, 2019. URL
https://openreview.net/forum?id=rkxacs0qY7.
Joost van Amersfoort, Lewis Smith, Yee Whye Teh, and Yarin Gal. Uncertainty estimation
using a single deep deterministic neural network. In International Conference on Machine
Learning, 2020.
Supplementary Materials
1. Theoretical Results
1.1. Linearization of Bayesian Neural Network Predictive Distribution
Proposition 1 (Predictive Distribution of Linearized BNN):
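Consider θ, $f(x;\theta)$, and $\tilde{f}(x;\theta)$ as defined above. The mean and variance of $\tilde{q}(\tilde{f}(x;\theta))$ are
given by
$$\mathbb{E}_{\tilde{q}(\tilde{f}(x;\theta))}[\tilde{f}(x;\theta)] = f(x;\mu) \quad \text{and} \quad \mathbb{V}(\tilde{f}(x;\theta)) = J_\mu(x)\,\Sigma\,J_\mu(x')^\top.$$
Proof Since $\theta \sim \mathcal{N}(\theta \mid \mu, \Sigma)$, and $\tilde{f}(x;\theta) = f(x;\mu) + J_\mu(x)(\theta - \mu)$ is a linear transformation
of θ, $\tilde{f}(x;\theta)$ is a Gaussian process
$$\tilde{q}(\tilde{f}(x);\theta) = \mathcal{GP}(\tilde{f} \mid m(x), S(x, x')) \tag{1.2}$$
with some predictive mean $m(x)$ and predictive covariance $S(x, x')$. To find $\tilde{q}(\tilde{f}(x;\theta))$, we
need to find the predictive mean $m(x)$ and the predictive covariance $S(x, x')$, which, by
definition, we can write as:
$$m(x) = \mathbb{E}[\tilde{f}(x;\theta)] \tag{1.3}$$
and
\begin{align}
S(x, x') &= \mathrm{Cov}(\tilde{f}(x;\theta), \tilde{f}(x';\theta)) \tag{1.4} \\
&= \mathbb{E}\big[(\tilde{f}(x;\theta) - \mathbb{E}[\tilde{f}(x;\theta)])\,(\tilde{f}(x';\theta) - \mathbb{E}[\tilde{f}(x';\theta)])^\top\big]. \tag{1.5}
\end{align}
To see that $m(x) = \mathbb{E}[\tilde{f}(x;\theta)] = f(x;\mu)$, note that, by linearity of expectation, we have
\begin{align}
m(x) &= \mathbb{E}[\tilde{f}(x;\theta)] \tag{1.6} \\
&= \mathbb{E}[f(x;\mu) + J_\mu(x)(\theta - \mu)] \tag{1.7} \\
&= f(x;\mu) + J_\mu(x)(\mathbb{E}[\theta] - \mu) \tag{1.8} \\
&= f(x;\mu). \tag{1.9}
\end{align}
To see that $S(x, x') = \mathrm{Cov}(\tilde{f}(x;\theta), \tilde{f}(x';\theta)) = J_\mu(x)\,\Sigma\,J_\mu(x')^\top$, note that in general,
$\mathrm{Cov}(X, X) = \mathbb{E}[XX^\top] - \mathbb{E}[X]\,\mathbb{E}[X]^\top$, and hence,
$$\mathrm{Cov}(\tilde{f}(x;\theta), \tilde{f}(x';\theta)) = \mathbb{E}[\tilde{f}(x;\theta)\tilde{f}(x';\theta)^\top] - \mathbb{E}[\tilde{f}(x;\theta)]\,\mathbb{E}[\tilde{f}(x';\theta)]^\top. \tag{1.10}$$
We already know that $\mathbb{E}[\tilde{f}(x;\theta)] = f(x;\mu)$, so we only need to find $\mathbb{E}[\tilde{f}(x;\theta)\tilde{f}(x';\theta)^\top]$:
\begin{align}
&\mathbb{E}_{q(\theta)}[\tilde{f}(x;\theta)\tilde{f}(x';\theta)^\top] \notag \\
&= \mathbb{E}_{q(\theta)}\big[(f(x;\mu) + J_\mu(x)(\theta - \mu))\,(f(x';\mu) + J_\mu(x')(\theta - \mu))^\top\big] \tag{1.11} \\
&= \mathbb{E}_{\epsilon \sim \mathcal{N}(\epsilon \mid 0, I)}\big[(f(x;\mu) + J_\mu(x)(\mu + \sqrt{\Sigma}\,\epsilon - \mu))\,(f(x';\mu) + J_\mu(x')(\mu + \sqrt{\Sigma}\,\epsilon - \mu))^\top\big] \tag{1.12} \\
&= \mathbb{E}_{\epsilon \sim \mathcal{N}(\epsilon \mid 0, I)}\big[(f(x;\mu) + J_\mu(x)\sqrt{\Sigma}\,\epsilon)\,(f(x';\mu) + J_\mu(x')\sqrt{\Sigma}\,\epsilon)^\top\big], \tag{1.13}
\end{align}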
where the reparameterization is possible by Assumption 1. With some algebra, this expression
can be further simplified to
$$\mathbb{E}_{q(\theta)}[\tilde{f}(x;\theta)\tilde{f}(x';\theta)^\top] = f(x;\mu)\,f(x';\mu)^\top + J_\mu(x)\,\Sigma\,J_\mu(x')^\top. \tag{1.14}$$
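The algebra behind this simplification can be spelled out as follows. This is a sketch under the assumptions that ε has zero mean and identity covariance and that $\sqrt{\Sigma}$ denotes a matrix square root satisfying $\sqrt{\Sigma}\sqrt{\Sigma}^\top = \Sigma$; the notation for ε and $\sqrt{\Sigma}$ is introduced here for illustration.

```latex
\begin{align*}
&\mathbb{E}_{\epsilon}\big[(f(x;\mu) + J_\mu(x)\sqrt{\Sigma}\,\epsilon)\,
  (f(x';\mu) + J_\mu(x')\sqrt{\Sigma}\,\epsilon)^\top\big] \\
&\quad= f(x;\mu)\,f(x';\mu)^\top
  + f(x;\mu)\,\mathbb{E}[\epsilon]^\top \sqrt{\Sigma}^\top J_\mu(x')^\top
  + J_\mu(x)\sqrt{\Sigma}\,\mathbb{E}[\epsilon]\,f(x';\mu)^\top \\
&\qquad+ J_\mu(x)\sqrt{\Sigma}\,\mathbb{E}[\epsilon\epsilon^\top]\,
  \sqrt{\Sigma}^\top J_\mu(x')^\top \\
&\quad= f(x;\mu)\,f(x';\mu)^\top + J_\mu(x)\,\Sigma\,J_\mu(x')^\top ,
\end{align*}
% using E[eps] = 0 (the two cross terms vanish) and E[eps eps^T] = I.
```

The two cross terms vanish because they are linear in ε, and the last term collapses because the second moment of ε is the identity.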
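Since the $f(x;\mu)\,f(x';\mu)^\top$ terms cancel out, we obtain the covariance function
$$S(x, x') = \mathrm{Cov}(\tilde{f}(x;\theta), \tilde{f}(x';\theta)) = J_\mu(x)\,\Sigma\,J_\mu(x')^\top.$$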
$$\frac{d\hat{P}}{dP}(f) = \frac{p_X(Y \mid f)}{p(Y)}, \tag{1.16}$$
where $p_X(Y \mid f)$ is the likelihood and $p(Y) = \int_{\mathbb{R}^X} p_X(Y \mid f)\,dP(f)$ is the marginal likelihood.
We assume that the likelihood function is evaluated at a finite subset of the index set X.
Denote by $\pi_C : \mathbb{R}^X \to \mathbb{R}^C$ a projection function that takes a function and returns the same
function, evaluated at a finite set of points C, so we can write
$$D_{\mathrm{KL}}(Q \,\|\, \hat{P}) = \int_{\mathbb{R}^X} \log \frac{dQ}{dP}(f)\,dQ(f) - \int_{\mathbb{R}^X} \log \frac{d\hat{P}}{dP}(f)\,dQ(f), \tag{1.18}$$
where P is some prior stochastic process. Now, considering the second term, we can apply
the measure-theoretic Bayes’ Theorem to obtain
\begin{align}
\int_{\mathbb{R}^X} \log \frac{d\hat{P}}{dP}(f)\,dQ(f) &= \int_{\mathbb{R}^{X_D}} \log \frac{d\hat{P}_{X_D}}{dP_{X_D}}(f_{X_D})\,dQ_{X_D}(f_{X_D}) \tag{1.19} \\
&= \mathbb{E}_{Q_{X_D}}[\log p(y_D \mid f_{X_D})] - \log p(y_D), \tag{1.20}
\end{align}
giving us
$$D_{\mathrm{KL}}(Q \,\|\, \hat{P}) = \int_{\mathbb{R}^X} \log \frac{dQ}{dP}(f)\,dQ(f) - \mathbb{E}_{Q_{X_D}}[\log p(y_D \mid f_{X_D})] + \log p(y_D). \tag{1.21}$$
which corresponds to the expression for the function-space variational objective in Section 3.
2. Further Background
Mean-Field Variational Inference Since bnns are non-linear in their parameters, exact
inference over the network parameters is analytically intractable. Mean-field variational
inference is a variational approach for finding an approximate posterior distribution over
network parameters and using it to draw samples from a bnn’s approximate posterior
predictive distribution. Under a mean-field assumption, the joint distribution over the
stochastic network parameters has a diagonal covariance, rendering the network parameters
independent of one another under the variational distribution. Furthermore, to obtain a
tractable variational objective, prior works assume the prior and variational distributions
over the network parameters to be Gaussian (Blundell et al., 2015). This approximation
results in a tractable and scalable variational objective, given by
$$\mathcal{L}(q(\theta)) \overset{\text{def}}{=} \mathbb{E}_{q(f_{X_D};\theta)}[\log p(y \mid f_{X_D})] - D_{\mathrm{KL}}(q(\theta) \,\|\, p(\theta)), \tag{2.26}$$
where $p(\theta) = \mathcal{N}(\theta \mid 0, I)$, and $f_{X_D}$ are function values at the training inputs $X_D$.
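With a diagonal Gaussian variational distribution and a standard normal prior, the KL term in the mean-field objective has a simple closed form that sums over parameters. The sketch below evaluates it; parameterizing the variances through their logarithms is a common convention assumed here, not something stated in the text.

```python
import numpy as np

def mean_field_kl(mu, log_var):
    """KL( N(mu, diag(exp(log_var))) || N(0, I) ), summed over all parameters."""
    return 0.5 * np.sum(np.exp(log_var) + mu**2 - 1.0 - log_var)

# When q equals the prior, the divergence is zero.
kl_zero = mean_field_kl(np.zeros(3), np.zeros(3))

# Shifting one mean by 1 while keeping unit variances adds 0.5 * 1**2.
kl_shift = mean_field_kl(np.array([1.0, 0.0]), np.zeros(2))
print(kl_zero, kl_shift)
```

Because the sum runs over individual parameters, the cost of this term is linear in the number of parameters, which is what makes the parameter-space objective scalable.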
Gaussian Processes Gaussian process (gp) models define distributions over functions.
Unlike in bnns where a prior distribution over parameters implicitly induces a prior
distribution over functions, Gaussian processes explicitly define distributions over
functions by specifying a covariance function over possible function realizations. A gp prior
$p(f \mid x) = \mathcal{GP}(m(x), k(x, x'))$ is completely specified by its mean and covariance function,
m(·) and k(·, ·).
3. Model Details
For the experiments presented in this paper, we diagonalized the covariance of the linearized
bnn in the KL divergence. While this simplification is not necessary, it speeds up training
with larger numbers of inducing points. Furthermore, we assumed a bnn prior that is locally
equal to a gp with zero mean and a diagonal covariance scaled by 1e6 on all evaluation
points. We chose the set of inducing points by uniformly sampling within some range, e.g.,
within the range of admissible pixel values for prediction tasks with image inputs.
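Uniform sampling of inducing points within a fixed range can be sketched as follows. The function name, the pixel range [0, 1), the image shape, and the number of points are hypothetical choices for illustration; the text does not specify them.

```python
import numpy as np

def sample_inducing_points(num_points, input_shape, low=0.0, high=1.0, seed=0):
    """Draw inducing inputs uniformly at random from [low, high) in every input dimension."""
    rng = np.random.default_rng(seed)
    return rng.uniform(low, high, size=(num_points, *input_shape))

# Hypothetical example: 10 inducing images of shape 28 x 28 x 1 with pixels in [0, 1).
X_I = sample_inducing_points(10, (28, 28, 1))
print(X_I.shape, X_I.min(), X_I.max())
```

In training one would typically redraw the set at every gradient step or reuse a fixed set; which of the two is used is not stated here.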
For the Two Moons and Snelson experiments, we use a fully-connected neural network
with two hidden layers and 100 hidden units per layer. For the fashionMNIST/MNIST
experiments, we use a three-layer convolutional neural network without batch normalization.
For the CIFAR-10/SVHN experiments, we use a seven-layer convolutional neural network
without batch normalization. We use the Adam optimizer for all experiments with learning
rates between η = 1e-4 and η = 1e-3 (depending on which yielded the best performance).
The deterministic neural networks that were used for the ensemble were trained with a
weight decay of λ = 1e-1. We used early stopping when training bnns with mfvi to avoid
overfitting. We did not use early stopping to train bnns with fsvi.
[Figure 3 panels: (a) fsvi Posterior Predictive Distribution; (b) mfvi Posterior Predictive Distribution.]
Figure 3: 1D Regression on the Snelson dataset. The plots show the predictive distribution of a
bnn obtained via function-space variational inference (fsvi) under the local approximation described
in Section 4 (Figure 3(a)) and obtained via mean-field variational inference (mfvi) (Figure 3(b)).
The plots show noisy data (in red), the posterior predictive means (in black), ten function draws
from the bnns (in blue), and two standard deviations of the empirical distribution over functions.
[Figure 4 panels: (a) Posterior Predictive Mean, E[y|D; x]; (b) Posterior Predictive Variance, V[y|D; x].]
Figure 4: Binary classification on the Two Moons dataset. The plots show the posterior predictive
mean (Figure 4(a)) and variance (Figure 4(b)) of a bnn obtained via fsvi. They represent the expected
class probabilities and the model’s epistemic uncertainty over the class probabilities, respectively.
The predictive distribution is able to faithfully capture the geometry of the data manifold and
exhibits high uncertainty over the class probabilities in areas of the data space about which the data is
not informative. In contrast, related methods, such as ensembles, are unable to accurately capture
the geometry of the data manifold and only exhibit high uncertainty around the decision boundary (van
Amersfoort et al., 2020).
[Figure 5 panels (a) and (b): Model Accuracy plotted against Model Confidence for FSVI, MFVI, Ensemble of NNs, and Single NN.]
Figure 5: Reliability Diagram. Figure 5(a) shows the reliability of different models, expressed as a
plot of model accuracy against model confidence for models trained on fashionMNIST and evaluated
on a fashionMNIST test set. Figure 5(b) shows the reliability of different models, expressed as a plot
of model accuracy against model confidence for models trained on CIFAR-10 and evaluated on a
CIFAR-10 test set.
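A reliability diagram of this kind bins predictions by confidence and compares the average confidence in each bin with the accuracy in that bin; the expected calibration error (ECE) reported in Table 1 summarizes the gap. The sketch below uses ten equal-width bins, which is a common convention assumed here because the binning scheme is not stated in the text.

```python
import numpy as np

def expected_calibration_error(confidences, correct, n_bins=10):
    """Weighted average of |accuracy - confidence| over equal-width confidence bins."""
    confidences = np.asarray(confidences, dtype=float)
    correct = np.asarray(correct, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        # Half-open bins (lo, hi]; a confidence of exactly 0 is never produced
        # by the maximum of a softmax, so it is not handled specially.
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.any():
            gap = abs(correct[in_bin].mean() - confidences[in_bin].mean())
            ece += in_bin.mean() * gap
    return ece

# Calibrated toy model: 90% confidence and 9 of 10 predictions correct.
ece_calibrated = expected_calibration_error([0.9] * 10, [1] * 9 + [0])
# Maximally overconfident toy model: full confidence, always wrong.
ece_overconfident = expected_calibration_error([1.0] * 4, [0] * 4)
print(ece_calibrated, ece_overconfident)
```

Here `confidences` is meant to hold the maximum predicted class probability per input and `correct` a 0/1 indicator of whether the predicted class matches the label.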
[Figure 6 panels (a) and (b): True Positive Rate plotted against False Positive Rate for FSVI, MFVI, Ensemble of NNs, and Single NN.]
Figure 6: Receiver operating characteristic (ROC) curve. Figure 6(a) shows the ROC for different
predictive distributions for the binary classification problem of distinguishing in-distribution inputs
(fashionMNIST) from out-of-distribution inputs (MNIST). Figure 6(b) shows the ROC for different
predictive distributions for the binary classification problem of distinguishing in-distribution inputs
(CIFAR-10) from out-of-distribution inputs (SVHN).
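The area under such a ROC curve equals the probability that a randomly chosen out-of-distribution input receives a higher uncertainty score than a randomly chosen in-distribution input. The sketch below computes it directly from that pairwise definition. Treating out-of-distribution inputs as the positive class and counting ties as one half are assumptions made for illustration.

```python
import numpy as np

def auroc(scores_in, scores_out):
    """AUROC for separating OOD inputs (positive class) from in-distribution inputs."""
    scores_in = np.asarray(scores_in, dtype=float)
    scores_out = np.asarray(scores_out, dtype=float)
    # Compare every OOD score with every in-distribution score.
    greater = (scores_out[:, None] > scores_in[None, :]).mean()
    ties = (scores_out[:, None] == scores_in[None, :]).mean()
    return greater + 0.5 * ties

# Perfect separation: every OOD score exceeds every in-distribution score.
auc_perfect = auroc([0.1, 0.2], [0.3, 0.4])
# Interleaved scores: half of the pairs are ordered correctly.
auc_mixed = auroc([0.1, 0.4], [0.2, 0.3])
print(auc_perfect, auc_mixed)
```

The pairwise comparison needs memory proportional to the product of the two sample sizes, so a rank-based implementation would be preferable for large test sets.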
[Figure 7 panels (a) and (b): curves plotted against Confidence Threshold τ.]
Figure 7: Confidence on Out-of-Distribution Inputs. Figure 7(a) shows the confidence of different
predictive means of models trained on fashionMNIST, evaluated on out-of-distribution inputs
(MNIST). Figure 7(b) shows the confidence of different predictive means of models trained on
CIFAR-10, evaluated on out-of-distribution inputs (SVHN). Curves further to the left are better, as
they indicate that a model assigns low confidence to a higher number of out-of-distribution inputs.
Curves that cover a larger area (i.e., that are further to the top left) are better, as they indicate a
higher true than false positive rate.
[Figure 8 panels (a) and (b): recoverable axis labels are ECE and AUROC.]
Figure 8: Figure 8(a) shows the effect of increasing the number of inducing points on in-distribution
predictions for bnns trained via fsvi on fashionMNIST. Increasing the number of inducing points
does not affect the test accuracy, but does increase the predictive mean’s expected calibration error.
Figure 8(b) shows that increasing the number of inducing points also increases the predictive
entropy on out-of-distribution inputs as well as the area under the receiver operating characteristic
curve computed from it.