A Simple Baseline For Bayesian Uncertainty in Deep Learning

Samsung AI Center Moscow
Samsung-HSE Laboratory, National Research University Higher School of Economics
Abstract
We propose SWA-Gaussian (SWAG), a simple, scalable, and general purpose
approach for uncertainty representation and calibration in deep learning. Stochastic
Weight Averaging (SWA), which computes the first moment of stochastic gradient
descent (SGD) iterates with a modified learning rate schedule, has recently been
shown to improve generalization in deep learning. With SWAG, we fit a Gaussian
using the SWA solution as the first moment and a low rank plus diagonal covariance
also derived from the SGD iterates, forming an approximate posterior distribution
over neural network weights; we then sample from this Gaussian distribution to
perform Bayesian model averaging. We empirically find that SWAG approximates
the shape of the true posterior, in accordance with results describing the stationary
distribution of SGD iterates. Moreover, we demonstrate that SWAG performs
well on a wide variety of tasks, including out of sample detection, calibration,
and transfer learning, in comparison to many popular alternatives including MC
dropout, KFAC Laplace, SGLD, and temperature scaling.
1 Introduction
Ultimately, machine learning models are used to make decisions. Representing uncertainty is crucial
for decision making. For example, in medical diagnoses and autonomous vehicles we want to protect
against rare but costly mistakes. Deep learning models typically lack a representation of uncertainty,
and provide overconfident and miscalibrated predictions [e.g., 28, 19].
Bayesian methods provide a natural probabilistic representation of uncertainty in deep learning [e.g.,
6, 31, 9], and previously had been a gold standard for inference with neural networks [49]. However,
existing approaches are often highly sensitive to hyperparameter choices, and hard to scale to modern
datasets and architectures, which limits their general applicability in modern deep learning.
In this paper we propose a different approach to Bayesian deep learning: we use the information
contained in the SGD trajectory to efficiently approximate the posterior distribution over the weights
of the neural network. We find that the Gaussian distribution fitted to the first two moments of
SGD iterates, with a modified learning rate schedule, captures the local geometry of the posterior
surprisingly well. Using this Gaussian distribution we are able to obtain convenient, efficient,
accurate and well-calibrated predictions in a broad range of tasks in computer vision. In particular,
our contributions are the following:
• We propose SWAG (SWA-Gaussian) for uncertainty representation and calibration in deep learning. SWAG builds on SWA [27], which computes an average of SGD iterates with a high constant learning rate schedule to provide improved generalization in deep learning, and on the interpretation of SGD as approximate Bayesian inference [43]. SWAG additionally computes a low-rank plus diagonal approximation to the covariance of the iterates, which is used together with the SWA mean to define a Gaussian posterior approximation over neural network weights.
• SWAG is motivated by the theoretical analysis of the stationary distribution of SGD iterates
[e.g., 43, 10], which suggests that the SGD trajectory contains useful information about the
geometry of the posterior. In Appendix B we show that the assumptions of Mandt et al. [43]
do not hold for deep neural networks, due to non-convexity and over-parameterization (with
further analysis in the supplementary material). However, we find in Section 4 that in the
low-dimensional subspace spanned by SGD iterates the shape of the posterior distribution is
approximately Gaussian within a basin of attraction. Further, SWAG is able to capture the
geometry of this posterior remarkably well.
• In an exhaustive empirical evaluation we show that SWAG can provide well-calibrated
uncertainty estimates for neural networks across many settings in computer vision. In partic-
ular SWAG achieves higher test likelihood compared to many state-of-the-art approaches,
including MC-Dropout [14], temperature scaling [19], SGLD [59], KFAC-Laplace [54] and
SWA [27] on CIFAR-10, CIFAR-100 and ImageNet, on a range of architectures. We also
demonstrate the effectiveness of SWAG for out-of-domain detection, and transfer learning.
While we primarily focus on image classification, we show that SWAG can significantly im-
prove test perplexities of LSTM networks on language modeling problems, and in Appendix
G we also compare SWAG with Probabilistic Back-propagation (PBP) [23], Deterministic
Variational Inference (DVI) [60], and Deep Gaussian Processes [7] on regression problems.
• We release PyTorch code at https://github.com/wjmaddox/swa_gaussian.
2 Related Work
2.1 Bayesian Methods
Bayesian approaches represent uncertainty by placing a distribution over model parameters, and then
marginalizing over these parameters to form the posterior predictive distribution, in a procedure known as
Bayesian model averaging. In the late 1990s, Bayesian methods were the state-of-the-art approach to
learning with neural networks, through the seminal works of Neal [49] and MacKay [40]. However,
modern neural networks often contain millions of parameters, the posterior over these parameters
(and thus the loss surface) is highly non-convex, and mini-batch approaches are often needed to
move to a space of good solutions [29]. For these reasons, Bayesian approaches have largely been
intractable for modern neural networks. Here, we review several modern approaches to Bayesian
deep learning.
Markov chain Monte Carlo (MCMC) was at one time a gold standard for inference with neural
networks, through the Hamiltonian Monte Carlo (HMC) work of Neal [49]. However, HMC requires
full gradients, which is computationally intractable for modern neural networks. To extend the HMC
framework, stochastic gradient HMC (SGHMC) was introduced by Chen et al. [9] and allows for
stochastic gradients to be used in Bayesian inference, crucial for both scalability and exploring a space
of solutions that provide good generalization. Alternatively, stochastic gradient Langevin dynamics
(SGLD) [59] uses first order Langevin dynamics in the stochastic gradient setting. Theoretically,
both SGHMC and SGLD asymptotically sample from the posterior in the limit of infinitely small
step sizes. In practice, using finite learning rates introduces approximation errors (see e.g. [43]), and
tuning stochastic gradient MCMC methods can be quite difficult.
Variational Inference: Graves [18] suggested fitting a Gaussian variational posterior approxima-
tion over the weights of neural networks. This technique was generalized by Kingma and Welling
[33] which proposed the reparameterization trick for training deep latent variable models; multiple
variational inference methods based on the reparameterization trick were proposed for DNNs [e.g.,
32, 6, 45, 39]. While variational methods achieve strong performance for moderately sized networks,
they are empirically noted to be difficult to train on larger architectures such as deep residual networks
[22]; Blier and Ollivier [5] argue that the difficulty of training is explained by variational methods
providing insufficient data compression for DNNs despite being designed for data compression (mini-
mum description length). Recent key advances [39, 60] in variational inference for deep learning
typically focus on smaller-scale datasets and architectures. An alternative line of work re-interprets
noisy versions of optimization algorithms: for example, noisy Adam [30] and noisy KFAC [64], as
approximate variational inference.
Dropout Variational Inference: Gal and Ghahramani [14] used a spike and slab variational distri-
bution to view dropout at test time as approximate variational Bayesian inference. Concrete dropout
[15] extends this idea to optimize the dropout probabilities as well. From a practical perspective,
these approaches are quite appealing as they only require ensembling dropout predictions at test time,
and they have been successfully applied to several downstream tasks [28, 46].
Laplace Approximations assume a Gaussian posterior, N (θ∗ , I(θ∗ )−1 ), where θ∗ is a MAP
estimate and I(θ∗ )−1 is the inverse of the Fisher information matrix (expected value of the Hessian
evaluated at θ∗ ). It was notably used for Bayesian neural networks in MacKay [41], where a diagonal
approximation to the inverse of the Hessian was utilized for computational reasons. More recently,
Kirkpatrick et al. [34] proposed using diagonal Laplace approximations to overcome catastrophic
forgetting in deep learning. Ritter et al. [54] proposed the use of either a diagonal or block Kronecker
factored (KFAC) approximation to the Hessian matrix for Laplace approximations, and Ritter et al.
[53] successfully applied the KFAC approach to online learning scenarios.
Mandt et al. [43] proposed to use the iterates of averaged SGD as an MCMC sampler, after analyzing
the dynamics of SGD using tools from stochastic calculus. From a frequentist perspective, Chen et al.
[10] showed that under certain conditions a batch means estimator of the sample covariance matrix of
the SGD iterates converges to $A = H(\theta)^{-1} C(\theta) H(\theta)^{-1}$, where $H(\theta)^{-1}$ is the inverse of the Hessian
of the log likelihood and $C(\theta) = \mathbb{E}\big[\nabla \log p(\theta)\, \nabla \log p(\theta)^{\top}\big]$ is the covariance of the gradients of the
log likelihood. Chen et al. [10] then show that using A and the sample average of the iterates for a
Gaussian approximation produces well calibrated confidence intervals of the parameters and that the
variance of these estimators achieves the Cramer Rao lower bound (the minimum possible variance).
A description of the asymptotic covariance of the SGD iterates dates back to Ruppert [55] and Polyak
and Juditsky [52], who show asymptotic convergence of Polyak-Ruppert averaging.
Lakshminarayanan et al. [36] proposed using ensembles of several networks for enhanced calibration,
and incorporated an adversarial loss function to be used when possible as well. Outside of probabilistic
neural networks, Guo et al. [19] proposed temperature scaling, a procedure which uses a validation set
and a single hyperparameter to rescale the logits of DNN outputs for enhanced calibration. Kuleshov
et al. [35] propose calibrated regression using a similar rescaling technique.
Standard training of deep neural networks (DNNs) proceeds by applying stochastic gradient descent
on the model weights θ with the following update rule:
$$\Delta \theta_t = -\eta_t \left( \frac{1}{B} \sum_{i=1}^{B} \nabla_\theta \log p\big(y_i \mid f_\theta(x_i)\big) \;-\; \frac{\nabla_\theta \log p(\theta)}{N} \right),$$
where the learning rate is $\eta_t$, the $i$th input (e.g. image) and label are $\{x_i, y_i\}$, the size of the whole training set is $N$, the size of the batch is $B$, and the DNN $f$ has weight parameters θ.² The loss function is a negative log likelihood $-\sum_i \log p(y_i \mid f_\theta(x_i))$, combined with a regularizer $\log p(\theta)$.
This type of maximum likelihood training does not represent uncertainty in the predictions or
parameters θ.
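As a concrete illustration of this training objective, the following is a minimal PyTorch sketch of a single step (not the training script used in our experiments); the dimensions, prior variance, and learning rate are illustrative placeholders.

```python
# One SGD step on the minibatch negative log likelihood plus the (1/N)-scaled
# log-prior term; here the prior is Gaussian, which corresponds to weight decay.
import torch
import torch.nn.functional as F

model = torch.nn.Linear(20, 10)                 # stand-in for a DNN f_theta
N, B, lr, prior_var = 50_000, 128, 0.1, 1.0     # illustrative values
x, y = torch.randn(B, 20), torch.randint(0, 10, (B,))

nll = F.cross_entropy(model(x), y)              # (1/B) sum_i -log p(y_i | f_theta(x_i))
log_prior = -0.5 / prior_var * sum((p ** 2).sum() for p in model.parameters())
loss = nll - log_prior / N                      # regularizer enters scaled by 1/N

model.zero_grad()
loss.backward()
with torch.no_grad():
    for p in model.parameters():
        p -= lr * p.grad                        # Delta theta_t = -eta_t * gradient
```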
The main idea of SWA [27] is to run SGD with a constant learning rate schedule starting from a
pre-trained solution, and to average the weights of the models it traverses. Denoting the weights of
the network obtained after epoch $i$ of SWA training $\theta_i$, the SWA solution after $T$ epochs is given
by $\theta_{\text{SWA}} = \frac{1}{T}\sum_{i=1}^{T} \theta_i$. A high constant learning rate schedule ensures that SGD explores the set of
possible solutions instead of simply converging to a single point in the weight space. Izmailov et al.
[27] argue that conventional SGD training converges to the boundary of the set of high-performing
solutions; SWA on the other hand is able to find a more centered solution that is robust to the shift
between train and test distributions, leading to improved generalization performance. SWA and
related ideas have been successfully applied to a wide range of applications [see e.g. 2, 61, 62, 51]. A
related but different procedure is Polyak-Ruppert averaging [52, 55] in stochastic convex optimization,
which uses a learning rate decaying to zero. Mandt et al. [43] interpret Polyak-Ruppert averaging as a
sampling procedure, with convergence occurring to the true posterior under certain strong conditions.
Additionally, they explore the theoretical feasibility of SGD (and averaged SGD) as an approximate
Bayesian inference scheme; we test their assumptions in Appendix A.
3.3 SWAG-Diagonal
We first consider a simple diagonal format for the covariance matrix. In order to fit a diagonal
covariance approximation, we maintain a running average of the second uncentered moment for each
weight, and then compute the covariance using the following standard identity at the end of training:
$$\overline{\theta^{2}} = \frac{1}{T}\sum_{i=1}^{T} \theta_i^{2}, \qquad \Sigma_{\text{diag}} = \mathrm{diag}\!\left(\overline{\theta^{2}} - \theta_{\text{SWA}}^{2}\right);$$
here the squares in $\theta_{\text{SWA}}^{2}$ and $\theta_i^{2}$ are applied elementwise.
The resulting approximate posterior distribution is then $\mathcal{N}(\theta_{\text{SWA}}, \Sigma_{\text{diag}})$. In our experiments, we term this method SWAG-Diagonal.
Constructing the SWAG-Diagonal posterior approximation requires storing two additional copies
of DNN weights: θSWA and θ2 . Note that these models do not have to be stored on the GPU. The
additional computational complexity of constructing SWAG-Diagonal compared to standard training
is negligible, as it only requires updating the running averages of weights once per epoch.
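As a sketch of these running-moment updates (assuming the weights are flattened into a single vector and collected once per epoch; this is not the released implementation, and get_flat_weights is an illustrative helper):

```python
# Sketch of the SWAG-Diagonal moment updates.
import torch
import torch.nn as nn

def get_flat_weights(model):
    # Hypothetical helper: all parameters concatenated into one flat tensor.
    return torch.cat([p.detach().reshape(-1) for p in model.parameters()])

model = nn.Linear(4, 2)                      # stand-in for a DNN
swa_mean = torch.zeros_like(get_flat_weights(model))
sq_mean = torch.zeros_like(swa_mean)
n_models = 0

for epoch in range(10):
    # ... one epoch of SGD with the SWA constant learning rate would go here ...
    w = get_flat_weights(model)
    swa_mean = (n_models * swa_mean + w) / (n_models + 1)     # running first moment
    sq_mean = (n_models * sq_mean + w ** 2) / (n_models + 1)  # running second moment
    n_models += 1

# Diagonal covariance: diag(theta2_bar - theta_SWA^2), clamped for numerical safety.
sigma_diag = torch.clamp(sq_mean - swa_mean ** 2, min=0.0)
```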
We now describe the full SWAG algorithm. While the diagonal covariance approximation is standard
in Bayesian deep learning [6, 34], it can be too restrictive. We extend the idea of diagonal covariance
approximations to utilize a more flexible low-rank plus diagonal posterior approximation. SWAG
approximates the sample covariance Σ of the SGD iterates along with the mean θSWA .3
Note that the sample covariance matrix of the SGD iterates can be written as the sum of outer products, $\Sigma = \frac{1}{T-1}\sum_{i=1}^{T}(\theta_i - \theta_{\text{SWA}})(\theta_i - \theta_{\text{SWA}})^{\top}$, and is of rank $T$. As we do not have access to the value of $\theta_{\text{SWA}}$ during training, we approximate the sample covariance with $\Sigma \approx \frac{1}{T-1}\sum_{i=1}^{T}(\theta_i - \bar{\theta}_i)(\theta_i - \bar{\theta}_i)^{\top} = \frac{1}{T-1} D D^{\top}$, where $D$ is the deviation matrix comprised of columns $D_i = (\theta_i - \bar{\theta}_i)$, and $\bar{\theta}_i$ is the running estimate of the parameters' mean obtained from the first $i$ samples. To limit the rank of the estimated covariance matrix we only use the last $K$ of the $D_i$ vectors, corresponding to the last $K$ epochs of training. Here $K$ is the rank of the resulting approximation and is a hyperparameter of the method. We define $\hat{D}$ to be the matrix with columns equal to $D_i$ for $i = T-K+1, \ldots, T$.
²We ignore momentum for simplicity in this update; however, we use momentum in our experiments, and it is covered theoretically [43].
³We note that stochastic gradient Monte Carlo methods [9, 59] also use the SGD trajectory to construct samples from the approximate posterior. However, these methods are principally different from SWAG in that they (1) require adding Gaussian noise to the gradients, (2) decay the learning rate to zero, and (3) do not construct a closed-form approximation to the posterior distribution, which for instance enables SWAG to draw new samples with minimal overhead. We include comparisons to SGLD [59] in the Appendix.
Related methods for estimating the covariance of SGD iterates were considered in Mandt et al. [43]
and Chen et al. [10], but store full-rank covariance Σ and thus scale quadratically in the number of
parameters, which is prohibitively expensive for deep learning applications. We additionally note that
using the deviation matrix for online covariance matrix estimation comes from viewing the online
updates used in Dasgupta and Hsu [12] in matrix fashion.
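A minimal sketch of such an online update of the running mean and the last-$K$ deviation columns (illustrative class and method names, not the released code):

```python
# Maintain the last K deviation columns D_i = theta_i - (running mean over theta_1..theta_i).
import torch

class DeviationBuffer:
    """Keeps the running mean of the collected weights and the last K deviations."""
    def __init__(self, num_params, k):
        self.k = k
        self.mean = torch.zeros(num_params)
        self.n = 0
        self.columns = []

    def collect(self, flat_weights):
        self.mean = (self.n * self.mean + flat_weights) / (self.n + 1)
        self.n += 1
        self.columns.append(flat_weights - self.mean)   # deviation from the running mean
        if len(self.columns) > self.k:
            self.columns.pop(0)                          # drop the oldest column

    def dev_mat(self):
        return torch.stack(self.columns, dim=1)          # D_hat, shape [num_params, K]

# Usage: buf = DeviationBuffer(num_params=d, k=20); call buf.collect(flat_w) once per epoch.
```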
The full Bayesian model averaging procedure is given in Algorithm 1. As in Izmailov et al. [27]
(SWA) we update the batch normalization statistics after sampling weights for models that use batch
normalization [25]; we investigate the necessity of this update in Appendix D.4.
Algorithm 1 Bayesian Model Averaging with SWAG
Input: $\theta_0$: pretrained weights; $\eta$: learning rate; $T$: number of steps; $c$: moment update frequency; $K$: maximum number of columns in deviation matrix; $S$: number of samples in Bayesian model averaging

Train SWAG:
  $\bar{\theta} \leftarrow \theta_0$, $\overline{\theta^2} \leftarrow \theta_0^2$  {Initialize moments}
  for $i \leftarrow 1, 2, \ldots, T$ do
    $\theta_i \leftarrow \theta_{i-1} - \eta \nabla_\theta L(\theta_{i-1})$  {Perform SGD update}
    if MOD($i$, $c$) $= 0$ then
      $n \leftarrow i/c$  {Number of models}
      $\bar{\theta} \leftarrow \frac{n\bar{\theta} + \theta_i}{n+1}$, $\;\overline{\theta^2} \leftarrow \frac{n\overline{\theta^2} + \theta_i^2}{n+1}$  {Update moments}
      if NUM_COLS($\hat{D}$) $= K$ then REMOVE_COL($\hat{D}[:, 1]$)
      APPEND_COL($\hat{D}$, $\theta_i - \bar{\theta}$)  {Store deviation}
  return $\theta_{\text{SWA}} = \bar{\theta}$, $\Sigma_{\text{diag}} = \overline{\theta^2} - \theta_{\text{SWA}}^2$, $\hat{D}$

Test: Bayesian Model Averaging:
  for $i \leftarrow 1, 2, \ldots, S$ do
    Draw $\tilde{\theta}_i \sim \mathcal{N}\!\left(\theta_{\text{SWA}}, \tfrac{1}{2}\Sigma_{\text{diag}} + \tfrac{\hat{D}\hat{D}^{\top}}{2(K-1)}\right)$   (1)
    Update batch norm statistics with the new sample
    $p(y^* \mid \text{Data}) \mathrel{+}= \tfrac{1}{S}\, p(y^* \mid \tilde{\theta}_i)$
  return $p(y^* \mid \text{Data})$
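In code, the draw in Equation (1) can be implemented by sampling the diagonal and low-rank components separately; the following is a simplified sketch rather than the released implementation (the batch normalization update of Algorithm 1 is handled separately):

```python
# Sketch of drawing one weight sample from the SWAG posterior of Eq. (1):
#   theta ~ N(theta_SWA, 1/2 * Sigma_diag + D_hat D_hat^T / (2(K-1))).
import torch

def sample_swag(swa_mean, sq_mean, dev_mat):
    """swa_mean, sq_mean: flattened running first/second moments, shape [d];
    dev_mat: deviation matrix D_hat with the last K columns, shape [d, K]."""
    var_diag = torch.clamp(sq_mean - swa_mean ** 2, min=0.0)   # diagonal of Sigma_diag
    k = dev_mat.shape[1]
    z1 = torch.randn_like(swa_mean)                            # noise for the diagonal term
    z2 = torch.randn(k)                                        # noise for the low-rank term
    return (swa_mean
            + (0.5 * var_diag).sqrt() * z1                     # contributes (1/2) Sigma_diag
            + dev_mat @ z2 / (2.0 * (k - 1)) ** 0.5)           # contributes D D^T / (2(K-1))

# Toy usage: d parameters, K stored columns, S samples for model averaging.
d, K, S = 10, 5, 30
swa_mean, sq_mean, dev_mat = torch.zeros(d), torch.ones(d), torch.randn(d, K)
samples = torch.stack([sample_swag(swa_mean, sq_mean, dev_mat) for _ in range(S)])
```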
Figure 1: Left: Posterior joint density cross-sections along the rays corresponding to different
eigenvectors of SWAG covariance matrix. Middle: Posterior joint density surface in the plane
spanned by eigenvectors of SWAG covariance matrix corresponding to the first and second largest
eigenvalues and (Right:) the third and fourth largest eigenvalues. All plots are produced using
PreResNet-164 on CIFAR-100. The SWAG distribution projected onto these directions fits the
geometry of the posterior density remarkably well.
SWAG thus approximates the posterior over network weights with a Gaussian $\mathcal{N}(\theta; \mu, \Sigma)$, and then samples from this distribution to perform a Bayesian model average.
In our procedure, optimization with different regularizers, to characterize the Gaussian posterior approximation, corresponds to approximate Bayesian inference with different priors $p(\theta)$.
Prior Choice Typically, weight decay is used to regularize DNNs, corresponding to explicit L2
regularization when SGD without momentum is used to train the model. When SGD is used with
momentum, as is typically the case, implicit regularization still occurs, producing a vague prior on
the weights of the DNN in our procedure. This regularizer can be given an explicit Gaussian-like
form (see Proposition 3 of Loshchilov and Hutter [38]), corresponding to a prior distribution on the
weights.
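As a worked example of this correspondence (a sketch under the assumption of explicit $L_2$ regularization and a zero-mean Gaussian prior, not necessarily the exact prior used in our experiments): with $p(\theta) = \mathcal{N}(0, \sigma^{2} I)$, the regularizer $\log p(\theta)$ of the update rule above contributes, up to a constant,
$$-\tfrac{1}{N}\log p(\theta) = \tfrac{1}{2N\sigma^{2}}\,\lVert\theta\rVert^{2},$$
so an $L_2$ penalty $\tfrac{\lambda}{2}\lVert\theta\rVert^{2}$ corresponds to $\lambda = \tfrac{1}{N\sigma^{2}}$, or equivalently a prior variance $\sigma^{2} = \tfrac{1}{N\lambda}$.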
Thus, SWAG is an approximate Bayesian inference algorithm in our experiments (see Section 5) and
can be applied to most DNNs without any modifications of the training procedure (as long as SGD is
used with weight decay or explicit L2 regularization). Alternative regularization techniques could
also be used, producing different priors on the weights. It may also be possible to similarly utilize
Adam and other stochastic first-order methods, which we view as a promising direction for future work.
(Figure 2 panels: negative log likelihood for WideResNet28x10, PreResNet-164, and VGG-16 on CIFAR-100, CIFAR-10, and the CIFAR-10→STL-10 transfer task, and for DenseNet-161 and ResNet-152 on ImageNet. Methods shown: SWAG, SWAG-Diag, SGD, SWA, SGD-Temp, SWA-Temp, KFAC-Laplace, SGD-Drop, SWA-Drop, SGLD.)
Figure 2: Negative log likelihoods for SWAG and baselines. Mean and standard deviation (shown
with error-bars) over 3 runs are reported for each experiment on CIFAR datasets. SWAG (blue
star) consistently outperforms alternatives, with lower negative log likelihood, with the largest
improvements on transfer learning. Temperature scaling applied on top of SWA (SWA-Temp) often
performs close to as well on the non-transfer learning tasks, but requires a validation set.
The SWAG covariance is estimated in a low-dimensional subspace of the weight space, and we cannot guarantee that SWAG variance estimates are adequate along all
directions in weight space. In particular, we would expect SWAG to under-estimate the variances
along random directions, as the SGD trajectory is in a low-dimensional subspace of the weight
space, and a random vector has a close-to-zero projection on this subspace with high probability. In
Appendix A we visualize the trajectory of SGD applied to a quadratic function, and further discuss
the relation between the geometry of objective and SGD trajectory. In Appendices A and B, we also
empirically test the assumptions behind theory relating the SGD stationary distribution to the true
posterior for neural networks.
5 Experiments
In this section we evaluate the quality of uncertainty estimates as well as predictive accuracy for
SWAG and SWAG-Diagonal on CIFAR-10, CIFAR-100 and ImageNet ILSVRC-2012 [56].
For all methods we analyze test negative log-likelihood, which reflects both the accuracy and the
quality of predictive uncertainty. Following Guo et al. [19] we also consider a variant of reliability
diagrams to evaluate the calibration of uncertainty estimates (see Figure 3) and to show the difference
between a method’s confidence in its predictions and its accuracy. To produce this plot for a given
method we split the test data into 20 bins uniformly based on the confidence of a method (maximum
predicted probability). We then evaluate the accuracy and mean confidence of the method on the
images from each bin, and plot the difference between confidence and accuracy. For a well-calibrated
model, this difference should be close to zero for each bin. We found that this procedure gives a more
effective visualization of the actual confidence distribution of DNN predictions than the standard
reliability diagrams used in Guo et al. [19] and Niculescu-Mizil and Caruana [50].
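A minimal sketch of this binning procedure (the array names and the equal-count binning are our reading of the description above, not the exact plotting code used for the figures):

```python
# Confidence-vs-accuracy curve with 20 bins of (approximately) equal numbers of
# test points: sort by confidence, split into bins, and record confidence - accuracy.
import numpy as np

def reliability_curve(probs, labels, n_bins=20):
    """probs: [n, classes] predicted probabilities; labels: [n] true class indices."""
    confidences = probs.max(axis=1)
    predictions = probs.argmax(axis=1)
    order = np.argsort(confidences)
    bins = np.array_split(order, n_bins)            # equal-sized bins of test points
    bin_conf, bin_gap = [], []
    for idx in bins:
        conf = confidences[idx].mean()
        acc = (predictions[idx] == labels[idx]).mean()
        bin_conf.append(conf)
        bin_gap.append(conf - acc)                   # > 0: overconfident, < 0: underconfident
    return np.array(bin_conf), np.array(bin_gap)

# Example with random predictions over 10 classes.
rng = np.random.default_rng(0)
p = rng.dirichlet(np.ones(10), size=1000)
y = rng.integers(0, 10, size=1000)
conf, gap = reliability_curve(p, y)
```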
We provide tables containing the test accuracy, negative log likelihood and expected calibration error
for all methods and datasets in Appendix E.3.
CIFAR datasets On CIFAR datasets we run experiments with VGG-16, PreResNet-164 and
WideResNet-28x10 networks. In order to compare SWAG with existing alternatives we report the
results for standard SGD and SWA [27] solutions (single models), MC-Dropout [14], temperature
scaling [19] applied to SWA and SGD solutions, SGLD [59], and K-FAC Laplace [54] methods. For
all the methods we use our implementations in PyTorch (see Appendix H). We train all networks
for 300 epochs, starting to collect models for SWA and SWAG approximations once per epoch after
epoch 160. For SWAG, K-FAC Laplace, and Dropout we use 30 samples at test time.
Figure 3: Reliability diagrams for WideResNet28x10 on CIFAR-100 and transfer task; ResNet-152
and DenseNet-161 on ImageNet. Confidence is the value of the max softmax output. A perfectly
calibrated network has no difference between confidence and accuracy, represented by a dashed black
line. Points below this line correspond to under-confident predictions, whereas points above the
line are overconfident predictions. SWAG is able to substantially improve calibration over standard
training (SGD), as well as SWA. Additionally, SWAG significantly outperforms temperature scaling
for transfer learning (CIFAR-10 to STL), where the target data are not from the same distribution as
the training data.
ImageNet On ImageNet we report our results for SWAG, SWAG-Diagonal, SWA and SGD. We
run experiments with DenseNet-161 [24] and Resnet-152 [22]. For each model we start from a
pre-trained model available in the torchvision package, and run SGD with a constant learning rate
for 10 epochs. We collect models for the SWAG versions and SWA 4 times per epoch. For SWAG
we use 30 samples from the posterior over network weights at test-time, and use randomly sampled
10% of the training data to update batch-normalization statistics for each of the samples. For SGD
with temperature scaling, we use the results reported in Guo et al. [19].
Transfer from CIFAR-10 to STL-10 We use the models trained on CIFAR-10 and evaluate them
on STL-10 [11]. STL-10 has a similar set of classes as CIFAR-10, but the image distribution is
different, so adapting the model from CIFAR-10 to STL-10 is a commonly used transfer learning
benchmark. We provide further details on the architectures and hyperparameters in Appendix H.
Results We visualize the negative log-likelihood for all methods and datasets in Figure 2. On all
considered tasks SWAG and SWAG diagonal perform comparably or better than all the considered
alternatives, SWAG being best overall. We note that the combination of SWA and temperature scaling
presents a competitive baseline. However, unlike SWAG it requires using a validation set to tune the
temperature; further, temperature scaling is not effective when the test data distribution differs from
train, as we observe in experiments on transfer learning from CIFAR-10 to STL-10.
Next, we analyze the calibration of uncertainty estimates provided by different methods. In Figure
3 we present reliability plots for WideResNet on CIFAR-100, DenseNet-161 and ResNet-152 on
ImageNet. The reliability diagrams for all other datasets and architectures are presented in the
Appendix E.1. As we can see, SWAG and SWAG-Diagonal both achieve good calibration across
the board. The low-rank plus diagonal version of SWAG is generally better calibrated than SWAG-
Diagonal. We also present the expected calibration error for each of the methods, architectures and
datasets in Tables A.3,4. Finally, in Tables A.9,10 we present the predictive accuracy for all of the
methods, where SWAG is comparable with SWA and generally outperforms the other approaches.
much lower NLL on the larger PreResNet-164 and WideResNet28x10; the results for accuracy and
ECE are analogous.
To evaluate SWAG on out-of-domain data detection we train a WideResNet as described in Section 5.1
on the data from five classes of the CIFAR-10 dataset, and then analyze predictions of SWAG variants
along with the baselines on the full test set. We expect the predicted class probabilities for objects belonging to classes that were not present in the training data to have high entropy, reflecting the model's high uncertainty in its predictions, and considerably lower entropy on images similar to those on which the network was trained. We plot the histograms of predictive entropies on the in-domain and out-of-domain data in Figure A.10 for a qualitative comparison and report the symmetrized
KL divergence between the binned in and out of sample distributions in Table 2, finding that SWAG
and Dropout perform best on this measure. Additional details are in Appendix E.2.
We next apply SWAG to an LSTM network on language modeling tasks on Penn Treebank and
WikiText-2 datasets. In Appendix F we demonstrate that SWAG easily outperforms both SWA and
NT-ASGD [44], a strong baseline for LSTM training, in terms of test and validation perplexities.
The main difference between SWA and NT-ASGD, which is also based on weight averaging, is that NT-ASGD starts averaging much earlier: it switches to ASGD (averaged SGD) typically around epoch 100, whereas we start SWA averaging after pre-training for 500 epochs. We report validation and test perplexities for all methods on both datasets in Table 1.
As we can see, SWA substantially improves perplexities on both datasets over NT-ASGD. Further,
we observe that SWAG is able to substantially improve test perplexities over the SWA solution.
Table 1: Validation and Test perplexities for NT-ASGD, SWA and SWAG on Penn Treebank and
WikiText-2 datasets.
Method PTB val PTB test WikiText-2 val WikiText-2 test
NT-ASGD 61.2 58.8 68.7 65.6
SWA 59.1 56.7 68.1 65.0
SWAG 58.6 56.26 67.2 64.1
5.5 Regression
Finally, while the empirical focus of our paper is classification calibration, we also compare to
additional approximate BNN inference methods which perform well on smaller architectures, includ-
ing deterministic variational inference (DVI) [60], single-layer deep GPs (DGP) with expectation
propagation [7], SGLD [59], and re-parameterization VI [33] on a set of UCI regression tasks. We
report test log-likelihoods, RMSEs and test calibration results in Appendix Tables 12 and 13 where it
is possible to see that SWAG is competitive with these methods. Additional details are in Appendix
G.
6 Discussion
In this paper we developed SWA-Gaussian (SWAG) for approximate Bayesian inference in deep
learning. There has been a great desire to apply Bayesian methods in deep learning due to their
theoretical properties and past success with small neural networks. We view SWAG as a step towards
practical, scalable, and accurate Bayesian deep learning for large modern neural networks.
A key geometric observation in this paper is that the posterior distribution over neural network
parameters is close to Gaussian in the subspace spanned by the trajectory of SGD. Our work shows
Bayesian model averaging within this subspace can improve predictions over SGD or SWA solutions.
Furthermore, Gur-Ari et al. [20] argue that the SGD trajectory lies in the subspace spanned by the
eigenvectors of the Hessian corresponding to the top eigenvalues, implying that the SGD trajectory
subspace corresponds to directions of rapid change in predictions. In recent work, Izmailov et al. [26]
show promising results from directly constructing subspaces for Bayesian inference.
Acknowledgements
WM, PI, and AGW were supported by an Amazon Research Award, Facebook Research, NSF
IIS-1563887, and NSF IIS-1910266. WM was additionally supported by an NSF Graduate Research
Fellowship under Grant No. DGE-1650441. DV was supported by the Russian Science Foundation
grant no.19-71-30020. We would like to thank Jacob Gardner, Polina Kirichenko, and David Widmann
for helpful discussions.
References
[1] Asmussen, S. and Glynn, P. W. (2007). Stochastic simulation: algorithms and analysis. Number 57 in Stochastic Modelling and Applied Probability. Springer, New York.
[2] Athiwaratkun, B., Finzi, M., Izmailov, P., and Wilson, A. G. (2019). There are many consistent
explanations for unlabeled data: why you should average. In International Conference on Learning
Representations. arXiv: 1806.05594.
[3] Babichev, D. and Bach, F. (2018). Constant step size stochastic gradient descent for probabilistic
modeling. In Uncertainty in Artificial Intelligence. arXiv preprint arXiv:1804.05567.
[4] Berger, J. O. (2013). Statistical decision theory and Bayesian analysis. Springer Science &
Business Media.
[5] Blier, L. and Ollivier, Y. (2018). The Description Length of Deep Learning models. In Advances
in Neural Information Processing Systems, page 11.
[6] Blundell, C., Cornebise, J., Kavukcuoglu, K., and Wierstra, D. (2015). Weight Uncertainty in
Neural Networks. In International Conference on Machine Learning. arXiv: 1505.05424.
[7] Bui, T., Hernández-Lobato, D., Hernandez-Lobato, J., Li, Y., and Turner, R. (2016). Deep
gaussian processes for regression using approximate expectation propagation. In International
Conference on Machine Learning, pages 1472–1481.
[8] Chaudhari, P. and Soatto, S. (2018). Stochastic gradient descent performs variational infer-
ence, converges to limit cycles for deep networks. In International Conference on Learning
Representations. arXiv: 1710.11029.
[9] Chen, T., Fox, E. B., and Guestrin, C. (2014). Stochastic Gradient Hamiltonian Monte Carlo. In
International Conference on Machine Learning. arXiv: 1402.4102.
[10] Chen, X., Lee, J. D., Tong, X. T., and Zhang, Y. (2016). Statistical Inference for Model
Parameters in Stochastic Gradient Descent. arXiv: 1610.08637.
[11] Coates, A., Ng, A., and Lee, H. (2011). An Analysis of Single-Layer Networks in Unsuper-
vised Feature Learning. In Proceedings of the Fourteenth International Conference on Artificial
Intelligence and Statistics, pages 215–223.
[12] Dasgupta, S. and Hsu, D. (2007). On-Line Estimation with the Multivariate Gaussian Distri-
bution. In Bshouty, N. H. and Gentile, C., editors, Twentieth Annual Conference on Learning
Theory., volume 4539, pages 278–292, Berlin, Heidelberg. Springer Berlin Heidelberg.
[13] Draxler, F., Veschgini, K., Salmhofer, M., and Hamprecht, F. A. (2018). Essentially No Barriers
in Neural Network Energy Landscape. In International Conference on Machine Learning, page 10.
[14] Gal, Y. and Ghahramani, Z. (2016). Dropout as a Bayesian Approximation. In International
Conference on Machine Learning.
[15] Gal, Y., Hron, J., and Kendall, A. (2017). Concrete Dropout. In Advances in Neural Information
Processing Systems. arXiv: 1705.07832.
[16] Gardner, J., Pleiss, G., Weinberger, K. Q., Bindel, D., and Wilson, A. G. (2018). Gpytorch:
Blackbox matrix-matrix gaussian process inference with gpu acceleration. In Advances in Neural
Information Processing Systems, pages 7587–7597.
[17] Garipov, T., Izmailov, P., Podoprikhin, D., Vetrov, D. P., and Wilson, A. G. (2018). Loss
surfaces, mode connectivity, and fast ensembling of dnns. In Advances in Neural Information
Processing Systems, pages 8789–8798.
[18] Graves, A. (2011). Practical variational inference for neural networks. In Advances in neural
information processing systems, pages 2348–2356.
[19] Guo, C., Pleiss, G., Sun, Y., and Weinberger, K. Q. (2017). On Calibration of Modern Neural
Networks. In International Conference on Machine Learning. arXiv: 1706.04599.
[20] Gur-Ari, G., Roberts, D. A., and Dyer, E. (2019). Gradient descent happens in a tiny subspace.
[21] Halko, N., Martinsson, P.-G., and Tropp, J. A. (2011). Finding structure with randomness:
Probabilistic algorithms for constructing approximate matrix decompositions. SIAM review,
53(2):217–288.
[22] He, K., Zhang, X., Ren, S., and Sun, J. (2016). Deep Residual Learning for Image Recognition.
In CVPR. arXiv: 1512.03385.
[23] Hernández-Lobato, J. M. and Adams, R. (2015). Probabilistic Backpropagation for Scalable
Learning of Bayesian Neural Networks. In Advances in Neural Information Processing Systems.
[24] Huang, G., Liu, Z., van der Maaten, L., and Weinberger, K. Q. (2017). Densely Connected
Convolutional Networks. In CVPR. arXiv: 1608.06993.
[25] Ioffe, S. and Szegedy, C. (2015). Batch normalization: Accelerating deep network training by
reducing internal covariate shift. arXiv preprint arXiv:1502.03167.
[26] Izmailov, P., Maddox, W. J., Kirichenko, P., Garipov, T., Vetrov, D., and Wilson, A. G. (2019).
Subspace inference for bayesian deep learning. arXiv preprint arXiv:1907.07504.
[27] Izmailov, P., Podoprikhin, D., Garipov, T., Vetrov, D., and Wilson, A. G. (2018). Averaging
weights leads to wider optima and better generalization. Uncertainty in Artificial Intelligence
(UAI).
[28] Kendall, A. and Gal, Y. (2017). What Uncertainties Do We Need in Bayesian Deep Learning
for Computer Vision? In Advances in Neural Information Processing Systems, Long Beach.
[29] Keskar, N. S., Mudigere, D., Nocedal, J., Smelyanskiy, M., and Tang, P. T. P. (2017). On
Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima. In International
Conference on Learning Representations. arXiv: 1609.04836.
[30] Khan, M. E., Nielsen, D., Tangkaratt, V., Lin, W., Gal, Y., and Srivastava, A. (2018). Fast and
Scalable Bayesian Deep Learning by Weight-Perturbation in Adam. In International Conference
on Machine Learning. arXiv: 1806.04854.
[31] Kingma, D. P., Salimans, T., and Welling, M. (2015a). Variational Dropout and the Local
Reparameterization Trick. arXiv:1506.02557 [cs, stat]. arXiv: 1506.02557.
[32] Kingma, D. P., Salimans, T., and Welling, M. (2015b). Variational dropout and the local
reparameterization trick. In Advances in Neural Information Processing Systems, pages 2575–
2583.
[33] Kingma, D. P. and Welling, M. (2013). Auto-encoding variational bayes. In International
Conference on Learning Representations.
[34] Kirkpatrick, J., Pascanu, R., Rabinowitz, N., Veness, J., Desjardins, G., Rusu, A. A., Milan, K.,
Quan, J., Ramalho, T., Grabska-Barwinska, A., et al. (2017). Overcoming catastrophic forgetting
in neural networks. Proceedings of the national academy of sciences, page 201611835.
[35] Kuleshov, V., Fenner, N., and Ermon, S. (2018). Accurate Uncertainties for Deep Learning
Using Calibrated Regression. In International Conference on Machine Learning, page 9.
[36] Lakshminarayanan, B., Pritzel, A., and Blundell, C. (2017). Simple and Scalable Predictive
Uncertainty Estimation using Deep Ensembles. In Advances in Neural Information Processing
Systems.
[37] Li, H., Xu, Z., Taylor, G., Studer, C., and Goldstein, T. (2018). Visualizing the Loss Landscape
of Neural Nets. In Advances in Neural Information Processing Systems. arXiv: 1712.09913.
[38] Loshchilov, I. and Hutter, F. (2019). Decoupled Weight Decay Regularization. In International
Conference on Learning Representations. arXiv: 1711.05101.
[39] Louizos, C. and Welling, M. (2017). Multiplicative normalizing flows for variational bayesian
neural networks. In International Conference on Machine Learning.
[40] MacKay, D. J. C. (1992a). Bayesian Interpolation. Neural Computation.
[41] MacKay, D. J. C. (1992b). A Practical Bayesian Framework for Backpropagation Networks.
Neural Computation, 4(3):448–472.
[42] MacKay, D. J. C. (2003). Information theory, inference, and learning algorithms. Cambridge
University Press, Cambridge, UK ; New York.
[43] Mandt, S., Hoffman, M. D., and Blei, D. M. (2017). Stochastic Gradient Descent as Approximate
Bayesian Inference. JMLR, 18:1–35.
[44] Merity, S., Keskar, N. S., and Socher, R. (2017). Regularizing and optimizing lstm language
models. arXiv preprint arXiv:1708.02182.
[45] Molchanov, D., Ashukha, A., and Vetrov, D. (2017). Variational dropout sparsifies deep neural
networks. arXiv preprint arXiv:1701.05369.
[46] Mukhoti, J. and Gal, Y. (2018). Evaluating Bayesian Deep Learning Methods for Semantic
Segmentation.
[47] Müller, U. K. (2013). Risk of bayesian inference in misspecified models, and the sandwich
covariance matrix. Econometrica, 81(5):1805–1849.
[48] Naeini, M. P., Cooper, G. F., and Hauskrecht, M. (2015). Obtaining well calibrated probabilities
using bayesian binning. In AAAI, pages 2901–2907.
[49] Neal, R. M. (1996). Bayesian Learning for Neural Networks, volume 118 of Lecture Notes in
Statistics. Springer New York, New York, NY.
[50] Niculescu-Mizil, A. and Caruana, R. (2005). Predicting good probabilities with supervised
learning. In International Conference on Machine Learning, pages 625–632, Bonn, Germany.
ACM Press.
[51] Nikishin, E., Izmailov, P., Athiwaratkun, B., Podoprikhin, D., Garipov, T., Shvechikov, P.,
Vetrov, D., and Wilson, A. G. (2018). Improving stability in deep reinforcement learning with
weight averaging.
[52] Polyak, B. T. and Juditsky, A. B. (1992). Acceleration of Stochastic Approximation by Averag-
ing. SIAM Journal on Control and Optimization, 30(4):838–855.
[53] Ritter, H., Botev, A., and Barber, D. (2018a). Online Structured Laplace Approximations For
Overcoming Catastrophic Forgetting. In Advances in Neural Information Processing Systems.
arXiv: 1805.07810.
[54] Ritter, H., Botev, A., and Barber, D. (2018b). A Scalable Laplace Approximation for Neural
Networks. In International Conference on Learning Representations.
[55] Ruppert, D. (1988). Efficient Estimators from a Slowly Convergent Robbins-Monro Process.
Technical Report 781, Cornell University, School of Operations Research and Industrial Engineering.
[56] Russakovsky, O., Deng, J., Su, H., Krause, J., Satheesh, S., Ma, S., Huang, Z., Karpathy, A.,
Khosla, A., Bernstein, M., Berg, A. C., and Fei-Fei, L. (2015). ImageNet Large Scale Visual
Recognition Challenge. IJCV, 115(3):211–252. arXiv: 1409.0575.
[57] Sagun, L., Evci, U., Guney, V. U., Dauphin, Y., and Bottou, L. (2018). Empirical Analysis of
the Hessian of Over-Parametrized Neural Networks. In International Conference on Learning
Representations Workshop Track. arXiv: 1706.04454.
[58] Vaart, A. W. v. d. (1998). Asymptotic Statistics. Cambridge Series in Statistical and Probabilistic
Mathematics. Cambridge University Press, Cambridge.
[59] Welling, M. and Teh, Y. W. (2011). Bayesian learning via stochastic gradient langevin dynamics.
In Proceedings of the 28th international conference on machine learning (ICML-11), pages
681–688.
[60] Wu, A., Nowozin, S., Meeds, E., Turner, R. E., Hernández-Lobato, J. M., and Gaunt, A. L.
(2019). Fixing variational bayes: Deterministic variational inference for bayesian neural networks.
In International Conference on Learning Representations. arXiv preprint arXiv:1810.03958.
[61] Yang, G., Zhang, T., Kirichenko, P., Bai, J., Wilson, A. G., and De Sa, C. (2019). Swalp:
Stochastic weight averaging in low precision training. In International Conference on Machine
Learning, pages 7015–7024.
[62] Yazici, Y., Foo, C.-S., Winkler, S., Yap, K.-H., Piliouras, G., and Chandrasekhar, V. (2019). The
Unusual Effectiveness of Averaging in GAN Training. In International Conference on Learning
Representations. arXiv: 1806.04498.
[63] Zagoruyko, S. and Komodakis, N. (2016). Wide Residual Networks. In BMVC. arXiv:
1605.07146.
[64] Zhang, G., Sun, S., Duvenaud, D., and Grosse, R. (2017). Noisy Natural Gradient as Variational
Inference. arXiv:1712.02390 [cs, stat]. arXiv: 1712.02390.
[65] Zhang, R., Li, C., Zhang, J., Chen, C., and Wilson, A. G. (2019). Cyclical stochastic gradient
mcmc for bayesian deep learning. arXiv preprint arXiv:1902.03932.
Figure 4: Trajectory of SGD with isotropic Gaussian gradient noise on a quadratic loss function.
Left: SGD without momentum; Right: SGD with momentum.
On a quadratic loss without gradient noise, SGD either converges to the optimizer or diverges to infinity, depending on the learning rate. However, when we
add isotropic Gaussian noise to the gradients, SGD converges to the correct Gaussian distribution, as
we visualize in the left panel of Figure 4. Furthermore, adding momentum affects the scale of the
distribution, but not its shape, as we show in the right panel of Figure 4. These conclusions hold as
long as the learning rate in SGD is not too large.
The results we show in Figure 4 are directly predicted by theory in Mandt et al. [43]. In general, if the
gradient noise is not isotropic, the stationary distribution of SGD would be different from the exact
posterior distribution. Mandt et al. [43] provide a thorough empirical study of the SGD trajectory
for convex problems, such as linear and logistic regression, and show that SGD can often provide a
competitive baseline on these problems.
Given the covariance matrix $A = H(\theta)^{-1}\, \mathbb{E}\big[\nabla \log p(\theta)\, \nabla \log p(\theta)^{\top}\big]\, H(\theta)^{-1}$, Chen et al. [10] show
that a batch means estimator of the iterates (similar to what SWAG uses) themselves will converge
to A in the limit of infinite time. We tried batch means based estimators but saw no improvement;
however, it could be interesting to explore further in future work.
Intriguingly, the covariance A has the same form as sandwich estimators [see e.g. 47, for a Bayesian
analysis in the model mis-specification setting], and so A = H(θ)−1 under model well-specification
[47, 10]. We then tie the covariance matrix of the iterates back to the well known Laplace approx-
imation, which uses H(θ)−1 as its covariance as described by MacKay [42, Chapter 28], thereby
justifying SWAG theoretically as a sample based Laplace approximation.
Finally, Chapter 4 of Berger [4] constructs an example (Example 10) of fitting a Gaussian approximation from an MCMC chain, arguing that it empirically performs well in Bayesian decision-theoretic contexts. Berger [4] explains this via the Bernstein-von Mises theorem, which states that in the limit the posterior will itself converge to a Gaussian. However, we would
expect that even in the infinite data limit the posterior of DNNs would converge to something very
non-Gaussian, with connected modes surrounded by gorges of zero posterior density [17]. One could
use this justification for fitting a Gaussian from the iterates of SGLD or SGHMC instead.
In this section, we investigate the results of Mandt et al. [43] in the context of deep learning. Mandt et al. [43] make the following assumptions:
4. In the stationary distribution, the loss is approximately quadratic near the optimum, i.e. approximately $(\theta - \theta^{*})^{\top} H(\theta^{*}) (\theta - \theta^{*})$, where $H(\theta^{*})$ is the Hessian at the optimum; further, the Hessian is assumed to be positive definite.
Assumption 1 is motivated by the central limit theorem, and Assumption 3 is necessary for the
analysis in Mandt et al. [43]. Assumptions 2 and 4 may or may not hold for deep neural networks (as
well as other models). Under these assumptions, Theorem 1 of Mandt et al. [43] derives the optimal
constant learning rate that minimizes the KL-divergence between the SGD stationary distribution and
the posterior⁶:
$$\eta^{*} = 2\,\frac{B\, d}{N\,\operatorname{tr}(C)}, \qquad (2)$$
where $N$ is the size of the dataset, $d$ is the dimension of the model, $B$ is the minibatch size and $C$ is the gradient noise covariance.
We computed Equation 2 over the course of training for two neural networks in Figure A.5a, finding
that the predicted optimal learning rate was an order of magnitude larger than what would be used in
practice to explore the loss surface in a reasonable time (about 4 compared to 0.1).
We now focus on seeing how Assumptions 2 and 4 fail for DNNs; this will give further insight into
what portions of the theory do hold, and may give insights into a corrected version of the optimal
learning rate.
In Figure A.5b, the trace of the gradient noise covariance and thus the optimal learning rates are
nearly constant; however, the total variance is much too small to induce effective learning rates,
probably due to over-parameterization effects inducing non full rank gradient covariances as was
found in Chaudhari and Soatto [8]. We note that this experiment is not sufficient to be fully confident
that C is independent of the parameterization near the local optima, but rather that tr(C) is close to
constant; further experiments in this vein are necessary to test if the diagonals of C are constant. The
result that tr(C) is close to constant suggests that a constant learning rate could be used for sampling
in a stationary phase of training. The dimensionality parameter in Equation 2 could be modified to
use the number of effective parameters or the rank of the gradient noise to reduce the optimal learning
rate to a feasible number.
To estimate $\operatorname{tr}(C)$ from the gradient noise we need to divide the estimated variance by the batch size (as $V(\hat g(\theta)) = B\, C(\theta)$), for a correct version of Equation 2. From Assumption 1 and Equation 6 of Mandt et al. [43], we see that
$$\hat g(\theta) \approx g(\theta) + \frac{1}{\sqrt{B}}\,\Delta g(\theta), \qquad \Delta g(\theta) \sim \mathcal{N}(0, C(\theta)),$$
where $B$ is the batch size. Thus, collecting the variance of $\hat g(\theta)$ (the variance of the stochastic gradients) will give estimates that are upscaled by a factor of $B$, leading to a cancellation of the batch size terms:
$$\eta \approx 2\,\frac{d}{N\,\operatorname{tr}\big(V(\hat g(\theta))\big)}.$$
To include momentum, we can repeat the analysis in Sections 4.1 and 4.3 of Mandt et al. [43], finding that this also involves scaling the optimal learning rate, but by a factor of $\mu$, the momentum term.⁷ This gives the final optimal learning rate equation as
$$\eta \approx \frac{2\mu\, d}{N\,\operatorname{tr}\big(V(\hat g(\theta))\big)}. \qquad (3)$$
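A sketch of how Equation (3) can be evaluated in practice: estimate $\operatorname{tr}(V(\hat g(\theta)))$ from per-minibatch gradients at fixed weights and plug it in (a simplified illustration, not the code used to produce Figure 5; the model, loader, and constants are placeholders).

```python
# Estimate tr(V(g_hat)) -- the total variance of the minibatch gradients -- from
# several minibatches at fixed weights, then apply Eq. (3):
#   eta ~= 2 * mu * d / (N * tr(V(g_hat))).
import torch

def flat_grad(model):
    return torch.cat([p.grad.detach().reshape(-1) for p in model.parameters()])

def optimal_lr(model, loss_fn, loader, dataset_size, momentum_mu=0.1, num_batches=50):
    grads = []
    for b, (x, y) in enumerate(loader):
        if b >= num_batches:
            break
        model.zero_grad()
        loss_fn(model(x), y).backward()
        grads.append(flat_grad(model))
    G = torch.stack(grads)                        # [num_batches, d]
    trace_v = G.var(dim=0, unbiased=True).sum()   # summed per-coordinate gradient variance
    d = G.shape[1]
    return 2.0 * momentum_mu * d / (dataset_size * trace_v)

# Toy usage with a small regression problem.
from torch.utils.data import DataLoader, TensorDataset
X, Y = torch.randn(512, 8), torch.randn(512, 1)
model = torch.nn.Linear(8, 1)
loader = DataLoader(TensorDataset(X, Y), batch_size=32, shuffle=True)
eta = optimal_lr(model, torch.nn.MSELoss(), loader, dataset_size=512)
```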
In Figure 5b, we computed tr(C) for VGG-16 and PreResNet-164 on CIFAR-100 beginning from
the start of training (referred to as from scratch), as well as the start of the SWAG procedure (referred
to in the legend as SWA). We see that tr(C) is never quite constant when trained from scratch, while
for a period of constant learning rate near the end of training, referred to as the stationary phase,
⁶An optimal diagonal preconditioner is also derived; our empirical work applies to that setting as well. A similar analysis with momentum holds as well, adding in only the momentum coefficient.
⁷Our experiments used $\mu = 0.1$, corresponding to $\rho = 0.9$ in PyTorch's SGD implementation.
tr(C) is essentially constant throughout. This discrepancy is likely due to large gradients at the very
beginning of training, indicating that the stationary distribution has not been reached yet.
Next, in Figure 5a, we used the computed tr(C) estimate for all four models and Equation 3 to
compute the optimal learning rate under the assumptions of Mandt et al. [43], finding that these
learning rates are not constant for the estimates beginning at the start of training and that they are too
large (1-3 at the minimum compared to a standard learning rate of 0.1 or 0.01).
To test assumption 4, we used a GPU-enabled Lanczos method from GPyTorch [16] and used
restarting to compute the minimum eigenvalue of the train loss of a pre-trained PreResNet-164 on
CIFAR-100. We found that even at the end of training, the minimum eigenvalue was −272 (the
maximum eigenvalue was 3580 for comparison), indicating that the Hessian is not positive definite.
This result harmonizes with other work analyzing the spectra of the Hessian for DNN training
[37, 57]. Further, Garipov et al. [17] and Draxler et al. [13] argue that the loss surfaces of DNNs
have directions along which the loss is completely flat, suggesting that the loss is nowhere near a
positive-definite quadratic form.
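For illustration, the same kind of extreme-eigenvalue check can be sketched with Hessian-vector products and SciPy's Lanczos routine, rather than the GPU-enabled Lanczos from GPyTorch [16] used for the experiment above; the toy model and data are placeholders.

```python
# Estimate extreme eigenvalues of the training-loss Hessian via Hessian-vector products.
import numpy as np
import torch
from scipy.sparse.linalg import LinearOperator, eigsh

def hessian_eig_extremes(model, loss):
    params = [p for p in model.parameters() if p.requires_grad]
    n = sum(p.numel() for p in params)
    grads = torch.autograd.grad(loss, params, create_graph=True)
    flat_g = torch.cat([g.reshape(-1) for g in grads])

    def hvp(v_np):
        v = torch.from_numpy(v_np).float()
        hv = torch.autograd.grad(flat_g @ v, params, retain_graph=True)
        return torch.cat([h.reshape(-1) for h in hv]).detach().numpy()

    op = LinearOperator((n, n), matvec=hvp, dtype=np.float32)
    lam_max = eigsh(op, k=1, which='LA', return_eigenvectors=False)
    lam_min = eigsh(op, k=1, which='SA', return_eigenvectors=False)
    return lam_min, lam_max

# Toy usage on a tiny network and batch.
model = torch.nn.Sequential(torch.nn.Linear(5, 5), torch.nn.Tanh(), torch.nn.Linear(5, 1))
x, y = torch.randn(32, 5), torch.randn(32, 1)
loss = torch.nn.functional.mse_loss(model(x), y)
print(hessian_eig_extremes(model, loss))
```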
Figure 6: Left: Posterior-density cross-sections along the rays corresponding to different eigenvec-
tors of the SWAG covariance matrix. Middle: Posterior-density surface in the plane spanned by
eigenvectors of SWAG covariance matrix corresponding to the first and second largest eigenvalues
and (Right:) the third and fourth largest eigenvalues. Each row in the figure corresponds to an
architecture-dataset pair indicated in the title of each panel.
(Figure 7 panels: (a) accuracy versus training epoch for PreResNet56 on CIFAR-100, comparing SWAG, SWA, and ensembles of SGD iterates; (b, c) NLL and accuracy versus number of samples for WideResNet28x10 on CIFAR-100, comparing SWAG, SWAG-Diag, and SWA; (d) NLL versus number of samples for WideResNet28x10 on CIFAR-100 with different covariance scales.)
We now evaluate the effect of the covariance matrix rank on the SWAG approximation. To do so,
we trained a PreResNet56 on CIFAR-100 with SWAG beginning from epoch 161, and evaluated 30
sample Bayesian model averages obtained at different epochs; the accuracy plot from this experiment
is shown in Figure 7 (a). The rank of each model after epoch 161 is simply min(epoch − 161, 140),
and 30 samples from even a low rank approximation reach the same predictive accuracy as the SWA
model. Interestingly, both SWAG and SWA outperform ensembles of the iterates from a standard SGD run and ensembles of the SGD models collected during the SWA run.
In most situations where SWAG will be used, no closed-form expression for the integral $\int f(y)\, q(\theta \mid y)\, d\theta$ will exist. Thus, Monte Carlo approximations will be used; Monte Carlo integration converges at a rate of $1/\sqrt{K}$, where $K$ is the number of samples used, but practically good results may be found with very few samples (e.g. Chapter 29 of MacKay [42]).
To test how many samples are needed for good predictive accuracy in a Bayesian model averaging
task, we used a rank 20 approximation for SWAG and then tested the NLL on the test set as a function
of the number of samples for WideResNet28x10 [63] on CIFAR-100.
The results from this experiment are shown in Figure 7 (b, c), where it is possible to see that about 3
samples will match the SWA result for NLL, with about 30 samples necessary for stable accuracy
(about the same as SWA for this network). In most of our experiments, we used 30 samples for
consistency. In practice, we suggest tuning this number by looking at a validation set as well as the
computational resources available and comparing to the free SWA predictions that come with SWAG.
First, we note that the covariance, Σ, estimated using SWAG, is a function of the learning rate (and
momentum) for SGD. While the theoretical work of Mandt et al. [43] suggests that it is possible to
optimally set the learning rate, our experiments in Appendix B show that currently the assumptions
of the theory do not match the empirical reality in deep learning. In practice the learning rate can
be chosen to minimize the negative log-likelihood on a validation set. In the linear setting as in Mandt
et al. [43], the learning rate controls the scale of the asymptotic covariance matrix. If the optimal
learning rate (Equation 2) is used in this setting, the covariance matches the true posterior. To attempt
to disassociate the learning rate from the covariance in practice, we rescale the covariance matrix
when sampling by a constant factor for a WideResNet on CIFAR-100 shown in Figure 7 (d).
Over several replications, we found that a scale of 0.5 worked best, which is expected because the
low rank plus diagonal covariance incorporates the variance twice (once for the diagonal and once
from the low rank component).
One possible slowdown of SWAG at inference time is in the usage of updated batch norm parameters.
Following Izmailov et al. [27], we found that in order for the averaging and sampling to work well, it
was necessary to update the batch norm parameters of networks after sampling a new model. This is
shown in Figure 8 for a WideResNet on CIFAR-100 for two independently trained models.
Figure 8: NLL by number of samples for SWAG with and without batch norm updates after sampling.
Updating the batch norm parameters after sampling results in a significant improvement in NLL.
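A sketch of the batch normalization update performed after drawing each weight sample: one forward pass over (a subset of) the training data in train mode, with the running statistics reset so they are recomputed for the sampled weights. Recent PyTorch releases provide a similar helper (torch.optim.swa_utils.update_bn); the version below just spells out the idea, assuming the loader yields (inputs, targets) pairs.

```python
# Recompute batch norm running statistics for a sampled weight vector.
import torch

@torch.no_grad()
def refresh_bn_statistics(model, loader):
    bn_types = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)
    bn_layers = [m for m in model.modules() if isinstance(m, bn_types)]
    if not bn_layers:
        return
    for bn in bn_layers:
        bn.reset_running_stats()
        bn.momentum = None          # None => cumulative moving average over the pass
    was_training = model.training
    model.train()
    for x, _ in loader:
        model(x)                     # forward passes update the running mean/variance
    model.train(was_training)
```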
From our experimental findings, we see that given an equal amount of training time, SWAG typically
outperforms other methods for uncertainty calibration. SWAG additionally does not require a valida-
tion set like temperature scaling and Platt scaling (e.g. Guo et al. [19], Kuleshov et al. [35]). SWAG
also appears to have a distinct advantage over temperature scaling, and other popular alternatives,
when the target data are from a different distribution than the training data, as shown by our transfer
learning experiments.
Deep ensembles [36] require several times longer training for equal calibration, but often perform
somewhat better due to incorporating several independent training runs. Thus SWAG will be
particularly valuable when training time is limited, but inference time may not be. One possible
application is thus in medical applications when image sizes (for semantic segmentation) are large,
but predictions can be parallelized and may not have to be instantaneous.
We provide the additional reliability diagrams for all methods and datasets in Figure 9. SWAG
consistently improves calibration over SWA, and performs on par or better than temperature scaling.
In transfer learning temperature scaling fails to achieve good calibration, while SWAG still provides
a significant improvement over SWA.
Next, we evaluate the SWAG variants along with the baselines on out-of-domain data detection. To do so we train a WideResNet as described in Section H on the data from five classes of the CIFAR-10 dataset.
Figure 9: Reliability diagrams (see Section 5.1) for all models and datasets. The dataset and
architecture are listed in the title of each panel.
Figure 10: In and out of sample entropy distributions for WideResNet28x10 on CIFAR5 + 5.
Table 2: Symmetrized, discretized KL divergence between the distributions of predictive entropies for data from the first and last five classes of CIFAR-10, for models trained only on the first five classes. The entropy distributions under SWAG are more distinct than under the baseline models.
Method Symmetrized KL
SWAG 3.31
SWAG-Diag 2.27
MC Dropout 3.04
SWA 1.68
SGD (Baseline) 3.14
SGD + Temp. Scaling 2.98
To do so, we train a WideResNet as described in Section H on data from five classes of the CIFAR-10 dataset, and then analyze its predictions on the full test set. We expect the predicted class probabilities to have high entropy on objects from classes that were not present in the training data, reflecting the model's high uncertainty, and considerably lower entropy on images similar to those on which the network was trained.
To make this comparison quantitative, we computed the symmetrized KL divergence between the
binned in and out of sample distributions in Table 2, finding that SWAG and Dropout perform best
on this measure. We plot the histograms of predictive entropies on the in-domain (classes that were
Table 3: ECE for SWAG variants, temperature scaling, MC Dropout, and other baselines on CIFAR-10 and CIFAR-100.
CIFAR-10 CIFAR-10 CIFAR-10 CIFAR-100 CIFAR-100 CIFAR-100
Model VGG-16 PreResNet-164 WideResNet28x10 VGG-16 PreResNet-164 WideResNet28x10
SGD 0.0483 ± 0.0022 0.0255 ± 0.0009 0.0166 ± 0.0007 0.1870 ± 0.0014 0.1012 ± 0.0009 0.0479 ± 0.0010
SWA 0.0408 ± 0.0019 0.0203 ± 0.0010 0.0087 ± 0.0002 0.1514 ± 0.0032 0.0700 ± 0.0056 0.0684 ± 0.0022
SWAG-Diag 0.0267 ± 0.0025 0.0082 ± 0.0008 0.0047 ± 0.0013 0.0819 ± 0.0021 0.0239 ± 0.0047 0.0322 ± 0.0018
SWAG 0.0158 ± 0.0030 0.0053 ± 0.0004 0.0088 ± 0.0006 0.0395 ± 0.0061 0.0587 ± 0.0048 0.0113 ± 0.0020
KFAC-Laplace 0.0094 ± 0.0005 0.0092 ± 0.0018 0.0060 ± 0.0003 0.0778 ± 0.0054 0.0158 ± 0.0014 0.0379 ± 0.0047
SWA-Dropout 0.0284 ± 0.0036 0.0162 ± 0.0000 0.0094 ± 0.0014 0.1108 ± 0.0181 * 0.0574 ± 0.0028
SWA-Temp 0.0366 ± 0.0063 0.0172 ± 0.0010 0.0080 ± 0.0007 0.0291 ± 0.0097 0.0175 ± 0.0037 0.0220 ± 0.0007
SGLD 0.0082 ± 0.0012 0.0251 ± 0.0012 0.0192 ± 0.0007 0.0424 ± 0.0029 0.0363 ± 0.0008 0.0296 ± 0.0008
Table 2 shows the symmetrized, discretized KL divergence between the in-sample and out-of-sample entropy distributions for the CIFAR5 + 5 out-of-sample image detection task. We used the same bins as in Figure 10 to discretize the entropy distributions, then smoothed the binned counts by 1e-7 before calculating KL(IN||OUT) + KL(OUT||IN) using the scipy.stats.entropy function. Even qualitatively, we can see that the distributions are more distinct for SWAG and SWAG-Diagonal than for the other methods, particularly temperature scaling.
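A minimal sketch of this computation is shown below; the bin edges and array names are placeholders, and the smoothing is applied to the histogram counts before scipy.stats.entropy normalizes them.

import numpy as np
from scipy.stats import entropy

def symmetrized_kl(in_entropies, out_entropies, bins, eps=1e-7):
    """Symmetrized, discretized KL divergence between two sets of predictive entropies."""
    p, _ = np.histogram(in_entropies, bins=bins)
    q, _ = np.histogram(out_entropies, bins=bins)
    p = p.astype(float) + eps  # smooth empty bins
    q = q.astype(float) + eps
    # scipy.stats.entropy(p, q) normalizes its inputs and returns KL(p || q)
    return entropy(p, q) + entropy(q, p)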
We provide test accuracies (Tables 9, 10, 11) and negative log-likelihoods (NLL) (Tables 6, 7, 8) for all methods and datasets. We observe that SWAG is competitive with SWA, SWA with temperature scaling, and SWA-Dropout in terms of test accuracy, and typically outperforms all the baselines in terms of NLL. SWAG-Diagonal is generally inferior to SWAG in log-likelihood, but outperforms SWA.
In Tables 3, 4, 5 we additionally report the expected calibration error [ECE, 48], a metric of calibration of the predictive uncertainties. To compute ECE for a given model we split the test points into 20 bins based on the confidence of the model, compute the absolute difference between the average confidence and the accuracy within each bin, and average the obtained values over all bins. Please refer to [48, 19] for more details. We observe that SWAG is competitive with temperature scaling in ECE. Again, SWAG-Diagonal achieves better calibration than SWA, but using the low-rank plus diagonal covariance approximation in SWAG leads to substantially improved performance.
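For concreteness, a sketch of the ECE computation described above is given below. We use equal-width confidence bins and an unweighted average over non-empty bins here; weighting bins by their size is another common convention, and the exact binning used in our experiments may differ.

import numpy as np

def expected_calibration_error(probs, labels, n_bins=20):
    """Average |mean confidence - accuracy| over confidence bins.

    probs:  (n, C) predicted class probabilities
    labels: (n,)   integer class labels
    """
    confidences = probs.max(axis=1)
    correct = (probs.argmax(axis=1) == labels).astype(float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    gaps = []
    for lo, hi in zip(edges[:-1], edges[1:]):
        mask = (confidences > lo) & (confidences <= hi)
        if mask.any():
            gaps.append(abs(confidences[mask].mean() - correct[mask].mean()))
    return float(np.mean(gaps)) if gaps else 0.0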
Table 6: NLL on CIFAR-10 and CIFAR-100.
Dataset CIFAR-10 CIFAR-100
Model VGG-16 PreResNet-164 WideResNet28x10 VGG-16 PreResNet-164 WideResNet28x10
SGD 0.3285 ± 0.0139 0.1814 ± 0.0025 0.1294 ± 0.0022 1.7308 ± 0.0137 0.9465 ± 0.0191 0.7958 ± 0.0089
SWA 0.2621 ± 0.0104 0.1450 ± 0.0042 0.1075 ± 0.0004 1.2780 ± 0.0051 0.7370 ± 0.0265 0.6684 ± 0.0034
SWAG-Diag 0.2200 ± 0.0078 0.1251 ± 0.0029 0.1077 ± 0.0009 1.0163 ± 0.0032 0.6837 ± 0.0186 0.6150 ± 0.0029
SWAG 0.2016 ± 0.0031 0.1232 ± 0.0022 0.1122 ± 0.0009 0.9480 ± 0.0038 0.6595 ± 0.0019 0.6078 ± 0.0006
KFAC-Laplace 0.2252 ± 0.0032 0.1471 ± 0.0012 0.1210 ± 0.0020 1.1915 ± 0.0199 0.7881 ± 0.0025 0.7692 ± 0.0092
SWA-Dropout 0.2328 ± 0.0049 0.1270 ± 0.0000 0.1094 ± 0.0021 1.1872 ± 0.0524 * 0.6500 ± 0.0049
SWA-Temp 0.2481 ± 0.0245 0.1347 ± 0.0038 0.1064 ± 0.0004 1.0386 ± 0.0126 0.6770 ± 0.0191 0.6134 ± 0.0023
SGLD 0.2001 ± 0.0059 0.1418 ± 0.0005 0.1289 ± 0.0009 0.9699 ± 0.0057 0.6981 ± 0.0052 0.678 ± 0.0022
SGD-Ens 0.1881 ± 0.002 0.1312 ± 0.0023 0.1855 ± 0.0014 0.8979 ± 0.0065 0.7839 ± 0.0046 0.7655 ± 0.0026
F Language Modeling
We evaluate SWAG using standard Penn Treebank and WikiText-2 benchmark language modeling
datasets. Following [44] we use a 3-layer LSTM model with 1150 units in the hidden layer and an
embedding of size 400; we apply dropout, weight-tying, activation regularization (AR), and temporal activation regularization (TAR).
Table 12: Unnormalized test log-likelihoods on small UCI datasets for the proposed methods, alongside the numbers reported for deterministic variational inference (DVI, Wu et al. [60]), Deep Gaussian Processes with expectation propagation (DGP1-50, Bui et al. [7]), and variational inference (VI) with the re-parameterization trick [32]. * denotes results reproduced from [60]. Note that SWAG wins on two of the six datasets, and that SGD serves as a strong baseline throughout.
dataset N D SGD SWAG DVI* DGP1-50* VI* SGLD* PBP*
boston 506 13 -2.536 ± 0.240 -2.469 ± 0.183 -2.41 ± 0.02 -2.33 ± 0.06 -2.43 ±0.03 -2.40 ± 0.05 -2.57 ± 0.09
concrete 1030 8 -3.02 ± 0.126 -3.05 ± 0.1 -3.06 ± 0.01 -3.13 ± 0.03 -3.04 ±0.02 -3.08 ± 0.03 -3.16 ± 0.02
energy 768 8 -1.736 ± 1.613 -1.679 ± 1.488 -1.01 ± 0.06 -1.32 ± 0.03 -2.38 ±0.02 -2.39 ± 0.01 -2.04 ± 0.02
naval 11934 16 6.567 ± 0.185 6.708 ± 0.105 6.29 ± 0.04 3.60 ± 0.33 5.87 ±0.29 3.33 ± 0.01 3.73 ± 0.01
yacht 308 6 -0.418 ± 0.426 -0.404 ± 0.418 -0.47 ± 0.03 -1.39 ± 0.14 -1.68 ±0.04 -2.90 ± 0.01 -1.63 ± 0.02
power 9568 4 -2.772 ± 0.04 -2.775 ± 0.038 -2.80 ± 0.00 -2.81 ± 0.01 -2.66 ± 0.01 -2.67 ± 0.00 -2.84 ± 0.01
We follow [44] for specific hyper-parameter settings such as the dropout rates for different types of layers. We train all models for the language modeling tasks and evaluate validation and test perplexity. For SWA and SWAG we pre-train the models using standard SGD for 500 epochs, and then run the model for 100 more epochs to estimate the mean θSWA and covariance Σ for SWAG. For this experiment we introduce a small change to SWA and SWAG: to estimate the mean θSWA we average weights after each mini-batch of data rather than once per epoch, as we found more frequent averaging to greatly improve performance. After the SWAG distribution is constructed, we sample and ensemble 30 models from it. We use rank 10 for the low-rank part of the covariance matrix of the SWAG distribution.
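The per-mini-batch averaging only changes how often the running mean of the weights is updated. A minimal sketch of the update, with hypothetical names, is:

import torch

@torch.no_grad()
def update_swa_mean(swa_params, model, n_averaged):
    """Fold the current weights into the running SWA mean.

    Called once per mini-batch here, rather than once per epoch.
    """
    for swa_p, p in zip(swa_params, model.parameters()):
        swa_p.mul_(n_averaged / (n_averaged + 1.0)).add_(p.detach() / (n_averaged + 1.0))
    return n_averaged + 1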
G Regression
For the small UCI regression datasets, we use the architecture from Wu et al. [60] with one hidden layer of 50 units, training for 50 epochs (starting SWAG at epoch 25) and using 20 repetitions of 90/10 train/test splits. We fixed a single seed for tuning before using 20 different seeds for the results in the paper.
We use SGD, manually tune the learning rate and weight decay, and use a batch size of N/10, where N is the dataset size. All models predict heteroscedastic uncertainty (i.e., they output a variance). In Table 12, we compare SWAG and SGD to deterministic VI (DVI, Wu et al. [60]) and deep Gaussian processes with expectation propagation (DGP1-50, Bui et al. [7]). SWAG outperforms DVI and the other methods on three of the six datasets and is competitive on the other three, despite its vastly reduced computational cost (the same as SGD, whereas DVI is known to be 300x slower). Additionally, we note the strong performance of well-tuned SGD as a baseline against the other approximate inference methods, as it consistently performs nearly as well as both SWAG and DVI.
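As a sketch of what predicting heteroscedastic uncertainty means here, the network outputs a mean and a variance for each input and is trained with the Gaussian negative log-likelihood. The softplus link for the variance is one common choice and an assumption on our part:

import torch
import torch.nn.functional as F

def heteroscedastic_gaussian_nll(outputs, targets):
    """NLL for a network whose two outputs are a mean and an unconstrained variance.

    outputs: (n, 2); column 0 is the mean, column 1 is mapped to a positive variance.
    targets: (n,) regression targets
    """
    mean = outputs[:, 0]
    var = F.softplus(outputs[:, 1]) + 1e-6  # keep the variance strictly positive
    nll = 0.5 * (torch.log(2 * torch.pi * var) + (targets - mean) ** 2 / var)
    return nll.mean()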
Finally, in Table 13, we compare the calibration of SWAG and SGD (coverage of the 95% credible sets of SWAG and of the 95% confidence regions of SGD). Note that neither method is ever severely over-confident (far beneath 95% coverage), and that SWAG is considerably better calibrated on four of the six datasets.
Table 13: Calibration on small-scale UCI datasets. Bolded numbers are those closest to 0.95 (the predicted coverage).
Dataset N D SGD SWAG
boston 506 13 0.913 ± 0.039 0.936 ± 0.036
concrete 1030 8 0.909 ± 0.032 0.930 ± 0.023
energy 768 8 0.947 ± 0.026 0.951 ± 0.027
naval 11934 16 0.948 ± 0.051 0.967 ± 0.008
yacht 308 6 0.895 ± 0.069 0.898 ± 0.067
power 9568 4 0.956 ± 0.006 0.957 ± 0.005
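The coverage numbers above can be computed from a Gaussian predictive mean and variance as the fraction of test targets falling inside the central 95% interval; a sketch with hypothetical array names:

import numpy as np

def coverage_95(pred_mean, pred_var, targets):
    """Fraction of targets inside mean +/- 1.96 * predictive standard deviation."""
    std = np.sqrt(pred_var)
    return float(np.mean(np.abs(targets - pred_mean) <= 1.96 * std))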
H Image Classification: Architectures and Hyper-Parameters
We adapt the following model implementations:
• VGG-16: https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
• Preactivation-ResNet-164: https://github.com/bearpaw/pytorch-classification/blob/master/models/cifar/preresnet.py
• WideResNet28x10: https://github.com/meliketoy/wide-resnet.pytorch/blob/master/networks/wide_resnet.py
For all datasets and architectures we use the same piecewise-constant learning rate schedule and weight decay as in Izmailov et al. [27], except that we train PreResNet for 300 epochs and start averaging after epoch 160 in SWAG and SWA. For all methods we use our own implementations in PyTorch. We describe the hyper-parameters for each method below:
SWA We use the same hyper-parameters as Izmailov et al. [27] on the CIFAR datasets. On ImageNet we used a constant learning rate of 10⁻³ instead of the cyclical schedule, and averaged 4 models per epoch. We adapt the code from https://github.com/timgaripov/swa for our implementation of SWA.
SWAG In all experiments we use rank K = 20 and draw 30 weight samples for Bayesian model averaging. We re-use all other hyper-parameters from SWA.
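A minimal sketch of how the SWAG statistics (running first and second moments plus the last K deviation vectors) can be accumulated from flattened weight vectors is given below; the class and attribute names are ours, and the released implementation differs in its bookkeeping.

import torch

class SwagMoments:
    """Running first and second moments plus the last K deviation vectors."""

    def __init__(self, dim, max_rank=20):
        self.n = 0
        self.max_rank = max_rank
        self.mean = torch.zeros(dim)
        self.sq_mean = torch.zeros(dim)
        self.deviations = []  # columns of the low-rank matrix D

    def update(self, theta):
        """Incorporate one flattened weight vector (e.g., once per epoch)."""
        self.n += 1
        self.mean += (theta - self.mean) / self.n
        self.sq_mean += (theta ** 2 - self.sq_mean) / self.n
        self.deviations.append((theta - self.mean).clone())
        self.deviations = self.deviations[-self.max_rank:]  # keep rank K

    def diag_variance(self):
        return (self.sq_mean - self.mean ** 2).clamp(min=0.0)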
KFAC-Laplace For our implementation we adapt the code for the KFAC Fisher approximation from https://github.com/Thrandis/EKFAC-pytorch and implement our own code for sampling. Following [54], we tune the scale of the approximation on a validation set for every model and dataset.
MC-Dropout To implement MC-dropout we add dropout layers before each weight layer and sample 30 different dropout masks for Bayesian model averaging at inference time. To choose the dropout rate, we ran the models with dropout rates in {0.1, 0.05, 0.01} and chose the one that performed best on validation data. For both VGG-16 and WideResNet28x10 we found that a dropout rate of 0.05 worked best and used it in all experiments. On PreResNet-164 we could not achieve reasonable performance with any of the three dropout rates, an issue that has also been reported in He et al. [22]. We report results for MC-Dropout in combination with both SWA (SWA-Drop) and SGD (SGD-Drop) training.
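A sketch of the inference-time procedure: only the dropout layers are put in train mode so that each forward pass draws a fresh mask, while batch norm keeps its evaluation statistics.

import torch

def mc_dropout_predict(model, inputs, n_samples=30):
    """Average softmax predictions over independently sampled dropout masks."""
    model.eval()
    for m in model.modules():
        if isinstance(m, torch.nn.Dropout):
            m.train()  # re-enable stochastic masks for dropout layers only
    with torch.no_grad():
        probs = torch.stack([torch.softmax(model(inputs), dim=-1)
                             for _ in range(n_samples)])
    return probs.mean(dim=0)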
Temperature Scaling For SWA and SGD solutions we picked the optimal temperature by min-
imizing negative log-likelihood on validation data, adapting the code from https://github.com/
gpleiss/temperature_scaling.
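The temperature can be found by a small one-dimensional optimization on held-out logits; the sketch below optimizes log T for positivity, which differs slightly from the referenced repository.

import torch
import torch.nn.functional as F

def fit_temperature(logits, labels, max_iter=50):
    """Return the temperature T minimizing validation NLL for fixed logits."""
    log_t = torch.zeros(1, requires_grad=True)
    optimizer = torch.optim.LBFGS([log_t], lr=0.1, max_iter=max_iter)

    def closure():
        optimizer.zero_grad()
        loss = F.cross_entropy(logits / log_t.exp(), labels)
        loss.backward()
        return loss

    optimizer.step(closure)
    return log_t.exp().item()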
SGLD We initialize SGLD from checkpoints pre-trained with SGD. We run SGLD for 100 epochs on WideResNet and for 150 epochs on PreResNet-164. We use the learning rate schedule of [59]:
ηt = a / (b + t)^0.55.
We tune the constants a, b on a validation set. For WideResNet we use a = 38.0348, b = 13928.7, and for PreResNet we use a = 40.304, b = 15476.4; these values are selected so that the initial learning rate is 0.2 and the final learning rate is 0.1. We also had to rescale the noise in the gradients by a factor of 5 · 10⁻⁴ compared to [59]. Without this rescaling, we found that SGLD diverged even with learning rates on the scale of 10⁻⁷. We note that noise rescaling is commonly used with stochastic gradient MCMC methods (see, e.g., the implementation of [65]).
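For reference, one SGLD step under this schedule might look as follows; the exact gradient scaling and noise parameterization follow our description above rather than a specific implementation, and the names are placeholders.

import torch

def sgld_step(params, grads, t, a, b, noise_scale=5e-4):
    """One SGLD update with step size eta_t = a / (b + t)**0.55.

    The injected Gaussian noise (std sqrt(eta_t) in Welling and Teh [59])
    is additionally rescaled by noise_scale, as described above.
    """
    lr = a / (b + t) ** 0.55
    with torch.no_grad():
        for p, g in zip(params, grads):
            noise = torch.randn_like(p) * (lr ** 0.5) * noise_scale
            p.add_(-lr * g + noise)
    return lr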
On the CIFAR datasets, we used the last 5,000 training points as a validation set for tuning hyper-parameters. On ImageNet we used 5,000 test points for validation. On the CIFAR-10 to STL-10 transfer task, we report accuracy on all ten STL-10 classes, even though frogs are not part of the STL-10 test set (and monkeys are not part of the CIFAR-10 training set).