Black Box Variational Inference
ρ_t = η diag(G_t)^{-1/2}.   (11)

This is a per-component learning rate, since diag(G_t) has the same dimension as the gradient. Note that since AdaGrad only uses the diagonal of G_t, those are the only elements we need to compute. AdaGrad captures noise and varying length scales through the square of the noisy gradient, and it reduces the number of free parameters in our algorithm relative to the standard two-parameter Robbins-Monro learning rate.
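To make the update concrete, the following is a minimal numpy sketch of this per-component AdaGrad step. It only accumulates squared noisy gradients, since only diag(G_t) is needed; the function noisy_elbo_grad and the initial parameter vector are placeholders for whatever Monte Carlo gradient estimator and variational parameters a model uses (these names are illustrative, not from the paper's code).

    import numpy as np

    def adagrad_ascent(noisy_elbo_grad, lam_init, eta=1.0, n_steps=1000, eps=1e-8):
        """Stochastic ascent on the ELBO with a per-component AdaGrad learning rate.

        noisy_elbo_grad(lam) -> unbiased Monte Carlo estimate of the ELBO gradient.
        """
        lam = np.array(lam_init, dtype=float)
        grad_sq_sum = np.zeros_like(lam)             # running diag(G_t)
        for _ in range(n_steps):
            g = noisy_elbo_grad(lam)
            grad_sq_sum += g ** 2                    # accumulate squared noisy gradients
            rho = eta / np.sqrt(grad_sq_sum + eps)   # rho_t = eta * diag(G_t)^{-1/2}
            lam = lam + rho * g                      # gradient ascent step on the ELBO
        return lam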
4.2 Stochastic Inference in Hierarchical Bayesian Models

Stochastic optimization has also been used to scale variational inference in hierarchical Bayesian models to massive data (Hoffman et al., 2013). The basic idea is to subsample observations to compute noisy gradients. We can use a similar idea to scale our method.

In a hierarchical Bayesian model, we have a hyperparameter η, global latent variables β, local latent variables z_{1...n}, and observations x_{1...n} having the log joint distribution

log p(x_{1...n}, z_{1...n}, β) = log p(β|η) + Σ_{i=1}^n [log p(z_i|β) + log p(x_i|z_i, β)].   (12)
This is the same definition as in Hoffman et al. (2013), but they place further restrictions on the forms of the distributions and the complete conditionals. Under the mean field approximating family, applying Eq. 10 to construct noisy gradients of the ELBO would require iterating over every datapoint. Instead, we can compute noisy gradients using a sampled observation and samples from the variational distribution. The derivation, along with variance reductions, can be found in the supplement.
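As a rough illustration of this subsampling idea (not the paper's reference implementation), the sketch below forms a noisy gradient for the global variational parameters from a single sampled observation, scaling its local terms by n so the estimate remains unbiased. The model and q objects, and all of their methods, are hypothetical stand-ins for model- and family-specific log densities and samplers.

    import numpy as np

    def noisy_global_gradient(model, q, n, rng, n_samples=100):
        """Monte Carlo score-function gradient for the global variational parameters,
        using one subsampled observation. `model` and `q` are hypothetical objects
        exposing the log densities and samplers named below."""
        i = rng.integers(n)                          # subsample one observation
        grads = []
        for _ in range(n_samples):
            beta = q.sample_beta(rng)                # beta_s ~ q(beta | lambda)
            z_i = q.sample_z(i, rng)                 # z_is ~ q(z_i | phi_i)
            signal = (model.log_p_beta(beta) - q.log_q_beta(beta)
                      + n * (model.log_p_z(z_i, beta)          # local terms scaled by n
                             + model.log_p_x(i, z_i, beta)))
            grads.append(q.grad_log_q_beta(beta) * signal)     # score-function estimator
        return np.mean(grads, axis=0)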
5 Empirical Study

We use Black Box Variational Inference to quickly construct and evaluate several models on longitudinal medical data. We demonstrate the effectiveness of our variance reduction methods and compare the speed and predictive likelihood of our algorithm to sampling-based methods. We evaluate the various models using predictive likelihood and demonstrate the ease with which several models can be explored.

5.1 Longitudinal Medical Data

Our data consist of longitudinal data from 976 patients (803 train + 173 test) from a clinic at New York Presbyterian hospital who have been diagnosed with chronic kidney disease. These patients visited the clinic a total of 33K times. During each visit, a subset of 17 measurements (labs) were measured.

The data are observational and consist of measurements (lab values) taken at the doctor's discretion when the patient is at a checkup. This means both that the labs at each time step are sparse and that the times between patient visits are highly irregular. The lab values are all positive, as the labs measure the amount of a particular quantity such as sodium concentration in the blood.

Our modeling goal is to come up with a low dimensional summarization of patients' labs at each of their visits. From this, we aim to find latent factors that summarize each visit as positive random variables. As in medical data applications, we want our factors to be latent indicators of patient health.

We evaluate our model using predictive likelihood. To compute predictive likelihoods, we need an approximate posterior on both the global parameters and the per-visit parameters. We use the approximate posterior on the global parameters and calculate the approximate posterior on the local parameters on 75% of the data in the test set. We then calculate the predictive likelihood on the other 25% of the data in the validation set using Monte Carlo samples from the approximate posterior.

We initialize randomly and choose the variational families to be fully-factorized, with gamma distributions for positive variables and normals for real valued variables. We use both the AdaGrad and doubly stochastic extensions on top of our base algorithm. We use 1,000 samples from the variational distribution and set the batch size at 25 in all our experiments.
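As a sketch of this evaluation (not the authors' code), the held-out log predictive likelihood can be approximated by averaging the likelihood of the held-out labs over Monte Carlo samples from the fitted approximate posterior; sample_q and log_lik_heldout are hypothetical, model-specific callables.

    import numpy as np
    from scipy.special import logsumexp

    def heldout_log_predictive(rng, sample_q, log_lik_heldout, n_samples=1000):
        """Estimate log (1/S) sum_s p(x_heldout | z_s, beta_s) with (z_s, beta_s) ~ q.
        `sample_q` draws local and global latents from the approximate posterior;
        `log_lik_heldout` scores the held-out 25% of labs (both hypothetical)."""
        log_liks = np.empty(n_samples)
        for s in range(n_samples):
            z, beta = sample_q(rng)                    # sample from the variational posterior
            log_liks[s] = log_lik_heldout(z, beta)     # log p(x_heldout | z, beta)
        # log-mean-exp for numerical stability
        return logsumexp(log_liks) - np.log(n_samples)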
5.2 Model

To meet our goals, we construct a Gamma-Normal time series model. We model our data using weights drawn from a Normal distribution and observations drawn from a Normal, allowing each factor to both positively and negatively affect each lab while letting factors represent lab measurements. The generative process for this model, with hyperparameters denoted with σ, is

Draw W ∼ Normal(0, σ_w), an L × K matrix
For each patient p: 1 to P
    Draw o_p ∼ Normal(0, σ_o), a vector of L
    Define x_{p0} = α_0
    [...]
Metropolis-Hastings works by sampling from a proposal distribution and accepting or rejecting the samples based on the likelihood. Standard Metropolis-Hastings [...]

On this model, we compared Black Box Variational Inference to Metropolis-Hastings inside Gibbs. We used a fixed computational budget of 20 hours. Figure 1 plots time versus predictive likelihood on the held out set for both methods. We found similar results with different random initializations of both models. Black Box Variational Inference gives better predictive likelihoods and gets them faster than Metropolis-Hastings within Gibbs.

Figure 1: Comparison between Metropolis-Hastings within Gibbs and Black Box Variational Inference. The x axis is time (in hours) and the y axis is the held-out log predictive likelihood on the test set. Black Box Variational Inference reaches better predictive likelihoods faster than Gibbs sampling. The Gibbs sampler's progress slows considerably after 5 hours.
[Figure: comparison of the Basic, Rao-Blackwell, and Rao-Blackwell+CV estimators; only the plot legend survived extraction.]
Gamma. We model the latent factors that summarize each visit in our models as positive random variables; as noted above, we expect these to be indicative of patient health. The Gamma model is a positive-value factor model where all of the factors, weights, and observations have positive values. The generative process for this model is

Draw W ∼ Gamma(α_w, β_w), an L × K matrix
For each patient p: 1 to P
    Draw o_p ∼ Gamma(α_o, β_o), a vector of L
    For each visit v: 1 to v_p
        Draw x_pv ∼ Gamma(α_x, β_x)
        Draw l_pv ∼ GammaE(W x_pv + o_p, σ_o), a vector of L.

We set all hyperparameters save σ_o to be 1. As in the previous model, σ_o is set to .01.
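For concreteness, here is a numpy sketch of sampling from this generative process. It assumes GammaE(m, v) denotes a gamma distribution parameterized by its expectation m and a variance-like spread v (so shape = m²/v and rate = m/v); the paper does not spell out that convention here, so the gammaE helper and the function names are illustrative assumptions, not the authors' code.

    import numpy as np

    def gammaE(rng, mean, var):
        """Assumed expectation parameterization: gamma with the given mean and variance."""
        shape = mean ** 2 / var
        scale = var / mean                     # numpy uses shape/scale; rate = mean/var
        return rng.gamma(shape, scale)

    def sample_gamma_model(rng, L, K, P, visits, a_w=1., b_w=1., a_o=1., b_o=1.,
                           a_x=1., b_x=1., sigma_o=0.01):
        """Draw one dataset from the Gamma factor model described above."""
        W = rng.gamma(a_w, 1. / b_w, size=(L, K))           # L x K positive weights
        labs = {}
        for p in range(P):
            o_p = rng.gamma(a_o, 1. / b_o, size=L)          # per-patient positive intercept
            for v in range(visits[p]):
                x_pv = rng.gamma(a_x, 1. / b_x, size=K)     # positive factors for this visit
                labs[p, v] = gammaE(rng, W @ x_pv + o_p, sigma_o)   # positive lab values
        return W, labs

    # Example: W, labs = sample_gamma_model(np.random.default_rng(0), L=17, K=5, P=3, visits=[4, 2, 6])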
Gamma-TS. We can link the factors through time using the expectation parameterization of the Gamma distribution. (Note this is harder with the usual natural parameterization of the Gamma.) This changes x_pv to be distributed as GammaE(x_{p,v-1}, σ_v). We draw x_p1 as above. In this model, the expected value of the factors at the next visit is the same as the value at the current visit. This allows us to propagate patient states through time.
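A small sketch of this time chaining, under the same assumed GammaE(m, v) convention as above (mean m, variance v); all names are illustrative.

    import numpy as np

    def sample_factor_chain(rng, x_p1, n_visits, sigma_v=1.0):
        """Propagate a patient's positive factor vector through time:
        x_pv ~ GammaE(x_{p,v-1}, sigma_v), so E[x_pv | x_{p,v-1}] = x_{p,v-1}."""
        chain = [np.asarray(x_p1, dtype=float)]
        for _ in range(1, n_visits):
            mean = chain[-1]
            shape = mean ** 2 / sigma_v        # assumed mean/variance parameterization
            scale = sigma_v / mean
            chain.append(rng.gamma(shape, scale))
        return np.stack(chain)                 # shape (n_visits, K)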
We set the AdaGrad scaling parameter to 1 for both the Gamma-Normal models and to .5 for the Gamma models.

[Table 1: only caption fragments are recoverable: "... Gamma-TS, and Gamma-Normal ... patient health dataset. We find that taking into account the longitudinal nature of the data in the model leads to a better fit. The Gamma weight models perform relatively poorly."]

Model Comparisons. Table 1 details our models along with their predictive likelihoods. From this we see that time helps in modelling our longitudinal healthcare data. We also see that the Gamma-Gamma models perform poorly. This is likely because they cannot capture the negative correlations that exist between different medical labs. More importantly, by using Black Box Variational Inference we were able to quickly fit and explore a set of complicated non-conjugate models. Without a generic algorithm, approximating the posterior of any of these models is a project in itself.

6 Conclusion

We developed and studied Black Box Variational Inference, a new algorithm for variational inference that drastically reduces the analytic burden. Our main approach is a stochastic optimization of the ELBO by sampling from the variational posterior to compute a noisy gradient. Essential to its success are model-free variance reductions to reduce the variance of the noisy gradient. Black Box Variational Inference works well for new models, while requiring minimal analytic work by the practitioner.

There are several natural directions for future improvements to this work. First, the software libraries that we provide can be augmented with score functions for a wider variety of variational families (each score function is simply the gradient of the log variational distribution with respect to the variational parameters). Second, we believe that the number of samples could be set dynamically. Finally, carefully-selected samples from the variational distribution (e.g., with quasi-Monte Carlo methods) are likely to significantly decrease sampling variance.

References

C. Bishop. Pattern Recognition and Machine Learning. Springer New York, 2006.

D. Blei and J. Lafferty. A correlated topic model of Science. Annals of Applied Statistics, 1(1):17-35, 2007.

L. Bottou and Y. LeCun. Large scale online learning. In Advances in Neural Information Processing Systems, 2004.

M. Braun and J. McAuliffe. Variational inference for large-scale models of discrete choice. Journal of the American Statistical Association, 105(489), 2007.

P. Carbonetto, M. King, and F. Hamze. A stochastic approximation method for inference in probabilistic graphical models. In Advances in Neural Information Processing Systems 22, pages 216-224, 2009.

G. Casella and C. P. Robert. Rao-Blackwellisation of sampling schemes. Biometrika, 83(1):81-94, 1996.

E. Cinlar. Probability and Stochastics. Springer, 2011.

D. R. Cox and D. V. Hinkley. Theoretical Statistics. Chapman and Hall, 1979.

J. Duchi, E. Hazan, and Y. Singer. Adaptive subgradient methods for online learning and stochastic optimization. Journal of Machine Learning Research, 12:2121-2159, 2011.

Z. Ghahramani and M. Beal. Propagation algorithms for variational Bayesian learning. In NIPS 13, pages 507-513, 2001.

M. Hoffman, D. Blei, C. Wang, and J. Paisley. Stochastic variational inference. Journal of Machine Learning Research, 14:1303-1347, 2013.

T. Jaakkola and M. Jordan. A variational approach to Bayesian logistic regression models and their extensions. In International Workshop on Artificial Intelligence and Statistics, 1996.

M. Jordan, Z. Ghahramani, T. Jaakkola, and L. Saul. Introduction to variational methods for graphical models. Machine Learning, 37:183-233, 1999.

D. Kingma and M. Welling. Auto-encoding variational Bayes. ArXiv e-prints, December 2013.

D. Knowles and T. Minka. Non-conjugate variational message passing for multinomial and binary regression. In Advances in Neural Information Processing Systems, 2011.

H. Kushner and G. Yin. Stochastic Approximation Algorithms and Applications. Springer New York, 1997.

J. Paisley, D. Blei, and M. Jordan. Variational Bayesian inference with stochastic search. In International Conference on Machine Learning, 2012.

H. Robbins and S. Monro. A stochastic approximation method. The Annals of Mathematical Statistics, 22(3):400-407, 1951.

S. M. Ross. Simulation. Elsevier, 2002.

T. Salimans and D. Knowles. Fixed-form variational approximation through stochastic linear regression. ArXiv e-prints, August 2012.

M. Wainwright and M. Jordan. Graphical models, exponential families, and variational inference. Foundations and Trends in Machine Learning, 1(1-2):1-305, 2008.

C. Wang and D. M. Blei. Variational inference for nonconjugate models. JMLR, 2013.

D. Wingate and T. Weber. Automated variational inference in probabilistic programming. ArXiv e-prints, January 2013.
7 Appendix: The Gradient of the ELBO

The key idea behind our algorithm is that the gradient of the ELBO can be written as an expectation with respect to the variational distribution. We start by differentiating Eq. 1,

∇_λ L = ∇_λ ∫ (log p(x, z) − log q(z|λ)) q(z|λ) dz
      = ∫ ∇_λ [(log p(x, z) − log q(z|λ)) q(z|λ)] dz
      = ∫ ∇_λ [log p(x, z) − log q(z|λ)] q(z|λ) dz + ∫ ∇_λ q(z|λ) (log p(x, z) − log q(z|λ)) dz
      = −E_q[∇_λ log q(z|λ)] + ∫ ∇_λ q(z|λ) (log p(x, z) − log q(z|λ)) dz,   (13)

where we have exchanged derivatives with integrals via the dominated convergence theorem⁴ (Cinlar, 2011) and used ∇_λ[log p(x, z)] = 0.

The first term in Eq. 13 is zero. To see this, note that

E_q[∇_λ log q(z|λ)] = E_q[∇_λ q(z|λ) / q(z|λ)] = ∫ ∇_λ q(z|λ) dz = ∇_λ ∫ q(z|λ) dz = ∇_λ 1 = 0.   (14)

To simplify the second term, first observe that ∇_λ[q(z|λ)] = ∇_λ[log q(z|λ)] q(z|λ). This fact gives us the gradient as an expectation,

∇_λ L = ∫ ∇_λ [q(z|λ)] (log p(x, z) − log q(z|λ)) dz
      = ∫ ∇_λ log q(z|λ) (log p(x, z) − log q(z|λ)) q(z|λ) dz
      = E_q[∇_λ log q(z|λ) (log p(x, z) − log q(z|λ))].

⁴ The score function exists. The score and likelihoods are bounded.
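This final expectation is what the algorithm estimates with Monte Carlo samples from q. A minimal numpy sketch of that estimator follows, assuming user-supplied log_joint, sampler, and score functions for the variational family (all names are illustrative).

    import numpy as np

    def elbo_gradient_estimate(rng, sample_q, grad_log_q, log_q, log_joint, S=100):
        """Basic score-function estimator:
        grad L ~= (1/S) sum_s grad_log_q(z_s) * (log p(x, z_s) - log q(z_s)),  z_s ~ q."""
        total = 0.0
        for _ in range(S):
            z = sample_q(rng)                               # z_s ~ q(z | lambda)
            total += grad_log_q(z) * (log_joint(z) - log_q(z))
        return total / S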
Derivation of the Rao-Blackwellized Gradient

To compute the Rao-Blackwellized estimators, we need to compute conditional expectations. Due to the mean-field assumption, the conditional expectation simplifies due to the factorization

E[J(X, Y)|X] = ∫ J(x, y) p(x) p(y) dy / ∫ p(x) p(y) dy = ∫ J(x, y) p(y) dy = E_y[J(x, y)].   (A.15)

Therefore, to construct a lower variance estimator when the joint distribution factorizes, all we need to do is integrate out some variables. In our problem this means that for each component of the gradient, we should compute expectations with respect to the other factors. We present the estimator in the full mean field family of variational distributions, but note it applies to any variational approximation with some factorization, like structured mean-field.

Thus, under the mean field assumption the Rao-Blackwellized estimator for the gradient becomes

∇_λ L = E_{q_1} ... E_{q_n}[ Σ_{j=1}^n ∇_λ log q_j(z_j|λ_j) (log p(x, z) − Σ_{j=1}^n log q_j(z_j|λ_j)) ].   (A.16)

Recall the definitions from Section 3, where we defined ∇_{λ_i} L as the gradient of the ELBO with respect to λ_i, p_i as the components of the log joint that include terms from the ith factor, and E_{q_(i)} as the expectation with respect to the set of latent variables that appear in the complete conditional for z_i. Let p_{−i} be the components of the joint that do not include terms from the ith factor. We can write the gradient with respect to the ith factor's variational parameters as

∇_{λ_i} L = E_{q_1} ... E_{q_n}[ ∇_{λ_i} log q_i(z_i|λ_i) (log p(x, z) − Σ_{j=1}^n log q_j(z_j|λ_j)) ]
          = E_{q_1} ... E_{q_n}[ ∇_{λ_i} log q_i(z_i|λ_i) (log p_i(x, z) + log p_{−i}(x, z) − Σ_{j=1}^n log q_j(z_j|λ_j)) ]
          = E_{q_i}[ ∇_{λ_i} log q_i(z_i|λ_i) (E_{q_{−i}}[log p_i(x, z_(i))] − log q_i(z_i|λ_i) + E_{q_{−i}}[log p_{−i}(x, z) − Σ_{j≠i} log q_j(z_j|λ_j)]) ]
          = E_{q_i}[ ∇_{λ_i} log q_i(z_i|λ_i) (E_{q_{−i}}[log p_i(x, z)] − log q_i(z_i|λ_i) + C_i) ]
          = E_{q_i}[ ∇_{λ_i} log q_i(z_i|λ_i) (E_{q_{−i}}[log p_i(x, z_(i))] − log q_i(z_i|λ_i)) ]
          = E_{q_(i)}[ ∇_{λ_i} log q_i(z_i|λ_i) (log p_i(x, z_(i)) − log q_i(z_i|λ_i)) ],   (A.17)

where we have leveraged the mean field assumption and made use of the identity for the expected score, Eq. 14. This means we can Rao-Blackwellize the gradient of the variational parameter λ_i with respect to the latent variables outside of the Markov blanket of z_i without needing model-specific computations.
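A Monte Carlo sketch of this Rao-Blackwellized per-factor estimator, assuming a hypothetical log_p_i that sums only the joint's terms containing z_i (its Markov blanket) and hypothetical samplers and scores for the ith factor:

    import numpy as np

    def rao_blackwellized_gradient_i(rng, sample_q_markov_blanket, grad_log_q_i,
                                     log_q_i, log_p_i, S=100):
        """(1/S) sum_s grad_log_q_i(z_i) * (log p_i(x, z_(i)) - log q_i(z_i)),
        sampling only the latent variables z_(i) in the ith Markov blanket."""
        total = 0.0
        for _ in range(S):
            z_blanket = sample_q_markov_blanket(rng)     # z_(i),s ~ q_(i)
            z_i = z_blanket["z_i"]                       # the ith factor's own variable
            total += grad_log_q_i(z_i) * (log_p_i(z_blanket) - log_q_i(z_i))
        return total / S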
Derivation of Stochastic Inference in Hierarchical Bayesian Models

Recall the definition of a hierarchical Bayesian model with n observations given in Eq. 12,

log p(x_{1...n}, z_{1...n}, β) = log p(β|η) + Σ_{i=1}^n [log p(z_i|β) + log p(x_i|z_i, β)].

Let the variational approximation for the posterior distribution be from the mean field family. Let λ be the global variational parameter and let φ_{1...n} be the local variational parameters. The variational family is

q(β, z_{1...n}) = q(β|λ) Π_{i=1}^n q(z_i|φ_i).   (A.18)
Using the Rao-Blackwellized estimator to compute noisy gradients in this family for this model gives

∇̂_λ L = (1/S) Σ_{s=1}^S ∇_λ log q(β_s|λ) (log p(β_s|η) − log q(β_s|λ) + Σ_{i=1}^n [log p(z_is|β_s) + log p(x_i|z_is, β_s)])

∇̂_{φ_i} L = (1/S) Σ_{s=1}^S ∇_{φ_i} log q(z_is|φ_i) (log p(z_is|β_s) + log p(x_i|z_is, β_s) − log q(z_is|φ_i)).

These estimators iterate over all of the observations at each update. If we instead sample a single observation i uniformly at random and scale its local terms by n, the noisy gradients become

∇̂_λ L = (1/S) Σ_{s=1}^S ∇_λ log q(β_s|λ) (log p(β_s|η) − log q(β_s|λ) + n [log p(z_is|β_s) + log p(x_i|z_is, β_s)])

∇̂_{φ_i} L = (1/S) Σ_{s=1}^S ∇_{φ_i} log q(z_is|φ_i) (n [log p(z_is|β_s) + log p(x_i|z_is, β_s) − log q(z_is|φ_i)])

∇̂_{φ_j} L = 0 for all j ≠ i.

With the control variates â*_λ and â*_{φ_i} subtracted from the learning signal, the estimators are

∇̂_λ L = (1/S) Σ_{s=1}^S ∇_λ log q(β_s|λ) (log p(β_s|η) − log q(β_s|λ) − â*_λ + n [log p(z_is|β_s) + log p(x_i|z_is, β_s)])

∇̂_{φ_i} L = (1/S) Σ_{s=1}^S ∇_{φ_i} log q(z_is|φ_i) (−â*_{φ_i} + n [log p(z_is|β_s) + log p(x_i|z_is, β_s) − log q(z_is|φ_i)])

∇̂_{φ_j} L = 0 for all j ≠ i.   (A.20)
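The control variate scalars â* above are estimated from the same Monte Carlo samples. The sketch below shows one common choice, the optimal scalar â* = Σ_d Cov(f_d, h_d) / Σ_d Var(h_d) computed from samples of the per-component learning-signal terms f and the score h; the paper's exact estimator may differ in details, so treat this as illustrative.

    import numpy as np

    def optimal_cv_scalar(f_samples, h_samples):
        """a_hat* = sum_d Cov(f_d, h_d) / sum_d Var(h_d), estimated from Monte Carlo
        samples of the learning-signal terms f and the score h (arrays of shape (S, D))."""
        f_centered = f_samples - f_samples.mean(axis=0)
        h_centered = h_samples - h_samples.mean(axis=0)
        cov = (f_centered * h_centered).sum() / (f_samples.shape[0] - 1)
        var = (h_centered ** 2).sum() / (h_samples.shape[0] - 1)
        return cov / var
        # The variance-reduced gradient then averages f_samples - a_hat* * h_samples over s.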