Mean Field
Abstract
While a class of models and techniques in deep learning has achieved
empirical success, the interactions of their underlying mechanisms are
under-explored. Oftentimes, researchers who seek clarity in the science
of deep learning adapt theoretical tools developed in other scientific fields.
In statistical mechanics, an approximation technique for complex, inter-
active systems called Mean Field Theory (MFT) is now broadly applied
to explain why deep learning works. In particular, MFT highlights the
dynamical system similarities between a deep network’s parameters and
interacting particles. By adapting MFT to study a network’s signal prop-
agation, theorists can explore behaviors of very large, general neural net-
works that experimental work alone cannot cover [Saul et al., 1996]. Most
recently, various papers on this topic have been gaining popularity [Hanin,
2018, Karakida et al., 2018, Kawamoto et al., 2018, Pretorius et al., 2018,
Schoenholz et al., 2016].
To machine learning practitioners, however, the conference papers on
the topic may be too short to be accessible. This paper serves as an
introduction to mean field formalism as applied to study properties of
neural networks. Readers who wish to understand the subfield should
find here tools, definitions, and illustrations that clarify the motivation and
assumptions used in current works.
We first introduce mean field theory in its historical context of
physics, with an example on the Ising model. Then we connect MFT to
machine learning through parallels drawn in variational inference. Finally, we summarize the setups of MFT modelling in recent advances to help understand neural networks' expressivity [Poole et al., 2016], ResNets [Yang and
Schoenholz, 2017b], Convolutional neural networks [Xiao et al., 2018], and
most recently batch normalization [Yang et al., 2019] and gradient descent
dynamics. In summary, we show that the application of MFT touches very
popular architectures and empirical techniques in today’s deep learning
era.
Contents

1 MFT In Statistical Physics
  1.1 Isolated Magnet In a Heat Bath
  1.2 Ising Model
    1.2.1 Correlation function
    1.2.2 Factorization Approximation in MFT
  1.3 High Dimension Ising Model and Mean Field Approximation
    1.3.1 Self-averaging MFT
  1.4 Conclusion
2 MFT in statistics
  2.1 Variational Inference
  2.2 Limitations
7 Conclusion
8 Appendix
  8.1 Derivation for ELBO
  8.2 Variational Mean Field for the Ising model
1 MFT In Statistical Physics
Strong assumptions notwithstanding, some simplified models can explain real-world observations without resorting to very difficult mathematics. Originally, Mean Field Theory stemmed from such models, which physicists used to explain macroscopic phenomena.
The Ising Model proposes that spins of particles arrange themselves on a
chain in one-dimensional space, or a lattice in higher dimensions. Furthermore,
each particle takes on a binary state: up or down. In addition, every spin’s
(stochastic) properties are only dependent on its nearest neighbors: two on a
line, four in a plane, and 2d in d dimensions.
This section sets up a physical system under the Ising Model and introduces the use of the mean field approximation in deriving a phase transition, applied to magnetization, largely based on a Statistical Mechanics lecture [Susskind, 2013].
for numeric factor J. We obtain its partition function, summed over all 2 configurations:

Z = \sum_{\text{configs}} e^{\beta J \sigma} = e^{(+1)\beta J} + e^{(-1)\beta J} = e^{\beta J} + e^{-\beta J} = 2\cosh(\beta J)
The whole system of N individual magnets has a factored partition function, which is the individual spin's partition function raised to the N-th power, since each spin is independent. That allows us to take its logarithm and get a sum. We calculate the expected value of the thermodynamic energy, which is the negative inverse of the partition function times its derivative with respect to the inverse temperature:

E_{\text{one spin}} = -\frac{1}{Z}\frac{\partial Z}{\partial \beta} = -J\frac{\sinh(\beta J)}{\cosh(\beta J)} = -J\tanh(\beta J)
where Ei is the energy associated with the state of interest. As per convention, β = 1/T , and
we will use β for inverse temperature throughout this paper.
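As a quick numerical check of these closed forms (our own illustration, not part of the original lecture), the sketch below enumerates the two states of a single spin with energy E(σ) = −Jσ and compares the resulting partition function and average energy with 2 cosh(βJ) and −J tanh(βJ).

```python
import numpy as np

# Minimal check of Z = 2*cosh(beta*J) and <E> = -J*tanh(beta*J)
# for a single spin with energy E(sigma) = -J*sigma, sigma in {+1, -1}.
J, beta = 1.0, 0.7
states = np.array([+1, -1])
energies = -J * states                   # E(sigma) = -J * sigma
weights = np.exp(-beta * energies)       # Boltzmann factors e^{-beta*E}
Z = weights.sum()                        # partition function
E_mean = (energies * weights).sum() / Z  # thermodynamic average energy

print(Z, 2 * np.cosh(beta * J))          # both ~2.510
print(E_mean, -J * np.tanh(beta * J))    # both ~-0.604
```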
Figure 2: A 1-D Ising model of ferromagnetic σ's.
E = -J \sum_{i=1}^{N-1} \sigma_i \sigma_{i+1} = -J \sum_{i=1}^{N-1} \mu_i,

where each bond variable is defined as \mu_i \equiv \sigma_i \sigma_{i+1}.
Notice that there is one fewer bond than there are particles. In this rewriting of the energy, the individual bonds have no relationship among them as far as the equation goes. No information is lost, either, since knowing the μ's (together with one spin) is as good as knowing the σ's. So now we can substitute the sum over spins in Z, the partition function, with a sum over the values of the bonds:

Z = 2 \sum_{\{\mu\}} e^{\beta J \sum_i \mu_i}

The factor 2 arises from the two possibilities for the first spin, which we condition on. Recall that N = |\{\sigma_i\}| = |\{\mu_i\}| + 1: with N spins there are N − 1 bonds. The Boltzmann factor here is a product over bonds, so we can factorize the partition function into that of one bond raised to the (N − 1)-th power and obtain Z = (2\cosh \beta J)^{N-1}, a familiar-looking partition function we saw in Section 1.1.
Despite the identical form of the partition function, the physical meaning is different from that of N − 1 isolated magnets, because each μ is the product of two neighboring spins. Now consider ⟨μ⟩ = E(μ), the average correlation between immediate pairs of neighbors. Through the same calculus as before,

\langle \mu \rangle = \tanh(\beta J),

where a positive J biases the bond towards a positive value. In the physical system, this indicates a tendency towards alignment, in a way that is better than even chance. We write the correlation between the spins at sites i and i + n as the product of the bonds between them:

\sigma_i \sigma_{i+n} = \mu_i \mu_{i+1} \cdots \mu_{i+n-1}.

If we assert the independence of the μ's in this formulation and substitute the average,

\langle \sigma_i \sigma_{i+n} \rangle \approx \langle \mu \rangle^{n} = (\tanh \beta J)^{n}.

Since this is positive, the two spins align with probability better than the uniformly random 1/2: we see a long-range memory in magnetization, with everything biased to point up if the first spin points up, and the bias decaying exponentially with separation.

Though all MFT has its origin in taking the average, for simplicity we refer to this specific approximation strategy as "factorization".
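To make the factorization argument concrete, the following sketch (our own toy check; the chain length N = 10 and the coupling are arbitrary) enumerates every configuration of a short open chain and compares the exact correlation ⟨σ_1 σ_{1+n}⟩ with the factorized prediction (tanh βJ)^n.

```python
import itertools
import numpy as np

# Exact enumeration of an open 1-D Ising chain with N spins and
# E = -J * sum_i sigma_i * sigma_{i+1}, compared with the bond-factorized
# prediction <sigma_1 sigma_{1+n}> = (tanh(beta*J))**n.
J, beta, N = 1.0, 0.5, 10

def boltzmann_weight(spins):
    bonds = np.array(spins[:-1]) * np.array(spins[1:])  # mu_i = sigma_i * sigma_{i+1}
    return np.exp(beta * J * bonds.sum())               # e^{-beta*E}

configs = list(itertools.product([+1, -1], repeat=N))
Z = sum(boltzmann_weight(s) for s in configs)

for n in (1, 3, 5):
    corr = sum(s[0] * s[n] * boltzmann_weight(s) for s in configs) / Z
    print(n, corr, np.tanh(beta * J) ** n)  # exact vs. factorized prediction
```

For the open chain the bond variables really are independent, so the two columns agree to machine precision; the factorization only becomes an approximation once loops or longer-range couplings enter.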
on one spin of a small magnet, assumed to be at equilibrium with the rest of
the environment, with the partition function
Z_{\text{one spin}} = \sum_{\sigma} e^{+\beta J\sigma} = e^{\beta J} + e^{-\beta J} = 2\cosh(\beta J)

Z_{\text{whole system}} = \left(e^{\beta J} + e^{-\beta J}\right)^{k} = \left(2\cosh(\beta J)\right)^{k}

E_{\text{one spin}} = -\frac{1}{Z}\frac{\partial Z}{\partial \beta} = -J\frac{\sinh(\beta J)}{\cosh(\beta J)} = -J\tanh(\beta J).
So in expectation, ⟨σ⟩_average = tanh βJ. This sets up the mean field approximation: in approximating the bias, we assume a high-dimensional Ising model where the average fluctuation is much smaller than the average bias. There, using the average is a good approximation for individual behavior if the number of neighbors, 2d, is large. For one spin,

E_i = -J\sigma_i \sum_{j \text{ neighboring } i} \sigma_j.

For simplicity, let \bar{\sigma} denote the average of the neighboring spins, so that the spin of interest behaves like an isolated magnet in an effective field 2dJ\bar{\sigma}, with its own average

\bar{\bar{\sigma}} = \tanh[(2\beta d J)\bar{\sigma}].

Similarly, if the neighbors themselves have average \bar{\sigma}, then self-consistency requires \bar{\bar{\sigma}} = \bar{\sigma}. This gives an equation that applies at all temperatures:

\bar{\sigma} = \tanh[(2\beta d J)\bar{\sigma}]
Let y = (2\beta d J)\bar{\sigma}; then

\frac{y}{2\beta d J} = \tanh y.

Recalling that β is the inverse temperature, this becomes

\frac{T y}{2 d J} = \tanh y.

We plot both sides at different temperatures, as shown in Figure 4. At very high T the only possible solution is y = 0, so the average of σ is 0, as expected. As we lower the temperature, the slope T/(2dJ) of the line on the left-hand side decreases until it reaches 1 and the line becomes tangent to tanh(·) at the origin, which happens at T = 2dJ. This is a critical point and signals that our approximation exhibits a phase transition!
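The self-consistency equation can also be examined numerically. The sketch below (our own, with d and J chosen arbitrarily) iterates σ̄ ← tanh(2βdJσ̄) at several temperatures: the iteration settles on a nonzero magnetization only below the predicted critical temperature T_c = 2dJ.

```python
import numpy as np

# Fixed-point iteration of the self-averaging MFT equation
#   sigma_bar = tanh(2*beta*d*J*sigma_bar),   beta = 1/T.
d, J = 3, 1.0
Tc = 2 * d * J                      # predicted critical temperature

def magnetization(T, n_iter=500):
    sigma_bar = 0.5                 # start from a small positive bias
    for _ in range(n_iter):
        sigma_bar = np.tanh((2 * d * J / T) * sigma_bar)
    return sigma_bar

for T in (0.5 * Tc, 0.9 * Tc, 1.1 * Tc, 2.0 * Tc):
    print(f"T = {T:5.2f}   sigma_bar = {magnetization(T):.4f}")
# Below Tc the iteration settles on a nonzero magnetization;
# above Tc it decays to 0, reproducing the phase transition.
```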
Figure 4: Phase transition. When T = 2dJ, the left-hand side Ty/(2dJ) equals y, and the hyperbolic tangent curve tanh y intersects the line y only at the origin y = 0, as shown in the left plot. As we lower the temperature, e.g. to T = dJ so that Ty/(2dJ) = y/2, an additional intersection point appears, as shown in the right plot. That point corresponds to a critical temperature at which the phase transition happens.
1.4 Conclusion
The Ising Model is a simplified generating mechanism in statistical physics that
encapsulates complex behavior. In 1D, magnets influence each other with de-
caying correlation at long distance. By seeing each spin as a mean of the field of
spins it is in, we demonstrate two MFT flavors, factorization and self-averaging,
and derive the phenomenon of magnetization.
The mean field approach simplifies the Ising Model's mathematics only under certain conditions. In studying the phase transition, MFT relies on high dimensionality, which is the dominant criterion in the derivation: essentially, a particular situation was picked so that each spin has many nearest neighbors over which to apply the mean field.
Without mean field assumptions, a closed form would be very hard to compute. Unsurprisingly, extending mean field approaches has proven demonstrably powerful in high dimensional statistics. The next section introduces Variational Mean Field methods in statistics.
2 MFT in statistics
One of the major problems in statistics is to approximate hard-to-compute
probability distributions for a system. This is especially important in Bayesian
Inference and statistical machine learning, where a joint probability distribution
over unobserved and observed data is required. The distribution may be easy to compute for some small models. However, for large, complex models it is not at all easy, and exact inference on such models is not practically possible. We look
at a class of approximation techniques called variational methods that attempt
to approximate the probability distributions as best as possible. In the section
that follows, we briefly introduce the problem of inference and how a mean-field
assumption helps in efficiently computing the required estimate of probability
distribution. This section is largely based on Blei et al. [2017].
In a graphical model, an edge drawn between variables Z and X represents a conditional distribution P (X|Z). We now look at a general problem formula-
tion. Consider hidden variables Z = {Z1 , Z2 , · · · , Zm } and visible variables
X = {X1 , X2 , · · · , Xn }. Inference in a Bayesian setting usually involves calculating the posterior over the hidden variables, i.e., the probability of the hidden variables conditioned on the observed data. By Bayes' theorem, we have
P(Z|X) = \frac{P(X, Z)}{P(X)}
The denominator P(X) is the marginal probability of the observations, also called the likelihood of the evidence. This is obtained by marginalizing the hidden variables out of the joint distribution P(X, Z), i.e., simply summing the joint distribution over all possible configurations of the hidden variables. Thus the likelihood of the evidence P(X) is represented as
P(X) = \sum_{Z} P(X, Z)
In order to calculate the posterior over hidden variables, we require the like-
lihood of the evidence. This is all well and good for small models, but for large, complex models the number of hidden variables tends to be very large. In this case, the sum in the likelihood of the evidence becomes very hard to compute: the number of terms grows exponentially with the number of hidden variables. The sum is therefore intractable for a large number of hidden variables, and some sort of approximation is required for P (Z|X).
But how do we go about approximating P (Z|X)? One method for approx-
imate inference is a sampling-based method called Markov chain Monte Carlo (MCMC) sampling. MCMC algorithms are very popular and find applications in a wide range of problems. One key feature of such methods is that they
provide guarantees (asymptotically) of producing exact samples from the target
density (the density that had to be approximated). This makes them ideal for
scenarios that require precise samples. However, MCMC tends to be computa-
tionally expensive and does not scale well (in terms of computation time) for
large and complicated models. For such cases, variational inference acts as a
faster alternative. Even though variational inference does not provide guarantees similar to those of MCMC, it gives reasonable results. Thus it is suitable for scenarios where there is a huge amount of data and fast exploration through models is needed.
In variational inference, we introduce a family of distributions Q over the
hidden variables Z. Each member Q(Z) in the family Q is a potential approxima-
tion to the posterior over the hidden variables. To find the best approximation,
we resort to the Kullback-Leibler (KL) divergence between our posterior P (Z|X) and
a member of Q. The Q(Z) that is closest in KL divergence with our posterior
is the best approximation.
Q^{*}(Z) = \arg\min_{Q(Z) \in \mathcal{Q}} D_{KL}\big(Q(Z) \,\|\, P(Z|X)\big)
Since the KL divergence involves the intractable evidence, rather than minimizing the KL divergence directly, we can minimize a new objective, which is the difference between the expectation of the logarithm of the distribution Q(Z) and the expectation of the logarithm of the joint distribution of observed and hidden variables. The log-likelihood of the evidence P(X) does not depend on the distribution Q(Z) and therefore remains a constant:

D_{KL}\big(Q(Z) \,\|\, P(Z|X)\big) = E[\log Q(Z)] - E[\log P(X, Z)] + \log P(X)

The negative of this new objective, E[\log P(X, Z)] - E[\log Q(Z)], is called the ELBO or the evidence lower bound; maximizing the ELBO is equivalent to minimizing the KL divergence.
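Because the identity above only rearranges expectations, the resulting bound can be verified by brute force on a tiny discrete model. The sketch below (our own toy example; the joint table and the factors q1, q2 are arbitrary) enumerates a joint distribution over one observed and two hidden binary variables and checks that the ELBO never exceeds log P(X).

```python
import itertools
import numpy as np

# Toy check of  ELBO(Q) = E_Q[log P(X,Z)] - E_Q[log Q(Z)]  <=  log P(X).
rng = np.random.default_rng(0)

Zs = list(itertools.product([0, 1], repeat=2))  # two binary hidden variables
Xs = [0, 1]                                     # one binary observed variable
probs = rng.random((len(Xs), len(Zs)))
probs /= probs.sum()                            # joint P(X, Z), sums to 1

x_obs = 1
p_xz = {z: probs[x_obs, i] for i, z in enumerate(Zs)}
log_evidence = np.log(sum(p_xz.values()))       # log P(X = x_obs)

# An arbitrary fully factorized Q(Z) = q1(z1) * q2(z2).
q1, q2 = 0.3, 0.8                               # Q(z1 = 1), Q(z2 = 1)
def Q(z):
    return (q1 if z[0] else 1 - q1) * (q2 if z[1] else 1 - q2)

elbo = sum(Q(z) * (np.log(p_xz[z]) - np.log(Q(z))) for z in Zs)
print(elbo, "<=", log_evidence)                 # the bound holds for any Q
```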
The family of distributions chosen for the factors Q_i is usually the exponential family. It turns out that this family, along with the independence assumptions, simplifies the optimization of the objective. We will look into this in a little more detail later. We apply the variational mean field method to an example, specifically the high-dimensional Ising model; the full derivation can be found in Appendix 8.2.
We have seen that variational inference converts the original inference problem into an optimization problem that maximizes the ELBO. The posterior is then approximated with a family of mean field distributions, i.e., factorized models that assume sparse interaction terms².
Superior in computability, MF sacrifices interaction terms between groups
of latent variables. The independence assumption welcomes many optimization
² For the justification of factorization, see exponential-family-conditional models, a.k.a. conditionally conjugate models, where latent variables are independent [Blei et al., 2017].
methods, such as coordinate ascent: at every iteration, some coordinates are held fixed while others are optimized. Effectively, the coordinate updates allow the ELBO to climb to a local optimum. If an exponential family is used for the mean field distribution, the updates in the coordinate ascent algorithm simplify, resulting in faster computations.
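As a concrete illustration of coordinate ascent under a factorized family (our own example, not taken from Blei et al. [2017]), the sketch below fits Q(z1, z2) = q1(z1) q2(z2) to a correlated bivariate Gaussian. For a Gaussian target the optimal coordinate update is available in closed form: q_k is Gaussian with variance 1/Λ_kk and a mean that depends on the other factor's current mean, where Λ is the target's precision matrix (a standard closed-form result for Gaussian targets).

```python
import numpy as np

# Coordinate ascent mean-field VI for a correlated 2-D Gaussian target
# p(z) = N(mu, Sigma).  With Q(z) = q1(z1)*q2(z2), each optimal factor is
# Gaussian:  q_k = N(m_k, 1/Lambda_kk)  with
#   m_k = mu_k - (1/Lambda_kk) * sum_{j != k} Lambda_kj * (m_j - mu_j),
# where Lambda = inverse(Sigma).
mu = np.array([1.0, -1.0])
Sigma = np.array([[1.0, 0.9],
                  [0.9, 1.0]])
Lam = np.linalg.inv(Sigma)

m = np.zeros(2)                  # initial factor means
for _ in range(50):              # coordinate ascent sweeps
    for k in (0, 1):
        j = 1 - k
        m[k] = mu[k] - Lam[k, j] * (m[j] - mu[j]) / Lam[k, k]

print("factor means    :", m)                 # converge to mu
print("factor variances:", 1 / np.diag(Lam))  # 0.19, smaller than Sigma_kk = 1.0
```

Note that the converged factor variance 1/Λ_kk is smaller than the true marginal variance Σ_kk, which previews the limitation discussed next.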
2.2 Limitations
Despite the performance boosts of mean-field variational inference in terms of
computation costs, the method suffers from limitations. The main limitation
of mean-field inference is that it explicitly ignores correlations between latent
variables when making the independence assumption. As a result, despite cap-
turing the marginal probability distribution of latent variables, it fails to capture
their correlations. Moreover, the marginal variances of the approximation under-represent those of the true posterior. This behaviour can be explained by the form of the KL
divergence used in mean field variational inference. The KL divergence penalizes
mass placed in variational distribution Q(Z) when the true posterior P (Z|X) is
small. This basically means that Q(Z) is forced to be small wherever P (Z|X) is small. This behaviour can be seen as 'zero-forcing', since P (Z|X) = 0 implies Q(Z) = 0 [Minka, 2005]. This zero-forcing behaviour emphasizes modelling the tails of the distribution rather than the bulk, which results in underestimating the variance of the true posterior. Another consequence of this is
that mean-field variational inference does not approximate well when the true
posterior is a multi-modal distribution. It tends to model the mode with highest
probability mass rather than the entire distribution.
Figure 6: The approximate distribution q models the tails of the true distribu-
tion rather than the mass [Minka, 2005]
3 When is Mean Field Good?
In the previous sections, MFT appears to simplify complex mathematics aggressively without sacrificing much fidelity. Though conceptually intriguing and empirically successful, it is not yet clear how such a mean field strategy can be
formally employed. In particular, why should the particles be self-consistent in
a large network, how will the observed transition in the mean field generalize,
when are interaction terms “safe” to ignore, and why are factorized distributions
a good choice in optimization? We summarize some key ideas here and leave the bulk of the research to the reader.
Suitability: Like the Ising Model where MFT originated, the probabilistic
graphical models of concern are generally large. When the statistics of inter-
correlation at long distances decay, the maximal terms dominate, provided that
the energy is modeled as a summation dependent on interaction strengths. This justifies the first-order methods in mean field analyses. In addition, part of the model's apparent success came from studying behaviors near extremal points, such as zero temperature or very high dimensions. This strategy pushes down the influence of other terms and ensures the dominance of the averaging effects.
MFT’s empirical success may be due to appropriate mean-field assumptions,
because the decoupling of variables is not too far from what one sees in the true
posterior due to natural clustering: some particles are closer to each other than
to random particles, thus allowing the variation in parameters to capture the diversity in observations [Xing et al., 2002].
Tests: The MFT trade-off is reflected in the difference between the observation
and the approximation. In physics, this is done via the Gibbs-Bogoliubov-
Feynman inequality; in variational statistics, it is done via testing the distance
from the ELBO, since the mean-field models’ marginal variances are lower
bounds on the variance of real data. Specifically, TAP correction [J. Thou-
less et al., 1977] and second order approximations [Kappen and Wiegerinck,
2000] are commonly used in conjunction with MFT to improve the quality of
mean-field results.
Heuristics: For a heuristics-based procedure, we summarize several strategies for using MFT. The setup of the problem needs to have elements of stochasticity, and the number of particles needs to reach a large scale, so that self-averaging behavior can be observed. In approximate inference, factorization is the most prominent strategy in variational methods.
deep learning theory work that places mean field analysis squarely at its center,
which we coin as a kind of phenomenological deep learning.
While large-scale neural networks work amazingly well, many phenomena they exhibit remain elusive. In untangling various effects, empirical work alone is often insufficient, partly due to its limited coverage across data sets, parameters, and architectural features. A phenomenology is thus desirable. This
section lays down a shared mean field formalism used by [Chizat and Bach, 2018,
Mei et al., 2018, Yang and Schoenholz, 2017a,b, Yang et al., 2019] to study the
efficacy of deep learning itself.
Figure 7: Methods that discover general insights in large scale learning make
trade-offs between incurring expensive computation and making strict theoretical assumptions.
Mean field theory becomes a natural tool in this pursuit of generalizable insights, such as optimization behaviors at limits [Mei et al., 2018], why neural
networks generalize [Jacot et al., 2018], where gradient explosion and vanish-
ing happen [Yang and Schoenholz, 2017b], and the efficacy of BatchNorm in
stabilizing training [Ioffe and Szegedy, 2015, Yang et al., 2019].
x : network input
D : input dimension
N : number of neurons
{W_ij} : weight matrix
{b_j} : bias vector
σ_w² : variance of the weights
σ_b² : variance of the biases
k : number of SGD steps

Common assumptions
At every layer, we can take the output of the previous layer and define a layer-wise recurrence relationship. This effectively adopts a dynamical-systems view, studying the changes of inputs from layer to layer over the space of covariance matrices. As the width goes to infinity, the Gaussian process relating input and output can be seen as deterministic. This feedforward dual is studied more extensively in Section 5.1.
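A minimal sketch of this layer-to-layer view follows (our own; it assumes a fully-connected tanh network with W_ij ∼ N(0, σ_w²/N) and b_j ∼ N(0, σ_b²), with σ_w, σ_b, widths and depth chosen arbitrarily). It compares the empirical length q^l of a single input, averaged over neurons in one random network, with the deterministic mean-field recursion; as the width grows, the two trajectories stay close.

```python
import numpy as np

rng = np.random.default_rng(1)
sw2, sb2 = 1.5 ** 2, 0.1 ** 2        # sigma_w^2, sigma_b^2
D, N, L = 100, 2000, 10              # input dim, width, depth
phi = np.tanh

# Gauss-Hermite quadrature for Gaussian integrals  int Dz f(z)
nodes, weights = np.polynomial.hermite.hermgauss(60)
gauss = lambda f: (weights * f(np.sqrt(2.0) * nodes)).sum() / np.sqrt(np.pi)

x = rng.standard_normal(D)

# --- one random network: empirical length per layer ---
z = rng.standard_normal((N, D)) @ x * np.sqrt(sw2 / D) + np.sqrt(sb2) * rng.standard_normal(N)
q_emp = [np.mean(z ** 2)]
for _ in range(L - 1):
    h = phi(z)
    z = rng.standard_normal((N, N)) @ h * np.sqrt(sw2 / N) + np.sqrt(sb2) * rng.standard_normal(N)
    q_emp.append(np.mean(z ** 2))

# --- deterministic mean-field recursion for the same quantity ---
q_mf = [sw2 * np.dot(x, x) / D + sb2]
for _ in range(L - 1):
    q_prev = q_mf[-1]
    q_mf.append(sw2 * gauss(lambda u: phi(np.sqrt(q_prev) * u) ** 2) + sb2)

for l, (qe, qm) in enumerate(zip(q_emp, q_mf), start=1):
    print(f"layer {l:2d}:  empirical {qe:.3f}   mean-field {qm:.3f}")
```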
The idea is that in the suitable scaling limit, the reduction of population loss
is captured by ρ; on the other hand, ρ is the solution to a partial differential
equation. As a result, SGD can be approximated using Wasserstein gradient
flow. This is useful in studying macroscopic phenomena, because the scaling limit can then be derived from the PDE formulation, often by examining the
dependency of the critical points with respect to the variables. This flavor is
very similar to the duality view in physics, as seen in Section 1.2.2, which often
accompanies the application of MFT. This duality is, however, not exact, and
thus requires further examination to hold.
Notably in [Mei et al., 2018], the cost function in the space of (P, W2 ) is
viewed as a gradient flow. The results and approach are further summarized in
Section 6.
geometry, they give a theoretical formulation that proves that the expressivity
of neural networks increases exponentially with depth. We briefly discuss their
approach in this section.
Signal propagation in deep neural networks can be understood by studying
the geometry of simple manifolds in the input layer x[0] . Essentially, we would
like to know how the geometry is modified as the manifold propagates through
numerous layers. For the simplest case of a single vector, one can track its
'length', i.e., the squared norm, represented as

q^l = \frac{1}{N_l} \sum_{i=1}^{N_l} \left(z_i^l\right)^2

Similar to (2), we can derive a correlation map for q_{12}^l:
q_{12}^l = C\!\left(c_{12}^{l-1}, q_{11}^{l-1}, q_{22}^{l-1} \,\middle|\, \sigma_w, \sigma_b\right) \equiv \sigma_w^2 \int Dz_1\, Dz_2\, \phi(u_1)\,\phi(u_2) + \sigma_b^2 \qquad (3)

where u_1 = \sqrt{q_{11}^{l-1}}\, z_1, u_2 = \sqrt{q_{22}^{l-1}}\left[c_{12}^{l-1} z_1 + \sqrt{1 - (c_{12}^{l-1})^2}\, z_2\right], and c_{12}^l = q_{12}^l\,(q_{11}^l q_{22}^l)^{-1/2} is the correlation coefficient⁴. Together, (2) and
3 These can be from the input layer x[0] or pre-activations in intermediate layers
4 Also corresponds to the cosine similarity between pre-activations
5 Points in the input manifold for a layer or in other words two inputs
propagate through a neural network. Analyzing these equations in the (σ_w, σ_b) plane reveals an interesting order-to-chaos transition for the system. The relation between two points can be tracked by the correlation coefficient c_{12}^l.
Using the fixed point q^*(σ_w, σ_b) for the length of a single vector, we calculate an iterative correlation coefficient map (C-map) as

c_{12}^l = \frac{1}{q^*}\, C\!\left(c_{12}^{l-1}, q^*, q^* \,\middle|\, \sigma_w, \sigma_b\right)
The C-map has a fixed point at 1 (c^* = 1). However, the stability of the fixed point depends on the slope at 1:

\chi_1 \equiv \left.\frac{\partial c_{12}^l}{\partial c_{12}^{l-1}}\right|_{c=1} = \sigma_w^2 \int Dz \left[\phi'\!\left(\sqrt{q^*}\, z\right)\right]^2
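Since χ_1 is a one-dimensional Gaussian integral, it is easy to evaluate numerically. The sketch below (our own, for φ = tanh and an arbitrary small σ_b) first iterates the length map to its fixed point q^*, then computes χ_1 and reports which side of the order-to-chaos boundary a given (σ_w, σ_b) falls on.

```python
import numpy as np

# chi_1 = sigma_w^2 * int Dz [phi'(sqrt(q*) z)]^2  for phi = tanh,
# with q* the fixed point of the length map
#   q = sigma_w^2 * int Dz phi(sqrt(q) z)^2 + sigma_b^2.
nodes, weights = np.polynomial.hermite.hermgauss(60)
gauss = lambda f: (weights * f(np.sqrt(2.0) * nodes)).sum() / np.sqrt(np.pi)

phi = np.tanh
dphi = lambda u: 1.0 - np.tanh(u) ** 2

def chi1(sw, sb, n_iter=200):
    q = 1.0
    for _ in range(n_iter):          # iterate the length map to its fixed point q*
        q = sw ** 2 * gauss(lambda z: phi(np.sqrt(q) * z) ** 2) + sb ** 2
    return sw ** 2 * gauss(lambda z: dphi(np.sqrt(q) * z) ** 2)

for sw in (0.5, 1.0, 1.5, 2.0, 3.0):
    c = chi1(sw, sb=0.05)
    phase = "ordered" if c < 1 else "chaotic"
    print(f"sigma_w = {sw:.1f}   chi_1 = {c:.3f}   ({phase})")
# For small sigma_b the transition sits near sigma_w ~ 1, the
# critical initialization at the edge of chaos.
```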
5.1.2 Gradients
While Poole et al. [2016] investigate the nature of the signal as it propagates through the network in the forward dynamics, Schoenholz et al. [2016] study the nature of gradients, drawing on a duality between forward and backward propagation.
Consider the backpropagation of a given loss E:
\frac{\partial E}{\partial W_{ij}^l} = \delta_i^l\, \phi(z_j^{l-1}), \qquad \delta_i^l = \frac{\partial E}{\partial z_i^l}
6 Term coined by Yang and Schoenholz [2017b]
Within mean field theory, it is clear that the scale of fluctuations of the gradients of the weights in a layer will be proportional to the second moment of δ_i^l [Schoenholz et al., 2016]. The authors note that, unlike the pre-activations in forward propagation, δ_i^l will not be Gaussian distributed even for N → ∞. However, we can obtain a recurrence relation for \tilde{q}_{aa}^l = E[(\delta_i^l)^2] under the assumption that the weights used during backpropagation are drawn independently from the weights used in forward propagation:

\tilde{q}_{aa}^l = \tilde{q}_{aa}^{l+1}\, \chi_1
Note: The equation above also contains a factor proportional to Nl+1 /Nl which
is unity for our setup. Since χ1 depends only on the asymptotic c∗ , the above
equation has an exponential solution resulting in a phase transition boundary
similar to what was discussed in the previous section, but for gradients. In the ordered phase (χ_1 < 1), the gradients are expected to vanish with depth, whereas in the chaotic phase (χ_1 > 1), gradients are expected to explode. At the edge of chaos, namely the region χ_1 → 1, the gradients should be stable regardless of depth.
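The recurrence predicts gradient norms that scale like χ_1 raised to the depth. The sketch below (our own; it reuses the forward weights in the backward pass, so the independence assumption behind the recurrence only holds approximately) backpropagates a random error signal through a deep random tanh network and tracks the mean squared error signal per layer in an ordered and a chaotic initialization.

```python
import numpy as np

# Empirical look at the backward recurrence q~^l = q~^{l+1} * chi_1:
# backpropagate a random error signal through a deep random tanh network
# and track the mean squared error signal E[(delta_i^l)^2] per layer.
rng = np.random.default_rng(2)
N, L, sb = 500, 30, 0.05

def delta_norms(sw):
    x = rng.standard_normal(N)
    Ws = [np.sqrt(sw ** 2 / N) * rng.standard_normal((N, N)) for _ in range(L)]
    bs = [sb * rng.standard_normal(N) for _ in range(L)]
    zs, h = [], x
    for W, b in zip(Ws, bs):          # forward pass, storing pre-activations
        z = W @ h + b
        zs.append(z)
        h = np.tanh(z)
    delta = rng.standard_normal(N) / np.sqrt(N)  # error signal at the top layer
    norms = [np.mean(delta ** 2)]
    for l in range(L - 1, 0, -1):     # backward: delta^l = phi'(z^l) * (W^{l+1})^T delta^{l+1}
        delta = (1.0 - np.tanh(zs[l - 1]) ** 2) * (Ws[l].T @ delta)
        norms.append(np.mean(delta ** 2))
    return norms

for sw, label in ((0.8, "ordered, chi_1 < 1"), (2.0, "chaotic, chi_1 > 1")):
    n = delta_norms(sw)
    print(f"sigma_w = {sw}:  top {n[0]:.2e}   bottom {n[-1]:.2e}   ({label})")
# In the ordered phase the signal shrinks toward the input; in the
# chaotic phase it grows, matching the exponential solution above.
```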
The results in 5.1.1 and 5.1.2 lead to a trainability vs expressivity trade-
off for fully-connected neural networks. While deep networks operating in the
chaotic phase tend to be more expressive (with expressivity increasing with
depth up to a fixed point), the gradients for such networks tend to explode with increasing depth. For networks on the edge of chaos, extremely deep neural networks can be trained. This is because information about the inputs is able to propagate forward and information about the gradients is also able to propagate backwards through the deep network.
5.2 Resnets
In the previous sections, we have seen that the exponential forward dynamics
of sigmoidal feed forward neural networks causes a rapid collapse of the input
geometry7 . A similar scenario exists for the backward dynamics causing gra-
dients to drastically vanish or explode. Yang and Schoenholz [2017b] build on
previous works ([Poole et al., 2016],[Schoenholz et al., 2016]) and show that by
adding skip connections, the network adopts sub-exponential or polynomial forward and backward dynamics (depending on the non-linearity). This slower
convergence to the fixed points allows residual networks to ’hover’ over the edge
of chaos longer. This provides some theoretical justification as to why ResNets
with a large number of layers work well in practice.
The main results in Yang and Schoenholz [2017b] are:
• The forward dynamics for tanh and α-ReLU (α < 1) is polynomial with depth (a numerical sketch contrasting this with the plain feedforward case follows this list).
• The backward dynamics for tanh is sub-exponential, whereas the backward dynamics for α-ReLU (α < 1) is asymptotically polynomial.
7 The input geometry exponentially converges to the fixed point.
• ReLU exhibits exponential forward and backward dynamics asymptoti-
cally. One interesting observation is that not all gradient signals exhibit
exponential behaviour. The gradient norm with respect to the weights w is
independent of how far the gradient has propagated (it is constant). This,
however, is not the case with bias b for which it increases exponentially.
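Below is a minimal numerical sketch of the contrasting forward dynamics (our own; it assumes a residual block of the form z^{l+1} = z^l + Wφ(z^l) + b, so that in the mean-field limit the branch's variance contribution simply adds to the identity path, with σ_w, σ_b and depth chosen arbitrarily).

```python
import numpy as np

# Forward length dynamics, feedforward vs. residual, for phi = tanh.
# Feedforward: q^{l+1} = sw2 * E[phi(sqrt(q^l) z)^2] + sb2        (converges to q*)
# Residual  : q^{l+1} = q^l + sw2 * E[phi(sqrt(q^l) z)^2] + sb2   (grows roughly linearly)
nodes, weights = np.polynomial.hermite.hermgauss(60)
gauss = lambda f: (weights * f(np.sqrt(2.0) * nodes)).sum() / np.sqrt(np.pi)

sw2, sb2, L = 1.5 ** 2, 0.05 ** 2, 50
branch = lambda q: sw2 * gauss(lambda z: np.tanh(np.sqrt(q) * z) ** 2) + sb2

q_ff, q_res = [1.0], [1.0]
for _ in range(L):
    q_ff.append(branch(q_ff[-1]))
    q_res.append(q_res[-1] + branch(q_res[-1]))

print("feedforward q^L:", round(q_ff[-1], 3))   # saturates at its fixed point
print("residual    q^L:", round(q_res[-1], 3))  # keeps growing with depth
```

The feedforward length map saturates at its fixed point within a few layers, while the residual map keeps growing roughly linearly, which is the slower, polynomial dynamics described above.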
5.3 CNN
Convolutional Neural Networks have been crucial to the success of deep learning.
However, most of these deep models are only trainable by employing techniques
like residual connections and batch normalization. Although we have seen some
justification in support of techniques like residual connections (Section 5.2), it
is still unclear whether deep CNNs necessarily require these techniques for
successful training. Xiao et al. [2018] develop a mean field theory of CNNs to
investigate this issue by furthering works discussed above. One key difference
in the mean field assumption for CNNs is that instead of the large-width assumption, i.e., N → ∞, we assume a large number of channels, i.e., the number of channels c in a filter tends to infinity.
Xiao et al. [2018] find that the mean field derivation for signal propagation
in CNNs is similar to that of [Poole et al., 2016] and the stability condition
is precisely the one that governs fully-connected networks (as discussed in Sec-
tion 5.1.1). Moreover, the fixed point analysis for CNNs leads to the same result
as in the case of feed-forward neural networks. This means that for CNNs too,
there exists a phase transition boundary at χ1 = 1. For χ1 < 1, c∗ = 1 is a
stable fixed point and the network exists in an ordered phase where all pixels
converge to the same value. For χ_1 > 1, there exists a stable fixed point c^* < 1. This corresponds to the chaotic phase, where all pixel values de-correlate.
The analysis for the backward propagation of the signal leads to the same
result as the one derived in Section 5.1.2. Thus, the network must stay at the edge of chaos to ensure that gradient signals neither explode nor vanish as they
back-propagate through a convolutional network.
The authors note that although the order-to-chaos phase boundaries of fully-
connected and convolutional networks look identical, the underlying mean-field
theories differ. A novel aspect of the convolutional theory is the existence of
multiple depth scales that control signal propagation at different spatial fre-
quencies. In the large depth limit, signals can only propagate along modes
with minimal spatial structure; all other modes end up deteriorating, even at
criticality.
Xiao et al. [2018] push their analysis beyond mean-field theory by incorpo-
rating dynamical isometry [Pennington et al., 2017] for CNNs. They develop
a modified initialization scheme that allows for balanced propagation of signals
among all frequencies. They call this scheme Delta-Orthogonal initialization.
This scheme allows them to train ultra deep vanilla CNNs with no degradation
in performance.
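A rough sketch of the Delta-Orthogonal idea follows (our own reading of the scheme: all spatial taps are zero except the centre one, which holds an orthogonal matrix; any nonlinearity-dependent gain used in the paper is ignored here).

```python
import numpy as np

# A delta-orthogonal convolution kernel: every spatial tap is zero except
# the centre, which holds an orthogonal matrix, so at initialization the
# layer acts as an orthogonal map applied pixel-wise.
def delta_orthogonal(k, c):
    """k x k spatial kernel with c input and c output channels."""
    a = np.random.default_rng(0).standard_normal((c, c))
    q, r = np.linalg.qr(a)
    q *= np.sign(np.diag(r))       # fix column signs; q stays orthogonal
    kernel = np.zeros((k, k, c, c))
    kernel[k // 2, k // 2] = q
    return kernel

K = delta_orthogonal(3, 64)
H = K[1, 1]
print(np.allclose(H @ H.T, np.eye(64)))                       # True: centre tap is orthogonal
print(np.count_nonzero(K[0, 0]), np.count_nonzero(K[2, 2]))   # 0 0: off-centre taps are zero
```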
5.4 BatchNorm and Gradients
BatchNorm remains elusive in machine learning theory. On one hand, it clearly works well in practice. On the other, there is no clear theory of why it works; some theorize a pre-conditioning effect that lends some notion of stability to training because the landscape being optimized is much smoother, and the story seems to have stopped there [Santurkar et al., 2018]. Similar to the phase transition in Section 1.3, Yang et al. [2019] find the limiting regimes of the phenomena of interest: at
L < 50, gradient explosion is mild because the gradients are small compared to the weights and the weights do not change much; at L > 50, explosion dominates the W's: the weight norm decreases, and from t = 0 to t = 1, gradients cross the threshold of |W| = |∇(·)|. The major contributions enabled by MFT include
the mathematics showing that BatchNorm causes gradient explosion, enlarging the gradient norm by a factor of about 1.47 with every layer, and the suggestion of a linear activation to minimize the explosion rate to \frac{b-2}{b-3}, where b denotes the batch size. This conclusion is at odds with several other theories that postulate stability benefits of BatchNorm, suggesting that BatchNorm works through other benefits. In this way, MFT made a difficult-to-test theory more feasible to study.
to be random.
\dot{\theta}_i = -\nabla\Psi\!\left(\theta_i;\, \hat{\rho}_t^{(N)}\right), \qquad \hat{\rho}_t^{(N)} = \frac{1}{N}\sum_{j=1}^{N} \delta_{\theta_j(t)}
The system describes a particle moving in the force field defined by the other particles, following a non-linear dynamic. Recall that the initializations θ_i^0 are drawn i.i.d. from ρ_0. On the other hand, we formulate a different system with θ̄_i(t), which describes N independent trajectories. Their evolution is described in a way akin to that of a system of particles in physics:

\dot{\bar{\theta}}(t) = -\nabla\Psi\!\left(\bar{\theta}(t);\, \rho_t\right).
The system of θ̄ is then used to relate to θ, with θ̄i (0) = θi0 so that they live in the
same probability space. Eventually, to show that the PDE written from gradient
flow approximates SGD, it suffices to bound the distance of this approximation
for some bounded function M :
d_{\theta,\bar{\theta}}(t) = \frac{1}{N}\sum_i \left|\theta_i(t) - \bar{\theta}_i(t)\right|^2 \le M(N, t, D)
This PDE describes the evolution of the particle in the force field provided by
the “density” of all the other particles. This strips N from the optimization at
large N , showing that the optimization does not infinitely scale with the number
of neurons. While not an actionable result in empirical machine learning, this suggests that over-parametrization is only part of the story of why neural nets work. Additionally, this formulation effectively reduces an N × D-dimensional problem to a problem of only D dimensions, and a very random process to a somewhat deterministic one.
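A toy illustration of this reduction (our own; the two-layer network, target function, data, and learning-rate scaling are all arbitrary choices) trains f(x; θ) = (1/N) Σ_i a_i tanh(w_i x) by full-batch gradient descent with the mean-field 1/N output scaling. The claim illustrated is that, as N grows, the trajectory is governed by the initial parameter distribution ρ_0 and the data rather than by N itself.

```python
import numpy as np

# Toy mean-field limit of two-layer network training:
# f(x; theta) = (1/N) * sum_i a_i * tanh(w_i * x).
rng = np.random.default_rng(3)
X = rng.uniform(-2, 2, size=200)
Y = np.sin(X)                                # an arbitrary smooth target

def train(N, steps=2000, lr=0.1):
    w = rng.standard_normal(N)               # theta_i = (a_i, w_i) drawn from rho_0
    a = rng.standard_normal(N)
    for _ in range(steps):
        pre = np.tanh(np.outer(w, X))        # shape (N, n_data)
        f = a @ pre / N                      # network output
        err = f - Y
        grad_a = 2 * pre @ err / (N * len(X))                                   # dL/da_i
        grad_w = 2 * (a[:, None] * (1 - pre ** 2) * X[None, :]) @ err / (N * len(X))  # dL/dw_i
        a -= lr * N * grad_a                 # mean-field scaling: O(1) force per particle
        w -= lr * N * grad_w
    return np.mean((a @ np.tanh(np.outer(w, X)) / N - Y) ** 2)

for N in (10, 100, 1000):
    print(N, round(train(N), 5))
# As N grows, the final loss (and the whole trajectory) becomes
# insensitive to N and to the particular draw from rho_0.
```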
The two mean fields may merge in the future. As of now, however, the biggest weakness is that the assumptions made are extremely restrictive; an open challenge lies not just in taking the width of each layer to its limit, but also in extending the analysis to multiple layers, most recently attempted "non-rigorously" by Nguyen [2019]. While the scaling limits between neuron count N, number of steps k, and dimension D are reasonable, the number of hidden layers staying at 1 is unacceptable. In addition, the particle-descent formulation requires a continuity equation, which is not rigorously shown to be convergent in realistic settings. As is, this MFT formulation is thus unlikely to offer generalizing insights to practitioners. However, this prototypical framing of particle evolution abstracts the training dynamics away from network features, thus allowing for the novel application of diverse mathematics to study what makes deep learning work, as seen in [Chizat and Bach, 2018, Rotskoff and Vanden-Eijnden, 2018].
7 Conclusion
Explaining deep learning’s empirical success requires an intersection of acute
observation and appropriate approximation. Mean Field Theory is a power-
ful technique, uniquely applied to the scale and practice of deep learning. As
practitioners set out to fully understand and apply MFT, it is essential to un-
derstand the situations under which MFT is effective, the strategies to use, and
the limitations where the theory is inappropriate.
This survey paper motivates the use of MFT in deep learning through the
historical practice of mean-field methods in physics and statistics, with flavors of
factorization and self-averaging. Its general philosophy states that the study of
the phenomenon of the system can be divorced from the study of its parts, and
that the parts are self-consistent. In studying why neural networks work, this
abstraction is drastically different from the experimental methods that try to
isolate the widgets with the most effect. At the heart of its mathematics, mean field approximation assumes some extent of independence among entities, making it a suitable theoretical tool for studying practical regimes of over-parametrized neural nets.
Throughout the survey, we discuss the ample restrictions in each of the MF
approximations. Like all theoretical models, MFT is wrong when its assumptions deviate from practice: for example, it is inconsistent near critical conditions, because the fluctuations and correlations between particles are not modelled, and the details of the phenomenon studied may be much more diverse than the averaging effects mean field theory assumes. To mitigate this, higher-order methods such as the TAP correction are used.
We summarize current progress on connecting MFT to deep learning, a fast-
moving area of research. We introduce a specific formalism which is agnostic to
the scale of data and model. This MFT has been successfully applied to study
the behaviors of a variety of popular deep learning architectures. Though pow-
erful, MFT comes at a cost: mean-field modeling in deep learning necessarily simplifies the entire phenomenon. The results obtained are correct in a weak sense: they are only exact under strict assumptions, and are otherwise approximations. In complement, experimental work is used to verify the phenomena derived through mathematics under those unrealistic assumptions. Future work in MFT should consider second-order interactions when there are finitely many neurons, while extending the mathematics to be universal across neural network features.
References
David M Blei, Alp Kucukelbir, and Jon D McAuliffe. Variational inference: A
review for statisticians. Journal of the American Statistical Association, 112
(518):859–877, 2017.
Lenaic Chizat and Francis Bach. On the global convergence of gradient descent
for over-parameterized models using optimal transport. In Advances in neural
information processing systems, pages 3036–3046, 2018.
Boris Hanin. Which neural net architectures give rise to exploding and vanishing
gradients? In S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-
Bianchi, and R. Garnett, editors, Advances in Neural Information Processing
Systems 31, pages 582–591. Curran Associates, Inc., 2018. URL http://papers.nips.cc/paper/7339-which-neural-net-architectures-give-rise-to-exploding-and-vanishing-gradients.pdf.
Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating
deep network training by reducing internal covariate shift. arXiv preprint
arXiv:1502.03167, 2015.
D. J. Thouless, Philip Anderson, and R. G. Palmer. Solution of 'solvable model of a spin glass'. Phil. Mag., 35:593–601, 1977. doi: 10.1080/14786437708235992.
Arthur Jacot, Franck Gabriel, and Clément Hongler. Neural tangent kernel:
Convergence and generalization in neural networks. In Advances in neural
information processing systems, pages 8571–8580, 2018.
Ryo Karakida, Shotaro Akaho, and Shun-ichi Amari. Universal statistics of
fisher information in deep neural networks: Mean field approach, 2018.
Tatsuro Kawamoto, Masashi Tsubaki, and Tomoyuki Obuchi. Mean-
field theory of graph neural networks in graph partitioning. In
S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-
Bianchi, and R. Garnett, editors, Advances in Neural Information
Processing Systems 31, pages 4361–4371. Curran Associates, Inc.,
2018. URL http://papers.nips.cc/paper/7689-mean-field-theory-of-graph-neural-networks-in-graph-partitioning.pdf.
Song Mei, Andrea Montanari, and Phan-Minh Nguyen. A mean field view of the
landscape of two-layer neural networks. Proceedings of the National Academy
of Sciences, 115(33):E7665–E7671, 2018.
Thomas Minka. Divergence measures and message passing. 2005.
Phan-Minh Nguyen. Mean field limit of the learning dynamics of multilayer
neural networks. arXiv preprint arXiv:1902.02880, 2019.
Jeffrey Pennington, Samuel S. Schoenholz, and Surya Ganguli. Resurrecting
the sigmoid in deep learning through dynamical isometry: theory and
practice. In Advances in Neural Information Processing Systems 30:
Annual Conference on Neural Information Processing Systems 2017, 4-9
December 2017, Long Beach, CA, USA, pages 4788–4798, 2017. URL http://papers.nips.cc/paper/7064-resurrecting-the-sigmoid-in-deep-learning-through-dynamical-isometry-theory-and-practice.
Ben Poole, Subhaneil Lahiri, Maithra Raghu, Jascha Sohl-Dickstein, and Surya
Ganguli. Exponential expressivity in deep neural networks through transient
chaos. In Advances in neural information processing systems, pages 3360–
3368, 2016.
Arnu Pretorius, Elan Van Biljon, Steve Kroon, and Herman Kamper. Critical
initialisation for deep signal propagation in noisy rectifier neural networks. In
NeurIPS, 2018.
Grant M Rotskoff and Eric Vanden-Eijnden. Neural networks as interacting
particle systems: Asymptotic convexity of the loss landscape and universal
scaling of the approximation error. arXiv preprint arXiv:1805.00915, 2018.
Shibani Santurkar, Dimitris Tsipras, Andrew Ilyas, and Aleksander Madry. How
does batch normalization help optimization? In Advances in Neural Infor-
mation Processing Systems, pages 2483–2493, 2018.
Lawrence K. Saul, Tommi S. Jaakkola, and Michael I. Jordan. Mean field theory
for sigmoid belief networks. J. Artif. Intell. Res., 4:61–76, 1996.
Samuel S Schoenholz, Justin Gilmer, Surya Ganguli, and Jascha Sohl-Dickstein.
Deep information propagation. arXiv preprint arXiv:1611.01232, 2016.
Leonard Susskind. Statistical Mechanics lecture 9. https://www.youtube.com/watch?v=AT4_S9vQJgc, 2013. Accessed: 2019-04-26.
Martin J. Wainwright and Michael I. Jordan. Graphical models, exponential
families, and variational inference. Found. Trends Mach. Learn., 1(1-2):1–
305, January 2008. ISSN 1935-8237. doi: 10.1561/2200000001. URL http://dx.doi.org/10.1561/2200000001.
Lechao Xiao, Yasaman Bahri, Jascha Sohl-Dickstein, Samuel Schoenholz, and
Jeffrey Pennington. Dynamical isometry and a mean field theory of CNNs:
How to train 10,000-layer vanilla convolutional neural networks. In Jennifer
Dy and Andreas Krause, editors, Proceedings of the 35th International Con-
ference on Machine Learning, volume 80 of Proceedings of Machine Learning
Research, pages 5393–5402, Stockholmsmässan, Stockholm Sweden, 10–15 Jul
2018. PMLR. URL http://proceedings.mlr.press/v80/xiao18a.html.
Eric P Xing, Michael I Jordan, and Stuart Russell. A generalized mean field
algorithm for variational inference in exponential families. In Proceedings
of the Nineteenth conference on Uncertainty in Artificial Intelligence, pages
583–591. Morgan Kaufmann Publishers Inc., 2002.
Ge Yang and Samuel Schoenholz. Mean field residual networks: On the edge
of chaos. In Advances in neural information processing systems, pages 7103–
7114, 2017a.
Greg Yang and Samuel S. Schoenholz. Mean field residual networks: On the
edge of chaos. In NIPS, 2017b.
Greg Yang, Jeffrey Pennington, Vinay Rao, Jascha Sohl-Dickstein, and
Samuel S. Schoenholz. A mean field theory of batch normalization. In
International Conference on Learning Representations, 2019. URL https://openreview.net/forum?id=SyMDXnCcF7.
8 Appendix
8.1 Derivation for ELBO
The log likelihood of the evidence can be written as follows:

\begin{aligned}
\log P(X) &= \log \sum_Z P(Z, X) \\
&= \log \sum_Z Q(Z|X)\, \frac{P(X, Z)}{Q(Z|X)} && \text{(introduce a distribution } Q(Z|X)\text{)} \\
&\ge \sum_Z Q(Z|X) \log \frac{P(X, Z)}{Q(Z|X)} && \text{(by Jensen's inequality)} \\
&= E_Q\!\left[\log \frac{P(X, Z)}{Q(Z|X)}\right] \\
&= E_Q[\log P(X, Z)] - E_Q[\log Q(Z|X)]
\end{aligned}

Thus, \log P(X) \ge E_Q[\log P(X, Z)] - E_Q[\log Q(Z|X)], and the right-hand side is the evidence lower bound (ELBO).
8.2 Variational Mean Field for the Ising model

We represent the state of each atom i by a random variable σ_i which takes values +1 or −1. The marginal probability of a configuration of states σ is represented as

p(\sigma) \propto e^{-\beta H(\sigma)}
The normalization factor Z = \sum_{\sigma} e^{-\beta H(\sigma)} requires a sum over a huge number of configurations, and so the exact marginal probability is often intractable (a sum over 2^N and 2^{N^2} terms in the case of 1D and 2D lattices respectively, which can be very large as N increases).
Our main goal is to use mean field methods to approximate the probability of
a configuration of lattice points. One important thing to note is that, typically,
the true distribution is not in the variational family obtained by mean field. We
approximate the exact probability p(σ) with q(σ) such that q belongs to the
exponential family of distributions. We consider that q is fully factorizable:

q(\sigma) = \prod_{i \in V} q_i(\sigma_i) = \prod_{i=1}^{N} q_i(\sigma_i)
Our goal is to find the q that acts as the best approximation to p. For this we consider minimizing the Kullback-Leibler divergence between the two distributions:
\begin{aligned}
D_{KL}\!\left(\prod_{i=1}^{N} q_i \,\Big\|\, p\right)
&= \int \Big(\prod_{i=1}^{N} q_i\Big) \log \frac{\prod_{j=1}^{N} q_j}{p}\, d\sigma \\
&= \int \prod_{i=1}^{N} q_i \sum_{j=1}^{N} \log q_j \, d\sigma \;-\; \int \prod_{i=1}^{N} q_i \log p \, d\sigma + C \\
&= \int \prod_{i=1}^{N} q_i \log q_k \, d\sigma \;+\; \int \prod_{i=1}^{N} q_i \sum_{j \ne k} \log q_j \, d\sigma \;-\; \int \prod_{i=1}^{N} q_i \log p \, d\sigma + C \\
&= \int q_k \log q_k \, d\sigma_k \;+\; \underbrace{\int \prod_{i \ne k} q_i \sum_{i \ne k} \log q_i \, d\sigma_{-k}}_{\text{constant w.r.t. } q_k} \;-\; \int \prod_{i=1}^{N} q_i \log p \, d\sigma + C \\
&= \int q_k \Big(\log q_k - \int \prod_{i \ne k} q_i \log p \, d\sigma_{-k}\Big)\, d\sigma_k + C'
\end{aligned}
Let us consider r(\sigma_k) = \int \prod_{i \ne k} q_i \log p \, d\sigma_{-k}. We can normalize this to obtain a distribution s(\sigma_k) = \dfrac{e^{r(\sigma_k)}}{\int e^{r(\sigma_k)}\, d\sigma_k}. Now, the above equation can be written as:
\begin{aligned}
D_{KL}\!\left(\prod_{i=1}^{N} q_i \,\Big\|\, p\right)
&= \int q_k \Big(\log q_k - \log s_k + \underbrace{\log \int e^{r(\sigma_k)}\, d\sigma_k}_{\text{constant w.r.t. } q_k}\Big)\, d\sigma_k + C' \\
&= \int q_k \log \frac{q_k}{s_k}\, d\sigma_k + C'' \underbrace{\int q_k \, d\sigma_k}_{=1} + C' \\
&= D_{KL}(q_k \,\|\, s_k) + \text{constant w.r.t. } q_k
\end{aligned}
The divergence is minimized when q_k = s_k, i.e.,

\log q_k = \int \prod_{i \ne k} q_i \log p \, d\sigma_{-k} + \text{const. w.r.t. } \sigma_k = E_{q_{-k}}[\log p] + C
We will now use this result and apply it to the Ising model. We know that the joint probability of states is

p(\sigma) \propto e^{-\beta H(\sigma)}, \qquad H(\sigma) = -\frac{1}{2}\sum_{i=1}^{N}\sum_{j=1}^{N} J\,\sigma_i \sigma_j - \sum_{i=1}^{N} B\,\sigma_i,

where J couples only nearest-neighbor pairs. Substituting into the optimal-factor condition and keeping only the terms that depend on σ_k,

\log q_k(\sigma_k) = E_{q_{-k}}[\log p(\sigma)] + C = \beta \sigma_k \underbrace{\Big(J \sum_{i \in \mathrm{Nbr}(k)} \mu_i + B\Big)}_{H} + C'',

where \mu_i = E_{q_i}[\sigma_i]. Hence

q_k(\sigma_k) = C e^{\beta H \sigma_k}
We will now find the value of the constant C. Requiring

\int q_k(\sigma_k)\, d\sigma_k = 1

gives

q_k(\sigma_k) = \frac{e^{\beta H \sigma_k}}{e^{\beta H} + e^{-\beta H}},

so that

q_k(\sigma_k = 1) = \mathrm{Sigmoid}(2\beta H), \qquad q_k(\sigma_k = -1) = \mathrm{Sigmoid}(-2\beta H).

Finally,

\mu_k = E[\sigma_k] = q_k(\sigma_k = 1) - q_k(\sigma_k = -1) = \frac{e^{\beta H} - e^{-\beta H}}{e^{\beta H} + e^{-\beta H}} = \tanh(\beta H)
Note:
• Nbr(k) represents the nearest neighbours of lattice point k
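A minimal sketch of the resulting mean-field update (our own; lattice size, temperature, and external field are arbitrary) iterates μ_k ← tanh(β(J Σ_{i∈Nbr(k)} μ_i + B)) on a 2-D periodic lattice until the magnetizations are self-consistent.

```python
import numpy as np

# Mean-field (variational) update for a 2-D Ising model on an L x L grid
# with periodic boundaries:
#   mu_k <- tanh(beta * (J * sum_{i in Nbr(k)} mu_i + B))
L, J, B, beta = 32, 1.0, 0.01, 0.6
rng = np.random.default_rng(0)
mu = rng.uniform(-0.1, 0.1, size=(L, L))   # initial magnetizations E_q[sigma_k]

for _ in range(500):
    nbr_sum = (np.roll(mu, 1, 0) + np.roll(mu, -1, 0) +
               np.roll(mu, 1, 1) + np.roll(mu, -1, 1))
    mu = np.tanh(beta * (J * nbr_sum + B))

print("mean magnetization:", mu.mean())
# For beta above the mean-field critical value 1/(2*d*J) = 0.25 (d = 2),
# the iteration settles on a nonzero magnetization, consistent with the
# phase transition derived in Section 1.3; below it, mu decays to zero.
```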