
3rd Symposium on Advances in Approximate Bayesian Inference, 2020

Rethinking Function-Space Variational Inference in Bayesian Neural Networks

Anonymous Authors
Anonymous Institution

Abstract
Bayesian neural networks (bnns) define distributions over functions induced by distributions
over parameters. In practice, this model specification makes it difficult to define and
use meaningful prior distributions over functions that could aid in training. What’s
more, previous attempts at defining an explicit function-space variational objective for
approximate inference in bnns require approximations that do not scale to high-dimensional
data. We propose a new function-space approach to variational inference in bnns and
derive a tractable variational objective by linearizing the bnn's posterior predictive distribution about
its mean parameters, allowing function-space variational inference to be scaled to large
and high-dimensional datasets. We evaluate this approach empirically and show that it
leads to models with competitive predictive accuracy and significantly improved predictive
uncertainty estimates compared to parameter-space variational inference.

1. Introduction
Approximate inference in Bayesian neural networks (bnns) typically involves performing
probabilistic inference directly over a set of stochastic network parameters. Unfortunately,
standard approaches for parameter-space inference in bnns often do not result in approximate
posterior predictive distributions that reliably exhibit high predictive uncertainty away from
the training data or under distribution shift, making them of limited use in practice.
Instead of explicitly performing approximate inference over bnn parameters, we propose a
method for tractable and efficient approximate inference in bnns by inferring an approximate
posterior distribution over the network parameters implicitly and optimizing a variational
objective over the induced distribution over functions instead. This way, it is possible to
better control the distribution over functions induced by the network parameters and
obtain higher-quality uncertainty estimates than state-of-the-art Bayesian and non-Bayesian
methods. We evaluate the resulting approximate posterior predictive distribution empirically
on a number of high-dimensional prediction tasks and demonstrate that it outperforms
related methods in terms of accuracy, out-of-distribution detection, and calibration.
The main contribution of this paper is the conceptualization and formalization of a new
approach to function-space variational inference in bnns that is more scalable and better
performing than previously proposed approaches. We address the conceptual and practical
limitations of prior work, perform an extensive empirical evaluation on high-dimensional
prediction tasks, and conduct ablation studies on the proposed method.

© A. Authors.

[Figure 1 shows three panels: (a) Predictive Distribution (with predictive mean, function draws, and training data), (b) Predictive Mean E[y | D; x], and (c) Predictive Variance V[y | D; x].]

Figure 1: 1D regression on the Snelson dataset and binary classification on the Two Moons dataset.
The plots show the predictive distribution of a bnn, obtained via function-space variational inference
(fsvi) under the local approximation described in Section 4. For further plots, see Appendix 4.

2. Preliminaries
We consider supervised learning tasks on data D := {(x_n, y_n)}_{n=1}^N with inputs x_n ∈ X ⊆ R^D and targets y_n ∈ Y, where Y ⊆ R^Q for regression and Y ⊆ {0, 1}^Q for classification tasks.

Bayesian Neural Networks Consider a neural network f(x; θ) parameterized by stochastic parameters θ ∈ R^P and define a conditional distribution of targets given function values
f (x; θ): p(y | x, θ; f ). For θ ∼ p(θ), we thus obtain a joint distribution p(y | x, θ; f ) p(θ),
where the semicolon denotes a dependency on some non-stochastic quantity (in this case, the
architecture of the neural network), and the resulting distribution of the targets under a given
architecture and parameter vector is determined by the model p(y | x, θ; f ). For example, for
regression and softmax classification tasks, the conditional distribution p(y | x, θ; f ) would be
a Gaussian distribution with mean f (x; θ) (and some variance) or a categorical distribution
defined via a softmax model (Bridle, 1990), respectively.
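
For concreteness, these two likelihood models might be written as follows (a minimal sketch; the helper names and the fixed noise scale are our assumptions, not prescribed by the paper):

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def gaussian_log_likelihood(y, f_x, noise_scale=0.1):
    # Regression: p(y | x, theta; f) = N(y | f(x; theta), noise_scale^2).
    return jnp.sum(norm.logpdf(y, loc=f_x, scale=noise_scale))

def categorical_log_likelihood(y_onehot, f_x):
    # Classification: p(y | x, theta; f) = Categorical(softmax(f(x; theta))),
    # where f(x; theta) are the network logits (Bridle, 1990).
    return jnp.sum(y_onehot * jax.nn.log_softmax(f_x))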
Mean-Field Variational Inference Since bnns are non-linear in their parameters, exact
inference over the network parameters is analytically intractable. Mean-field variational
inference is a variational approach for finding an approximate posterior distribution over
network parameters. For further details, see Appendix 2.

3. A Function-Space Perspective on Variational Inference


In this section, we present a function-space perspective on variational inference in bnns and
discuss shortcomings of prior approaches to function-space variational inference (fsvi).
Consider again the probabilistic model p(y | x, θ; f ) p(θ) defined in Section 2. Instead of
defining the probabilistic model explicitly in terms of the parameters, we will instead define it
explicitly in terms of the stochastic functions induced by the stochastic parameters θ ∼ p(θ).
Specifically, we consider a probabilistic model of the targets and a latent random function f
distributed according to some prior distribution over functions, p(f | x; θ), parameterized by
θ. For a model p(y | f (x; θ)), we can then express the joint distribution over targets and
latent random functions as p(y, f | x) = p(y | f )p(f | x; θ) and frame the inference problem
of finding a posterior distribution over functions, p(f | D), variationally as minimizing the KL divergence DKL(q(f; θ) ‖ p(f | D)), where q(f; θ) and p(f | D) are distributions over functions
defined on an infinite index set. For a likelihood function defined on a finite set of training targets y, we can express this minimization problem as maximizing the variational objective

F(q(f; θ)) := E_{q(f_{X_D}; θ)}[log p(y | f_{X_D})] − DKL(q(f; θ) ‖ p(f)),  (1)

where DKL(q(f; θ) ‖ p(f)) is again a KL divergence between distributions over functions.
For a measure-theoretic derivation of this result, see Appendix 1. Unfortunately, it is not
immediately clear how to evaluate such a KL divergence if q(f; θ) and p(f) are bnn posterior
and prior predictive distributions. In an effort to make this objective more tractable, Sun
et al. (2019) show that DKL(q(f; θ) ‖ p(f)) can be expressed as the supremum of the KL
divergence from q(f; θ) to p(f) over all finite sets of evaluation points, X_I, that is,

F(q(f; θ)) = E_{q(f_{X_D}; θ)}[log p(y | f_{X_D})] − sup_{n∈N, X_I∈X^n} DKL(q(f_{X_I}; θ) ‖ p(f_{X_I})).  (2)

Unfortunately, this objective is still extremely challenging to estimate in practice: The
supremum cannot be found analytically, searching for it iteratively may lead to undesirable
optimization behavior (Sun et al., 2019), and the KL divergence term itself is intractable
as well—even for finite XI . What’s more, existing approaches to approximating the KL
divergence do not scale to high input or target dimensions (Sun et al., 2019).
We propose a fundamentally different approach to function-space variational inference.
Starting from Equation (1), we consider a locally accurate approximation to q(f ; θ) and p(f )
by linearizing them about their mean parameters. Assuming a Gaussian distribution over the
network parameters, this approximation turns q(f ; θ) and p(f ) into Gaussian processes. To
evaluate the resulting locally accurate KL divergence, we make a prior conditional matching
assumption, which results in a tractable KL divergence evaluated at a finite number of
evaluation points. We present this approximation in more detail next.

4. Function-Space Variational Inference via Local Linearization


The primary obstacle to making the objective in Equation (1) tractable is the KL divergence
from q(f ; θ) to p(f ). We start by considering the distribution over parameters that gives
rise to the distribution over functions q(f ; θ). Specifically,

Assumption 1 (Mean-Field Variational Distribution Over Parameters):


Assume a factorized Gaussian variational distribution, q(θ) = N (θ | µ, Σ).

We denote the distribution over functions induced by q(θ) as q(f ; θ). To obtain a tractable
distribution over functions, we make a local approximation about the bnn’s mean parameters:

Assumption 2 (Linearization about Mean Parameters):


Linearize the stochastic function f about its mean parameters µ to obtain the locally accurate approximation f(x; θ) ≈ f̃(x; θ) ≡ f(x; µ) + J_µ(x)(θ − µ), where J_µ(x) denotes the Jacobian ∂f(x; θ)/∂θ |_{θ=µ} and θ ∼ N(θ | µ, Σ).

Due to local linearity, the approximation f˜(x; θ) will be accurate for realizations θ̂ close to
µ, and hence, the distribution over f˜(x; θ) (induced by θ) will be close to the distribution
over f (x; θ) for small variance parameters Σ. Under the two assumptions above, we obtain
a locally accurate approximate distribution over functions:


Proposition 1 (Predictive Distribution of Linearized BNN) Consider θ, f(x; θ), and f̃(x; θ) as defined above. The mean and covariance of q̃(f̃(x; θ)) are given by

E_{q̃(f̃(x;θ))}[f̃(x; θ)] = f(x; µ) and Cov(f̃(x; θ), f̃(x′; θ)) = J_µ(x) Σ J_µ(x′)ᵀ,

and the predictive distribution q̃ over f̃(x; θ) is a Gaussian process given by

q̃(f̃(x); θ) = GP(f̃ | f(x; µ), J_µ(x) Σ J_µ(x′)ᵀ).  (3)
Proof See Appendix 1.
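
As an illustration of Proposition 1, the sketch below computes the linearized predictive mean and covariance with jax, assuming a factorized Gaussian q(θ) over a flattened parameter vector with mean mu and per-parameter variances sigma2 (the names and the flat parameterization are simplifying assumptions, not part of the paper):

import jax
import jax.numpy as jnp

def linearized_predictive(f, mu, sigma2, x1, x2):
    # f: network taking (inputs, flat parameter vector) -> outputs.
    mean = f(x1, mu)                          # E[f~(x; theta)] = f(x; mu)
    J1 = jax.jacfwd(lambda t: f(x1, t))(mu)   # J_mu(x), shape (outputs, P)
    J2 = jax.jacfwd(lambda t: f(x2, t))(mu)   # J_mu(x')
    # Cov[f~(x), f~(x')] = J_mu(x) Sigma J_mu(x')^T, with diagonal Sigma.
    cov = J1 @ (sigma2[:, None] * J2.T)
    return mean, cov

# Toy check with the linear model f(x, theta) = theta[0] * x + theta[1]:
f = lambda x, t: t[0] * x + t[1]
mean, cov = linearized_predictive(
    f, jnp.array([1.0, 0.0]), jnp.array([0.1, 0.1]),
    jnp.array([2.0]), jnp.array([3.0]))
# cov = 0.1 * (2 * 3) + 0.1 * (1 * 1) = 0.7, matching J(x) Sigma J(x')^T.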
Under the linearization about µ, we obtain a local approximation to the objective:

F(q(f; θ)) ≈ E_{q(f_{X_D}; θ)}[log p(y | f_{X_D})] − DKL(q̃(f̃; θ) ‖ p̃(f̃)),  (4)

where p̃(f˜) is a local approximation to a prior distribution over functions induced by some
Gaussian prior distribution over the network parameters. The approximate variational
objective in Equation (4) includes a KL divergence between two Gaussian processes q̃(f˜)
and p̃(f˜). Unfortunately, in the absence of additional assumptions about q̃(f˜) and p̃(f˜), this
KL divergence is still intractable (de G. Matthews et al., 2016).
To obtain a tractable variational objective, we make an assumption about how the
variational predictive distribution q̃(f̃; θ) = q̃(f̃_{X_∗}, f̃_{X_D}, f̃_{X_I}; θ) factorizes, where X_D is a
finite set of training inputs, X_I is a finite set of so-called inducing points, X_∗ := X \ {X_D, X_I}
is an infinite set of evaluation points containing all points in the data space except for X_D
and X_I, and f̃_X are function values at evaluation points X. Specifically, we assume prior
conditional matching, that is:

Assumption 3 (Prior Conditional Matching):
Let the variational distribution factorize as

q̃(f̃_{X_∗}, f̃_{X_D}, f̃_{X_I}; θ) := p̃(f̃_{X_∗} | f̃_{X_D}) p̃(f̃_{X_D} | f̃_{X_I}) q̃(f̃_{X_I}; θ).  (5)
Under Assumption 3, we can now simplify the function-space variational objective as follows:
Proposition 2 Under Assumptions 1, 2, and 3, we obtain the variational objective

F(q(f; θ)) = E_{q(f_{X_D}; θ)}[log p(y | f_{X_D})] − DKL(q̃(f̃_{X_I}; θ) ‖ p̃(f̃_{X_I})),  (6)

where DKL(q̃(f̃_{X_I}; θ) ‖ p̃(f̃_{X_I})) is analytically tractable.


Proof Sketch We can express the variational objective in Equation (4) as

F(q(f; θ)) = E_{q(f_{X_D}; θ)}[log p(y | f_{X_D})] − DKL(q̃(f̃_{X_∗}, f̃_{X_D}, f̃_{X_I}; θ) ‖ p̃(f̃_{X_∗}, f̃_{X_D}, f̃_{X_I})),  (7)

and under prior conditional matching, the KL divergence term becomes

DKL(p̃(f̃_{X_∗} | f̃_{X_D}) p̃(f̃_{X_D} | f̃_{X_I}) q̃(f̃_{X_I}) ‖ p̃(f̃_{X_∗} | f̃_{X_D}) p̃(f̃_{X_D} | f̃_{X_I}) p̃(f̃_{X_I})),  (8)

which simplifies to DKL(q̃(f̃_{X_I}; θ) ‖ p̃(f̃_{X_I})), a KL divergence between multivariate Gaussian distributions that can be expressed analytically. For a full, measure-theoretic proof, see de G. Matthews et al. (2016, Sections 3.2 and 3.3).
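
For concreteness, denote by µ_q := f(X_I; µ) and Σ_q := J_µ(X_I) Σ J_µ(X_I)ᵀ the mean and covariance of q̃(f̃_{X_I}; θ), and by (µ_p, Σ_p) those of the prior p̃(f̃_{X_I}) (our notation, introduced here for illustration). The KL divergence in Equation (6) then takes the standard closed form for k-dimensional Gaussians,

DKL(q̃(f̃_{X_I}; θ) ‖ p̃(f̃_{X_I})) = ½ [ tr(Σ_p⁻¹ Σ_q) + (µ_p − µ_q)ᵀ Σ_p⁻¹ (µ_p − µ_q) − k + log(det Σ_p / det Σ_q) ],

where k is the number of function values at the inducing points.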


Algorithm 1 fsvi: Function-Space Variational Inference

Input: data D, number of inducing points |X_I|, learning rate η, prior mean µ_0, prior variance Σ_0
1 Initialization: θ ∼ p_{θ_0}, I ∼ p_I
  while ℓ(q(f; θ)) − DKL(q̃(f̃_{X_I}; θ) ‖ p̃(f̃_{X_I})) not converged do
2   Sample inducing points X_I ⊂ I and a minibatch B ⊂ D
3   Θ(X_I, X_I) = J_µ(X_I) Σ J_µ(X_I)ᵀ
4   ℓ(q(f; θ)) = (1/S) Σ_{i=1}^{S} Σ_{(x_B, y_B) ∈ B} log p(y_B | f(x_B, ε_i; θ)), with ε_i ∼ N(ε | 0, I)
5   DKL(q̃(f̃_{X_I}; θ) ‖ p̃(f̃_{X_I})) = Σ_{j=1}^{|X_I|} [ ½ log([Σ_0]_jj / [Θ(X_I, X_I)]_jj) + ([Θ(X_I, X_I)]_jj + ([f(X_I; µ)]_j − µ_0)²) / (2 [Σ_0]_jj) − ½ ]
6   θ ← θ − η ∇_θ (DKL(q̃(f̃_{X_I}; θ) ‖ p̃(f̃_{X_I})) − ℓ(q(f; θ)))
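
To make the algorithm concrete, the following is a minimal jax sketch of one fsvi objective evaluation for a regression likelihood with diagonalized covariances (cf. Appendix 3); the network f, the fixed noise scale, and all names here are illustrative assumptions, not the authors' implementation:

import jax
import jax.numpy as jnp

def fsvi_loss(key, mu, log_var, f, x_batch, y_batch, x_ind,
              prior_mu=0.0, prior_var=1e6, n_samples=5, noise=0.1):
    var = jnp.exp(log_var)

    # Line 4: Monte Carlo estimate of E_q[log p(y | f_XD)] via the
    # reparameterization theta = mu + sqrt(var) * eps (up to a constant).
    def log_lik(eps):
        theta = mu + jnp.sqrt(var) * eps
        return jnp.sum(-0.5 * ((y_batch - f(x_batch, theta)) / noise) ** 2)
    eps = jax.random.normal(key, (n_samples,) + mu.shape)
    ell = jnp.mean(jax.vmap(log_lik)(eps))

    # Lines 3 and 5: diagonal of Theta(X_I, X_I) = J_mu Sigma J_mu^T and the
    # resulting factorized-Gaussian KL at the inducing points.
    J = jax.jacfwd(lambda t: f(x_ind, t))(mu)     # J_mu(X_I)
    q_var = jnp.sum(J ** 2 * var, axis=-1)        # [Theta(X_I, X_I)]_jj
    q_mean = f(x_ind, mu)                         # f(X_I; mu)
    kl = jnp.sum(0.5 * (jnp.log(prior_var / q_var)
                        + (q_var + (q_mean - prior_mu) ** 2) / prior_var
                        - 1.0))
    return kl - ell  # negative of the variational objective

# Line 6: one gradient step on the variational parameters (mu, log_var), e.g.
# grads = jax.grad(fsvi_loss, argnums=(1, 2))(key, mu, log_var, f, xb, yb, xi)
# mu, log_var = mu - eta * grads[0], log_var - eta * grads[1]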

5. Empirical Evaluation
We evaluate the proposed function-space variational inference (fsvi) method on illustrative
regression and classification tasks as well as on high-dimensional classification tasks that prior
work (Sun et al., 2019) was unable to scale to. We show that fsvi (sometimes significantly)
outperforms existing Bayesian and non-Bayesian methods in terms of in-distribution
uncertainty calibration and out-of-distribution uncertainty estimation.

5.1. Illustrative Examples


Figure 1 shows the posterior predictive distribution obtained via fsvi on a 1D regression
and a binary classification problem. On the regression problem (Figure 1(a)), the posterior
predictive distribution is certain about the training data, becomes somewhat uncertain when
interpolating between datapoints (in the interval [−1, 0]), and grows very uncertain away
from the training data, as desired. Figures 1(b) and 1(c) show the posterior predictive
mean and variance obtained via fsvi on the Two Moons classification task. As can be
seen in Figure 1(b), the predictive mean (in the form of binary class probabilities) is highly
confident around the data manifold and converges to 0.5, the maximum level of uncertainty,
further away from it. Similarly, the epistemic uncertainty over the class probabilities, shown
in Figure 1(c), is low on and close to the data manifold and increases further away from it.

5.2. Evaluation of In- and Out-of-Distribution Performance


To evaluate the predictive performance of fsvi, we consider a selection of widely used high-
dimensional classification datasets to which prior approaches to function-space variational
inference were unable to scale.
Table 1 and Figure 2 show that fsvi consistently either performs on par with related
methods or outperforms them. More specifically, Figure 2 shows that fsvi has a predictive
entropy on out-of-distribution inputs that is as high as or higher than that of ensembles or
bnns with mfvi, indicating high uncertainty under distribution shift, as we would like. This
high level of uncertainty is reflected in the corresponding receiver operating characteristic
(ROC) curve, which shows that fsvi outperforms other methods at distinguishing in- from
out-of-distribution inputs. Finally, as can be seen in the rightmost column of Figure 2,


Table 1: Comparison of in- and out-of-distribution performance metrics. AUROC: area under ROC
curve. ECE: expected calibration error. 1 Computed from mutual information scores. 2 Computed
from predictive entropy scores. 3 bnn trained with mfvi (Blundell et al., 2015). 4 bnn map estimate
obtained by training a deterministic neural network with weight regularization.

                  fashionMNIST/MNIST                     CIFAR-10/SVHN
Model             Accuracy  ECE   AUROC1  AUROC2    Accuracy  ECE   AUROC1  AUROC2
Ours: fsvi        91%       0.02  91%     86%       85%       0.03  54%     99%
mfvi3             91%       0.04  85%     84%       85%       0.04  68%     97%
map4              91%       0.07  72%     76%       86%       0.09  72%     90%
Ensemble          93%       0.02  91%     90%       91%       0.03  88%     98%

[Figure 2 shows six panels in two rows comparing FSVI, MFVI, an ensemble of NNs, and a single NN: (a)/(d) Predictive Entropy (number of examples against entropy in nats), (b)/(e) ROC Curve (true positive rate against false positive rate), and (c)/(f) OOD Confidence (density of examples with p(y|x) ≥ τ against confidence threshold τ).]

Figure 2: Uncertainty evaluation metrics for in- and out-of-distribution (OOD) prediction. Top
row: results for models trained on fashionMNIST, with MNIST images as OOD inputs; bottom row:
results for models trained on CIFAR-10, with SVHN images as OOD inputs. Left column: closer to
diagonal is better; center column: closer to top left corner is better; right column: closer to bottom
left corner is better. For further details, see Appendix 4.

fsvi exhibits low confidence on out-of-distribution inputs, as desired. Table 1 further
corroborates these observations and shows that fsvi leads to high predictive accuracy and
good uncertainty estimates. In Appendix 4, we provide an ablation study on the effect of the
number of inducing points on a bnn’s predictive accuracy and the quality of its predictive
uncertainty estimates.

6. Conclusion
We proposed a fundamentally new approach to variational inference in bnns, where the
parameters are inferred indirectly by performing inference over the induced distribution over
functions. We showed that fsvi exhibits in- and out-of-distribution predictive performance
on par with or better than related state-of-the-art Bayesian and non-Bayesian approaches.


References
Charles Blundell, Julien Cornebise, Koray Kavukcuoglu, and Daan Wierstra. Weight uncertainty in neural network. In Proceedings of the 32nd International Conference on Machine Learning, volume 37 of Proceedings of Machine Learning Research, pages 1613–1622, Lille, France, 2015. PMLR. URL http://proceedings.mlr.press/v37/blundell15.html.

John S. Bridle. Probabilistic interpretation of feedforward classification network outputs, with relationships to statistical pattern recognition. In Françoise Fogelman Soulié and Jeanny Hérault, editors, Neurocomputing, pages 227–236, Berlin, Heidelberg, 1990. Springer Berlin Heidelberg. ISBN 978-3-642-76153-9.

Alexander G. de G. Matthews, James Hensman, Richard Turner, and Zoubin Ghahramani. On sparse variational methods and the Kullback-Leibler divergence between stochastic processes. In Proceedings of the 19th International Conference on Artificial Intelligence and Statistics, volume 51 of Proceedings of Machine Learning Research, pages 231–239, Cadiz, Spain, 2016. PMLR. URL http://proceedings.mlr.press/v51/matthews16.html.

M. J. Schervish. Theory of Statistics. Springer-Verlag, New York, NY, 1995.

Shengyang Sun, Guodong Zhang, Jiaxin Shi, and Roger B. Grosse. Functional variational Bayesian neural networks. In 7th International Conference on Learning Representations, ICLR 2019, New Orleans, LA, USA, May 6–9, 2019. OpenReview.net, 2019. URL https://openreview.net/forum?id=rkxacs0qY7.

Joost van Amersfoort, Lewis Smith, Yee Whye Teh, and Yarin Gal. Uncertainty estimation
using a single deep deterministic neural network. In International Conference on Machine
Learning, 2020.


Supplementary Materials
1. Theoretical Results
1.1. Linearization of Bayesian Neural Network Predictive Distribution
Proposition 1 (Predictive Distribution of Linearized BNN):
Consider θ, f(x; θ), and f̃(x; θ) as defined above. The mean and covariance of q̃(f̃(x; θ)) are given by

E_{q̃(f̃(x;θ))}[f̃(x; θ)] = f(x; µ) and Cov(f̃(x; θ), f̃(x′; θ)) = J_µ(x) Σ J_µ(x′)ᵀ,

and the predictive distribution q̃ over f̃(x; θ) is given by

q̃(f̃(x); θ) = GP(f̃ | f(x; µ), J_µ(x) Σ J_µ(x′)ᵀ).  (1.1)

Proof Since θ ∼ N(θ | µ, Σ) and f̃(x; θ) = f(x; µ) + J_µ(x)(θ − µ) is a linear transformation of θ, f̃(x; θ) is a Gaussian process

q̃(f̃(x); θ) = GP(f̃ | m(x), S(x, x′))  (1.2)

with some predictive mean m(x) and predictive covariance S(x, x′). To find q̃(f̃(x; θ)), we need to find the predictive mean m(x) and the predictive covariance S(x, x′), which, by definition, we can write as

m(x) = E[f̃(x; θ)]  (1.3)

and

S(x, x′) = Cov(f̃(x; θ), f̃(x′; θ))  (1.4)
         = E[(f̃(x; θ) − E[f̃(x; θ)]) (f̃(x′; θ) − E[f̃(x′; θ)])ᵀ].  (1.5)
To see that m(x) = E[f̃(x; θ)] = f(x; µ), note that, by linearity of expectation, we have

m(x) = E[f̃(x; θ)]  (1.6)
     = E[f(x; µ) + J_µ(x)(θ − µ)]  (1.7)
     = f(x; µ) + J_µ(x)(E[θ] − µ)  (1.8)
     = f(x; µ).  (1.9)
To see that S(x, x′) = Cov(f̃(x; θ), f̃(x′; θ)) = J_µ(x) Σ J_µ(x′)ᵀ, note that in general, Cov(X, X) = E[XXᵀ] − E[X]E[X]ᵀ, and hence,

Cov(f̃(x; θ), f̃(x′; θ)) = E[f̃(x; θ) f̃(x′; θ)ᵀ] − E[f̃(x; θ)] E[f̃(x′; θ)]ᵀ.  (1.10)

We already know that E[f̃(x; θ)] = f(x; µ), so we only need to find E[f̃(x; θ) f̃(x′; θ)ᵀ]:

E_{q(θ)}[f̃(x; θ) f̃(x′; θ)ᵀ]
  = E_{q(θ)}[(f(x; µ) + J_µ(x)(θ − µ))(f(x′; µ) + J_µ(x′)(θ − µ))ᵀ]  (1.11)
  = E_{ε ∼ N(ε | 0, I)}[(f(x; µ) + J_µ(x)(µ + √Σ ε − µ))(f(x′; µ) + J_µ(x′)(µ + √Σ ε − µ))ᵀ]  (1.12)
  = E_{ε ∼ N(ε | 0, I)}[(f(x; µ) + J_µ(x) √Σ ε)(f(x′; µ) + J_µ(x′) √Σ ε)ᵀ],  (1.13)


where the reparameterization is possible by Assumption 1. With some algebra, this expression can be further simplified to

E_{q(θ)}[f̃(x; θ) f̃(x′; θ)ᵀ] = f(x; µ) f(x′; µ)ᵀ + J_µ(x) Σ J_µ(x′)ᵀ.  (1.14)

Since the f(x; µ) f(x′; µ)ᵀ terms cancel out, we obtain the covariance function

S(x, x′) := Cov(f̃(x; θ), f̃(x′; θ)) = J_µ(x) Σ J_µ(x′)ᵀ,  (1.15)

which concludes the proof.

1.2. Function-Space Variational Objective


This proof follows steps from de G. Matthews et al. (2016). Consider measures P̂ and P, both of which define distributions over some function f, indexed by an infinite index set X. Let D be a dataset and let X_D denote a set of inputs and y_D a set of targets. Consider the measure-theoretic version of Bayes' Theorem (Schervish, 1995):

(dP̂/dP)(f) = p_X(Y | f) / p(Y),  (1.16)

where p_X(Y | f) is the likelihood and p(Y) = ∫_{R^X} p_X(Y | f) dP(f) is the marginal likelihood.
We assume that the likelihood function is evaluated at a finite subset of the index set X.
Denote by π_C : R^X → R^C a projection function that takes a function and returns the same function evaluated at a finite set of points C, so we can write

(dP̂/dP)(f) = (dP̂_{X_D}/dP_{X_D})(π_{X_D}(f)) = p(y_D | π_{X_D}(f)) / p(y_D),  (1.17)
and similarly, the marginal likelihood becomes p(y_D) = ∫_{R^{X_D}} p(y_D | f_{X_D}) dP_{X_D}(f_{X_D}). Now, considering the measure-theoretic version of the KL divergence between an approximating stochastic process Q and a posterior stochastic process P̂, we can write

DKL(Q ‖ P̂) = ∫_{R^X} log (dQ/dP)(f) dQ(f) − ∫_{R^X} log (dP̂/dP)(f) dQ(f),  (1.18)
where P is some prior stochastic process. Now, considering the second term, we can apply
the measure-theoretic Bayes’ Theorem to obtain

∫_{R^X} log (dP̂/dP)(f) dQ(f) = ∫_{R^{X_D}} log (dP̂_{X_D}/dP_{X_D})(f_{X_D}) dQ_{X_D}(f_{X_D})  (1.19)
                             = E_{Q_{X_D}}[log p(y_D | f_{X_D})] − log p(y_D),  (1.20)

giving us

DKL(Q ‖ P̂) = ∫_{R^X} log (dQ/dP)(f) dQ(f) − E_{Q_{X_D}}[log p(y_D | f_{X_D})] + log p(y_D).  (1.21)


Rearranging, and using the non-negativity of the KL divergence, we get

log p(y_D) = E_{Q_{X_D}}[log p(y_D | f_{X_D})] − ∫_{R^X} log (dQ/dP)(f) dQ(f) + DKL(Q ‖ P̂)  (1.22)
           ≥ E_{Q_{X_D}}[log p(y_D | f_{X_D})] − ∫_{R^X} log (dQ/dP)(f) dQ(f).  (1.23)

By the measure-theoretic definition of the KL divergence, we can thus write

log p(y_D) ≥ E_{Q_{X_D}}[log p(y_D | f_{X_D})] − ∫_{R^X} log (dQ/dP)(f) dQ(f)  (1.24)
           = E_{Q_{X_D}}[log p(y_D | f_{X_D})] − DKL(Q ‖ P),  (1.25)

which corresponds to the expression for the function-space variational objective in Section 3.


2. Further Background
Mean-Field Variational Inference Since bnns are non-linear in their parameters, exact
inference over the network parameters is analytically intractable. Mean-field variational
inference is a variational approach for finding an approximate posterior distribution over
the network parameters and using it to draw samples from a bnn's approximate posterior
predictive distribution. Under a mean-field assumption, the joint distribution over the
stochastic network parameters has a diagonal covariance, rendering the variational parameters
independent of one another. Furthermore, to obtain a tractable variational objective, prior
works assume the prior and variational distributions over the network parameters to be
Gaussian (Blundell et al., 2015). This approximation results in a tractable and scalable
variational objective, given by

L(q(θ)) := E_{q(f_{X_D}; θ)}[log p(y | f_{X_D})] − DKL(q(θ) ‖ p(θ)),  (2.26)

where p(θ) = N(θ | 0, I) and f_{X_D} are function values at the training inputs X_D.
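
As a point of comparison with the fsvi objective above, a single-sample reparameterized estimate of this parameter-space objective might look as follows (a minimal sketch under our naming assumptions, with a Gaussian likelihood and p(θ) = N(θ | 0, I)):

import jax
import jax.numpy as jnp

def mfvi_elbo(key, mu, log_var, f, x_batch, y_batch, noise=0.1):
    var = jnp.exp(log_var)
    theta = mu + jnp.sqrt(var) * jax.random.normal(key, mu.shape)  # theta ~ q(theta)
    # Gaussian log-likelihood, up to an additive constant.
    log_lik = jnp.sum(-0.5 * ((y_batch - f(x_batch, theta)) / noise) ** 2)
    # Closed-form KL(q(theta) || N(0, I)) for factorized Gaussians.
    kl = 0.5 * jnp.sum(var + mu ** 2 - 1.0 - log_var)
    return log_lik - kl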
Gaussian Processes Gaussian process (gp) models define distributions over functions.
Unlike in bnns where a prior distribution over parameters implicitly induces a prior
distribution over functions, Gaussian processes explicitly define distributions over func-
tions by specifying a covariance function over possible function realizations. A gp prior
p(f | x) = GP(m(x), k(x, x0 )) is completely specified by its mean and covariance function,
m(·) and k(·, ·).
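
For instance, a function can be drawn from a zero-mean gp prior at a finite set of evaluation points as follows (an illustrative sketch assuming an RBF covariance function):

import jax
import jax.numpy as jnp

def rbf(x1, x2, lengthscale=1.0):
    # k(x, x') = exp(-0.5 * (x - x')^2 / lengthscale^2) for 1D inputs.
    return jnp.exp(-0.5 * ((x1[:, None] - x2[None, :]) / lengthscale) ** 2)

x = jnp.linspace(-5.0, 5.0, 100)
K = rbf(x, x) + 1e-6 * jnp.eye(100)  # jitter for numerical stability
f_draw = jax.random.multivariate_normal(jax.random.PRNGKey(0), jnp.zeros(100), K)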

3. Model Details
For the experiments presented in this paper, we diagonalized the covariance of the linearized
bnn in the KL divergence. While this simplification is not necessary, it speeds up training
with larger numbers of inducing points. Furthermore, we assumed a bnn prior that is locally
equal to a gp with zero mean and a diagonal covariance scaled by 1e6 on all evaluation
points. We chose the set of inducing points by uniformly sampling within some range, e.g.,
within the range of admissible pixel values for prediction tasks with image inputs.
For the Two Moons and Snelson experiments, we use a fully-connected neural network
with two hidden layers and 100 hidden units per layer. For the fashionMNIST/MNIST
experiments, we use a three-layer convolutional neural network without batch normalization.
For the CIFAR-10/SVHN experiments, we use a seven-layer convolutional neural network
without batch normalization. We use the Adam optimizer for all experiments with learning
rates between η = 1e-4 and η = 1e-3 (depending on which yielded the best performance).
The deterministic neural networks used for the ensemble were trained with a weight decay
of λ = 1e-1. We used early stopping when training bnns with mfvi to avoid overfitting.
We did not use early stopping to train bnns with fsvi.
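
Concretely, the inducing-point sampling described above amounts to a uniform draw over the input domain; for image inputs with pixel values assumed to lie in [0, 1] (the shapes here are illustrative), it might look as follows:

import jax
key = jax.random.PRNGKey(0)
X_I = jax.random.uniform(key, shape=(40, 28, 28, 1), minval=0.0, maxval=1.0)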


4. Further Empirical Results

[Figure 3 shows two panels: (a) fsvi Posterior Predictive Distribution and (b) mfvi Posterior Predictive Distribution, each with the predictive mean, function draws, and training data.]

Figure 3: 1D regression on the Snelson dataset. The plots show the predictive distribution of a
bnn obtained via function-space variational inference (fsvi) under the local approximation described
in Section 4 (Figure 3(a)) and obtained via mean-field variational inference (mfvi) (Figure 3(b)).
The plots show noisy data (in red), the posterior predictive means (in black), ten function draws
from the bnns (in blue), and two standard deviations of the empirical distribution over functions.

[Figure 4 shows two panels: (a) Posterior Predictive Mean E[y | D; x] and (b) Posterior Predictive Variance V[y | D; x].]

Figure 4: Binary classification on the Two Moons dataset. The plots show the posterior predictive
mean (Figure 4(a)) and variance (Figure 4(b)) of a bnn obtained via fsvi. They represent the expected
class probabilities and the model’s epistemic uncertainty over the class probabilities, respectively.
The predictive distribution is able to faithfully capture the geometry of the data manifold and
exhibits high uncertainty over the class probabilities in areas of the data space of which the data is
not informative. In contrast, related methods, such as ensembles, are unable to accurately capture
the geometry of the data manifold and only exhibit high uncertainty around the decision boundary
(van Amersfoort et al., 2020).


[Figure 5 shows two reliability diagrams plotting model accuracy against model confidence for FSVI, MFVI, an ensemble of NNs, and a single NN: (a) fashionMNIST, (b) CIFAR-10.]

Figure 5: Reliability Diagram. Figure 5(a) shows the reliability of different models, expressed as a
plot of model accuracy against model confidence for models trained on fashionMNIST and evaluated
on a fashionMNIST test set. Figure 5(b) shows the reliability of different models, expressed as a plot
of model accuracy against model confidence for models trained on CIFAR-10 and evaluated on a
CIFAR-10 test set.

[Figure 6 shows two ROC curves plotting true positive rate against false positive rate for FSVI, MFVI, an ensemble of NNs, and a single NN: (a) fashionMNIST/MNIST, (b) CIFAR-10/SVHN.]

Figure 6: Receiver operating characteristic (ROC) curve. Figure 6(a) shows the ROC for different
predictive distributions for the binary classification problem of distinguishing in-distribution inputs
(fashionMNIST) from out-of-distribution inputs (MNIST). Figure 6(b) shows the ROC for different
predictive distributions for the binary classification problem of distinguishing in-distribution inputs
(CIFAR-10) from out-of-distribution inputs (SVHN).


[Figure 7 shows two panels plotting the density of examples with p(y|x) ≥ τ against the confidence threshold τ for FSVI, MFVI, an ensemble of NNs, and a single NN: (a) fashionMNIST/MNIST, (b) CIFAR-10/SVHN.]

Figure 7: Confidence on Out-of-Distribution Inputs. Figure 7(a) shows the confidence of different
predictive means of models trained on fashionMNIST, evaluated on out-of-distribution inputs
(MNIST). Figure 7(b) shows the confidence of different predictive means of models trained on
CIFAR-10, evaluated on out-of-distribution inputs (SVHN). Curves further to the left are better, as
they indicate that a model assigns low confidence to a larger number of out-of-distribution inputs.

[Figure 8 shows accuracy, ECE, predictive entropy, and AUROC plotted against the number of inducing points (5 to 40): (a) effect on in-distribution predictions, (b) effect on out-of-distribution predictions.]

Figure 8: Figure 8(a) shows the effect of increasing the number of inducing points on in-distribution
predictions for bnns trained via fsvi on fashionMNIST. Increasing the number of inducing points
does not affect the test accuracy, but does increase the predictive mean's expected calibration error.
Figure 8(b) shows that increasing the number of inducing points also increases the predictive
entropy on out-of-distribution inputs as well as the area under the receiver operating characteristic
curve computed from it.
