2022PhD - Princeton - Bridging Theory and Practice in Deep Learning Optimization and Generalization
2022PhD - Princeton - Bridging Theory and Practice in Deep Learning Optimization and Generalization
in Deep Learning:
Optimization and Generalization
Zhiyuan Li
A Dissertation
Presented to the Faculty
of Princeton University
in Candidacy for the Degree
of Doctor of Philosophy
September 2022
c Copyright by Zhiyuan Li, 2022.
Deep learning has been hugely successful for several important applications in
the past decade, yet mathematical understanding has lagged behind its breathtaking
empirical success. Classic machine learning theory is insufficient to explain various
new phenomena in deep learning and to provide guidance on algorithmic choices,
largely due to an oversimplified black box view that ignores the interaction between
the model and the optimization algorithm. This dissertation presents a collection of
theoretical results that take the interplay between the model and the optimization
algorithm into account and aims to bridge the gaps between theory and practice in
deep learning for both generalization and optimization.
For optimization, we first illustrate the mismatches between traditional optimiza-
tion theory and deep networks with normalization layers by presenting an exponentially
increasing learning rate schedule that works well empirically. We explain this surprise
by establishing its equivalence to SGD with Weight Decay and proving that their
convergence rates are fast and insensitive to initialization scale. Based on this, we
design a variant of BERT named SIBERT, which is trainable by SGD and thus more
memory-efficient than adaptive algorithms like ADAM. Finally we present the first
provable yet general setting where gradient descent decreases loss in a non-monotone
way, as observed empirically.
For generalization, we study the implicit bias of optimization algorithms, which
refers to the phenomenon that the algorithm returns solutions with good generalization
despite the existence of solutions with poor generalization due to the overparametrized
models. We first give a rigorous justification of why convolutional networks are
more sample-efficient than fully-connected networks. Then we provide theoretical
justification for the empirical observation that deep linear networks, including matrix
factorization, trained by gradient descent from small initialization implicitly bias
to low-rank solutions. We also identify a condition when gradient descent with
iii
reparametrization is equivalent to mirror descent which can be used to understand
implicit bias of non-linear models and recovers several previous results. We further
show gradient descent has an implicit bias for ‘flatter’ solutions when having certain
gradient noise or its learning rate is larger than two over sharpness of loss.
iv
Acknowledgements
vi
To my family.
vii
Contents
Abstract . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . iii
Acknowledgements . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . v
1 Introduction 1
1.1 The Black Box View in Existing Theory . . . . . . . . . . . . . . . . 2
1.2 Gaps in Generalization Theory and Practice . . . . . . . . . . . . . . 4
1.3 Gaps in Optimization Theory and Practice . . . . . . . . . . . . . . . 7
1.4 Our Contributions . . . . . . . . . . . . . . . . . . . . . . . . . . . . 8
1.5 Previously Published Works . . . . . . . . . . . . . . . . . . . . . . . 12
viii
2.6 Viewing Exponential Learning Rates via Canonical Optimization Frame-
work . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 30
2.7 Experiments . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 32
2.8 Proofs . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 35
2.9 Side Results on Parameter Norm Convergence . . . . . . . . . . . . . 51
2.10 Scale Invariance in Modern Network Architectures . . . . . . . . . . . 53
ix
II Implicit Bias Along Entire Optimization Trajectory 115
x
6.15 Proofs for Deep Matrix Factorization . . . . . . . . . . . . . . . . . . 224
6.16 Proof of Linear Convergence to Minimizer . . . . . . . . . . . . . . . 238
xi
9 Implicit Bias of Gradient Descent Operating on Edge of Stability:
Sharpness Reduction 394
9.1 Introduction . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 395
9.2 Related Works . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 399
9.3 Warm-up: Quadratic Loss Functions . . . . . . . . . . . . . . . . . . 401
9.4 Notations . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 405
9.5 Main Results: Sharpness Reduction . . . . . . . . . . . . . . . . . . . 406
9.6 Proof Overview . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 411
9.7 Experiments . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 416
9.8 Limitation and Future Work . . . . . . . . . . . . . . . . . . . . . . . 419
9.9 Proofs for Results for Quadratic Loss Functions . . . . . . . . . . . . 420
9.10 Setups for General Loss Functions . . . . . . . . . . . . . . . . . . . . 435
9.11 Analysis of Normalized GD on General Loss Functions . . . . . . . . 454
9.12 Phase I, Proofs of the Main Lemmas . . . . . . . . . . . . . . . . . . 461
9.13 Phase II, Proofs of the Main Lemmas . . . . . . . . . . . . . . . . . . 467
9.14 Some Useful Lemmas About Eigenvalues and Eigenvectors . . . . . . 493
√
9.15 Analysis of L . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 494
9.16 Additional Experimental Details . . . . . . . . . . . . . . . . . . . . . 503
Bibliography 506
xii
Chapter 1
Introduction
Despite enormous successful applications, deep learning still lacks good mathematical
understanding. Classic machine learning theory is often found incapable of explaining
or predicting various new phenomena in deep learning, not to mention aiding in the
design of better learning algorithms. One main issue behind this failure is that existing
theory usually holds a black box view which decouples the roles of the model and the
optimization algorithm, meaning when analyzing one of them, the rest one is treated
as a black box. Such decoupling typically leads to oversimplified assumptions and
vacuous bounds for deep learning practice.
In recent years, an effort has emerged to develop mathematical understanding
of deep learning via analyzing the trajectory of the optimization algorithm using
the specific property of the model. This dissertation is along this line of works and
presents some recent progress towards identifying and bridging the gap between
machine learning theory and deep learning practice, in both aspects of optimization
and generalization. Our approaches have a special focus on interplay between the
model and the optimization algorithm.
1
1.1 The Black Box View in Existing Theory
We consider the setting of supervised learning. Given a set of training data Zn = {zi }ni=1
and labels Yn = {yi }ni=1 where the data and label are jointly sampled from some
unknown distribution P , the goal of a machine learning algorithm A is to output a
function hn = A(Zn , yn ) that can predict the label of unseen data, and the quality of
prediction is measured by some given metric called loss function, `(ŷi , yi ), where ŷi is
the prediction made by the learned function hn . The notion of loss function can be
extended to the domain of prediction function as well, where L(h) := ni=1 `(h(zi ), yi )
P
and L(h) := E(z,y)∼P `(h(z), y) are used to denote the training and population loss of
a given function h respectively.
A typical machine learning algorithm consists of two parts: a model, which is a
function class H, and an optimization algorithm, which picks a function from the given
function class attaining a small average loss L(hn ) over the given training dataset
(Zn , Yn ). In the context of deep learning, the model is parametrized by real numbers
x ∈ RD and is in the form of artificial neural networks, which can be viewed as
the concatenation of a sequence of parametrized transformations. The optimization
algorithms are typically iterative and based on first-order local search, including
(stochastic) gradient descent (GD) and its variants, e.g., ADAM [1], AdaGrad [2], etc.
The decoupling of model and optimization algorithm originates from the following
standard three-part decomposition of the error of a learning algorithm (compared to the
ground truth function h∗ ): approximation error, optimization error and generalization
error.
E[L(hn )] − L(h∗ ) =E[L(hn ) − L(hn )] +E[L(hn ) − inf L(h)] + E[ inf L(h)] − L(h∗ )
h∈H h∈H
2
Given a function class H, approximation error refers to the gap between the smallest
loss achieved within the function class and the loss of the groundtruth function.
Optimization error is the gap between the training loss of the returned function and
the smallest training loss within the function class. The generalization error refers
to the difference between the loss of the function returned by the learning algorithm
evaluated on training data and new unseen data and occurs due to the finiteness of
training dataset.
Decoupling model and optimization algorithm is actually useful and ease the
analysis and design for some classical learning methods. In statistical learning theory,
the efficacy of the model can be evaluated without knowing details of the optimization
algorithm, i.e., by treating the optimization algorithm as a black box except assuming
its ability to attain small training error. The common approach there is to relax the
generalization error of the learned function to the supremum of generalization error of
all functions in the function class and further upper bound it by certain complexity
measure of the function class via uniform convergence bounds such as VC dimension,
Rademacher complexity, etc. Thus the design principle of the model is to balance the
trade-off between approximation and generalization error.
Similarly, in optimization theory, a lot of convergence results (and matching lower
bounds) have been derived in the oracle setting, where the optimization algorithm
can only access the training loss via querying an oracle regarding information like loss
value and gradient at certain point and the goal is to return parameters achieving
small optimzation error using as few queries as possible. The entire training loss (and
therefore the model) is viewed as a black box except assuming a few properties like
convexity and global smoothness [3]. Here global smoothness refers to the supremum
of the largest eigenvalue of Hessian matrix of the training loss function over the entire
domain.
3
However, this black box view becomes incapable of explaining phenomena in deep
learning. There are gaps between theory and practice in both aspects of generalization
and optimization, and we discuss them below respectively.
4
measures. As a result, such bounds typically don’t provide strong guidance on how to
make algorithmic choices for better generalization.
One exception here is sharpness of the loss landscape, where Foret et al. [8]
successfully reduced the generalization error by explicitly minimizing the -sharpness
proposed by Keskar et al. [9]. Sharpness based bounds originates from PAC Bayesian
theory [10] and can be made non-vacuous by directly optimizing bound for simple
tasks [11]. Jiang* et al. [12] empirically found that sharpness along the worst direction
1
and average sharpness of all directions correlate best with the generalization error.
However, despite its empirical success, it’s still open why normal training methods
would find solutions with low sharpness and why particular algorithmic choice leads
to flatter solutions than others, e.g. small batch v.s. large batch SGD [9].
In recent years, there is an emerging effort on understanding how particular
optimization algorithms, e.g., gradient descent, can reach solutions with small value
for complexity measures mentioned above or other interesting properties, including
margin maximization in linear models [13] and homogeneous models [14] on separable
data and norm control for infinitely wide neural networks [15]. Such phenomena is
called the implicit bias of the optimization algorithm. We don’t aim to provide a
complete list of related works here, but defer the discussion into each chapter.
However, some of these results are only for simplified or ideal settings and cannot
justify important algorithmic choices in practice towards better generalization, for
example, in the infinitely wide neural networks (or networks in Neural Tangent Kernel
regime [15]), stochastic gradient descent provably converges to the same solution
as full-batch deterministic gradient descent. However in practice, gradient noise in
stochastic gradient descent is observed to be beneficial in terms of generalization,
suggesting neural networks in practice do not completely operate in NTK regime.
1
In the differential form these two sharpness notions are just the the largest eigenvalue of the
Hessian of loss, λ1 (∇2 L(x)), and sum (or average) of eigenvalues of the Hessian, Tr[∇2 L(x)]. For
simplicity we will use sharpness to denote λ1 (∇2 L(x)) in later sections.
5
Thus the generalization part of the dissertation is along the above line of works
on implicit bias but aims to mathematically understand the following previously
unexplained phenomena which are in more realistic settings:
• How does parametrizing the same function class in a different way affect gen-
eralization? It has been observed empirically [16] that gradient on matrix
factorization (writing a symmetric matrix W as U U > and doing gradient descent
with respect to U , which has the same shape of W ) empirically generalizes better
than plain gradient descent (doing gradient descent on W ) when the ground
truth is low-rank. A recent line of works have established correspondence be-
tween mirror descent and gradient descent with a different parametrization and
explains the above phenomena in a restrictive setting. What’s the limit of this
approach using equivalence between mirror descent and reparametrized gradient
descent? How can we resolve the implicit bias for general matrix factorization
problems?
The high dimensional loss landscape is complicated and the largest eigenvalue of
Hessian, usually called sharpness or local smoothness, can vary drastically depending
on the position. It is often either too pessimistic to think about the convergence rate
of an optimization algorithm in the worst case under some given global smoothness
constant, or too optimistic (sometimes even invalid) to assume that there exists such
a global smoothness constant. As a result, the optimization behavior for gradient-
based algorithms in practice are very different from theoretical predictions and thus
optimization theory cannot give effective guidance on algorithmic choice towards faster
optimization. Indeed, the hyperparameter tuning of the optimization algorithm in
practice is more based on a trial-and-error principle.
For example, standard convergence analysis for gradient descent requires the
learning rate to be smaller than two over sharpness, 2/λ1 (∇2 L(x)) to ensure loss
decrease. However, Cohen et al. [17] empirically showed that for all reasonably large
learning rate, gradient descent in deep learning doesn’t decrease loss in a monotone
way and violate the descent lemma. Instead, they found gradient descent typically
operates in a regime named “Edge of Stability”, which means the sharpness hovers
just above the value 2/learning rate), and the training loss behaves non-monotonically
over short timescales, yet consistently decreases over long timescales.
7
Another mystery in optimization for deep neural networks are the success of
normalization methods, including Batch Normalization [18], Layer Normalization [19],
etc., which makes network training much more robust and efficient. There have been a
lot of debates on its mechanism but no consensus is reached. Part of the reason that
the classical optimization theory fails on networks with normalization layers is that
the the usage of normalization layers also makes the sharpness of the network vary
drastically, or more specifically, scaling inversely to the squared norm of the parameter.
Thus again no single global smoothness constant is correct to be assumed, as it could
be either too optimistic or too pessimistic depending on the actual trajectory of the
optimization algorithm.
In the optimization part of the dissertation, we aim to answer the following
optimization questions emerging with the usage of modern network architectures:
• Under what settings can Edge of Stability regime occur? How can gradient
descent decrease the training loss without the descent lemma and in a non-
monotone way?
8
In the first part, we study the unconventional optimization behavior of networks
with normalization layers and design more robust and memory-efficient training
methods for BERT as an application.
• In Chapter 2, we show that training can be done using SGD with momentum and
an exponentially increasing learning rate schedule, i.e., learning rate increases by
some (1 + α) factor in every epoch for some α > 0 and prove that it is equivalent
to the standard setting of BatchNorm + SGD + Standard Rate Tuning + Weight
Decay + Momentum. This equivalence holds for other normalization layers as
well, as long as their usage can make the loss function invariant to the scaling of
the parameters.
9
propose a novel clipping method named Global Relative Clipping and show that
it enhances training stability both theoretically and experimentally.
11
• In Chapter 9, we mathematically analyze a new mechanism of implicit regular-
ization in the EoS phase, whereby GD updates due to non-smooth loss landscape
turn out to evolve along some deterministic flow on the manifold of minimum
loss. This is in contrast to many previous results about implicit bias either
relying on infinitesimal updates or noise in gradient. Formally, for any smooth
function L with certain regularity condition, this effect is demonstrated for (1)
η
Normalized GD, i.e., GD with a varying LR ηt = k∇L(x(t))k and loss L; (2) GD
p
with constant LR and loss L − minx L(x). Both provably enter the Edge of
Stability, with the associated flow on the manifold minimizing λ1 (∇2 L). The
above theoretical results have been corroborated by an experimental study.
12
Part I
13
Chapter 2
An Exponentially Increasing
Learning Rates Schedule for
Normalized Networks
Intriguing empirical evidence exists that deep learning can work well with exotic
schedules for varying the learning rate. This chapter suggests that the phenomenon
may be due to Batch Normalization or BN[18], which is ubiquitous and provides
benefits in optimization and generalization across all standard architectures. The
following new results are shown about BN with weight decay and momentum (in other
words, the typical use case which was not considered in earlier theoretical analyses of
stand-alone BN [18, 28, 29]
• Training can be done using SGD with momentum and an exponentially increasing
learning rate schedule, i.e., learning rate increases by some (1 + α) factor in
every epoch for some α > 0. (Precise statement in the paper.) To the best of
our knowledge this is the first time such a rate schedule has been successfully
used, let alone for highly successful architectures. As expected, such training
14
rapidly blows up network weights, but the network stays well-behaved due to
normalization.
2.1 Introduction
z 2 − (1 + γ − λη)z + γ = 0, (2.1)
The above theorem requires that the product of learning rate and weight decay
√
factor, λη, is small than (1 − γ)2 , which is almost always satisfied in practice. The
rigorous and most general version of above theorem is Theorem 2.4.12, which deals
with multi-phase LR schedule, momentum and weight decay.
There are other recently discovered exotic LR schedules, e.g. Triangular LR
schedule [33] and Cosine LR schedule [34], and our exponential LR schedule is an
extreme example of LR schedules that become possible in presence of BN. Such an
exponential increase in learning rate seems absurd at first sight and to the best of
our knowledge, no deep learning success has been reported using such an idea before.
It does highlight the above-mentioned viewpoint that in deep learning, optimization
and regularization are not easily separated. Of course, the exponent trumps the effect
16
of initial lr very fast (See Figure 2.3), which explains why training with BN and
WD is not sensitive to the scale of initialization, since with BN, tuning the scale of
initialization is equivalent to tuning the initial LR η while fixing the product of LR
and WD, ηλ (See Lemma 2.4.7).
Note that it is customary in BN to switch to a lower LR upon reaching a plateau in
the validation loss. According to the analysis in the above theorem, this corresponds
to an exponential growth with a smaller exponent, except for a transient effect when a
correction term is needed for the two processes to be equivalent (see discussion around
Theorem 2.4.12).
Thus the final training algorithm is roughly as follows: Start from a convenient LR
like 0.1, and grow it at an exponential rate with a suitable exponent. When validation
loss plateaus, switch to an exponential growth of LR with a lower exponent. Repeat
the procedure until the training loss saturates.
In Section 2.5, we demonstrate on a toy example how weight decay and normaliza-
tion are inseparably involved in the optimization process. With either weight decay or
normalization alone, SGD will achieve zero training error. But with both turned on,
SGD fails to converge to global minimum.
In Section 2.7, we experimentally verify our theoretical findings on CNNs and
ResNets. We also construct better exponential LR schedules by incorporating the
Cosine LR schedule on CIFAR10, which opens the possibility of even more general
theory of rate schedule tuning towards better performance.
There have been other theoretical analyses of training models with scale-invariance.
Cho and Lee [35] proposed to run Riemanian gradient descent on Grassmann manifold
G(1, n) since the weight matrix is scaling invariant to the loss function. observed
17
ηw
that the effective stepsize is proportional to kxt k2
. Arora et al. [36] show the gradient
is always perpendicular to the current parameter vector which has the effect that
norm of each scale invariant parameter group increases monotonically, which has an
auto-tuning effect. Wu et al. [37] proposes a new adaptive learning rate schedule
motivated by scale-invariance property of Weight Normalization.
Previous work for understanding Batch Normalization. Santurkar et al.
[28] suggested that the success of BN has does not derive from reduction in Internal
Covariate Shift, but by making landscape smoother. Kohler et al. [38] essentially shows
linear model with BN could achieve exponential convergence rate assuming gaussian
inputs, but their analysis is for a variant of GD with an inner optimization loop rather
than GD itself. Bjorck et al. [39] observed that the higher learning rates enabled by
BN empirically improves generalization. Arora et al. [36] proved that with certain
mild assumption, (S)GD with BN finds approximate first order stationary point with
any fixed learning rate. None of the above analyses incorporated weight decay, but
Zhang et al. [40], Hoffer et al. [41], Van Laarhoven [42? ? ] argued qualitatively that
weight decay makes parameters have smaller norms, and thus the effective learning
ηw
rate, kxt k2
is larger. They described experiments showing this effect but didn’t have
a closed form theoretical analysis like ours. None of the above analyses deals with
momentum rigorously.
18
normalization layers, including Batch Normalization [18], Group Normalization [30],
Layer Normalization [19], Instance Norm [31], etc.
Implementations of SGD with Momentum/Nesterov comes with subtle variations
in literature. We adopt the variant from [43], also the default in PyTorch [44]. L2
regularization (a.k.a. Weight Decay) is another common trick used in deep learning.
Combining them together, we get the one of the mostly used optimization algorithms
below.
Definition 2.3.1. [SGD with Momentum and Weight Decay] At iteration t, with
randomly sampled batch Bt , update the parameters xt and momentum vt as following:
where ηt , λt are the learning rate and weight decay factor at iteration t respectively
and γ is the momentum coefficient. Usually, v0 is initialized to be 0.
For ease of analysis, we will use the following equivalent of Definition 2.3.1.
xt − xt−1 xt−1 − xt−2 λt−1 2
=γ − ∇x (L(xt−1 ) + kxt−1 k2 , (2.4)
ηt−1 ηt−2 2
x0 −x−1
where η−1 and x−1 must be chosen in a way such that v0 = η−1
is satisfied, e.g.
when v0 = 0, x−1 = x0 and η−1 could be arbitrary.
19
Lemma 2.3.2 (Scale Invariance). If for any c ∈ R+ , L(x) = L(cx), then
(1). h∇x L, xi = 0;
(2). ∇x L x=x0
= c∇x L x=cx0
, for any c > 0
As a warm-up in Section 2.4.1 we show that if momentum is turned off then Fixed
LR + Fixed WD can be translated to an equivalent Exponential LR. In Section 2.4.2
we give a more general analysis on the equivalence between Fixed LR + Fixed WD
+ Fixed Momentum Factor and Exponential LR + Fixed Momentum Factor. While
interesting, this still does completely apply to real-life deep learning where reaching
full accuracy usually requires multiple phases in training where LR is fixed within a
phase and reduced by some factor from one phase to the next. Section 2.4.3 shows
how to interpret such a multi-phase LR schedule + WD + Momentum as a certain
multi-phase exponential LR schedule with Momentum.
SGD
We use notation of Section 2.3 and assume LR is fixed over iterations, i.e. ηt = η0 ,
and γ (momentum factor) is set as 0. We also use λ to denote WD factor and x0 to
denote the initial parameters.
The intuition should be clear from Lemma 2.3.2, which says that shrinking parame-
ter weights by factor ρ (where ρ < 1) amounts to making the gradient ρ−1 times larger
without changing its direction. Thus in order to restore the ratio between original
parameter and its update (LR×Gradient), the easiest way would be scaling LR by ρ2 .
This suggests that scaling the parameter x by ρ at each step is equivalent to scaling
the LR η by ρ−2 .
20
To prove this formally we use the following formalism. We’ll refer to the vector
(x, η) the state of a training algorithm and study how this evolves under various
combinations of parameter changes. We will think of each step in training as a
mapping from one state to another. Since mappings can be composed, any finite
number of steps also correspond to a mapping. The following are some basic mappings
used in the proof.
1. Run GD with WD for a step: GDρt (x, η) = (ρx − η∇Lt (x), η);
For example, when ρ = 1, GD1t is vanilla GD update without WD, also abbreviated as
GDt . When ρ = 1 − λη0 , GD1−λη
t
0
is GD update with WD λ and LR η0 . Here Lt is
the loss function at iteration t, which is decided by the batch of the training samples
Bt in tth iteration. Below is the main result of this subsection, showing our claim that
GD + WD ⇔ GD+ Exp LR (when Momentum is zero). It will be proved after a
series of lemmas.
Theorem 2.4.1 (WD ⇔ Exp LR). For every ρ < 1 and positive integer t following
holds:
h t 2t
i −1 −2 −2 −1
GDρt−1 ◦ · · · ◦ GDρ0 = Πρ1 ◦ Πρ2 ◦ Πρ2 ◦ GDt−1 ◦ Π2ρ ◦ · · · ◦ GD1 ◦ Πρ2 ◦ GD0 ◦ Πρ2 .
With WD being λ, ρ is set as 1 − λη0 and thus the scaling factor of LR per iteration
is ρ−2 = (1 − λη0 )−2 , except for the first iteration it’s ρ−1 = (1 − λη0 )−1 .
We first show how to write GD update with WD as a composition of above defined
basic maps.
−1
Lemma 2.4.2. GDρt = Πρ2 ◦ Πρ1 ◦ GDt ◦ Πρ2 .
21
−2
Below we will define the proper notion of equivalence such that (1). Πρ1 ∼ Πρ2 ,
−1 −1
which implies GDρt ∼ Πρ2 ◦ GDt ◦ Πρ2 ; (2) the equivalence is preserved under future
GD updates.
We first extend the equivalence between weights (same direction) to that between
states, with additional requirement that the ratio between the size of GD update and
that of parameter are the same among all equivalent states, which yields the notion of
Equivalent Scaling.
The following lemma shows that equivalent scaling commutes with GD update
with WD, implying that equivalence is preserved under GD update (Lemma 2.4.4).
This anchors the notion of equivalence — we could insert equivalent scaling anywhere
in a sequence of basic maps(GD update, LR/parameter scaling), without changing
the final network.
2 2
Lemma 2.4.4. For any constant c, ρ > 0 and t ≥ 0, GDρt ◦[Πc1 ◦Πc2 ] = [Πc1 ◦Πc2 ]◦GDρt .
c c
In other words, (x, η) ∼ (x0 , η 0 ) =⇒ GDρt (x, η) ∼ GDρt (x0 , η 0 ).
Definition 2.4.5 (Equivalent Maps). Two maps F, G are equivalent iff ∃c > 0,
2 c
F = Πc1 ◦ Πc2 ◦ G, which is also denoted by F ∼ G.
ρ −1 −1
Proof of Theorem 2.4.1. By Lemma 2.4.2,, GDρt ∼ Πρ2 ◦GDt ◦Πρ2 . By Lemma 2.4.4,
c c
GD update preserves map equivalence, i.e. F ∼ G ⇒ GDρt ◦ F ∼ GDρt ◦ G, ∀c, ρ > 0.
Thus,
ρt −1 −2 −2 −1
GDρt−1 ◦ · · · ◦ GDρ0 ∼ Πρ2 ◦ GDt−1 ◦ Πρ2 ◦ · · · ◦ GD1 ◦ Πρ2 ◦ GD0 ◦ Πρ2 .
22
2.4.2 Replacing WD by Exponential LR: Case of constant
LR with momentum
In this subsection the setting is the same to that in Subsection 2.4.1 except that
the momentum factor is γ instead of 0. Suppose the initial momentum is v0 , we
set x−1 = x0 − v0 η. Presence of momentum requires representing the state of the
algorithm with four coordinates, (x, η, x0 , η 0 ), which stand respectively for the current
parameters/LR and the buffered parameters/LR (from last iteration) respectively.
Similarly, we define the following basic maps and equivalence relationships.
ρ 0 0 x−x0
1. Run GD with WD for a step: GDt (x, η, x , η ) = ρx + η γ η0 − ∇Lt (x) , η, x, η ;
Lemma 2.4.8. For any input (x, η, x0 , η), if α > 0 is a root of α + γα−1 = ρ + γ, then
h i
−1
GDρt (x, η, x0 , η) = Πα4 ◦ Πα2 ◦ Πα1 ◦ GDt ◦ Πα2 ◦ Πα3 ◦ Πα4 (x, η, x0 , η). In other words,
h −1 i
α −1 −1 −1
GDρt (x, η, x0 , η) ∼ Πα3 ◦ Πα4 ◦ Πα2 ◦ GDt ◦ Πα2 ◦ Πα3 ◦ Πα4 (x, η, x0 , η). (2.5)
Though looking complicated, the RHS of Equation (2.5) is actually the desired
−1 −1
Πα2 ◦ GDt ◦ Πα2 conjugated with some scaling on momentum part Πα3 ◦ Πα4 , and
−1 −1
Πα3 ◦ Πα4 in the current update cancels with the Πα3 ◦ Πα4 in the next update. Now we
are ready to show the equivalence between WD and Exp LR schedule when momentum
is turned on for both.
Theorem 2.4.9 (GD + WD ⇔ GD+ Exp LR; With Momentum). The following
defined two sequences of parameters ,{xt }∞ xt } ∞
t=0 and {e et = αt xt , thus they
t=0 , satisfy x
correspond to the same networks in function space, i.e. fxt = fxet , ∀t ∈ N, given
x e−1 = x−1 α, and ηet = η0 α−2t−1 .
e0 = x0 , x
et −e
x xt−1 xt−1 −e
γ(e xt−2 )
2. ηet
= ηet−1
− ∇x L(e
xt−1 )
24
where α is a positive root of equation x2 − (1 + γ − λη0 )x + γ = 0, which is always
smaller than 1(See Section 2.8.1). When γ = 0, α = 1 − λη0 is the unique non-zero
solution.
√
Remark 2.4.10. Above we implicitly assume that λη0 ≤ (1− γ)2 such that the roots
are real and this is always true in practice. For instance of standard hyper-parameters
λη0
where γ = 0.9, η0 = 0.1, λ = 0.0005, √
(1− γ)2
≈ 0.019 1.
h i
−1
Proof. Note that (e
x0 , ηe0 , x
e−1 , ηe−1 ) = Πα2 ◦ Πα3 ◦ Πα4 (x0 , η0 , x0 , η0 ), it suffices to
show that
h −1 −1 −1 −2 −2 −1
i
Πα3 ◦ Πα4 ◦ Πα2 ◦ GDt−1 ◦ Πα2 ◦ · · · ◦ GD1 ◦ Πα2 ◦ GD0 ◦ Πα2 ◦ Πα3 ◦ Πα4 (x0 , η0 , x0 , η0 )
αt
∼ GD1−λη
t−1
0
◦ · · · ◦ GD1−λη
0
0
(x0 , η0 , x0 , η0 ), ∀t ≥ 0.
which follows immediately from Lemma 2.4.7 and Lemma 2.4.8 by induction.
LR phases
Usual practice in deep learning shows that reaching full training accuracy requires
reducing the learning rate a few times.
Definition 2.4.11. Step Decay is the (standard) learning rate schedule, where training
has K phases I = 0, 1, . . . , K − 1, where phase I starts at iteration TI (T0 = 0), and
all iterations in phase I use a fixed learning rate of ηI∗ .
q
2
1+γ−ληI∗ + (1+γ−ληI∗ ) −4γ
where αI∗ = 2
, ηe0 = η0 · (α0∗ )−1 = η0∗ · (α0∗ )−1 .
The analysis in previous subsection give the equivalence within each phase, where
the same LR is used throughout the phase. To deal with the difference between
buffered LR and current LR when entering new phases, the idea is to pretend ηt−1 = ηt
xt −xt−1
and xt−1 becomes whatever it needs to maintain ηt−1
such that we can again apply
Lemma 2.4.8, which requires the current LR of the input state is equal to its buffered
LR. Because scaling α in RHS of Equation (2.5) is different in different phases, so
26
unlike what happens within each phase, they don’t cancel with each other at phase
transitions, thus remaining as a correction of the momentum. The proofs are delayed
to Section 2.8, where we proves a more general statement allowing phase-dependent
WD, {λI }K−1
I=0 .
For example, with WD 0.0005, max LR 0.1, momentum factor 0.9, the ratio is within
1 ± 0.0015 ∗ 0.9t−TI , meaning TEXP and TEXP++ are very close for Step Decay with
standard hyperparameters.
define the same sequence of network functions, i.e. fxt = fxet , ∀t ∈ N, given the initial
conditions, x e0 = P0 x0 , xe−1 = P−1 x−1 .
1. xtη−x t−1
t−1
= γ xt−1 −xt−2
ηt−2
− ∇ x (L(xt−1 ) +
λt−1
2
kxt−1 k22 , for t = 1, 2, . . .;
et −e
x xt−1 −e
2. = γ xet−1ηet−2
ηet−1
xt−2
− ∇x L(e
xt−1 ), for t = 1, 2, . . .,
t
αi−1 , ∀t ≥ −1 and αt recursively defined as
Q
where ηet = Pt Pt+1 ηt , Pt =
i=−1
ηt−1 −1
αt = −ηt−1 λt−1 + 1 + γ(1 − αt−1 ), ∀t ≥ 1. (2.7)
ηt−2
ηt }∞
The LR schedule {e t=0 is called Tapered Exponential ++, or TEXP++.
27
2.5 Example Illustrating Interplay of Weight De-
input of the last layer are already separable, and w.l.o.g. we assume the label is equal
to the sign of the first coordinate of z ∈ Rm , namely sign (x1 ) . Thus the training loss
and training error are simply
Case 1: WD alone: Since both the above objective with L2 regularization is strongly
convex and smooth in x, vanilla GD with suitably small learning rate could get
arbitrarily close to the global minimum for this regularized objective. In our case,
28
q
large batch SGD behaves similarly to GD and can achieve O( ηλ
B
) test error following
the standard analysis of convex optimization.
Case 2: BN alone: Add a BN layer after the linear layer, and fix scalar and bias term
to 1 and 0. The objective becomes
> x
LBN (x) = E [LBN (x, z)] = E ln(1 + exp(−z y)) .
z∼N (0,Im ),y=sign(x1 ) z∼N (0,Im ),y=sign(x1 ) kxk
From Section 2.8.6, there’s some constant C, such that ∀x ∈ Rm with constant
C
probability, k∇x LBN (x, z)k ≥ kxk
. By Pythagorean Theorem, kxt+1 k4 = (kxt k2 +
η 2 k∇x LBN (xt , z)k2 )2 ≥ kxt k4 + 2η 2 kxt k2 k∇x LBN (xt , z)k2 . As a result, for any fixed
learning rate, kxt+1 k4 ≥ 2 ti=1 η 2 kxk2 k∇x LBN (xi , z)k2 grows at least linearly with
P
high probability. Following the analysis by Arora et al. [36], this is like reducing the
effective learning rate, and when kxt k is large enough, the effective learning rate is
small enough, and thus SGD can find the local minimum, which is the unique global
minimum.
Case 3: Both BN and WD: When BN and WD are used together, no matter how
small the noise is, which comes from the large batch size, the following theorem shows
√
that SGD will not converge to any solution with error smaller than O( ηλ), which is
independent of the batch size (noise level).
Proof Sketch. (See full proof in Section 2.8.) The high level idea of this proof is that
if the test error is low, the weight is restricted in a small cone around the global
minimum, and thus the amount of the gradient update is bounded by the size of the
cone. In this case, the growth of the norm of the weight by Pythagorean Theorem is
not large enough to cancel the shrinkage brought by weight decay. As a result, the
29
norm of the weight converges to 0 geometrically. Again we need to use the lower bound
for size of the gradient, that k∇x Lt k = Θ( kxηt k m
p
B
) holds with constant probability.
Thus the size of the gradient will grow along with the shrinkage of kxt k until they’re
comparable, forcing the weight to leave the cone in next iteration.
This section tries to explain why the efficacy of exponential LR in deep learning is
mysterious to us, at least as viewed in the canonical framework of optimization theory.
Canonical framework for analysing 1st order methods This focuses on proving that
each —or most—steps of GD noticeably reduce the objective, by relying on some
assumption about the spectrum norm of the hessian of the loss, and most frequently,
the smoothness, denoted by β. Specifically, for GD update xt+1 = xt − η∇L(xt ), we
have
β βη
L(xt+1 ) − L(xt ) ≤ (xt+1 − xt )> ∇L(xt ) + kxt+1 − xt k2 = −η(1 − )k∇L(xt )k2 .
2 2
When β < η2 , the first order term is larger than the second order one, guaranteeing
the loss value decreases. Since the analysis framework treats the loss as a black box
(apart from the assumed bounds on the derivative norms), and the loss is non-convex,
the best one can hope for is to prove speedy convergence to a stationary point (where
gradient is close to 0). An increasing body of work proves such results.
Now we turn to difficulties in understanding the exponential LR in context of the
above framework and with scale-invariance in the network.
1. Since loss is same for x and c · x for all c > 0 a simple calculation shows that
along any straight line through the origin, smoothness is a decreasing function of
30
c, and is very high close to origin. (Note: it is also possible to one can show the
following related fact: In any ball containing the origin, the loss is nonconvex.)
Thus if one were trying to apply the canonical framework to argue convergence
to a stationary point, the natural idea would be to try to grow the norm
of the parameters until smoothness drops enough that the above-mentioned
Canonical Framework starts to apply. Arora et al. [36] showed this happens in
GD with fixed LR (WD turned off), and furthermore the resulting convergence
rate to stationary point is asymptotically similar to analyses of nonconvex
optimization with learning rate set as in the Canonical framework. Santurkar
et al. [28] observed similar phenomenon in experiments, which they described as
a smoothening effect of the objective due to BN.
31
2
3. It can be shown that if the local smoothness is upperbounded by η
(as stipulated
in Canonical Framework) during a sequence xt (t = 1, 2, . . .) of GD updates with
WD and constant LR then such sequence satisfies xt → 0. This contrasts with
the usual experimental observation that xt stays bounded away from 0. One
should thus conclude that in practice, with constant LR and WD, smoothness
doesn’t always stay small (unlike the above analyses where WD is turned off).
2.7 Experiments
32
Figure 2.3: Instant LR decay has only temporary effect when LR growth ηet /eηt−1 − 1
is large. The blue line uses an exponential LR schedule with constant exponent. The
orange line multiplies its LR by the same constant each iteration, but also divide
LR by 10 at the start of epoch 80 and 120. The instant LR decay only allows the
parameter to stay at good local minimum for 1 epoch and then diverges, behaving
similarly to the trajectories without no instant LR decay.
q
2
1+γ−ληI∗ + (1+γ−ληI∗ ) −4γ
where αI∗ = 2
, ηe0 = η0 · (α0∗ )−1 = η0∗ · (α0∗ )−1 .
We applied the TEXP LR schedule (Theorem 2.4.12) on the Cosine LR schedule [34],
where the learning rate changes every epoch, and thus correction terms cannot be
1+cos( Tt π)
ignored. The LR at epoch t ≤ T is defined as: ηt = η0 2
. Our experiments
33
Figure 2.4: Instant LR decay is crucial when LR growth ηet /e ηt−1 − 1 is very small.
The original LR of Step Decay is decayed by 10 at epoch 80, 120 respectively. In the
third phase, LR growth ηet /e ηt−1 − 1 is approximately 100 times smaller than that in
the third phase, it would take TEXP-- hundreds of epochs to reach its equilibrium.
As a result, TEXP achieves better test accuracy than TEXP--. As a comparison, in
ηt−1 − 1 is only 10 times smaller than that in the first phase and
the second phase, ηet /e
it only takes 70 epochs to return to equilibrium.
Figure 2.5: The orange line corresponds to PreResNet32 trained with constant LR
and WD divided by 10 at epoch 80 and 120. The blue line is TEXP-- corresponding
to Step Decay schedule which divides LR by 10 at epoch 80 and 120. They have
similar trajectories and performances by a similar argument to Theorem 2.4.12.(See
Theorem 2.8.2 and its proof in Section 2.8)
34
Figure 2.6: Both Cosine and Step Decay schedule behaves almost the same as their
exponential counterpart, as predicted by our equivalence theorem. The (exponential)
Cosine LR schedule achieves better test accuracy, with a entirely different trajectory.
show this hybrid schedule with Cosine LR performs better on CIFAR10 than Step
Decay, but this finding needs to be verified on other datasets.
2.8 Proofs
Lemma 2.8.1 (Some Facts about Equation (2.1)). Suppose z 1 , z 2 (z 1 ≥ z 2 ) are the
two real roots of the the following equation, we have
z 2 − (1 + γ − λη)z + γ = 0
√ √
1 1+γ−λη+ (1−γ)2 −2(1+γ)λη+λ2 η 2 2 1+γ−λη− (1−γ)2 −2(1+γ)λη+λ2 η 2
1. z = 2
, z = 2
√ 2
2. z 1 , z 2 are real ⇐⇒ λη ≤ (1 − γ) ;
3. z 1 z 2 = γ, z 1 + z 2 = (1 + γ − λη);
4. γ ≤ z 2 ≤ z 1 ≤ 1;
λη 1 λη
5. Let t = 1−γ
, we have z 1 ≥ 1+t
≥1−t=1− 1−γ
.
35
6. if we view z 1 (λη), z 2 (λη) as functions of λη, then z 1 (λη) is monotone decreasing,
z 2 (η) is monotone increasing.
5. It holds that
p
1 − γ + λη − (1 − γ)2 − 2(1 + γ)λη + λ2 η 2
1 − z1 =
q 2
1+γ
1+t− 1− 1−γ
t + t2
= (1 − γ)
2
1+γ
2t + 2 1−γ t
= (1 − γ) q
2(1 + t + 1 − 1+γ1−γ
t + t2 )
4
1−γ
t
≤ (1 − γ)
4(1 + t)
t
=
(1 + t)
η −1
GDρt (x, η) = (ρx − η∇Lt (x), η) = [Πρ1 ◦ Πρ2 ◦ GDt ](x, ) = [Πρ1 ◦ Πρ2 ◦ GDt ◦ Π2ρ ](x, η).
ρ
36
Proof of Lemma 2.4.4. For any (x, η), we have
h i
2 ∗
GDt ◦ Πc1 ◦ Πc2 (x, η) = GDt (cx, c2 η) = (cx − c2 x∇Lt (cx), c2 η) = (c(x − ∇Lt (x)), c2 η)
h 2
i
= Πc1 ◦ Πc2 ◦ GDt (x, η),
∗
where = is because of Scale Invariance of Lt (Lemma 2.3.2).
Proof of Lemma 2.4.7. For any input (x, η, x0 , η 0 ), it’s easy to check both composed
maps have the same outputs on the 2,3,4th coordinates, namely (c2 η, cx, c2 η 0 ). For
the first coordinate, we have
x − x0
0 2
ρ 2
2
GD (cx, c η, cx , c η) 1 = ρcx + c η γ − ∇Lt (cx)
η0
x − x0
∗
=c x + η γ − ∇Lt (x)
η0
=c [GDρ (x, η, x0 , η)]1 ,
∗
where = is because of Scale Invariance of Lt (Lemma 2.3.2).
Proof of Lemma 2.4.8. For any input (x, η, x0 , η 0 ), it’s easy to check both composed
maps have the same outputs on the 2,3,4th coordinates, namely (η, x, η). For the first
coordinate, we have
hh −1
i i
Πα3 Πα4 Πα2 Πα2 Πα3 Πα4 (x, η, x , η) = α GDt (x, α−1 η, αx0 , αη) 1
0
◦ ◦ ◦ ◦ GDt ◦ ◦
1
x − x0
=α x + α−1 η γ − ∇Lt (x)
η
x0
= α + γα−1 x − η∇Lt (x) − ηγ
η
= (ρ + γ) x − η∇Lt (x) − γx0 = [GDρt (x, η, x0 , η)]1
37
2.8.4 Omitted proofs of Theorem 2.4.12
Theorem 2.8.2 (A stronger version of Theorem 2.4.12). There exists a way to correct
the momentum only at the first iteration of each phase, such that the following Tapered-
Exponential LR schedule (TEXP) {e
ηt } with momentum factor γ and no WD, leads
the same sequence networks in function space compared to that of Step Decay LR
schedule(Definition 2.4.11) with momentum factor γ and phase-dependent WD λ∗I in
phase I, where phase I lasts from iteration TI to iteration TI+1 , T0 = 0.
∗
)−2
ηet × (αI−1
if TI−1 + 1 ≤ t ≤ TI − 1, I ≥ 1
ηet+1 = , (2.9)
ηI∗
(αI∗ )−1 (αI−1
∗
)−1
ηet ×
∗ × if t = TI , I ≥ 1
ηI−1
q
2
1+γ−λ∗I ηI∗ + (1+γ−λ∗I ηI∗ ) −4γ
where αI∗ = 2
, ηe0 = η0 (α0∗ )−1 = η0∗ (α0∗ )−1 .
Towards proving Theorem 2.4.12, we need the following lemma which holds by
expanding the definition, and we omit its proof.
Definition 2.8.4 (Equivalent Maps). For two maps F and G, we say F is equivalent
h i
c c2 c c2 c
to G iff ∃c > 0, F = Π1 ◦ Π2 ◦ Π3 ◦ Π4 ◦ G, which is also denoted by F ∼ G.
38
Note that for any (x, η, x0 , η 0 ), [N (x, η, x0 , η 0 )]2 = [N (x, η, x0 , η 0 )]4 . Thus as a direct
consequence of Lemma 2.4.8, the following lemma holds.
α −1 −1 −1 −1
Lemma 2.8.5. ∀ρ, α > 0, GDρt ◦ N ∼ Πα3 ◦ Πα4 ◦ Πα2 ◦ GDt ◦ Πα2 ◦ Πα3 ◦ Πα4 ◦ N .
Proof of Theorem2.4.12. Starting with initial state (x0 , η0 , x−1 , η−1 ) where η−1 = η0
and a given LR schedule {ηt }t≥0 , the parameters generated by GD with WD and
momentum satisfies the following relationship:
ηt+1
ηt 1−ηt λt
(xt+1 , ηt+1 , xt , ηt ) = Π2 ◦ GDt (xt , ηt , xt−1 , ηt−1 ).
b
Define Ft = Fb ◦ Fb−1 ◦ . . . ◦ Fa , for a ≤ b. By Lemma 2.8.3 and Lemma 2.8.5,
t=a
letting αt be the root of x2 − (γ + 1 − ηt−1 λt−1 )x + γ = 0, we have
ηt+1
T −1
Π2ηt
◦ GDt1−ηt λt
t=0
ηt+1
T −1
= Π2ηt
◦ GDt1−ηt λt ◦N
t=0
−1
TQ
αi T −1 ηt+1
i=0 ηt α−1 α−1 α−1 α−1 α α (2.10)
∼ Π2 ◦ Π3 t+1 ◦ Π4 t+1 ◦ Π2 t+1 ◦ GDt ◦ Π2 t+1 ◦ Π3 t+1 ◦ Π4 t+1 ◦N
t=0
ηT
ηT −1 α−1 α−1 α−1
=Π2 ◦ Π3 T −1 ◦ Π4 T −1 ◦ Π2 T ◦ GDT −1 ◦
i
T −1 h α−1 α−1
α−1
= Π2 t+1 t
◦ Ht ◦ GDt−1 ◦ Π2 1 ◦ Πα3 1 ◦ Πα4 1 ◦ N,
t=1
−1
TQ
αi
i=0
where ∼ is because of Lemma 2.8.5, and Ht is defined as
ηt−1 ηt
α α α−1 α−1 α−1 ηt−1
Ht = Πα2 t ◦ Π2 ηt
◦ Π3 t+1 ◦ Π4 t+1 ◦N ◦ Π3 t ◦ Π4 t ◦ Π2 t ◦ Π2 .
Since the canonicalization map N only changes the momentum part of the state, it’s
easy to check that Ht doesn’t touch the current parameter x and the current LR η.
Thus Ht only changes the momentum part of the input state. Now we claim that
39
Ht ◦ GDt−1 = GDt−1 whenever ηt = ηt−1 . This is because when ηt = ηt−1 , αt = αt+1 ,
thus Ht ◦ GDt−1 = GDt−1 . In detail,
Ht ◦ GDt−1
α−1 α−1 α−1
=Πα2 t ◦ Πα3 t ◦ Πα4 t ◦ N ◦ Π3 t ◦ Π4 t ◦ Π2 t ◦ GDt−1
∗ α−1 α−1 α−1
=Πα2 t ◦ Πα3 t ◦ Πα4 t ◦ Π3 t ◦ Π4 t ◦ Π2 t ◦ GDt−1
=GDt−1 ,
∗
where = is because GD update GDt sets η 0 the same as η, and thus ensures the input
of N has the same momentum factor in buffer as its current momentum factor, which
makes N an identity map.
0
Thus we could rewrite Equation (2.10) with a “sloppy”version of Ht , Ht =
Ht ηt 6= ηt−1 ;
:
Id o.w.
ηt+1
T −1
ηt 1−ηt λt
Π2 ◦ GDt
t=0
ηT i
α−1 α−1 α−1 T −1 h α−1 α−1
ηT −1 α−1
=Π2 ◦ Π3 T −1 ◦ Π4 T −1 ◦ Π2 T ◦ GDT −1 ◦ Π2 t+1 t
◦ Ht0
◦ GDt−1 ◦ Π2 1 ◦ Πα3 1 ◦ Πα4 1 ◦ N
t=1
ηT i
α−1 α−1 α−1 T −1 h α−1 −1
η t+1 αt 0 α−1
=Π2 T −1 ◦ Π3 T −1 ◦ Π4 T −1 ◦ Π2 T ◦ GDt ◦ Π2 ◦ Ht ◦ GD0 ◦ Π2 1 ◦ Πα3 1 ◦ Πα4 1 ◦ N,
t=1
(2.11)
Now we construct the desired sequence of parameters achieved by using the Tapered
Exp LR schedule 2.9 and the additional one-time momentum correction per phase.
Let (e
x0 , ηe0 , x
e−1 , ηe−1 ) = (x0 , η0 , x−1 , η0 ), and
40
α−1
h i
α1 α1
(e e0 , ηe0 ) = GD0 ◦ Π2 ◦ Π3 ◦ Π4 ◦ N (e
x1 , ηe1 , x 1
x0 , ηe0 , x
e−1 , ηe−1 )
α−1
h i
= GD0 ◦ Π2 1 ◦ Πα3 1 ◦ Πα4 1 (e x0 , ηe0 , x
e−1 , ηe−1 );
α−1 α−1
h i
(e et , ηet ) = GDt ◦ Π2 t+1 t ◦ Ht0 (e
xt+1 , ηet+1 , x xt , ηet , x
et−1 , ηet−1 ).
we claim {e
xt }t=0 is the desired sequence of parameters. We’ve already shown that
xt ∼ x
et , ∀t. Clearly {e
xt }t=0 is generated using only vanilla GD, scaling LR and
6 TI for any I, ηt = ηt−1 and thus
modifying the momentum part of the state. When t =
Ht0 = Id. Thus the modification on the momentum could only happen at TI (I ≥ 0).
Also it’s easy to check that αt = αI∗ , if TI + 1 ≤ t ≤ TI+1 .
41
xt − xt−1 xt−1 − xt−2 λt−1 2
=γ − ∇x (L(xt−1 ) + kxt−1 k2
ηt−1 ηt−2 2
Take gradient xt − xt−1 xt−1 − xt−2
=======⇒ =γ − ∇x L(xt−1 ) + λt−1 xt−1
ηt−1 ηt−2
Scale Invariance xt − xt−1 xt−1 − xt−2
=========⇒ =γ − Pt−1 ∇x L(ext−1 ) + λt−1 xt−1
ηt−1 ηt−2
Rescaling Pt (xt − xt−1 ) Pt−2 (xt−1 − xt−2 ) xt−1
=====⇒ =γ − ∇x L(e xt−1 ) − λt−1
Pt Pt−1 ηt−1 Pt−1 Pt−2 ηt−2 Pt−1
−1
Simplfying Pt xt − αt x et−1 et−1 − x
αt−1 x et−2 Pt xt−1
======⇒ =γ − ∇x L(e xt−1 ) − ηt−1 λt−1
ηet−1 ηet−2 ηt−1 Pt−1 Pt
−1
Simplfying Pt xt − αt x et−1 et−1 − x
αt−1 x et−2 α−1 x et−1
======⇒ =γ − ∇x L(e xt−1 ) − ηt−1 λt−1 t
ηet−1 ηet−2 ηet−1
−1
Simplfying Pt xt − αt (1 − ηt−1 λt−1 )e xt−1 et−1 − x
αt−1 x et−2
======⇒ =γ − ∇x L(e xt−1 )
ηet−1 ηet−2
To conclude that Pt xt = x
et , it suffices to show that the coefficients before x
et−1 is
the same to that in (2). In other words, we need to show
Lemma 2.8.6 (Sufficient Conditions for positivity of αt ). Let λmax = maxt λt , ηmax =
maxt ηt . Define zmin is the larger root of the equation z 2 − (1 + γ − λmax ηmax )z + γ = 0.
√
To guarantee the existence of zmax we also assume ηmax λmax ≤ (1 − γ)2 . Then we
have
42
Proof. We will prove the above theorem with a strengthened induction —
^ α−1 −1
0 t0 − 1 zmin −1
S(t) : ∀0 ≤ t ≤ t, zmin ≤ αt0 ≤ 1 ≤ .
ηt0 −1 ηmax
Since α0 = 1, S(0) is obviously true. Now suppose S(t) is true for some t ∈ N, we
will prove S(t + 1).
First, since 0 < αt ≤ 1, αt+1 = −ηt λt + 1 + ηt
ηt−1
γ(1 − αt−1 ) ≤ 1.
Again by Equation (2.7), we have
αt−1 − 1 z −1 − 1 −1
1 − αt+1 = ηt λt + ηt γ = ηt λt + min ηt γ ≤ ηt λt + (zmin − 1)γ = 1 − zmin ,
ηt−1 ηmax
which shows αt+1 ≥ zmin . Here the last step is by definition of zmin .
Because of αt+1 ≥ zmin , we have
−1
αt+1 −1 −1 1 − αt+1 −1 α−1 − 1
≤ zmin ≤ zmin (λt + t γ)
ηt ηt ηt−1
−1 z −1 − 1 −1 1 − zmin z −1 − 1
≤zmin (λmax + min γ) = zmin = min .
ηmax ηmax ηmax
Now we are ready to give the formal statement about the closeness of Equation (2.6)
and the reduced LR schedule by Theorem 2.4.13.
43
Theorem 2.8.7. Given a Step Decay LR schedule with {TI }K−1 ∗ K−1 ∗ K−1
I=0 , {ηI }I=0 , {λI }I=0 ,
η̂t = Pt Pt+1 ηt .
It’s the same as the TEXP LR schedule({η˜t }) in Theorem 2.4.12 throughout each
phase I, in the sense that ∀TI + 1 ≤ t ≤ TI+1
η̂t−1 ηet−1
− 1 ≤ 0.0015 × 0.9009t−TI −1 .
η̂t ηet
Proof of Theorem 2.8.7. Assuming zI1 and zI2 (zI1 ≥ zI2 ) are the roots of Equation (2.1)
√
with η = ηI and λ = λI , we have γ ≤ zI20 ≤ γ ≤ zmin ≤ zI1 ≤ 1, ∀I, I 0 ∈ [K − 1] by
Lemma 2.8.1.
We can rewrite the recursion in Theorem 2.4.13 as the following:
−1 −1
αt = −ηI λI + 1 + γ(1 − αt−1 ) = −(zI1 + zI2 ) + zI1 zI2 αt−1 . (2.13)
44
In other words, we have
zI2
αt − zI1 = (αt−1 − zI1 ), t ≥ 1. (2.14)
αt−1
z2I αt−1
By Lemma 2.8.6, we have αt ≥ zmin , ∀t ≥ 0. Thus | αz1t − 1| = |
αt−1 zI1
− 1| ≤
I
2
γ
zmin
| αzt−1
1 − 1| = 2
γ
zmin
| αzt−1
1 − 1| ≤ γ(1 + λη 2 αt−1
1−γ
) | z1 |, which means αt geometrically
I I I
ηet−1
converges to its stable fixed point zI1 . and ηet
= (zI1 )2 . Since that zmin ≤ αt ≤ 1,
αTI 1−zmin λmax ηmax
zmin ≤ zI1 ≤ 1, we have | zI1
− 1| ≤ zmin
= 1−γ
≤ 1 , and thus | αz1t − 1| ≤
I
λmax ηmax
1−γ
( z2γ )t−TI −1 ≤ 1, ∀TI + 1 ≤ t ≤ TI+1 .
min
η̂t−1
Note that αI∗ = zI1 , η̂t
= αt αt+1 By definition of TEXP and TEXP++, we have
1
)2
ηet−1 (zI−1
if TI−1 + 1 ≤ t ≤ TI − 1
= (2.15)
ηet ∗
ηI−1 z1z1
ηI∗ I I−1
if t = TI , I ≥ 1
η̂t−1 ηt−1 αt+1 αt
if TI−1 + 1 ≤ t ≤ TI − 1
= αt+1 αt = (2.16)
η̂t ηt ∗
ηI−1
η∗
αTI +1 αTI if t = TI , I ≥ 1
I
45
Thus we conclude ∀I ∈ [K − 1], TI + 1 ≤ t ≤ TI+1 , we have
t−TI −1
η̂t−1 ηet−1 λmax ηmax γ λmax ηmax t−TI −1 λmax ηmax 2(t−TI −1)
−1 ≤3 2
≤3 ·γ (1+ ) .
η̂t ηet 1−γ zmin 1−γ 1−γ
Case 1: WD alone Since the objective is strongly convex, it has unique argmin
w∗ . By symmetry, w∗ = βe1 , for some β > 0. By KKT condition, we have
r
|x1 | 2
λβ = E ≤ E [|x1 |] = ,
x1 ∼N (0,1) 1 + exp(β|x1 |) x1 ∼N (0,1) π
Case 3: Both BN and WD We will need the following lemma when lower
bounding the norm of the stochastic gradient.
i.i.d.
Lemma 2.8.8 (Concentration of Chi-Square). Suppose X1 , . . . , Xk ∼ N (0, 1), then
" k
#
X k2
Pr Xi2 < kβ ≤ βe1−β . (2.17)
i=1
46
" k
# " k
! #
X k X
Pr Xi2 < kβ ≤ βe 1−β 2
= Pr exp ktβ − t Xi2 ≥1
i=1 i=1
" k
!#
X
≤ E exp ktβ − t Xi2 (Markov Inequality)
i=1
− k2
=ektβ (1 + 2t) .
(2.18)
The last equality uses the fact that E [tXi2 ] = √ 1 for t < 1
. The proof is
1−2t 2
1−β
completed by taking t = 2β
.
B
η X > wt λ 2
wt+1 =wt − ∇ ln(1 + exp(−xt,b yt,b) ) + kwt k
B b=1 kwt k 2
B
η X yt,b Π⊥
wt xt,b
=(1 − λη)wt − w ,
B b=1 1 + exp(xt,b > kwtt k yt,b ) kwt k
i.i.d. wt wt>
where xt,b ∼ N (0, Im ), yt,b = sign ([xt,b ]1 ), and Π⊥
wt = I − kwt k2
.
1. kwt0 k = (1 − ηλ)kwk.
kwt0 k
2. kwt+1 k ≤ cos 2ε
.
47
The second property is because by Lemma 2.3.2, (wt+1 − wt0 ) ⊥ wt0 and by
assumption of small error, ∠wt+1 wt0 ≤ 2ε.
Therefore
2T1 2T1
kwT1 +T0 k2
1 − ηλ 1 − ηλ 2 2T1 2
≤ e−2T1 (ηλ−2ε )
2
≤ ≤ 2
≤ 1 − (ηλ − 2ε )
kwT0 k cos 2ε 1 − 2ε
r
η m−2
= .
64kwT0 k2 ε B
q
η m−2
In other word, kwT0 +T1 k2 ≤ 64ε B
. Since kwT0 +t k is monotone decreasing,
q
η m−2
kwT0 +t k2 ≤ 64ε B
holds for any t = T1 , . . . , T1 + T2 .
Step 2: We show that the norm of the stochastic gradient is lower bounded
with constant probability. In other words, we want to show the norm of ξt =
PB yt,b Π⊥
wt xt,b
b=1 1+exp(xt,b > wt
yt,b ) kwt k
is lower bounded with high probability.
kwt k
Let Π⊥
wt ,e1 be the projection matrix for the orthogonal space spanned by wt and
e1 . W.L.O.G, we can assume the rank of Π⊥
wt ,e1 is 2. In case wt = e1 , we just exclude
with kΠ⊥
wt ,e1 ξt k.
v !2
B
u B
X yt,b uX
d t yt,b
Π⊥ xt,b = Π⊥
wt ,e1 x,
b=1
1 + exp(xt,b kwt k yt,b ) wt
> wt
b=1
wt
1 + exp(xt,b > kw tk
yt,b )
(2.19)
where x ∼ N (0, Im ). We further note that kΠ⊥ 2 2
wt ,e1 xk ∼ χ (m − 2). By
Lemma 2.8.8,
48
m−2 1 m−2 1 1 1
Pr kΠ⊥
wt ,e1 xt k
2
≥ ≥ 1 − ( 7 ) 2 ≥ 1 − ( 7 )2 ≥ . (2.20)
8 8e 8 8e 8 3
2
PB yt,b
Now we will give a high probability lower bound for b=1 w
1+exp(xt,b > kwt k yt,b )
.
t
wt 1
Pr |x>
t,b |<1 ≥ , (2.21)
kwt k 2
h i
which implies the following, where At,b is defined as 1 |x> wt
t,b kwt k | <1≥ 1
2
:
" #
yt,b 1 1
E At,b = Pr k > wt
k≥ ≥ . (2.22)
1 + exp(xt,b kwt k yt ) 1+e 2
PB PB hP i
B B B
Note that b=1 At,b ≤ B, and E b=1 At,b ≥ 2
, we have Pr b=1 At,b < 4
≤ 23 .
Thus,
!2
" B #
B
X yt,b B X B 1
Pr
wt ≥ ≥ Pr At,b ≥ ≥ . (2.23)
b=1
1 + exp(xt,b > kwtk
yt,b ) 4(1 + e)2
b=1
4 3
Thus w.p. at least 19 , Equation (2.23) and Equation (2.20) happen together, which
implies
B B
η X > wt η X yt,b Π⊥
wt xt,b
k ∇ ln(1 + exp(−xt,b yt,b ))k = k > wt k
B b=1 kwt k B b=1 1 + exp(xt,b kwt k yt ) kwt k
√ r
η m−2 η m−2
≥ ≥
1 + e 8kwt k 32kwt k B
Step 3. To stay in the cone {w|∠we1 ≤ ε}, the SGD update kwt+1 − wt0 k =
k Bη B > wt
P
b=1 ∇ ln(1 + exp(−xt,b kwt k yt,b ))k has to be smaller than kwt k sin 2ε for any
49
t = T0 + T1 , . . . , T0 + T1 + T2 . However, step 1 and 2 together show that k∇ ln(1 +
exp(−x> wt
t kwt k yt ))k ≥ 2kwt kε w.p.
1
per iteration. Thus the probability that wt always
9
T
stays in the cone for every t = T0 + T1 , . . . , T0 + T1 + T2 is less than 89 2 ≤ δ.
It’s interesting that the only property of the global minimum we use is that
the if both wt , wt+1 are ε−optimal, then the angle between wt and wt+1 is at
most 2ε. Thus we indeed have proved a stronger statement: At least once in every
√
1 64kwT0 k2 ε B
2(ηλ−2ε2 )
ln √
η m−2
+ 9 ln 1δ iterations, the angle between wt and wt+1 will be larger
than 2. In other words, if the the amount of the update stabilizes to some direction
√
in terms of angle, then the fluctuation in terms of angle must be larger than 2ηλ for
this simple model, no matter how small the noise is.
Lemma 2.8.9. Suppose loss L is scale invariant, then L is non-convex in the following
two sense:
2. There exists no ball containing origin such that the loss is locally convex, unless
the loss is constant function.
Proof. Suppose L(x∗ ) = supx∈B L(x). W.L.O.G, we assume kx∗ k < 1. By convexity,
every line segment passing x∗ must have constant loss, which implies the loss is
∗
constant over set B − {c kxx∗ k | −1 ≤ c ≤ 0}. Applying the above argument on any
other maximum point x0 implies the loss is constant over B − {0}.
50
Proof in Item 3. By Lemma 2.3.2 and the update rule of GD with WD, we have
which implies
t−1
X
kxt k2 = (1 − λη)2(t−i−1) η 2 k∇L(xt−1 )k2 + (1 − λη)2(t−T ) kxT k2 .
i=T
T 0 0 −1
TX
! 0 −1
TX
!
X 1 1
kxt k2 ≤ k∇L(xt )k2 + kxT k2 ≤ k∇L(xt )k2 + kxT k2 .
t=T
1 − (1 − λη)2 t=T
λη t=T
P 0
Note that by assumption we have Tt=T−1 k∇L(xt )k2 = cη1 f (xT ) − f (xT 0 ).
P∞ 2
As a conclusion, we have 2
t=T kxt k ≤
f (xT )−minx f (x)
cη 2 λ
+ kxλη
Tk
, which implies
lim kxt k2 = 0.
t→∞
Now we rigorously analyze norm growth in this algorithm. This greatly extends
previous analyses of effect of normalization schemes [29, 37] for vanilla SGD.
Theorem 2.9.1. Under the update rule 2.3.1 with λt = 0, the norm of scale invariant
parameter xt satisfies the following property:
t
2
X 1 − γ t−i+1 1 − γ t+1
kxt+1 k = kxi − xi+1 k2 + γkxi−1 − xi k2 −γ (kx0 k2 −kx−1 k2 )
i=0
1−γ 1−γ
51
Proof. Let’s use Rt , Dt , Ct to denote kxt k2 , kxt+1 − xt k2 , x>
t (xt+1 − xt ) respectively.
We also have
Ct xt+1 − xt xt − xt−1 γ
= x>
t = x>
t (γ − λt xt ) = (Dt + Ct−1 ) − λt Rt ,
ηt ηt ηt−1 ηt−1
namely,
Ct γDt γ
∀t ≥ 0 P (t) : − = Ct−1 − λt Rt .
ηt ηt−1 ηt−1
S(t) γS(t−1)
Simplify ηt
− ηt−1
+ P (t), we have
When λt = 0, we have
t
Rt+1 − Rt R0 − R−1 X t−i Di Di−1 R0 − R−1
= γ t+1 + γ ( +γ ) ≥ γ t+1 .
ηt η−1 i=0
η i ηi−1 η0
t
X 1 − γ t−i+1 1 − γ t+1
Rt+1 = R0 + (Di + γDi−1 ) − γ (R0 − R−1 ),
i=0
1−γ 1−γ
t
X
Rt+1 = R0 + Di .
i=0
52
For general deep nets, we have the following result, suggesting that the mean square
of the update are constant compared to the mean square of the norm. The constant
is mainly determined by ηλ, explaining why the usage of weight decay prevents the
1
parameters to converge in direction.
Theorem 2.9.2. For SGD with constant LR η, weight decay λ and momentum γ,
P −1 P −1
when the limits R∞ = limT →∞ T1 Tt=0 kwt k2 , D∞ = limT →∞ T1 Tt=0 kwt+1 − wt k2
exist, we have
2ηλ
D∞ = R∞ .
1+γ
Proof of Theorem 2.9.2. Take average of Equation (2.24) over t, when the limits
P −1 P −1
R∞ = limT →∞ T1 Tt=0 kwt k2 , D∞ = limT →∞ T1 Tt=0 kwt+1 − wt k2 exists, we have
1+γ
D∞ = 2λR∞ .
η
tectures
In this section, we will discuss how Normalization layers make the output of the
network scale-invariant to its parameters. Viewing a neural network as a DAG, we
give a sufficient condition for the scale invariance which could be checked easily by
topological order, and apply this on several standard network architectures such as Fully
Connected(FC) Networks, Plain CNN, ResNet[48], and PreResNet[49]. For simplicity,
we restrict our discussions among networks with ReLU activation only. Throughout
this section, we assume the linear layers and the bias after last normalization layer are
1
? ] had a similar argument for this phenomenon by connecting this to the LARS[47], though
it’s not rigorous in the way it deals with momentum and equilibrium of norm.
53
fixed to its random initialization, which doesn’t harm the performance of the network
empirically[32].
2.10.1 Notations
Definition 2.10.2. For a module with n inputs and m outputs, we say the module is
(a1 , ...an ; b1 , ..., bm )-homogeneous if the m outputs are bi -homogeneous to the network
parameters whenever the n inputs are ai -homogeneous to the network parameters. A
model is scale invariant iff its output is (; 0)-homogeneous. (A complete model doesn’t
take any input from another module)
Suppose the network only contains following modules, and we list the degree of
homogeneity of these basic modules, given the degree of homogeneity of its input.
(I) Input
(B) Bias Layer(Adding Trainable Bias to the output of the previous layer)
(+) Addition Layer (adding the outputs of two layers with the same dimension2 .)
2
Addition Layer(+) is mainly used in ResNet and other similar architectures. In this section, we
also use it as an alternative definition of Bias Layer(B). See Figure 2.7
54
Table 2.1: Table showing how degree of homogeneity of the output of basic modules
depends on the degree of homogeneity of the input. Input module doesn’t require
any input and thus the output are trivially scale invariant. ReLU, Pooling( and
other fixed linear maps) are ignored because they keep the degree of homogeneity, i.e.
(x; x)-homo, and thus can be omitted when creating the DAG in Theorem 2.10.4.
Remark 2.10.3. For the purpose of deciding the degree of homogeneity of a network,
we can ignore the difference among convolutional layers, fully connected layer and the
diagonal linear layer in the affine transformation of Normalization layer, since they’re
all linear and the degree of homogeneity is increased by 1 after applying them.
On the other hand, BN and IN has some benefit which GN and LN doesn’t have,
namely the bias term (per channel) immediately before BN or IN has zero effect on
the network output and thus can be removed. (See Figure 2.15)
We also demonstrate the homogeneity of the output of the modules via the following
figures, which will be reused to later to define network architectures.
Theorem 2.10.4. For a network only consisting of modules defined above and ReLU
activation, we can view it as a Directed acyclic graph and check its scale invariance
by Algorithm 1.
We start with the simple cases where all bias term(including that of linear layer and
normalization layer) and the scaling term of normalization layer are fixed to be 0 and
1 element-wise respectively, which means the bias and the scaling could be dropped
55
(a) Input(I) (b) Linear(L) (c) Addition(+) (d) Normaliza-
tion(N)
Figure 2.7: Degree of homogeneity of the output of basic modules given degree of
homogeneity of the input.
56
from the network structure. We empirically find this doesn’t affect the performance of
network in a noticeable way. We will discuss the full case in the next subsection.
Figure 2.8: Degree of homogeneity for all modules in vanilla CNNs/FC networks.
ResNet: See Figure 2.10. To ensure the scaling invariance, we add an additional
normalizaiton layer in the shortcut after downsampling. This implementation is
sometimes used in practice and doesn’t affect the performance in a noticeable way.
Preactivation ResNet: See Figure 2.11. Preactivation means to change the order
between convolutional layer and normalization layer. For similar reason, we add an
additional normalizaiton layer in the shortcut before downsampling.
57
(a) The starting part of ResNet
Figure 2.10: Degree of homogeneity for all modules in ResNet without affine transfor-
mation in normalization layer. The last normalization layer is omitted.
Now we discuss the full case where the affine transformation part of normalization
layer is trainable. Due to the reason that the bias of linear layer (before BN) has 0
gradient as we mentioned in 2.10.3, the bias term is usually dropped from network
architecture in practice to save memory and accelerate training( even with other
normalization methods)(See PyTorch Implementation [44]). However, when LN or
GN is used, and the bias term of linear layer is trainable, the network could be scale
variant (See Figure 2.15).
58
(a) The starting part of PreResNet
Figure 2.11: Degree of homogeneity for all modules in ResNet without affine transfor-
mation in normalization layer. The last normalization layer is omitted.
Figure 2.12: Degree of homogeneity for all modules in vanilla CNNs/FC networks.
59
ResNet: See Figure 2.13. To ensure the scaling invariance, we add an additional
normalizaiton layer in the shortcut after downsampling. This implementation is
sometimes used in practice and doesn’t affect the performance in a noticeable way.
Figure 2.13: Degree of homogeneity for all modules in ResNet with trainable affine
transformation. The last normalization layer is omitted.
Preactivation ResNet: See Figure 2.14. Preactivation means to change the order
between convolutional layer and normalization layer. For similar reason, we add an
additional normalizaiton layer in the shortcut before downsampling.
60
(a) The starting part of PreResNet
Figure 2.14: Degree of homogeneity for all modules in PreResNet with trainable affine
transformation. The last normalization layer is omitted.
Figure 2.15: The network can be not scale variant if the GN or IN is used and the bias
of linear layer is trainable. The red ‘F’ means the Algorithm 1 will return False here.
61
Chapter 3
3.1 Introduction
1
1. Parameter norm converges to Θ(( λη ) 4 ) in T1 = O(
e 1 ) steps with high probability
ηλ
(A3). (L, x(0)) → (L0 , cx(0)), where L0 is defined as L0 (x) := L( xc ) for any c > 0.
Properties (1) and (2) suggest our results are more robust to initialization scale (by
only having logarithmic dependence on it), showing the advantage of using scale
invariant functions while matching the standard convergence rates for non-convex
functions.
3.2 Preliminary
In this section we present the definition of scale invariant functions and some of their
x
useful properties. For x ∈ Rd , we define x := kxk2
. We say a function is C k iff it is
k-times continuously differentiable. We also assume the loss function L is a C 2 and
scale invariant function and ρ := max k∇2 L(x)k. Same to Chapter 2, we use λ to
kxk=1
denote weighrt decay factor and η to denote learning rate.
63
Definition 3.2.1. Given a cone U ⊂ Rd , we say a function f : U → R is (positively) k-
homogeneous or of homogeneity of degree k iff for any c > 0 and x ∈ U , f (cx) = ck f (x).
We say a function is scale invariant iff it is 0-homogeneous.
We first present the convergence result in the deterministic case, i.e., Gradient Descent
over L(x) + λ
2
kxk22 .
1
Theorem 3.3.1 (GD+WD). For ηλ ≤ 2
, let x(t) be defined by GD (3.1), and
kx(0)k2
l m
1
T0 = 2ηλ ln ρπ2 η 2 + 3 . We have
Proof Sketch of Theorem 3.3.1. Scale invariant functions do not have bounded smooth-
ness at 0 making it a challenge to use standard convergence analysis. Our key insight
is that for scale invariant loss function, even with a fixed LR η, GD can tune its
η
effective LR kx(t)k22
by changing the norm. Thus once GD passes the area of the
ρ
suitable norm, the smoothness of scale invariant loss function is upper bounded by r2
!
1 ρη
L(x(t)) − L(x(t + 1)) ≥ η − 2 k∇L(x(t))k22 .
1 − ηλ 2 kx(t)k2 (1 − ηλ)2
!
2ρη
L(x(t)) − L(x(t + 1)) ≥ η 1 − k∇L(x(t))k22 .
kx(t)k22
Remark 3.3.3. One might wonder why the upper bounds on loss and gradient norm
do not appear in Theorem 3.3.1. This is because we are working on a compact domain
(the unit sphere) and twice-differentiability implies those bounds implicitly. (See
Lemmas 3.5.3 and 3.5.4)
66
where γt ∈ Γ are i.i.d. random variables. We further assume there exists constants
σ and σ, such that σ 2 ≤ E k∇Lγ (x)k22 ≤ σ 2 , for any kxk2 = 1. We finally need the
following condition on ηλ to bound convergence.
q
σ2 2
Condition 3.4.1. M2
≥ 3e4ηλ λη ln 2Tδ .
The Condition 3.4.1 is useful for proving norm convergence in high probability. In
practice, typically ηλ is very small. Our experiments use η = 0.0008 and λ = 0.01.
Hence e4ηλ ≈ 1, and Condition 3.4.1 essentially requires the gradient norm square
√
cannot exceed its average multiplied by 1/ ηλ ≈ 350, which is reasonable for most
iterates.
Theorem 3.4.2 (SGD+WD). Let x(t) be defined by SGD (3.3). For ηλ ≤ 0.1,
under Condition 3.4.1, with probability 1 − 5δ,
σ2 2λ
∀T1 ≤ t ≤ T − 1, ≤ kx(t)k42 ≤ 4σ 2 , (3.4)
2 η
and
T −1
1 X π 2 ρσ p ρσ 3
k∇L(x(t))k22 ≤ √ + 4 ηλ 2
T − T1 t=T (T − T1 ) 2ηλ σ
1
s s (3.5)
ln 2δ πρM σ ln 2δ p M 2 ρσ
+ 4 + 4 λη ,
T − T1 σ T − T1 σ2
n 2
o
2e4 M 2
where T1 = 1
4ηλ
max ln Mσ2ηλ + ln kx(0)k4 −2
η
, 8 .
2
The proof of this theorem is presented in Section 3.7. Similar to our earlier
result for GD this bound matches the standard O(T −1/4 ) convergence rate of SGD
e 1 ). Further, it only has a logarithmic
for non-convex functions by setting T = O( ηλ
67
scale as discussed earlier for GD. We further extend this result to the case where the
scale invariant loss has multiple scale invariant parameter groups in Section 3.8.
ρ kvk22
L(x + v) − L(x) ≤ hv, ∇L(x)i + .
2 kxk22
Proof of Lemma 3.5.1. Define γ(s) = x + sv, then we have L(γ(0)) = L(x) and
L(γ(1)) = L(x + v). Taking Taylor expansion of F (s) = L(γ(s)) at s = 0, we have
F 00 (s∗ )
F (1) − F (0) = F 0 (0) + , for some s∗ ∈ [0, 1].
2
ρ 0 ∗ 2
F 00 (s∗ ) =γ 0 (s∗ )∇2 L(γ(s∗ ))γ 0 (s∗ ) ≤ 2 kγ (s )k2 ,
∗
kγ(s )k2
where the last inequality uses the fact that L is scale invariant. The proof is completed
by noting that kγ(s∗ )k2 ≥ kγ(0)k2 = kxk22 and that γ 0 (s∗ ) = v.
K
ρ X kvi k22
L(x + v) − L(x) ≤ hv, ∇L(x)i + .
2 k=1 kxi k22
68
Proof of Lemma 3.5.2. We first prove for the case where kxk k2 = 1, ∀k ∈ [K]. Similar
to the proof of Lemma 3.5.1, it suffices to show that the smoothness of L is at
most ρ along the line joining x and x + v. This holds because ∀s ∈ [0, 1], k ∈ [K],
kxi + svi k2 ≥ kxi k2 by assumption that hxk , vk i = 0 for all k ∈ [K].
x> x>
Now we turn to the general case. b = [ kx11k , . . . , kxKKk ]> and v 0 =
Define x
2 2
v> v>
[ kx11k , . . . , kxKKk ]> . Since L is multi-group scale invariant, we have L(x) = L(b
x)
2 2
π
Lemma 3.5.3. If L is scale invariant, k∇L(x)k2 ≤ kxk2
supkxk=1 k∇2 L(x)k2 .
Proof of Lemma 3.5.3. It suffices to prove the above bound for all x with kxk2 = 1.
Let x∗ be any local minimizer of L on Sd−1 and γ : [0, 1] → Sd−1 be the geodesic curve
satisfying that γ(0) = x∗ and γ(1) = x. We know the length of {γ(t)}1t=0 ≤ π and
thus
Z 1 Z 1
2 dγ(t) dγ(t)
k∇L(x)k = ∇ L(γ(t)) dt ≤ ∇2 L(γ(t)) 2
dt ≤ ρ · π
t=0 dt t=0 dt 2
π2
Lemma 3.5.4. If L is scale invariant, supx,x0 L(x) − L(x0 ) ≤ 2
supkxk=1 k∇2 L(x)k2 .
3.5.2 Probablity
σ 2 s2
E[exp(sX)] ≤ exp( ), ∀s ∈ R.
2
69
In this work, we also use the following notion of conditional subgaussian. We say a
random variable X ∈ R is said to be sub-Gaussian with variance proxy σ 2 conditioned
on event E (denoted by X ∼ subG(σ 2 , E)) if its moment generating function satisfies
σ 2 s2
E[exp(sX)1[E]] ≤ exp( ), ∀s ∈ R.
2
Lemma 3.5.6 (Chernoff Bound with Conditioning). Let X ∼ subG(σ 2 , E). Then for
any t > 0, it holds that
t2 t2
P[X > t ∧ E] ≤ exp(− ), and P[X < −t ∧ E] ≤ exp(− )
2σ 2 2σ 2
When P[E] = 1, we get the standard Chernoff bound. Let X ∼ subG(σ 2 ). Then for
any t > 0, it holds that
t2 t2
P[X > t] ≤ exp(− ), and P[X < −t] ≤ exp(− )
2σ 2 2σ 2
σ 2 s2
P[X > t ∧ E] = P[esX ≥ est ∧ E] ≤ e−st E[esX 1[E]] = exp(−st + ).
2
t
The proof is completed by picking s = σ2
.
We will use (Ω, Σ, P) to note the probability space and {Ft }t∈N to denote the
filtration.
70
Proof. We will prove by induction on T . When T = 1, the statement is true by
assumption. Now suppose the statement holds for T − 1, we have for any s > 0
T
X T −1
X
E[exp(s Xi )1[ET −1 ]] =E[exp(s Xi )1[ET −1 ]E[exp(sXT )1[Et−1 ] | FT −1 ]]
i=1 i=1
T −1
X s2 σT2 −1
≤E[exp(s Xi )1[ET −1 ] exp( )]
i=1
2
T −1
X s2 σT2 −1
≤E[exp(s Xi )1[ET −2 ]] exp( )
i=1
2
PT −1
PT s2 σt2
Thus we have that E[exp(s i=1 Xi )1[ET −1 ]] ≤ exp( t=0
2
).
3.5.3 Others
t
X ekx
(1 − x)kτ ≤
τ =0
kx
t ∞ ∞
X X X 1 ekx
kτ
(1 − x) ≤ kτ
(1 − x) ≤ e−kxτ = ≤ ,
τ =0 τ =0 τ =0
1 − e−kx kx
Proof of Lemma 3.3.2. This is a special case of Lemma 3.5.1 with x = (1 − ηλ)x(t)
and v = −η∇L(x(t)). Here we use the assumption that L is scale invariant, ∇L is
∇L(x(t))
−1-homogeneous. By Lemma 3.2.3, which means ∇L(x) = 1−ηλ
.
The following lemma deals with the case where kx(0)k22 < π 2 ρη.
71
Lemma 3.6.1. Let I = {T 0 ∈ N | ∀0 ≤ t ≤ T 0 , kx(t)k22 ≤ π 2 ρη ∧ k∇L(x(t))k22 >
2(π 2 ρη)2
8π 4 ρ2 λη}. Suppose 0 ∈ I and T = max I. Then T ≤ 1
6λη
and kx(T + 1)k22 ≤ kx(0)k22
.
≥ − 2π 2 ρλη 2 + 8π 2 ρλη 2
=6π 2 ρλη 2 .
Thus 6π 2 ρλη 2 · T ≤ kx(T )k22 − kx(0)k22 < kx(T )k22 ≤ π 2 ρη, which implies that
1
T < 6λη
. Moreover, we have that
Proof of Theorem 3.6.2. We first claim there’s 0 ≤ t ≤ T0 , such that kx(t)k22 < π 2 ρη.
72
Otherwise, by Lemma 3.3.2, for t = 0, . . . , T0 , we have L(x(t)) − L(x(t + 1)) ≤
η
2
k∇L(x(t))k22 . Note that kx(t + 1)k22 − (1 − ηλ)2 kx(t)k22 = η 2 k∇L(x(t))k22 .
Therefore, we have that
0 −1
TX
kx(T0 )k22 − (1 − ηλ)2T0
kx(0)k22 = η 2 (1 − ηλ)2(T0 −t) k∇L(x(t))k22
t=0
TX0 −1
≤ η 2 k∇L(x(t))k22
t=0
η
≤ (L(x(0)) − L(xT0 −1 ))
2
ηπ 2 ρ
≤
2
ηπ 2 ρ
By the definition of T0 , we have (1 − ηλ)2T0 kx(0)k22 ≤ e−2ηλT0 kx(0)k22 ≤ 2
. Thus
kx(T0 )k ≤ π 2 ρη.
Without loss of generality, we let T be the smallest integer such that kx(T )k22 ≤
π 2 ρη. By assumption, T ≥ 1. Therefore kx(T − 1)k22 ≥ π 2 ρη. Because kx(T )k22 =
(1 − ηλ)2 kx(T − 1)k22 + η 2 k∇L(x(T − 1))k22 , we have that
kx(T )k22
Note that kx(T )k22 < π 2 ρη and (1−λη)2
≥ kx(T − 1)k22 ≥ π 2 ρη, we conclude that
kx(T )k22
k∇L(x(T − 1))k22 ≤η −2 kx(T )k22 − (1 − ηλ)2 kx(T − 1)k22 )
(1 − λη)2
1 − (1 − λη)2 2 2
≤ 2 (π ρη)
η (1 − λη)2
≤8ληπ 4 ρ2 ,
73
Combining Lemma 3.6.1 and Theorem 3.6.2 removes the initial condition in
Theorem 3.6.2, and completes the proof of Theorem 3.3.1.
We will use (Ω, Σ, P) to note the probability space and {Ft }t∈N to denote the filtration
where Ft := σ({γi | 0 ≤ i ≤ t}) is the σ-algebra generated by γ0 , . . . , γt .
4
Lemma 3.7.1. k∇Lγ (x)k22 − E k∇Lγ (x)k22 ∼ subG( 4kxk
M
4 ).
2
M2
Proof. Lemma 3.7.1 Note 0 ≤ k∇Lγ (x)k22 ≤ kxk22
. The proof is immediate by Hoeffding
Lemma (see Lemma 3.6 in [52]).
s
t
X M2 1 2T 2
(1 − ηλ)4(t−τ ) k∇Lγτ (x(τ ))k22 − E[k∇Lγτ (x(τ ))k22 | x(τ )] ≤ e4ηλ ·
ln .
τ =t0
4 λη δ
(3.6)
t
X e8ηλ M 4
(1 − ηλ)4(t−τ ) k∇Lγτ (x(τ ))k22 − E[k∇Lγτ (x(τ ))k22 | x(τ )] ∼ subG(
)
τ =t0
32
8(t−τ ) M 4 e8ηλ
Pt
Proof of Lemma 3.7.2. Note that τ =t0 (1−ηλ) 4
≤ 32
by Lemma 3.5.8. Thus
by Azuma Inequality and Lemma 3.7.1, we have that the martingale
t
X
(1 − ηλ)4(t−τ ) k∇Lγτ (x(τ ))k22 − E[k∇Lγτ (x(τ ))k22 | x(τ )]
τ =t0
e8ηλ
is 32
-subgaussian.
74
By Lemma 3.5.6, we have for any ∀0 ≤ t0 ≤ t ≤ T − 1, Equation (3.6) holds with
δ
probability at least T2
. The proof is completed by applying union bound.
Lemma 3.7.3 (Norm Lower Bound). Under Condition 3.4.1 and additionally assume
ηλ ≤ 21 . On ET , it holds that for any t ≥ 0,
s
1 − ηλ 1 1 2T 2
η −2 kx(t)k42 ≥ (1 − e−4tηλ(1−ηλ) )σ 2 − (1 − ηλ)2 M 2 e4ηλ ln (3.7)
2ηλ 2 λη δ
q
σ2 M 2 4ηλ 1 2
When 12ηλ
≥ 2
e λη
ln 2Tδ , the above condition is simplified into the following:
1
on ET for any ηλ
≤ t ≤ T,
In the above inequality, we also used the fact that 1 − e−4(1−ηλ) ≥ 56 , which is
implied by ηλ ≤ 0.5.
k∇Lγt (x(t))k22
kx(t + 1)k22 = (1 − ηλ)2 kx(t)k22 + η 2 . (3.9)
kx(t)k22
η 4 k∇Lγt (x(t))k42
kx(t + 1)k42 = (1 − ηλ)4 kx(t)k42 + 2(1 − ηλ)2 η 2 k∇Lγ (x(t))k22 + .
kx(t)k42
(3.10)
75
Thus
t
X
η −2
kx(t + 1)k42 ≥2 (1 − ηλ)4(t−τ )+2 k∇Lγτ (x(τ ))k22
τ =0
t
X
≥2 (1 − ηλ)4(t−τ )+2 E k∇Lγτ (x(τ ))k22
τ =0
t
X
(1 − ηλ)4(t−τ )+2 k∇Lγτ (x(τ ))k22 − E k∇Lγτ (x(τ ))k22 .
+2
τ =0
t t
X X 1 − e−4tηλ(1−ηλ) 1 − e−4tηλ(1−ηλ)
(1 − ηλ) 4(t−τ )
≥ e−4(t−τ )ηλ(1−ηλ) = ≥ .
τ =0 τ =0
1 − e−4ηλ(1−ηλ) 4ηλ(1 − ηλ)
s
1 − ηλ 1 1 2T 2
η −2 kx(t)k42 ≥ (1 − e−4tηλ(1−ηλ) )σ 2 − (1 − ηλ)2 M 2 e4ηλ ln
2ηλ 2 λη δ
Lemma 3.7.4 (Norm upper bound). Under Condition 3.4.1 and additionally assume
1
ηλ ≤ 0.1. Let T0 = d ηλ e. Let t∗ be the earliest step t in {0, . . . , T0 − 1} that
e8 (1−ηλ)2 σ 2
η −2 kx(t)k42 ≥ 4ηλ
and we denote t∗ = T0 if this doesn’t happen in {0, . . . , T0 −1}.
(1−ηλ)2 σ 2
For the case t∗ = T0 , we have η −2 kx(T0 )k42 ≤ 4ηλ
. On ET , for any t ≥ t∗ ,
( )
2e4 M 2
−4λη(t−t∗ ) ln 4σ
2
σ2
−2
1)k42 kx(0)k4
2 −2
η kx(t + ≤e max 2M e 2η ,e . + . (3.11)
ηλ ηλ
n 2
o
2e4 M 2
Thus, there exists T1 = T0 + 1
4ηλ
max ln Mσ2ηλ + ln kx(0)k4 −2 , 4 , such that ∀t ≥ T1 ,
η 2
2σ 2
η −2
kx(t + 1)k42 ≤ ηλ
.
76
Proof of Lemma 3.7.4. If t∗ < T0 , it holds that conditioned on ET , for any t∗ ≤ t < T0 ,
∗ (1 − ηλ)2 σ 2
η −2 kxt k42 ≥ (1 − ηλ)4(t−t ) η −2 kx(t∗ )k42 ≥ (1 − ηλ)4(T0 −1) η −2 kx(t∗ )k42 ≥
4ηλ
η −2 kx(t + 1)k42
k∇Lγt (x(t))k42
=(1 − ηλ)4 η −2 kx(t)k42 + 2(1 − λη)2 k∇Lγ (x(t))k22 +
kx(t)k42 η −2
t
4(t+1−t∗ ) −2
X
=(1 − ηλ) η kx(t ∗
)k42 +2 (1 − ηλ)4(t−τ )+2 E[k∇Lγτ (x(τ ))k22 | x(τ )]
τ =t∗
| {z }
(A)
t
X
(1 − ηλ)4(t−τ )+2 k∇Lγτ (x(τ ))k22 − E[k∇Lγτ (x(τ ))k22 | x(τ )]
+2
τ =t∗
| {z }
(B)
t 4
4(t−τ ) k∇Lγτ (x(τ ))k2
X
+ (1 − ηλ) .
τ =t∗
kx(τ )k42 η −2
| {z }
(C)
(3.12)
Below we will upper-bound the terms (A), (B) and (C) on ET respectively.
t
X (1 − ηλ)2 e4ηλ 2 e0.2 2
(A) ≤ 2 (1 − ηλ)4(t−τ )+2 σ 2 ≤ σ ≤ σ , (3.13)
τ =t∗
2ηλ 2ηλ
s
M 2 4ηλ 1 2T 2 (1 − ηλ)2 2
(B) ≤ (1 − ηλ)2 e ln ≤ σ (3.14)
2 λη δ 6ηλ
77
(C). Combining the above analysis and Lemma 3.7.3, we know conditioned on ET ,
(1−ηλ)2 σ 2
for any t ≥ t∗ , it holds kx(t)k42 /η 2 ≥ 4ηλ
.
t
4ηλM 4 X 4(t−τ )−2 e4ηλ M 4
(C) ≤ (1 − ηλ) ≤ (3.15)
σ 2 τ =t∗ (1 − ηλ)2 σ 2
σ2
Under Condition 3.4.1, we can further upper bound (C) by 9ηλe4ηλ (1−ηλ)2
≤
σ2 σ2
9× 98 × 87 ηλ
= 7ηλ
, where we used the fact that ηλ ≤ 0.1.
∗ −1)
η −1 kxt∗ −1 k22 ≥ (1−ηλ)2(t η −1 kx(0)k22 ≥ e−4(T0 −1)ηλ η −1 kx(0)k22 ≥ e−4 kx(0)k22 η −1 .
2
∇Lγt∗ −1 (x(t∗ − 1))
η −1
kx(t ∗
)k22 =(1 − ηλ) η 2 −1
kxt∗ −1 k22 + 2
kxt∗ −1 k22 η −1
s
e8 (1 − ηλ)2 σ 2 M2
≤(1 − ηλ)2 + e4
4ηλ kx(0)k22 η −1
s
e8 σ 2 4 M2
≤2 max{ ,e }
4ηλ kx(0)k22 η −1
(1−ηλ)2 σ 2
• t∗ = T0 . Then we have η −2 kx(t∗ )k42 ≤ 4ηλ
.
78
Plugging (3.16) back into (3.12), we got for any t ≥ t∗
η −2 kx(t + 1)k42
∗
=(1 − ηλ)4ηλ(t+1−t ) η −2 kx(t∗ )k42 + (A) + (B) + (C) (3.17)
( )
2e4 M 2
−4λη(t−t∗ ) 2 ln
kx(0)k4 η −2 4σ
2
σ2
≤e max 2M e 2 ,e . + ,
ηλ ηλ
Theorem 3.4.2 (SGD+WD). Let x(t) be defined by SGD (3.3). For ηλ ≤ 0.1,
under Condition 3.4.1, with probability 1 − 5δ,
σ2 2λ
∀T1 ≤ t ≤ T − 1, ≤ kx(t)k42 ≤ 4σ 2 , (3.4)
2 η
and
T −1
1 X π 2 ρσ p ρσ 3
k∇L(x(t))k22 ≤ √ + 4 ηλ 2
T − T1 t=T (T − T1 ) 2ηλ σ
1
s s (3.5)
ln 2δ πρM σ ln 2δ p M 2 ρσ
+ 4 + 4 λη ,
T − T1 σ T − T1 σ2
n o
1 M 2 ηλ 2e4 M 2
where T1 = 4ηλ
max ln σ2 + ln kx(0)k4 η−2 , 8 .
2
79
Summing up for t = T1 to T − 1, we have
T −1
X T −1
X
η k∇L(x(t))k22 kx(t)k−2
2 = η k∇L(x(t))k22
t=T1 t=T1
T −1
X ρη 2 E[k∇Lγt (x(t))k22 | x(t)]
≤(1 − ηλ) (L(xT1 ) − L(xT )) +
t=T1
2(1 − ηλ) kx(t)k42
| {z }
(A)
T −1
X η h∇L(x(t)), ∇L(x(t)) − ∇Lγt (x(t))i
+
t=T1
kx(t)k22
| {z }
(B)
T −1
ρη 2 k∇Lγt (x(t))k22 − E[k∇Lγt (x(t))k22 | x(t)]
X
+
t=T1
2(1 − ηλ) kx(t)k42
| {z }
(C)
Below we will give high-probability bounds for (A), (B) and (C) respectively. For
convenience, we will use A(t), B(t), C(t) to denote the tth term in (A), (B) and (C).
√ 2
Claim 3.7.5. ET =⇒ ∀T1 ≤ t ≤ T, A(t) ≤ 2 2ρηλ σσ2
PT −1 2 ληρ2 M 2
Claim 3.7.6. (B) = t=T1 B(t) is subG((T − T1 ) 4π σ2
, ET )
PT −1 2 λ2 η 2 M 4
Claim 3.7.7. (C) = t=T1 C(t) is subG((T − T1 ) 4ρ σ4
, ET )
√
Here Claim 3.7.5 follows from that 2(1 − ηλ) ≥ 2 and Lemma 3.7.3. Note by
the choice of T1 , we can upper and lower bound kx(t)k2 by Lemmas 3.7.3 and 3.7.4,
σ2 2σ 2
that is 4ηλ
≤ η −2 kx(t)k22 ≤ ηλ
. Thus Claims 3.7.6 and 3.7.7 is a direct consequence
of Lemma 3.5.7.
Thus we conclude w.p. 1 − 5δ,
T −1
r
λη 1 X 2 L(x(T1 )) − minx L(x) √ σ2
k∇L(x(t))k2 ≤ + 2 2ρηλ
2σ 2 T − T1 t=T T − T1 σ2
1
s s
2
8λη ln δ πρM 8 ln 2δ M 2ρ
+ + λη 2 ,
T − T1 σ T − T1 σ
80
rearranging it and applying Lemma 3.5.4, we get
T −1
1 X π 2 ρσ p ρσ 3
k∇L(x(t))k22 ≤ √ + 4 ηλ 2
T − T1 t=T (T − T1 ) 2ηλ σ
1
s s
ln 2δ 4πρM σ ln 2δ p M 2 ρσ
+ + 4 λη .
T − T1 σ T − T1 σ2
q
σ2
By Condition 3.4.1, we have M2
≥ 3 λη ln 2δ , and thus we have
T −1
s r
1 X π 2 ρσ p ρσ 3 4 1 1 4ρσ
k∇L(x(t))k22 ≤ √ + 4 ηλ 2 + πρσ + .
T − T1 t=T (T − T1 ) 2ηλ σ 3 (T − T1 )ηλ T − T1 3
1
variant Functions
In this section we extend our results to the multi-group scale invariant setting, which
is quite common in practice, e.g. a feedforward network with normalization after each
layer. By Definition 3.8.1, multi-group scale invariant function is also scale invariant.
However, it violates the assumption that the smoothness and the expectation of
stochastic gradient norm square is lower bounded on unit sphere (indeed the loss
function is not defined at everywhere on unit sphere), and thus needs to be treated
x y
separately. A simple example would be L(x, y) = L( kxk , kyk ), the loss L is undefined
2 2
at any point where kxk2 = 1 and y = 0. Yet our analysis for single scale invariant
parameter group can still extend to this case, with a similar assumption that the
expected gradient norm square is lower bounded.
81
Let d1 , . . . , dK be positive integers with d = K d d1 dK
P
k=1 dk . For x ∈ R = R ×. . .×R ,
we use sk to denote i≤k di and xk to denote the vector [xsk−1 , . . . , xsk −1 ]> . For
P
∂f (x)
convenience, we define ∇k f (x) = ∂xk
for any 1 ≤ k ≤ K.
p T −1
λη/2 1 X
PK k∇L(x(t))k22
k=1 σ k
T − T1
t=T1
K
π2ρ √ X σ 2k
≤ + 2 2ρηλ (3.18)
T − T1 σ2
k=1 k
s s
K K
8λη ln 2δ X Mk 8 ln 2δ X Mk2
+ πρ + ληρ 2
,
T − T1 k=1
σ k T − T 1
k=1
σ k
n o
1 Mk2 ηλ 2e4 Mk2
where T1 = 4ηλ
maxk ln σ2 + ln kx (0)k4 η−2 , 8 .
k k 2
Following the same strategy, we can prove the multi-group counterpart of norm
convergence result, Lemma 3.7.2. Given a integer T ≥ 0, let ET,k be the event that
∀0 ≤ t0 ≤ t ≤ T − 1,
s
t
X M2 1 2T 2
(1 − ηλ)4(t−τ ) k∇k Lγτ (x(τ ))k22 − E[k∇k Lγτ (x(τ ))k22 | x(τ )] ≤ e4ηλ · k
ln .
τ =t0
4 λη δ
82
Lemma 3.8.4. For any 0 ≤ t0 ≤ t ≤ T − 1, 1 ≤ k ≤ K, it holds that
t
X e8ηλ Mk4
(1 − ηλ)4(t−τ ) k∇k Lγτ (x(τ ))k22 − E[k∇k Lγτ (x(τ ))k22 | x(τ )] ∼ subG(
)
τ =t0
32
The following theorem is a restatement of Lemmas 3.7.3 and 3.7.4 in the context
of multi-group scale invariance.
n o
1 Mk2 ηλ 2e4 Mk2
Lemma 3.8.5. Under Condition 3.8.2, there exists T1 = 4ηλ
maxk ln σ2 + ln kx (0)k4 η−2 , 8 ,
k k 2
σ 2k 2σ 2k
such that ∀t ≥ T1 , 4ηλ
≤ η −2 kx(t)k42 ≤ ηλ
, conditioned on ∪K
k=1 ET,k .
K
η X ρη 2 k∇k Lγt (x(t))k22
L(x(t + 1)) − L(xt ) ≤ − h∇L(x(t)), ∇Lγt (x(t))i +
1 − ηλ k=1
2(1 − ηλ)2 kxk (t)k42
x> x>
b = [ kx11k , . . . , kxKKk ]> . Summing up for t = T1 to T − 1,
For convenience we define x
2 2
we have
T −1
X T −1
X
η k∇L(x(t))k22 kx(t)k−2
2 = η k∇L(x(t))k22
t=T1 t=T1
T −1 X
K
X ρη 2 E[k∇k Lγt (x(t))k22 | x(t)]
≤(1 − ηλ) (L(xT1 ) − L(xT )) +
t=T1 k=1
2(1 − ηλ) kxk (t)k42
| {z }
(A)
T −1 XK
X η h∇k L(b
x(t)), ∇k L(bx(t)) − ∇k Lγt (b
x(t))i
+ 2
t=T1 k=1
kxk (t)k2
| {z }
(B)
T −1 X
K
ρη 2 k∇k Lγt (x(t))k22 − E[k∇k Lγt (x(t))k22 | x(t)]
X
+
t=T1 k=1
2(1 − ηλ) kxk (t)k42
| {z }
(C)
83
Below we will give high-probability bounds for (A), (B) and (C) respectively. For
convenience, we will use A(t), B(t), C(t) to denote the tth term in (A), (B) and (C).
√ PK σ2k
Claim 3.8.6. ∪K
k=1 ET,k =⇒ ∀T1 ≤ t ≤ T, A(t) ≤ 2 2ρηλ k=1 σ 2k
PT −1 2 2
PK M k 2 K
Claim 3.8.7. (B) = t=T1 B(t) is subG(4π ληρ (T − T1 ) k=1 σ k , ∪k=1 ET,k )
P 2
P −1 Mk2
Claim 3.8.8. (C) = Tt=T 1
C(t) is subG(4ρ2 2 2
λ η (T − T1 ) K
k=1 σ 2 , ∪Kk=1 ET,k )
k
√
Here Claim 3.8.6 follows from that 2(1 − ηλ) ≥ 2 and Lemma 3.7.3. Note by
the choice of T1 , we can upper and lower bound kx(t)k2 by Lemma 3.8.5, that is
σ 2k 2σ 2k
4ηλ
≤ η −2 kxk (t)k22 ≤ ηλ
. Thus Claims 3.8.7 and 3.8.8 is a direct consequence of
Lemma 3.5.7.
Thus by Chernoff bound (Lemma 3.5.6), with probability at least 1 − (K + 2)δ,
Equation (3.18) holds.
84
Chapter 4
In contrast to SGD, adaptive gradient methods like Adam allow robust training
of modern deep networks, especially large language models. However, the use of
adaptivity not only comes at the cost of extra memory but also raises the fundamental
question: can non-adaptive methods like SGD enjoy similar benefits? In this chapter,
we provide an affirmative answer to this question by proposing to achieve both
robust and memory-efficient training via the following general recipe: (1) modify the
architecture and make it scale invariant, i.e. the scale of parameter doesn’t affect
the output of the network, (2) train with SGD and weight decay, and optionally
q
(3) clip the global gradient norm proportional to weight norm multiplied by 2λ
η
,
where η is learning rate and λ is weight decay. We show that this general approach is
robust to rescaling of parameter and loss by proving that its convergence only depends
logarithmically on the scale of initialization and loss, whereas the standard SGD
might not even converge for many initializations. Following our recipe, we design a
85
scale invariant version of Bert, called Sibert, which when trained simply by vanilla
SGD achieves performance comparable to Bert trained by adaptive methods like
Adam on downstream tasks.
4.1 Introduction
Neural architectures like transformers are the cornerstone for modern machine learning
applications. However, training them is difficult and often results in training instability
[53, 54]. To enable stable training, one typically requires adaptive and carefully tuned
learning rates. However, the reason behind this issue is not very well-understood and
lacks a formal treatment.
In this chapter, we hypothesize that a primary cause of such behavior is the
k-homogeneous (k ≥ 2) nature of the network i.e., property where network’s output is
scaled by sk when its parameters are scaled by s. To illustrate our point, we consider
the following instructive toy model.
separable, the global optimum X ∗ must be finite and with out loss of generality, we
assume it positive.
Since L
e is convex with bounded smoothness in X, there exists step size that are
86
This is because |∇L(X)|
e is positive and monotone increases among all X > X ∗ .
Since all xi are initialized equally, they must be the same at any iteration. It
X(t) X(t)
holds that xi (t + 1) = xi (t) − η xi (t) ∇L(X(t)) = xi (t) 1 − η x2 (t) ∇L(X(t)) , where
e e
i
2k
X(t)
2k
X(t) = Πj=1 xj (t). This implies X(t + 1) = X(t) 1 − η k √ ∇L(X(t))
e ≥
X(t)
2k
X(0) 1
X(t) 1 − η k √ ∇L(X(0))
e > X(t). Thus we conclude if η ≥ |∇L(X(0))|2
(X(0)) k −1
X(0) e
X(t) X(0)
and X(0) > X ∗ , η √
k
∇L(X(t))
e −1 ≥ η√
k
∇L(X(0))
e − 1 > 1 and thus X(t)
X(t) X(0)
In the above example, the success of optimization is very sensitive to the right
choice of the learning rate that depends on the initialization. Furthermore, the training
cannot recover once the norm explodes due to large gradient update.
Still it is possible to find a small workable learning rate by extensive grid search
that depends on the initial point in the above one-dimensional example. However,
the situation can get worse when the k-homogeneous structure has an unbalanced
initialization as below.
87
4
Specifically, when d = 1 and Y = 0 and for any r ≥ 1, choosing η > k∇2 L(A(0),B(0))k
Similar issues can exist in deep neural networks as the k-homogeneous structure
is quite common. For instance, Liu et al. [53] identified the gradient norm varies
with depth and that no single learning rate is globally optimal for all layers. To this
end, one has to resort to adaptive methods like Adam to handle the k-homogeneous
structure of deep networks and allow for its robust training. However, this not only
comes at the expense of higher memory, but also raises the key question of our interest:
Can non-adaptive methods like SGD enjoy fast and robust convergence without
training instability?
Answering this question, requires us to first define our notion of robustness. In
this chapter, we primarily aim for three aspects of robustness by preventing: explosion
of parameters (e.g. due to frequent large gradient updates), slow progress in training
(e.g. due to loss plateaus) and loss explosion or spikes (e.g. due to possibly infrequent
large magnitude updates). In this chapter, we propose a simple yet powerful general
approach for achieving such fast and robust convergence. At a high level, our recipe
for robust training includes three key ingredients:
2. Using SGD with weight decay for training, wherein enabling weight decay im-
proves training efficiency under rescaling of loss and initialization. While scale
invariance prevents explosion of parameters, the training convergence has strong
dependence on initialization scale and learning rate, which can make training
88
inefficient in face of parameter and initialization rescaling. Use of SGD with
weight decay circumvents this issue.
3. Using a novel Relative Global Clipping to prevent spikes in training loss and
improve overall convergence speed. Although scale invariance in the archi-
tecture already guarantees the training stability, it does not prevent severe
non-monotonic loss explosion. By using a new global clipping approach, we show
that one can prevent such loss explosions effectively.
We show that this surprisingly simple training recipe can not only improve the
memory efficiency over adaptive methods but also achieves robust training. In light of
the above background, we list our main contributions below.
• In Section 4.3, we propose a new general recipe for memory efficient, robust
training using (1) scale invariant architecture; (2) SGD+WD for training and (3)
a novel clipping rule, called Relative Global Clipping, for clipping the updates.
Following this recipe, we design a new variant of Bert called Scale Invariant
Bert (Sibert).
• In Section 4.5, we show SGD+WD with Relative Global Clipping has better
parameter norm convergence via a novel analysis. With assumptions that the
clipping does not bring too much bias in expected gradients, we show similar
convergence result to SGD+WD.
89
4.2 Related Work and Background
The literature on adaptive methods and scale invariance in neural networks is vast, so
we only discuss works that are most relevant to our paper.
Adaptive Methods & Clipping Methods. Adaptive learning rates have long
been studied [56]. In machine learning, adaptive learning rates have been popu-
larized by Adagrad, which particularly benefits from sparse stochastic gradients
[57]. Inspired by Adagrad, several adaptive methods, like Adam, RMSprop and
its variants have been proposed in the deep learning community [1, 58–61]. These
approaches have been crucial in the success of many deep learning applications [62–
64]. Several works have studied the benefits of adaptive methods in deep learning
settings (e.g. [53, 54]). However, as mentioned earlier, these benefits come at the
cost of computational and memory efficiency. Anil et al. [65] proposed a variant of
Adagrad requiring fewer parameters for adaptivity, but still requires momentum.
Adafactor [61] removes momentum and uses much fewer adaptivity parameters, but
for large models, Adafactor still needs momentum to ensure training stability [66].
Our approach is also related to normalized and projected gradient descent, which has
been studied for quasi-convex and non-convex settings (e.g. see [67–69]). However,
these methods have seen very limited success.
Clipping based optimization methods, especially gradient clipping, are widely used
in deep learning applications to improve training stability or ensure privacy [70–72].
These approaches typically use a constant threshold to clip the gradients before the
update. However, choosing this threshold is difficult and requires careful tuning.
Adaptive variants of clipping methods partially alleviate this issue and are closely
related to adaptive methods [54]; however, they again incur additional computation
and memory costs.
90
Scale Invariance in deep networks. Various normalization schemes are the main
source of scale invariance in deep learning, e.g., BatchNorm [18], LayerNorm [19],
Weight Normalization [73], GroupNorm [30], InstanceNorm [74]. Scale invariance
from normalization allows GD and SGD to converge to stationary points from any
initialization and with any learning rate, in O(T −1/2 ) and O(T
e −1/4 ) rates respectively
[36]. The interplay between SGD, scale invariance and WD has also been well studied.
It was shown that the effect of WD for normalized networks can be replaced by
LR schedules [40, 41]. Li and Arora [22] formally builds the equivalence between
SGD+WD and SGD with an exponential increasing LR schedule for scale invariant
loss. Van Laarhoven [42] first proposed the notion of effective LR, η/ kxk22 , for normal-
ized networks, and showed that the unique stationary value of kxk42 is proportional to
λ/η, where η is LR and λ is WD. Li et al. [50] proved that the parameter norm always
converges to the above value by modeling SGD as Stochastic Differential Equation.
Wan et al. [75] proved the parameter norm converges to the same value directly for
SGD+WD, but only in expectation.
4.3 Methods
In this section, we provide a more detailed description of our recipe for robust and
memory-efficient network training, which includes three building blocks: (1) scale
invariant architecture (Section 4.3.1), (2) SGD with Weight Decay (Section 4.3.2) and
optionally (3) the Relative Global Clipping (Section 4.3.3 and Algorithm 2).
√
Algorithm 2 C-Clipped SGD + WD
Input: Total steps T , Scale invariant loss {Lt }Tt≥1 , initialization x(0), LR η, WD λ,
clipping factor C > 1 (C = ∞ ⇔ no clipping).
T − 1 do
for t = 0 to nq o
2Cλ
Nt ← min η
kx(t)k 2 , k∇L t (x(t))k 2 .
∇Lt (x(t))
x(t + 1) ← (1 − ηλ)x(t) − ηNt k∇Lt (x(t))k
.
2
91
4.3.1 Designing Scaling Invariant Architectures
We first revisit an approach for introducing scale invariance in neural networks, which
is presented in [22]. Viewing the neural network computation as a directed graph, the
high level idea is to ensure same homogeneity degree of different edges reaching a node.
For example in a ResNet block, the output from an affine transform is added back to
the input z from the previous layer yielding z + Aff(z). Now if we scale all the network
parameters by c, both z and Aff(z) should have the same degree of homogeneity and
scale as ck . Otherwise the network is no longer homogeneous and, hence, cannot be
scale invariant.
In this chapter, we apply the above design philosophy to develop a scale invariant
version of Bert [63] — a transformer based model. A transformer has two main
building blocks that need to be made scale invariant – residual block and Attention [62].
For residual block, Li and Arora [22] already demonstrated how to make both the
PreNorm and PostNorm version of ResNet scale invariant (see Appendix of their
paper for more details). In this chapter, we use their PreNorm variant (see Figure 4.2).
Furthermore, we design a novel scale invariant version of Attention block in transformer,
as described below.
Scale Invariant Attention: Recall the standard self attention block computes the
following for a given input Q, K, V ∈ Rn×dmodel :
QW Q (KW K )>
Attention(Q, K, V ) = Softmax( √ )V W V .
dk
Here W Q , W K ∈ Rdmodel ×dk and W V ∈ Rdmodel ×dv are affine transformations and, hence,
are all 1-homogeneous transformations. The Softmax function computes row wise
softmax normalization. It is easy to see that standard attention is not homogeneous
as softmax is itself not homogeneous.
92
We design a novel Scale Invariant Attention (SI Attention) in the following way:
(also see Figure 4.4)
a
where N denotes the row-wise normalization by sum, i.e., [N(A)]ij = P ij and
j aij
ReLU(A) denote the element-wise max between matrix A and 0. Notably we replace
the softmax with a ReLU activation followed by normalization. Both ReLU and
normalization are homogeneous operations; thus, making the overall attention score
computation (N(ReLU(ZQK > Z > ))) scale invariant to the concatenation of all param-
eters x, assuming Q, K, V are already positive homogeneous to x. The full design of
Scale Invariant Bert (Sibert) is presented to Section 4.4.
Although scale invariance can prevent parameter divergence after a large gradient
update by eliminating the positive feedback between gradient and parameter norm, it
alone does not ensure SGD trains the network in a robust and efficient way. This is
because, as shown in [36], the parameter norm monotonically increases when SGD is
used to optimize a scale invariant loss. As a result, once the norm becomes too large
(e.g due to large gradient in some step) the training can slow down drastically as the
η
effective LR kxt k22
is too small; thus, preventing effective recovery from even minor
training instabilities.
To tackle this issue we propose to use Weight Decay(WD) as a way to reduce the
parameter norm; thereby, allowing the network to recover from slow training induced
by infrequent updates of large norm. Under mild assumptions that the expectation
of squared norm of stochastic gradient does not vary too much on the unit sphere,
1
[50, 75] show that the parameter norm will stabilize in O( ηλ ) steps and the learning
93
dynamics is equivalent to one on unit sphere with effective learning rate proportional
√
to Θ( λη).
Leveraging the advantage of quick norm convergence, it is shown in Chapter 3 that
the convergence of SGD+WD is insensitive to the following three operations: loss
rescaling (A1), initialization rescaling (A2) and re-parametrization (A3), meaning the
| log c|
same convergence rate (independent of scaling c) can be achieved, in up to λη
more
steps. (See formal statement in Theorems 3.3.1 and 3.4.2 This property reduces the
effort of hyperparameter tuning and also makes training more robust when switching
between different codebases and frameworks, which is likely to have different default
scaling or parametrization. Also note by scale invariance of loss L, (A2) is equivalent
to (A3).
(A3). (L, x(0)) → (L0 , cx(0)), where L0 is defined as L0 (x) := L( xc ) for any c > 0.
Gradient clipping is a widely used effective strategy to stabilize neural network training.
However, often the clipping threshold need to be tuned based on the optimization
problem and the specific gradient distribution. Furthermore, simply using a constant
threshold can severely degrade the performance [54]. Thus, it is unclear how the
94
clipping threshold needs to be set for SGD+WD on scale invariant functions such
that it is insensitive to rescaling of loss and reparametrization, e.g., (A1-3).
To this end, we propose a clipping strategy named Relative Global Clipping which
allows consistent and robust training behavior for SGD+WD on scale invariant loss
under the aforementioned operations. In particular, we propose to set the clipping
q √
threshold as 2Cλ
η
kxk 2 , where C ≥ 1 is a hyperparamer with default value C = 2.
The high level design idea is that (1) the clipping rule should be invariant to the
scalings (L, η, λ) → (cL, η/c, cλ) and (x, η, λ) → (cx, c2 η, λ/c2 ) for any c > 0, to which
SGD+WD is invariant (see Lemma 3.2.4); (2) the clipping rule should only remove
the extremely large gradients and should not trigger too often to ensure that gradient
after clipping remains almost unbiased.
Intuitively, the derivation of Relative Global Clipping involves the following line of
reasoning: Suppose the norm of the stochastic gradient k∇Lγ (x)k2 is constant, say
σ, for all data and every parameter x on the unit sphere. In this case, we expect
our clipping strategy to not be triggered since there are no extremely high stochastic
gradients. Since Lγ is scale invariant, Theorem 3.2.2 implies that h∇Lγ (x), xi = 0.
That is,
It is not difficult to show the iteration (4.1) has a unique stationary point, kx(t)k22 =
q
2η
λ(2−ηλ)
σ[42]. In other words, at norm equilibrium, it holds
s
σ λ(2 − ηλ)
k∇Lγ (x(t))k2 = = kx(t)k2 . (4.2)
kx(t)k2 η
95
The above calculation suggests the clipping threshold should be at least
q
2λ
kx(t)k2 . 1 Furthermore, it is not difficult to check that the clipping
η
q
2λ
threshold η
kx(t)k2 is indeed invariant to the above mentioned scalings
(L, η, λ) → (cL, η/c, cλ) and (x, η, λ) → (cx, c2 η, λ/c2 ). For each hyperparame-
ter C > 1, the behavior of SGD+WD is consistent for different scalings (A1-3) and
it also improves the norm convergence (reducing undesirable spikes in norm while
training) for SGD+WD (see Theorem 4.5.3). Under mild assumptions that such
clipping does not introduce too much bias in gradients, we show that our recipe
enables convergence to approximate stationary points. Furthermore, the rate only
depends logarithmically on the initialization and loss scale, as shown in the following
section.
Following Section 2.10, we view the computation graph as a directed acyclic graph,
where each module is a node and each tensor (including inputs, intermediate compu-
tation results and final output) as an edge. Each edge can be viewed as a function
of parameters, and we can decide the homogeneity by doing induction over the com-
putation graph by its topological order. In detail, we know the jth output edge
of some (a1 , . . . , an ; b1 , . . . , bn )- homogeneous module is bj homogeneous if for each
1 ≤ i ≤ n, the ith input edge is ai -homogeneous. For convenience, we allow ai ,bi to be
functions of free variable x, meaning the module is (a1 (x), . . . , an (x); b1 (x), . . . , bm (x))-
homogeneous for every x ∈ Rd .
In Table 4.1, we summarize the homogeneity of building blocks in our design.
1
We drop −ηλ for convenience. This doesn’t lead to any practical difference as ηλ is typically
very small, e.g. less than10−4 .
96
Overview of SIBERT structure: Our SIBERT has two main parts — encoder
and classification head, which is the same to standard BERT. We only make encoder
part scale invariant and train it by SGD+WD. We leave the classification head not
scale invariant and train it by Lamb. Note the classification head is only used in
pretraining and is not used in the downstream task.
Figure 4.1: Encoder and Classification Head (CLS). ‘x12/24’ means to stack 12 our
(2; 2)-homogeneous encoder layer for base SIBERT (or 24 for large SIBERT)
97
Figure 4.2: The (2; 2)-homogeneous encoder layer. ‘ATTN’ denotes our Scale Invariant
Attention (see Figure 4.4). ‘FF’ denotes the 2-layer feedforward structure, which is
(0; 2)-homogeneous.
98
small init=0.002 medium init=0.02 large init=0.2
8 SIBERT, SGD 8 SIBERT, SGD 8 SIBERT, SGD
SIBERT, SGD+WD SIBERT, SGD+WD SIBERT, SGD+WD
Training Loss
Training Loss
Training Loss
SIBERT,2-clipped SIBERT,2-clipped SIBERT,2-clipped
6 SGD+WD 6 SGD+WD 6 SGD+WD
4 4 4
2 2 2
0.0 0.2 0.4 0.6 0.8 1.0 0.0 0.2 0.4 0.6 0.8 1.0 0.0 0.2 0.4 0.6 0.8 1.0
Steps 1e6 Steps 1e6 Steps 1e6
Figure 4.5: SGD+WD optimizes the scale invariant training loss of Sibert robustly
for all initialization scales, and thus for loss scalings and different learning rates
(with λη fixed). Here the default initialization for parameters in Sibert encoder is
a truncated normal distribution with standard deviation equal to 0.02 (the same as
Bert).
Clipping
Now we will present our analysis for the clipped SGD. Recall the clipped SGD update
from Algorithm 2 has the following norm dynamics.
Norm dynamics of clipped SGD:
( )
k∇Lγ (x(t))k22 2λC
kx(t + 1)k22 = (1 − ηλ)2 kx(t)k22 + η 2 min 2 , kx(t)k22 .
kx(t)k2 η
99
Let Px denote the distribution of k∇Lγ (x)k22 . Below is a mild assumption saying
Px is universally well-concentrated from below in the sense that the mean of the
smallest (1 − C1 ) part of Px is at least a constant fraction of the C-clipped mean of Px .
Since µPx ,C ≤ µx , the assumption below holds whenever αC µx ≤ Et∼Px [t1[t < MPx , 1 ]].
C
Assumption 4.5.2. ∃αC > 0, such that for all x 6= 0, αC · µPx ,C ≤ Et∼Px [t1[t <
MPx , 1 ]].
C
We further define µC := min µPx ,C and µC := max µPx ,C and have the following
kxk2 =1 kxk2 =1
theorem:
√ √
Theorem 4.5.3 ( C-Clipped SGD+WD). Let x(t) be defined by C-Clipped SGD
+WD (Algorithm 2). Under Assumption 4.5.2, for ηλ = O(min{1, C lnαTC/δ2 }), with
probability 1 − 5δ, we have
µC 2λ
∀T 0 ≤ t ≤ T − 1, ≤ kx(t)k42 ≤ 2µC . (4.3)
2 η
and
T −1 D 2
√ 3
1 X E π ρ µ C
p ρµ 2
C
∇L(x(t)), ∇L(x(t))
g ≤ √ + 4 ηλ
T − T 0 t=T 0 (T − T 0 ) 2ηλ µC
s s (4.4)
2 2 2 3
ln δ πρµC ln δ p ρµ
+ 0
8 + 0
16 λη 2C .
T −T µC T −T µC
kxk22
n o h nq oi
R2 µ
where T 0 = 1
αC ηλ
max ln µ 0 , ln RC2 +O(1) and ∇L(x)
g := E ∇Lγ (x) min 2Cλ
η k∇Lγ (x)k
, 1 .
C 0 2
The proof of this theorem is presented in Section 4.7. Note that with clipping
Theorem 4.5.3 shows that the norm convergence (4.3) is more robust as it doesn’t need
to make any assumption about the maximum gradient norm M , unlike Theorem 3.4.2.
Indeed, from the definition of C-clipped mean, for each x, we can allow all the
gradients with norm larger than C · µPx ,C to become infinity, and yet not affect the
norm convergence, as µPx ,C and the condition in Assumption 4.5.2 do not change.
100
107
small init=0.002 107
medium init=0.02 107
large init=0.2
SIBERT, SGD SIBERT, SGD
SIBERT, SGD+WD SIBERT, SGD+WD
106 SIBERT,2-clipped 106 SIBERT,2-clipped 106
SGD+WD SGD+WD SIBERT, SGD
Norm Sqaure
Norm Sqaure
Norm Sqaure
SIBERT, SGD+WD
105 105 105 SIBERT,2-clipped
SGD+WD
104 104 104
Figure 4.6: The robust optimization performance of SGD+WD over the scale invariant
training loss of Sibert originates from its ability to fast adjust the parameter norm.
In contrast, when the initial norm is too large, SGD w.o. WD optimizes slowly.
Relative Global Clipping reduces the spikes in the norm curve, which verifies our
theoretical result Theorem 4.5.3 that clipping leads to better norm convergence. Here,
only the norm of the scale invariant part, i.e., the encoder part is plotted.
D E
Under the additional assumption that ∇L(x(t)), ∇L(x(t)
g = Ω(k∇L(x(t))k22 ), we
can use Equation (4.4) to show convergence to stationary points. This is a reasonable
assumption if the clipping frequency is low, e.g., it’s 1.5% in our experiments for
Sibert.
4.6 Experiments
101
5.0
4.5
4.0
3.5
Training Loss
3.0 BERT, SGD, small LR
SIBERT,2-clipped
2.5 SGD+WD
SIBERT, AdamW
2.0 BERT, AdamW
1.5
1.00.0 0.2 0.4 0.6 0.8 1.0
Steps 1e6
Figure 4.7: Our recipe (Sibert, SGD+WD and Relative Global Clipping) significantly
improves the optimization performance compared to the baseline, Bert trained by
SGD with small LR. The final training loss is close to Bert trained by Adam.
non-scale invariant parts. The initial LR for SGD is 8e − 4 without warmup and is
divided by 10 at step 600k and 900k. Default training is for 1M steps. For Lamb we
use a linear decay schedule with initial learning rate 8e − 4 and a linear warmup of
10k steps.
102
Next, we compare the downstream performance on three benchmark datasets
(SQuADv1.1 [77], SQuADv2 [78] and MNLI [79]). We tried to follow standard setup,
e.g. Bert is finetuned by Adam. However for Sibert we had to use LAMB, as
Adam is very sensitive to the scale. We observe comparable performance and when
trained longer it can even outperform conventional Bert.
Sibert
+ clipping 82.6 89.3 76.8 1.58
+ 2x training 83.3 90.3 80.0 1.495
Bert 86.8 92.4 84.1 1.181
Large
103
when starting from very different initialization scale, SGD+WD (+clipping) quickly
brings parameter norm to desired ranges. In contrast, SGD struggles when initial
norm and learning rate are not aligned - see the rightmost plot with large initialization
in Figure 4.6. This shows that our recipe has the ability to quickly adapt to different
initialization scales, in-line with our theoretical result (Theorem 4.5.3) showing better
norm convergence of SGD+WD (+clipping).
Global Clipping
Lemma 4.7.1 (General Properties of GP,C ). For any C > 1 and measure P supported
on R≥0 , it holds that
1
3. C
MP, 1 ≤ µP,C ≤ µP , where µP is the expectation of P .
C
Proof of Lemma 4.7.1. (1). Note min{x, ·} is a continuous and concave function
for any x, we know GP,C is a concave function. (2). When GP,C is differentiable,
we have G0P,C (µ) = CFP,C
0
(Cµ) − 1. Let G0P,C (µ) = 0 implies that FP,C
0
(Cµ) =
0
1
C
. Note FP,C (Cµ) = Pt∼P [t > FP,C ], we know G0P,C ( C1 MP, 1 ) = 0. By concavity,
C
supµ≥0 GP,C (µ) = GP,C ( C1 MP, 1 ). This argument can be easily generalized to non-
C
differentiable case by using GP,C (µ) must be larger than GP,C (µ ± δ) for infinitesimal
104
δ. (3). First note that FP,C (MP, 1 ) = Et∼P [min{t, MP, 1 }] ≥ MP, 1 · Pt∼P [t ≥ MP, 1 ] =
C C C C
1
C
MP, 1 . In other words, GP,C ( C1 MP, 1 ) ≥ 0.
C C
1
Now suppose C
MP, 1 > µP,C . If GP,C ( C1 MP, 1 ) = 0, then by definition, 1
C
MP, 1 ≤
C C C
1. If P[x = 0] < 1 − C1 , then FP,C (Cµ) = µ has exact two solutions which are 0 and
µP,C > 0;
1 1
2. If P[x = 0] = 1 − C
, then FP,C (Cµ) = µ for all 0 ≤ µ ≤ C
MP,C and µP,C =
1
C
MP,C ;
3. If P[x = 0] > 1 − C1 , then FP,C (Cµ) = µ has only one solution which is µP,C = 0.
Proof. Suppose there are two solutions 0 < µ1 < µ2 . By concavity, we have ∀0 ≤ µ ≤
µ2 , GP,C (µ) = 0. Thus 0 = GP,C (0) + GP,C (µ2 ) = 2g( µ22 ), which implies that
Cµ2
Et∼P [min{t, Cµ2 }] = 2Et∼P [min{t, }] = Et∼P [min{2t, Cµ2 }],
2
that is, Pt∼P [t ≥ Cµ2 ∨ t = 0] = 1. Thus for any 0 ≤ µ ≤ µ2 , we have GP,C (µ) =
1
CµP[x ≥ Cµ2 ] − µ = 0, which implies µ2 = C
MP, 1 and P[x = 0] = 1 − C1 !
C
Lemma 4.7.3. Under Assumption 4.5.2, it holds that GP,Cx ( C1 MPx , 1 ) ≥ αC µPx ,C for
C
all x 6= 0.
1 1
GP,Cx ( MPx , 1 ) = Et∼Px [t1[t < MPx ,C ]] + (Pt∼Px [t ≥ MPx ,C ] − ) · MPx ,C . (4.6)
C C C
1
By the definition of the C
-median, the second term is non-negative. The proof is
completed by applying Assumption 4.5.2.
105
Lemma 4.7.4 (Lower and upped bounds for GPx ,C ). Under Assumption 4.5.2, it
holds that
µPx ,C
1. GPx ,C (µ) ≥ αC µ, for 0 ≤ µ ≤ 2
;
µPx ,C
2. GPx ,C (µ) ≥ αC (µPx ,C − µ), for 2
≤ µ ≤ µPx ,C ;
Proof of Lemma 4.7.4. By Lemma 4.7.3, Assumption 4.5.2 implies that GP,Cx ( C1 MPx , 1 ) ≥
C
The above inequalities also directly imply the following version using µC and µC
as thresholds.
Lemma 4.7.5 (Uniform Lower and upped bounds for GPx ,C ). Under Assumption 4.5.2,
it holds that for kxk2 = 1,
µC
1. GPx ,C (µ) ≥ αC µ, for 0 ≤ µ ≤ 2
;
µC
2. GPx ,C (µ) ≥ αC (µC − µ), for 2
≤ µ ≤ µC ;
αC µ 4µC
4. GPx ,C (µ) ≥ 4
, for 0 ≤ µ ≤ 5
; (4. follows from Property 1. and 2.)
s
t
X h i √ 1 2T 2
βl t−s (e gs | x(s)]) 1 Rs2 ≤ µC
gs − E[e ≤ CµC ln .
s=t0
1 − βl 2 δ
106
Let ET2 be the event that ∀0 ≤ t0 ≤ t ≤ T,
s
t
X √ 1 2T 2
βl t−s (e gs | x(s)]) 1 Rs2 ≤ 2µC
gs − E[e ≤ 2 CµC ln .
s=t0
1 − βl 2 δ
t
r
X 2T 2
g s − E[g s | x(s)] ≤ C T ln .
s=t0
δ
Proof of Lemma 4.7.6. Note the sequence in ETi are martingales whose differences
are uniformly bounded (µC , µC and C). The lemma follows directly from Hoeffding
Inequality and Azuma Inequality.
Theorem 4.7.7 (Norm lower bound with clipping: Warm Start). Suppose Assump-
tion 4.5.2 holds, with probability at least 1 − δ (or whenever ET1 holds), if Rt2 ≥ 34 µC ,
then for any t0 ≥ t, we have
s !
t0 −t
βl p 2C T2
Rt20 ≥ 1− − O( ηλ) − ηλ ln (1 + O(ηλ)) µC (4.7)
4 αC δ
µC
Proof. We first claim for any t ≤ t0 ≤ T , conditioned on ET1 , it holds that Rt20 ≥ 2
.
µC
Below we prove by contradiction. If not, let t0 be the smallest step such that Rt20 < 2
.
We let t∗ be the largest step between t and t0 such that Rt2∗ ≥ µC (t∗ = t − 1 is no
such t∗ exists) Thus if t∗ ≥ t then Rt2∗ +1 is at least (1 − ηλ)4 Rt2 = (1 − O(ηλ))µC .
√
Otherwise t∗ = t and it implies that Rt2∗ +1 = Rt2 = ( 34 − O( ηλ))µC . By the definition,
we know for any t∗ + 1 ≤ s ≤ t0 , Rs2 ≤ µC .
107
Similar to Equation (3.10), we have
2
Rs+1 =Rs2 (1 − ηλ)4 + 4ηλ(1 − ηλ)2 ges + 4η 2 λ2 get2
Thus for any s such that µC ≤ Rs2 ≤ 2µC , by Lemma 4.7.5, it holds that
2
Rs+1 ≥Rs2 (1 − 2η 2 λ2 + η 4 λ4 )
That is,
2
4ηλαC (1 − ηλ)2 µC
Rs+1 −
1 − βl
4ηλαC (1 − ηλ)2 µC
≥βl (Rs2 − )
1 − βl
+4ηλ(1 − ηλ)2 (e
gs − E[e
gs | x(s)])
108
Applying the above inequality for s = t∗ + 1, . . . , t0 − 1, we have that
!
t0 −t∗ −1
4ηλαC (1 − ηλ)2 µC
Rt20 ≥ βl Rt2∗ +1 −
1 − βl
| {z }
(A)
4ηλαC (1 − ηλ)2 µC
+
1 − βl
| {z }
(B)
t 0
X h i
2 t−s
+ 4ηλ(1 − ηλ) βl gs − E[e
(e gs | x(s)]) 1 Rs2 ≤ µC .
s=t∗ +1
| {z }
(C)
For term (B), we have 1 − βu = 4ηλαC (1 − ηλ)2 (1 + O(ηλ)) and thus (B) =
0 ∗ √
µC (1 + O(ηλ)). Since Rt∗ +1 ≥ 43 µC , it holds that (A) ≥ −βl t −t −1 ( 14 + O( λη))µC ≥
√
−( 41 + O( λη))µC . Since ET1 holds, we have
s s
√ 1 2T 2 2C T2
|(C)| ≤ 4ηλ(1 − ηλ)2 · CµC ln = µC ηλ ln (1 + O(ηλ))
1 − βl 2 δ αC δ
Thus there’s some constant ι, such for ηλ ≤ min{ι, 64C αlnCT 2 /δ }, (A) + (B) + (C) ≥
√ √ µ
( 6−8 2 − O( ηλ))µC ≥ 2C . This leads to a contradiction to the definition of t0 . Thus
µ
for any t ≤ t0 ≤ T , conditioned on ET1 , it holds that Rt20 ≥ 2C . Furthermore, if t∗ 6= t,
√ √
then Rt∗ +1 ≥ (1 − O( ηλ))µC . Thus (A) ≥ −O( ηλ)µC . Otherwise if t∗ = t, then
0 √
(A) ≥ −βl t −t ( 14 + O( λη))µC . Combine the bounds in these two cases, we conclude
that
s !
t0 −t
βl p 2C T2
Rt20 ≥ 1− − O( ηλ) − ηλ ln (1 + O(ηλ)) µC
4 αC δ
Theorem 4.7.8 (Norm upper bound with clipping: Warm Start). Suppose Assump-
tion 4.5.2 holds, with probability at least 1 − δ (or whenever ET2 holds), if Rt2 ≤ 32 µC ,
109
then for any t0 ≥ t, we have
0
s !
βl t −t p 2C T2
Rt20 ≤ 1+ + O( ηλ) + ηλ ln (1 + O(ηλ)) µC
2 αC δ
2
Rs+1 ≤Rs2 (1 − ηλ)4 + 4ηλ(1 − ηλ)2 ges + 4η 2 λ2 gbs2
Thus for any s such that µC ≤ Rs2 , by Lemma 4.7.5, it holds that
2
Rs+1 ≤Rs2 (1 − 2η 2 λ2 + η 4 λ4 + 4η 2 λ2 C 2 )
110
That is,
2 4ηλαC (1 − ηλ)2 µC
Rs+1 −
1 − βu
4ηλαC (1 − ηλ)2 µC
≤βu (Rs2 − ) + 4ηλ(1 − ηλ)2 (e
gs − E[e
gs | x(s)])
1 − βu
4ηλαC (1 − ηλ)2 µC
t0 −t∗ −1
Rt20 ≤ βu Rt2∗ +1 −
1 − βu
| {z }
(A)
2
4ηλαC (1 − ηλ) µC
+
1 − βu
| {z }
(B)
t 0
X
2
βu t−s (e gs | x(s)]) 1 Rs2 ≤ 2µC .
+ 4ηλ(1 − ηλ) gs − E[e
s=t∗ +1
| {z }
(C)
For term (B), we have 1 − βu = 4ηλαC (1 − ηλ)2 (1 + O(ηλ)) and thus (B) =
0 ∗ √
µC (1 + O(ηλ)). Since Rt∗ +1 ≤ 23 µC , it holds that (A) ≤ βu t −t −1 ( 12 + O( λη))µC ≤
√
( 21 + O( λη))µC . Since ET2 holds, we have that
s s
√ 1 2T 2 2C T2
|(C)| ≤ 8ηλ(1 − ηλ)2 · CµC 2 ln = 2µC ηλ ln (1 + O(ηλ))
1 − βu δ αC δ
Thus there’s some constant ι, such for ηλ ≤ min{ι, 64C αlnCT 2 /δ }, (A) + (B) + (C) ≤
√ √
( 6+4 2 + O( ηλ))µC ≤ 2µC . This leads to a contradiction to the definition of t0 . Thus
for any t ≤ t0 ≤ T , conditioned on ET1 , it holds that Rt20 ≥ 2µC . Furthermore, if t∗ 6= t,
√ √
then Rt∗ +1 ≤ (1 + O( ηλ))µC . Thus (A) ≤ O( ηλ)µC . Otherwise if t∗ = t, then
0 √
(A) ≤ βu t −t ( 12 + O( λη))µC . Combine the bounds in these two cases, we conclude
111
that
s !
t0 −t
βl p 2C T2
Rt20 ≤ 1+ + O( ηλ) + ηλ ln (1 + O(ηλ)) µC
2 αC δ
µC
≤ Rt2 ≤ 2µC .
2
0 0
p p
Rt2 ∈ [(1 − βlt−T )µC − O(
e λη), µC (1 + βut−T ) + O(
e λη)].
Proof of Theorem 4.7.9. We will prove the desired inequality always holds when ETi
holds, for i = 1, 2, 3. We have already proved the result for the case where 34 µC ≤
Rt2 ≤ 32 µC in Theorems 4.7.7 and 4.7.8. Now we turn to the case where R02 ≥ 32 µC
and R02 ≤ 12 µC . Our goal is to prove with high probability, that Rt2 ∈ [ 34 µC , 23 µC ] for
at least some t < T 0 .
Below we first show ∃0 < t < T 0 , Rt2 ≤ 23 µC . Otherwise, similar to Equation (4.9),
2
Rs+1 ≤Rs2 (1 − ηλ)4 + 4ηλ(1 − ηλ)2 ges + 4η 2 λ2 gbs2
112
Thus for any s such that 23 µC ≤ Rs2 , by Lemma 4.7.5, it holds that
αC 2
GPx(s) ,C (Rs2 ) = E[e
gs | x(s)] − Rs2 ≥ αC (µC − Rs2 ) ≥ − R .
3 s
Thus,
2
Rs+1 ≤Rs2 (1 − 2η 2 λ2 + η 4 λ4 + 4η 2 λ2 C 2 )
4
− ηλαC (1 − ηλ)2 Rs2 + 4ηλ(1 − ηλ)2 (e
gs − E[e
gs | x(s)])
3
2 2 2 4 4 2 2 2 4 2 2
=Rs 1 − 2η λ + η λ + 4η λ C − ηλαC (1 − ηλ) + 4ηλ(1 − ηλ) (g s − E[g s | x(s)])
3
2 4
ln Rs+1 − ln Rs2 ≤ − ηλαC + ηλ(g s − E[g s | x(s)]) + O(η 2 λ2 )
3
r
3 2 2 2 4T 2T 2
ln + ln µC − ln R0 ≤ ln RT 0 − ln R0 ≤ − ηλαC + Cηλ T ln + O(η 2 λ2 T ),
4 3 δ
2
R0 µ
max ln ,ln C +O(1)
µC R02
which is in contradiction with the definition of T 0 = αC ηλ
.
Now we show ∃0 < t < T 0 , Rt2 ≥ 43 µC . Otherwise, similar to Equation (4.9),
2
Rs+1 =Rs2 (1 − ηλ)4 + 4ηλ(1 − ηλ)2 ges + 4η 2 λ2 get2
Thus for any s such that Rs2 ≤ 54 µC , by Lemma 4.7.5, it holds that
αC 2
GPx(s) ,C (Rs2 ) = E[e
gs | x(s)] − Rs2 ≥ R .
4 s
113
Thus, we have that
2
Rs+1 ≥Rs2 (1 − 2η 2 λ2 + η 4 λ4 )
2
ln Rs+1 − ln Rs2 ≥ ηλαC + ηλ(g s − E[g s | x(s)]) + O(η 2 λ2 )
r
2T 2
ln µC − ln R02 ≥ ln RT2 0 − ln R02 ≥ T ηλαC − Cηλ T ln + O(η 2 λ2 T ),
δ
2
R0 µ
max ln ,ln C +O(1)
µC R02
which is in contradiction with the definition of T 0 = αC ηλ
.
114
Part II
115
Chapter 5
116
5.1 Introduction
Deep convolutional nets (“ConvNets”) are at the center of the deep learning revo-
lution [48, 80, 81]. For many tasks, especially in vision, convolutional architectures
perform significantly better their fully-connected (“FC”) counterparts, at least given
the same amount of training data. Practitioners explain this phenomenon at an
intuitive level by pointing out that convolutional architectures have better “inductive
bias”, which intuitively means the following: (i) ConvNet is a better match to the
underlying structure of image data, and thus are able to achieve low training loss with
far fewer parameters (ii) models with fewer total number of parameters generalize
better.
Surprisingly, the above intuition about the better inductive bias of ConvNets over
FC nets has never been made mathematically rigorous. The natural way to make
it rigorous would be to show explicit learning tasks that require far more training
samples on FC nets than for ConvNets. (Here “task”means, as usual in learning theory,
a distribution on data points, and binary labels for them generated given using a fixed
labeling function.) Surprisingly, the standard repertoire of lower bound techniques in
ML theory does not seem capable of demonstrating such a separation. The reason is
that any ConvNet can be simulated by an FC net of sufficient width, since a training
algorithm can just zero out unneeded connections and do weight sharing as needed.
Thus the key issue is not an expressiveness per se, but the combination of architecture
plus the training algorithm. But if the training algorithm must be accounted for, the
usual hurdle arises that we lack good mathematical understanding of the dynamics of
deep net training, whether FC or ConvNet. How then can one establish such limitation
of “FC nets + current training algorithms”? (Indeed, many lower bound techniques
in PAC learning theory are information theoretic and ignore the training algorithm.)
The current paper makes significant progress on the above problem by exhibiting
simple tasks that require Ω(d2 ) factor more training samples for FC nets than for
117
ConvNets, where d is the data dimension. (In fact this is shown even for 1-dimensional
ConvNets; the lowerbound easily extends to 2-D ConvNets.) The lower bound holds
for FC nets trained with vanilla SGD with Gaussian initialization of network weights,
with the optional use of momentum, `2 regularization, and various learning rate
schedules. Our proof relies on the fact that these popular algorithms lead to an
orthogonal-equivariance property on the trained FC nets, which says that at the end
of training the FC net —no matter how deep or how wide — will make the same
predictions even if we apply orthogonal transformation on all datapoints (i.e., both
training and test). This notion is inspired by Ng [82] (where it is named “orthogonal
invariant”), which showed the power of logistic regression with `1 regularization versus
other learners. For a variety of learners (including kernels and FC nets) that paper
described explicit tasks where the learner has Ω(d) higher sample complexity than
logistic regression with `1 regularization. The lower bound example and technique
can also be extended to show a (weak) separation between FC nets and ConvNets.
(See Section 5.5.2)
Our separation is quantitatively stronger than the results by Ng [82] because the
sample complexity gap is Ω(d2 ) vs O(1), and not Ω(d) vs O(1). But in a more subtle
way our result is conceptually far stronger: the technique by Ng [82] seems incapable
of exhibiting a sample gap of more than O(1) between Convnets and FC nets in our
framework. The reason is that the technique by Ng [82] can exhibit a hard task for FC
nets only after fixing the training algorithm. But there are infinitely many training
algorithms once we account for hyperparameters associated in various epochs with
LR schedules, `2 regularizer and momentum, etc.. Thus the technique by Ng [82]
cannot exclude the possibility that the hard task for “FC net + Algorithm 1” is easy
for “FC net + Algorithm 2”. Note that we do not claim any issues with the results
claimed by Ng [82]; merely that the technique cannot lead to a proper separation
between ConvNets and FC nets, when the FC nets are allowed to be trained with any
118
1.0 Gauss 1.0 cifar-10
0.9 0.9
0.8 0.8
test acc
test acc
2-layer cnn w/ quadratic
3-layer cnn w/ relu
0.7 0.7 resnet14 cnn
hybrid w/ quadratic
hybrid w/ relu
2-layer fc w/ quadratic
0.6 0.6 3-layer fc w/ quadratic
3-layer fc w/ relu
3-layer fc w/ relu + bn
0.5 2 3 4 5 6
0.5 2 3 4 5
10 10 10 10 10 10 10 10 10
# training data # training data
of the infinitely many training algorithms. (Section 5.5.2 spells out in more detail the
technical difference between our technique and Ng’s idea.)
The reader may now be wondering what is the single task that is easy for ConvNets
but hard for FC nets trained with any standard algorithm? A simple example is the
following: data distribution in Rd is standard Gaussian, and target labeling function
is the sign of d/2
P 2
Pd 2
i=1 zi − i=d/2+1 zi . Figure 5.1 shows that this task is indeed much
more difficult for FC nets. Furthermore, the task is also hard in practice for data
distributions other than Gaussian; the figure shows that a sizeable performance gap
exists even on CIFAR images with such a target label.
Extension to broader class of algorithms. The orthogonal-equivariance property
holds for many types of practical training algorithms, but not all. Notable exceptions
are adaptive gradient methods (e.g. Adam and AdaGrad), `1 regularizer, and initial-
119
ization methods that are not spherically symmetric. To prove a lower bound against
FC nets with these algorithms, we identify a property, permutation-invariance, which
is satisfied by nets trained using such algorithms. We then demonstrate a single and
natural task on Rd × {±1} that resembles real-life image texture classification, on
which we prove any permutation-invariant learning algorithm requires Ω(d) training
examples to generalize, while Empirical Risk Minimization with O(1) examples can
learn a convolutional net.
Structure of this chapter. In Section 5.2 we discuss about related works. In
section 5.3, we define the notation and cover some preliminaries in PAC learning. In
Section 5.4, we define algorithmic equivariance and prove the orthogonal equivariance
of FC-Net trained by gradient descent. In Section 5.5, we give two warmup examples
and an overview for the proof technique for the main theorem. In Section 5.6, we
present our main results on the lower bound of orthogonal and permutation equivariant
algorithms.
Du et al. [83] attempted to investigate the reason why convolutional nets are more
sample efficient. Specifically they prove O(1) samples suffice for learning a convolutional
filter and also proved a Ω(d) min-max lower bound for learning the class of linear
classifiers. Their lower bound is against learning a class of distributions, and their
work fails to serve as a sample complexity separation, because their upper and lower
bounds are proved on different classes of tasks.
Arjevani and Shamir [84] also considered the notion of distribution-specific hardness
of learning neural nets. They focused on proving running time complexity lower
bounds against so-called ”orthogonally invariant” and ”linearly invariant” algorithms.
However, here we focus on sample complexity.
120
Recently, there has been progress in showing lower bounds against learning with
kernels. Wei et al. [85] constructed a single task on which they proved a sample
complexity separation between learning with neural networks vs. with neural tangent
kernels. Notably the lower bound is specific to neural tangent kernels [15]. Relatedly,
Allen-Zhu and Li [86] showed a sample complexity lower bound against all kernels for
a family of tasks, i.e., learning k-XOR on the hypercube.
We will use Z = Rd , Y = {−1, 1} to denote the domain of the data and label and H =
{h | h : Z → Y} to denote the hypothesis class. Formally, given a joint distribution
P , the error of a hypothesis h ∈ H is defined as errP (h) := Pz,y∼P [h(z) 6= y]. If h is a
random hypothesis, we define errP (h) := Pz,y∼P,h [h(z) 6= y] for convenience. A class
of joint distributions supported on Z × Y is referred as a problem, P.
We use k·k2 to denote the spectrum norm and k·kF to denote the Frobenius norm
of a matrix. We use A ≤ B to denote that B − A is a semi-definite positive matrix.
We also use O(d) and GL(d) to denote the d-dimensional orthogonal group and general
2
linear group respectively. We use Bpd to denote the unit Schatten-p norm ball in Rd×d .
We use N (µ, Σ) to denote Gaussian distribution with mean µ and covariance
Σ. For random variables X and Y , we denote X is equal to Y in distribution by
d
X = Y . In this work, we also always use PZ to denote the distributions on Z
and P to denote the distributions supported jointly on Z × Y. Given an input
distribution PZ and a hypothesis h, we define PZ h as the joint distribution on
Z × Y, such that (PZ h)(S) = PZ ({z|(z, h(z)) ∈ S}), ∀S ⊂ Z × Y. In other
words, to sample (Z, Y ) ∼ PZ h means to first sample Z ∼ PZ , and then set
Y = h(Z). For a family of input distributions PZ and a hypothesis class H, we define
121
PZ H = {PZ h | PZ ∈ PZ , h ∈ H}. In this work all joint distribution P can be
written as PZ h for some h, i.e. PY|Z is deterministic.
For set S ⊂ Z and bijection g : Z → Z, we define g(S) = {g(x)|x ∈ S}. We use
◦ to denote function composition. (f ◦ g)(x) is defined as f (g(x)), and for function
classes F, G, F ◦ G = {f ◦ g | f ∈ F, g ∈ G}. For any distribution PZ supported on
Z , we define PZ ◦ g as the distribution such that (PZ ◦ g)(S) = PZ (g(S)). In other
words, if Z ∼ PZ ⇐⇒ g −1 (Z) ∼ PZ ◦ g, because
−1
∀S ⊆ Z, P g (Z) ∈ S = P [Z ∈ g(S)] = [PZ ◦ g](S).
Z∼PZ Z∼PZ
122
Algorithm 3 Iterative algorithm A
Require: Initial parameter distribution Pinit supported in W = Rm , total iterations
T , training dataset {zi , yi }ni=1 , parametric model M : W → H, iterative update
rule F (x, M, {zi , yi }ni=1 )
Ensure: Hypothesis h : Z → Y.
Sample x(0) ∼ Pinit .
for t = 0 to T − 1 do
n
x(t+1) = F (x(t) , M, {zi , y(Ti }) i=1 ).
return h(·) = sign M[x ](·) .
For function class H, we use ΠH (n) to denote the growth function of H, i.e.
ΠH (n) := sup |{(h(z1 ), h(z2 ), . . . , h(zn )) | h ∈ H}| . The Vapnik–Chervonenkis(VC)
z1 ,...,zn ∈Z
dimension of H is defined as the largest integer such that ΠH (n) = 2n and we denote
VCdim(H)
en
it by VCdim(H). By Sauer-Shelah Lemma, we know ΠH (n) ≤ VCdim(H) for
n ≥ VCdim(H).
VCdim(H) ln 1ε + ln 1δ
N (A, PZ H, ε, δ) = O . (5.1)
ε
VCdim(H) + ln 1δ
N (A, PZ H, ε, δ) = Ω . (5.2)
ε
123
5.3.2 Parametric Models and Iterative Algorithms
Here, σ : R → R can be any function, and we abuse the notation such that σ is also
defined for vector inputs, in the sense that [σ(z)]i = σ(zi ).
ConvNets (CNN): In this chapter we will only use two layer Convolutional
Neural Networks with one channel. Suppose d = d0 r for some integers d0 and r, a
2-layer CNN parameterized by its weights x = (w, a, b) ∈ Rk × Rr × R is a function
CNN[x](·) : Rd → R:
r
X
CNN[x](z) = ai σ([w ∗ z]d0 (i−1)+1:d0 i ) + b,
i=1
wise non-linearity.
124
5.4 Algorithmic Equivariance in Fully-connected
In this section, we first give the formal definition of equivariant algorithms. Then we
start with an informal sketch of why FC nets trained with standard algorithms have
certain equivariance properties and then give the formal proof.
The high level idea here is if update rule of the network, or more generally,
the parametrized model, exhibits certain symmetry per step, i.e., property 2 in
Theorem 5.4.2, then by induction it will hold till the last iteration.
Taking linear regression as an example, let zi ∈ Rd , i ∈ [n] be the data and y ∈ Rn
2
be the labels, the GD update for L(w) = 12 ni=1 (z> 2 1 >
P
i w − yi ) = 2 Z w − y 2 would
2. Update rule F is invariant under any joint group action (g, τ (g)), ∀g ∈ G. In
other words, [τ (g)](F (x, M, {zi , yi }ni=1 )) = F ([τ (g)](x), M, {g(zi ), yi }ni=1 ).
Here we want to address that the three conditions in Theorem 5.4.2 are natural and
almost necessary. Condition 1 is the minimal expressiveness requirement for model
M to allow equivariance. Condition 3 is required for equivariance at initialization.
Condition 2 is necessary for induction.
126
Proof of Theorem 5.4.2. ∀g ∈ GZ , we sample x(0) ∼ Pinit , and x̃(0) = τ (g)(x(0) ).
d
By property (3), x̃(0) = x(0) ∼ Pinit . Let x(t+1) = F x(t) , M, {zi , yi }ni=1 and
d
A {zi , yi }ni=1 = M[x(T ) ],
and
d
M[x̃(T ) ] ◦ g = A({g(zi ), yi }ni=1 ) ◦ g.
By property (1), we have M[x̃(T ) ](g(z)) = M[τ (g)(x(T ) ](g(z)) = M[x(T ) ](z).
d d
Therefore, A({zi , yi }ni=1 ) = M[x(T ) ] = M[x̃(T ) ] ◦ g = A({g(zi ), yi }ni=1 ) ◦ g, meaning A
is GZ -equivariant.
Remark 5.4.3. Theorem 5.4.2 can be extended to the stochastic case and the adaptive
case which allows the algorithm to use information of the whole trajectory, i.e., the
update rule could be generalized as x(t+1) = Ft ({x(s) }ts=1 , M, {zi , yi }ni=1 ), as long as
(the distribution of) each Ft is invariant under joint transformations.
Below are two example applications of Theorem 5.4.2. Other results in Table 5.1
could be achieved in the same way.
For classification tasks, optimization algorithms often work with a differentiable
surrogate loss ` : R → R instead the 0-1 loss, such that `(yh(z)) ≥ 1 [yh(z) ≤ 0],
and the total loss for hypothesis h and training, L(M[x]; {zi , yi }ni=1 ) is defined as
Pn
i=1 `(M[x](zi )yi ). It’s also denoted by L(x) when there’s no confusion.
Definition 5.4.4 (Gradient Descent for FC nets, Algorithm 4). We call Algorithm 3
Gradient Descent for FC nets if M = FC-NN and F = GDL , where GDL (x) =
x − η∇L(x) is called the one-step Gradient Descent update and η > 0 is the learning
rate.
127
Algorithm 4 Gradient Descent for FC-NN (FC networks)
Require: Initial parameter distribution Pinit , total iterations T , training dataset
{zi , yi }ni=1 , loss function `
Ensure: Hypothesis h : Z → Y.
Sample x(0) ∼ Pinit .
for t = 0 to T − 1 do
n
x(t+1) = x(t) − η ∇`(FC-NN(x(t) )(zi ), yi )
P
i=1
return h = sign FC-NN[x(T ) ] .
Proof of Corollary 5.4.5. We will verify the three conditions required in Theorem 5.4.2
one by one.
Condition 1: This is the only place we use the FC structure.
Proof of Lemma 5.4.6. By definition, FC-NN[x](z) could be written FC-NN[x2:L ](σ(W1 z)),
which implies FC-NN[x](z) = FC-NN[W1 R−1 , x2:L ](Rz), ∀R ∈ O(d), and thus we can
pick τ (R) = O ∈ O(m), where O(x) = [W1 R−1 , x2:L ], and GW = τ (O(d)).
128
Proof of Lemma 5.4.7. By definition, it suffices to show that for each i ∈ [n], and
every x and x0 = Ox,
For any R ∈ O(d), and set O = τ (R) by Lemma 5.4.6, (L ◦ O−1 )[x] =
Pn −1
Pn
i=1 `(FC-NN[O (x)](zi ), yi ) = i=1 `(FC-NN[x](Rzi ), yi ). The second condition in
Corollary 5.4.8. FC nets trained with newton’s method from zero initialization for
the first layer and any initialization for the rest parameters is GL(d)-equivariant, or
equivariant under the group of invertible linear transformations.
Here, Netwon’s method means to use NT(x) = x − η(∇2 L(x))−1 ∇L(x) as the
update rule and we assume ∇2 L(x) is invertible.
Proof of Corollary 5.4.8. The proof is almost the same as that of Corollary 5.4.5,
except the following modifications.
Condition 1: If we replace the O(d), O(m) by GL(d), GL(m) in the statement
and proof Lemma 5.4.6, the lemma still holds.
Condition 2:By chain rule, one can verify the update rule Newton’s method is
invariant under invertible linear re-parametization, i.e. ONT[W ] = NTL◦O−1 [OW ], for
all invertible matrix O.
Condition 3: Since the first layer is initialized to be 0, it is invariant under any
linear transformation.
129
Remark 5.4.9. The above results can be easily extended to the case of momentum and
Lp regularization. For momentum, we only need to ensure that the following update
rule, x(t+1) = GDM(x(t) , x(t−1) , M, {zi , yi }ni=1 ) = (1 + γ)x(t) − γx(t−1) − η∇L(x(t) ),
also satisfies the property in Lemma 5.4.7. For Lp regularization, because kxkp is
independent of {zi , yi }ni=1 , we only need to ensure kxkp = kτ (R)(x)kp , ∀R ∈ GZ ,
which is easy to check when GZ only contains permutation or sign-flip.
To demonstrate the wide application of our lower bounds, we give two more examples
of algorithmic equivariance where the algorithm is not iterative. The proofs are
folklore.
GZ -equivariant.
130
5.5 Warm-up Examples and Proof Idea for Main
Results
Equivariant Methods
We start with a simple but insightful example to how equivariance alone could suffice
for some non-trivial lower bounds.
We consider a task on Rd × {±1} which is a uniform distribution on the set
{(ei y, y)|i ∈ {1, 2, . . . , d}, y = ±1}, denoted by P . Each sample from P is a one-hot
vector in Rd and the sign of the non-zero coordinate determines its label. Now imagine
our goal is to learn this task using an algorithm A. After observing a training set of n
labeled points S := {(zi , yi )}ni=1 , the algorithm is asked to make a prediction on an
unseen test data z, i.e., A(S)(z). Here we are concerned with orthogonal equivariant
algorithms ——the prediction of the algorithm on the test point remains the same
even if we rotate every zi and the test point x by any orthogonal matrix U , i.e.,
d
A({(U zi , yi )}ni=1 )(U z) = A({(zi , yi )}ni=1 )(z)
Now we show this algorithm fails to generalize on task P , if it observes only d/2
training examples. The main idea here is that, for a fixed training set S, the prediction
A({(zi , yi )}ni=1 )(z) is determined solely by the inner products between z and zi ’s due
to orthogonal equivariance, i.e., there exists a random function f (which may depend
on S) such that2
d
A({(zi , yi )}ni=1 )(z) = f (z> z1 , . . . , z> zn )
2
this can be made formal using the fact that Gram matrix determine a set of vectors up to an
orthogonal transformation.
131
But the input distribution for this task is supported on 1-hot vectors. Suppose n < d/2.
Then at test time the probability is at least 1/2 that the new data point (z, y) ∼ P ,
is such that z has zero inner product with all n points seen in the training set S. This
fact alone fixes the prediction of A to the value f (0, . . . , 0) whereas y is independently
and randomly chosen to be ±1. We conclude that A outputs the wrong answer with
probability at least 1/4.
This warm up example illustrates the main insight of Ng [82], namely, that when
an orthogonal equivariant algorithm is used to do learning on a certain task, it is
actually being forced to simultaneously learn all orthogonal transformations of this
task. Intuitively, this should make the learning much more sample-hungry compared
to even Simple SGD on ConvNets, which is not orthogonal equivariant. Now we sketch
why the obvious way to make this intuition precise using VC dimension (Theorem 5.3.3)
does not give a proper separation between ConvNets and FC nets, as mentioned in
the introduction.
hP i
d P2d
We first fix the ground truth labeling function h∗ (z) = sign 2
i=1 zi − 2
i=d+1 i .
z
Algorithm A is orthogonal equivariant (Definition 5.4.1) means that for any task
P = PZ h∗ , where PZ is the input distribution and h∗ is the labeling function, A must
have the same performance on P and its rotated version P ◦ U = (PZ ◦ U ) (h∗ ◦ U ),
where U can be any orthogonal matrix. Therefore if there is an orthogonal equivariant
learning algorithm A that learns h∗ on all distributions, then A will also learn
every the rotated copy of h∗ , h∗ ◦ U , on every distribution PZ , simply because A
learns h∗ on distribution PZ ◦ U −1 . Thus A learns the class of labeling functions
h∗ ◦ O(2d) := {h | h(z) = h∗ (U z), ∀U ∈ O(2d)} on all distributions. (See formal
statement in Theorem 5.6.2) By the standard lower bounds with VC dimension (See
Theorem 5.3.3), it takes at least Ω( VCdim(H◦O(2d))
ε
) samples for A to guarantee 1 − ε
132
accuracy. Thus it suffices to show the VC dimension VCdim(H ◦ O(2d)) = Ω(d2 ),
towards a Ω(d2 ) sample complexity lower bound. (Ng [82] picks a linear thresholding
function as h∗ , and thus VCdim(h∗ ◦ O(2d)) is only O(d).)
Formally, we have the following theorem, whose proof is deferred into Section 5.7.2:
As noted in the introduction, this doesn’t imply there is some task hard for every
training algorithm for the FC net. The VC dimension based lower bound implies for
each algorithm A the existence of a fixed distribution PZ ∈ P and some orthogonal
matrix UA such that the task (PZ ◦ UA−1 ) h∗ is hard for it. However, this does not
preclude (PZ ◦ UA−1 ) h∗ being easy for some other algorithm A0 .
At first sight, the issue highlighted above (and in the Introduction) seems difficult to
get around. One possible avenue is if the hard input distribution PZ in the task were
invariant under all orthogonal transformations, i.e., PZ = PZ ◦ U for all orthogonal
matrices U . Unfortunately, the distribution constructed in the proof of lower bound
with VC dimension is inherently discrete and cannot be made invariant to orthogonal
transformations.
Our proof uses a fixed PZ , the standard Gaussian distribution, which is indeed
invariant under orthogonal transformations. The proof also uses the Benedek-Itai’s
lower bound, Theorem 5.5.2, and the main technical part of our proof is the lower
bound for the the packing number D(H, ρ, ε) defined below (also see Equation (5.4)).
Let ρ be a metric on H, We define N (H, ρ, ε) as the ε-covering number of H w.r.t.
ρ, and D(H, ρ, ε) as the ε-packing number of H w.r.t. ρ. For distribution PZ , we use
133
ρZ (h, h0 ) := PX∼PZ [h(X) 6= h0 (X)] to denote the discrepancy between hypothesis h
and h0 w.r.t. PZ .
Theorem 5.5.2. [Benedek-Itai’s lower bound [88]] For any algorithm A that (ε, δ)-
learns H with n i.i.d. samples from a fixed distribution PZ , it must hold for every
Since ΠH (n) ≤ 2n , we have N (A, PZ H, ε, δ) ≥ log2 D(H, ρZ , 2ε) + log2 (1 − δ), which
is the original bound by Benedek and Itai [88]. Later Long [89] improved this bound
for the regime n ≥ VCdim(H) using Sauer-Shelah lemma, i.e.,
VCdim(H) 1
N (A, PZ , ε, δ) ≥ ((1 − δ)D(H, ρZ , 2ε)) VCdim(H) . (5.4)
e
Intuition behind Benedek-Itai’s lower bound. We first fix the data distribu-
tion as PZ . Suppose the 2ε-packing is labeled as {h1 , . . . , hD(H,ρZ ,2ε) } and ground truth
is chosen from this 2ε-packing, (ε, δ)-learns the hypothesis H means the algorithm
is able to recover the index of the ground truth w.p. 1 − δ. Thus one can think this
learning process as a noisy channel which delivers log2 D(H, ρZ , 2ε) bits of information.
Since the data distribution is fixed, unlabeled data is independent of the ground
truth, and the only information source is the labels. With some information-theoretic
inequalities, we can show the number of labels, or samples (i.e., bits of information)
N (A, PZ H, ε, δ) ≥ log2 D(H, ρZ , 2ε) + log2 (1 − δ). A more closer look yields Equa-
tion (5.4), because when VCdim(H) < ∞, then only log2 ΠH (n) instead of n bits
information can be delivered.
134
5.6 Main Results: Sample Complexity Lower
Below we first present a reduction from a special subclass of PAC learning to equivariant
learning (Theorem 5.6.2), based on which we prove our main separation results,
Theorem 5.5.1, 5.6.4, 5.6.5 and 5.6.6.
Lemma 5.6.1. Let A be the set of all algorithms and AGZ be the set of all GZ -
equivariant algorithms, the following inequality holds. The equality is attained when
GZ is a compact group.
Take infimum over AGZ over the both side of Equation (5.6), and note that AGZ ⊂ A,
Inequality (5.5) is immediate.
Suppose the group GZ is compact and let µ be the Haar measure on it, i.e.
∀S ⊂ GZ , g ∈ GZ , µ(S) = µ(g ◦ S). We claim for each algorithm A, the sample
complexity of the following equivariant algorithm A0 is no higher than that of A on
P GZ :
A0 ({zi , yi }ni=1 ) = A({g(zi ), yi }ni=1 ) ◦ g, where g ∼ µ.
135
By the definition of Haar measure, A0 is GZ -equivariant. Moreover, for any fixed
n ≥ 0, we have
inf E [errP (A0 ({zi , yi }ni=1 ))] = inf E E [errP (A({zi , yi }ni=1 ))]
P ∈P (zi ,yi )∼P P ∈P g∼µ (zi ,yi )∼P ◦g −1
≥ inf inf E [errP (A({zi , yi }ni=1 ))] = inf E [errP (A({zi , yi }ni=1 ))] ,
P ∈P g∈GZ (zi ,yi )∼P ◦g −1 P ∈P◦GZ (zi ,yi )∼P
Remark 5.6.3. The sample complexity in standard PAC learning is usually defined
again hypothesis class H only, i.e., PZ is the set of all the possible input distributions.
In that case, PZ is always invariant under group GZ , and thus Theorem 5.6.2 says
that GZ -equivariant learning against hypothesis class H is as hard as learning against
hypothesis H ◦ GZ without equivariance constraint.
a Fixed Distribution
In this subsection we show Ω(d2 ) vs O(1) separation on a single task in our main
theorem (Theorem 5.6.4). With the same proof technique, we further show we can
2
get correct dependency on ε for the lower bound, i.e., Ω( dε ), by considering a slightly
136
larger function class, which can be learnt by ConvNets with O(d) samples. We also
generalize this Ω(d2 ) vs O(d) separation to the case of `2 regression with a different
proof technique.
hP i
d P2d
Theorem 5.6.4. There’s a single task, PZ h∗ , where h∗ = sign 2
i=1 zi − 2
i=d+1 zi
and PZ = N (0, I2d ) and a constant ε0 > 0, independent of d, such that for any
orthogonal equivariant algorithm A, we have
Proof of Theorem 5.6.4. Upper bound: implied by upper bound in Theorem 5.5.1.
Lower bound: Note that the PZ = N (0, I2d ) is invariant under O(2d), by The-
orem 5.6.2, it suffices to show that there’s a constant ε0 > 0 (independent of d),
for any algorithm A, it takes Ω(d2 ) samples to learn the augmented function class
h∗ ◦ O(2d) w.r.t. PZ = N (0, I2d ). Define hU = sign z>
d×d
1:d U zd+1:2d , ∀U ∈ R , and
by Lemma 5.7.4, we have H = {hU | U ∈ O(d)} ⊆ h∗ ◦ O(2d). Thus it suffices to
show a Ω(d2 ) sample complexity lower bound for the function subclass H, i.e.,
d(d−1)
By Lemma 5.7.6, there’s some constant C, such that D(H, ρZ , ε) ≥ ( Cε ) 2 , ∀ε > 0.
137
kU −V kF
The high-level idea for Lemma 5.7.6 is to first show that ρZ (hU , hV ) ≥ Ω( √
d
),
and then we show the packing number of orthogonal matrices in a small neighborhood
k·kF
of Id w.r.t. √
d
is roughly the same as that in the tangent space of orthogonal manifold
d(d−1)
at Id , i.e., the set of skew matrices, which is of dimension 2
and has packing
d(d−1)
number ( Cε ) 2 . The advantage of working in the tangent space is that we can apply
the standard volume argument.
d(d−1)
Setting δ = 12 , we have that N ∗ (A, P, ε0 ) ≥ N (A, P, 12 , 2ε0 ) ≥ 2
log2 C
4ε0
−1 =
Ω(d2 ).
Indeed, we can improve the above lower bound by applying Equation (5.4), and
get
12 21 − 2d1
1 d2 1 d C 1 1
N (A, P, ε, ) ≥ = Ω(d2 ε− 2 + 2d ). (5.11)
2 e 2 ε
1 1
Note that the dependency in ε in Equation (5.11) is ε− 2 + 2d is not optimal, as
opposed to ε−1 in upper bounds and other lower bounds. A possible reason for
this might be that Theorem 5.5.2 (Long’s improved version) is still not tight and it
might require a tighter probabilistic upper bound for the growth number ΠH (n), at
least taking PZ into consideration, as opposed to the current upper bound using VC
2
dimension only. We left it as an open problem to show a single task P with Ω( dε )
sample complexity to achieve ε error for all orthogonal equivariant algorithms.
However, if the hypothesis class is of VC dimension O(d), using a similar idea, we
can prove a Ω(d2 /ε) sample complexity lower bound for equivariant algorithms, and
O(d) upper bounds for ConvNets.
such that for any orthogonal equivariant algorithms A and ε > 0, N ∗ (A, P, ε) =
Ω(d2 /ε), while there’s a 2-layer ConvNets architecture, such that N (ERMCNN , P, ε, δ) =
d log 1ε +log 1
O( ε
δ
).
138
Interestingly, we can show an analog of Theorem 5.6.5 for `2 regression, i.e., the
algorithm not only observes the signs but also the values of labels yi . Here we define
the `2 loss of function h : Rd → R as `P (h) = E [(h(z) − y)2 ] and the sample
(z,y)∼P
complexity N ∗ (A, P, ε) for `2 loss similarly as the smallest number n ∈ N such that
∀P ∈ P, E [`P (A({zi , yi }ni=1 ))] ≤ ε E [y 2 ]. The last term E [y 2 ] is added
(zi ,yi )∼P (x,y)∼P (x,y)∼P
for normalization to avoid the scaling issue and thus any ε > 1 could be achieved
trivially by predicting 0 for all data.
R} , such that for any orthogonal equivariant algorithms A and ε > 0, N ∗ (A, P, ε) ≥
d(d+3)
2
(1 − ε) − 1, while there’s a 2-layer ConvNet architecture CNN, such that
N ∗ (ERMCNN , P, ε) ≤ d for any ε > 0.
In this subsection we will present Ω(d) lower bound for permutation equivariance via
a different proof technique — direct coupling. The high-level idea of direct coupling
is to show with constant probability over (Zn , z), we can find a g ∈ GZ , such that
g(Zn ) = Zn , but z and g(z) has different labels, in which case no equivariant algorithm
could make the correct prediction.
Theorem 5.6.7. Let ti = ei +ei+1 and si = ei +ei+2 3 and P be the uniform distribution
on {(si , 1)}ni=1 ∪ {(ti , −1)}ni=1 , which is the classification problem for local textures in
a 1-dimensional image with d pixels. Then for any permutation equivariant algorithm
A, N (A, P, 18 , 18 ) ≥ N ∗ (A, P, 14 ) ≥ d
10
. Meanwhile, N (ERMCN N , P, 0, δ) ≤ log2 1δ + 2,
where ERMCN N stands for ERMCN N for function class of 2-layer ConvNets.
3
For vector z ∈ Rd , we define zi = z(i−1) mod d+1 if i ∈
/ [d].
139
Remark 5.6.8. The task could be understood as detecting if there are two consecutive
white pixels in the black background. For proof simplicity, we take texture of length
2 as an illustrative example. It is straightforward to extend the same proof to
more sophisticated local pattern detection problem of any constant length and to
2-dimensional images.
5.7 Proofs
Lemma 5.7.1.
p
∀x ∈ [−1, 1], arccos x ≥ 2(1 − x).
Proof of Lemma 5.7.1. Let x = cos(t). If t = 0, then both sides are equal to 0 and
the inequality holds. Otherwise if t ∈ (0, π], we have the following
arccos(x) t t √
√ =p =√ ≥ 2,
1−x 1 − cos(t) 2 sin(t/2)
√ √
C kM kF / d ≤ E [kM zk2 ] ≤ kM kF / d. (5.12)
z∼Sd−1
s r
tr[M M > ]
kM k
r
kM zk22 = = √ F.
E [kM zk2 ] ≤ E tr M E [zz> ] M > =
z∼Sd−1 z∼Sd−1 z∼Sd−1 d d
140
E [kM zk2 ] = E [kΣzk2 ], w.l.o.g., we only need to prove the lower bound for
z∼Sd−1 z∼Sd−1
all diagonal matrices.
By Proposition 2.5.1 in [90], there’s some constant C, such that
v v
u d u d
uX uX
C kΣkF = C t σi2 ≤ E t zi2 σi2 = E [kM zk]2 .
z∼N (0,Id ) z∼N (0,Id )
i=1 i=1
r √
kzk22 = d.
By Cauchy-Schwarz Inequality, we have E [kzk2 ] ≤ E
z∼N (0,Id ) z∼N (0,Id )
Therefore, we have that
2 z
Pr (|x| ≤ z) ≤ √
x∼N (0,σ) πσ
z
r
x2
Z
1 2z
Pr (|x| ≤ z) = √ exp − 2 dx ≤
x∼N (0,σ) −z 2π σ 2σ πσ
141
Proof of Lemma 5.7.4. Note that
0 U Id 0 0 Id Id 0
= · · ,
U> 0 0 U> Id 0 0 U
and
√ √ √ √
2 2 2 2
0 Id 2 Id − I
2 d
I
d 0 2 Id I
2 d
= √ √ · · √ √ ,
2 2
Id 0 2 d
I 2 d
I 0 −I d − 22 Id 2
2 d
I
> 0 U
hU (z) = sign z>
1:d U zd+1:2d = sign z z
U> 0
(5.14)
Id 0
=sign gU (z)> gU (z) ,
0 −Id
√ √
2 2
2 Id − I
2 d Id 0
where gU (z) = √ √ · · z is an orthogonal transformation on R2d .
2 2
2 d
I 2 d
I 0 U
Thus we conclude that hU ∈ h∗ ◦ O(2d).
Proof of Lemma 5.7.5. Now we claim H shatters {ei + ed+j }1≤i<j≤d , i.e. O(d) can
shatter {ei e>
j }1≤i<j≤d , or for any sign pattern {σij }1≤i<j≤d , there exists U ∈ O(d),
u2
exp(u) = Id + u + + · · · ∈ SO(d), ∀u ∈ so(d).
2
142
σij (ei e> > +
P
Thus for any sign pattern {σij }1≤i<j≤d , let u = j −ej ei ) and λ → 0 ,
1≤i<j≤d
it holds that sign exp(λu), ei e>
j = sign [0 + λσij + O(λ2 )] = sign [σij + O(λ)] =
σij .
Proof of Theorem 5.5.1. Lower bound: Suppose d = 2d0 for some integer d0 , we
construct P = PZ H, where PZ is the set of all possible distributions on Z = R3k ,
hP 0 P2d0 i
d 2 2 0
>
and H = {sign z
i=1 i − i=d0 +1 i }. By Lemma 5.7.4, H = {sign z1:d U zd+1:2d |
z
U ∈ O(d0 )} ⊆ H ◦ O(d). By Theorem 5.6.2, we have that
inf N ∗ (A, PZ H, ε) ≥ inf N ∗ (A, PZ (H◦GZ ), ε) ≥ inf N ∗ (A, PZ H0 , ε) (5.15)
A∈AGZ A∈A A∈A
143
For any convex non-increasing surrogate loss of 0-1 loss l satisfying l(0) ≥
1, limx→∞ l(x) = 0 e.g. logistic loss, we define the loss of the weight x as (zk,i is
the kth coordinate of zi )
n n 2 d0 ! !
X X X X
L(x) = l(FCNN [x](zi )yi ) = l ai x2(k−1)d0 +j,i w12 + b yi ,
i=1 i=1 k=1 j=1
ρ(U, V ) := ρZ (hU , hV ) = Pz∼N (0,I2d ) [hU (z) 6= hV (z)]. There exists a constant C,
d(d−1)
such that the packing number D(H, ρZ , ε) = D(O(d), ρ, ε) ≥ Cε 2
.
Proof of Lemma 5.7.6. The key idea here is to first lower bound ρZ (U, V ) by
√
kU − V kF / d and apply volume argument in the tangent space of Id in O(d). We
144
have that
z>
>
= P 1:d U zd+1:2d z1:d V zd+1:2d < 0
z∼N (0,I2d )
>
z1:d U V > z1:d
1
= E arccos
π z1:d ∼N (0,Id ) kz1:d k2
"s #
> >
1 z U V z1:d (5.16)
≥ E 2 − 2 1:d (by Lemma 5.7.1)
π z1:d ∼N (0,Id ) kz1:d k2
1 hp
>U V >z
i
= E 2 − 2z
π z∼Sd−1
1
(U > − V > )z F
= E
π z∼Sd−1
√
≥C1 kU − V kF / d (by Lemma 5.7.2)
Lemma 5.7.7. [91, Implication of Lemma 4] For any matrix A, B ∈ so(d), satisfying
that kAk∞ ≤ π4 , kBk∞ ≤ π4 , we have
√ π d2 √
D(H, ρZ , ε) ≥ D(O(d), C1 k·kF / d, ε) ≥ D(so(d)∩ B∞ , C1 k·kF / d, 2.5ε). (5.18)
4
145
d(d−1) 2
Note that so(d) is a 2
-dimensional subspace of Rd , by Inverse Santalo’s
inequality (Lemma 3, [92]), we have that
2
! d(d−1)
2
p
d
vol(so(d) ∩ B∞ ) dim(so(d))
d2
≥ C2 .
vol(so(d) ∩ B2 ) E Πso(d) (G) ∞
G∼N (0,Id2 )
d(d−1) G−G>
where vol(·) is the 2
volume defined in the space of so(d) and Πso(d) (G) = 2
is the projection operator onto the subspace so(d). We further have that
G − G> √
E Πso(d) (G) ∞
= E ≤ E [kGk∞ ] ≤ C3 d,
G∼N (0,Id2 ) G∼N (0,Id2 ) 2 ∞ G∼N (0,Id2 )
π d2 √
D(so(d) ∩ B∞ , C1 k·kF / d, 2.5ε)
4 √
d2 10 dε
=D(so(d) ∩ B∞ , k·kF , )
C1 π
d2
d(d−1)
vol(so(d) ∩ B∞ ) C1 π 2
≥ 2 × √
vol(so(d) ∩ B2d ) 10 dε (5.19)
q d(d−1)
2
C1 C2 π d(d−1)
2
≥
10dε
d(d−1)
C 2
:=
ε
146
Ω(d2 /ε), while there’s a 2-layer ConvNets architecture, such that N (ERMCNN , P, ε, δ) =
d log 1ε +log 1
O( ε
δ
).
n n d
!
X X X
L(x) = l(CNN[x](zi )yi ) = l ( 2
w12 ai zk,i + b)yi ,
i=1 i=1 k=1
A and ε > 0,
d2
N ∗ (A, {N (0, I4d )} H, ε) = Ω( ).
ε
2
1 d
Proof of Lemma 5.7.8. Below we will prove a Ω( ε
) lower bound for packing num-
ber, i.e. D(H, ρZ , 2ε0 ) = D(Rd×d , ρ, 2ε0 ), where ρ(U, V ) = ρZ (hU , hV ). Then we can
apply Long’s improved version Equation (5.4) of Benedek-Itai’s lower bound and get a
Ω(d2 /ε) sample complexity lower bound. The reason that we can get the correct rate
of ε is that the VCdim(H) is exactly equal to the exponent of the packing number. (cf.
the proof of Theorem 5.6.4)
Similar to the proof of Theorem 5.6.4, the key idea here is to first lower bound
√
ρ(U, V ) by kU − V kF / d and apply
volume
argument. Recall for A ∈ Rd×d , we
A 0
define MA ∈ R2d×2d as MA = , and hA : R4d → {−1, 1} as hA (z) =
0 Id
>
sign z1:2d MA z2d+1:4d . Then for H = {hA | ∀A ∈ Rd×d } . Below we will see it
2
suffices to lower bound the packing number of a subset of Rd×d , i.e. Id + 0.1B∞
d
,
d 2 d 2
where B∞ is the unit spectral norm ball. Clearly ∀z, kzk2 = 1, ∀U ∈ Id + 0.1B∞ ,
0.9 ≤ kU zk2 ≤ 1.1.
148
d 2
Thus, ∀U, V ∈ Id + 0.1B∞ we have that,
z> z>
= P 1:2d MV z2d+1:4d < 0
1:2d MU z2d+1:4d
z∼N (0,I4d )
" !#
1 z> M
1:2d U M >
z
V 1:2d
= E arccos
π z1:2d ∼N (0,I2d ) MU z1:2d 2 MV> z1:2d 2
>
"s #
1 z> M
1:2d U M >
z
V 1:2d
≥ E 2−2 (by Lemma 5.7.1)
π z1:2d ∼N (0,I2d ) MU z1:2d 2 MV> z1:2d 2
>
√ q
2 > > > >
≥ E MU z1:2d 2 MV z1:2d 2 − z1:2d MU MV z1:2d
1.1π z1:2d ∼N (0,I2d )
q
1 > > 2 > >
2
= E (MU − MV )z1:2d 2 − MU z1:2d 2 − MV z1:2d 2
1.1π z1:2d ∼N (0,I2d )
1
(MU> − MV> )z1:2d 2
≥ ( E
1.1π z1:2d ∼N (0,I2d )
It remains to lower bound the packing number. We have the following for some
constant C:
d2 d2
√ d2
d2 vol(B∞ ) 0.1C1 C
M(0.1B∞ , C1 k·kF / d, ε) ≥ × √ ≥ , (5.20)
d2 ε
vol(B2 ) dε
The proof is completed by plugging the above bound and VCdim(H) = d2 into
Equation (5.4).
q q
E [k(R − S)zk2 ] − E kRzk22 + kyk22 2 2
− kSzk2 + kyk2 ≥ C0 E [k(R − S)zk2 ] ,
z z,y z
149
for some constants C0 independent of R, S and d.
q q
kRzk2 + kyk2 − kSzk22 + kyk22
2 2
kRzk2 + kSzk2
= |kRzk2 − kSzk2 | q q
kRzk2 + kyk2 + kSzk22 + kyk22
2 2
kRzk2 + kSzk2
≤ k(R − S)zk2 q q
kRzk22 + kyk22 + kSzk22 + kyk22
Let F (x, d) be the cdf of chi-square distribution, i.e. F (x, d) = Pz kzk22 ≤ x . Let
z = xd , we have F (zd, d) ≤ (ze1−z )d/2 ≤ (ze1−z )1/2 . Thus Py kyk22 ≤ d/2 < 1, which
√
implies for any kzk2 ≤ 10 d,
q q
2 2 2 2
E kRzk2 + kyk2 − kSzk2 + kyk2
y
kRzk2 + kSzk2
≤ k(R − S)zk2 E q q
y 2 2 2 2
kRzk2 + kyk2 + kSzk2 + kyk2
q q
2 2 2 2
E [k(R − S)zk2 ] − E kRzk2 + kyk2 − kSzk2 + kyk2
z z,y
h h √ ii
≥ E k(R − S)zk2 1 kzk ≤ 10 d
z
√ i
q q h
2 2 2 2
−E kRzk2 + kyk2 − kSzk2 + kyk2 1 kzk2 ≤ 10 d
z,y
h h √ ii
≥α1 E k(R − S)zk2 1 kzk2 ≤ 10 d
z
150
for some constant α2 > 0. Here we use the other side of the tail bound of cdf of
chi-square, i.e. for z > 1, 1 − F (zd, d) < (ze1−z )d/2 < (ze1−z )1/2 .
>
(z M z)2
E
z∼N (0,Id )
" #
X
= E zi zj zi0 zj 0 Mij Mi0 j 0
z∼N (0,Id )
i,j,i0 j 0
2 2 X 2
X
(Mij2
4
= + Mij Mji + Mii Mjj ) E x + Mii E x
x∼N (0,1) x∼N (0,1)
i6=j i
X X
= (Mij2 + Mij Mji + Mii Mjj ) + 3 Mii2
i6=j i
> 2
M +M
= + (tr[M ])2
2 F
R} , such that for any orthogonal equivariant algorithms A and ε > 0, N ∗ (A, P, ε) ≥
d(d+3)
2
(1 − ε) − 1, while there’s a 2-layer ConvNet architecture CNN, such that
N ∗ (ERMCNN , P, ε) ≤ d for any ε > 0.
Proof of Theorem 5.6.6. Lower bound: Similar to the proof of Theorem 5.6.5, it
suffices to for any algorithm A, N ∗ (A, H ◦ O(d), ε) ≥ d(d+3) 2
(1 − ε) − 1. Note that
P
H◦O(d) = { i,j βij zi zj | βij ∈ R} is the set of all quadratic functions. For convenience
we denote hM (z) = z> M z, ∀M ∈ Rd×d . Now we claim quadratic functions such that
151
d(d+1)
any learning algorithm A taking at most n samples must suffer 2
− n loss if the
ground truth quadratic function is sampled from i.i.d. gaussian. Moreover, the loss
d(d+3)
is at most 2
for the trivial algorithm always predicting 0. In other words, if the
d(d+1)
−n
expected relative error ε ≤ 2
d(d+3) , we must have the expected sample complexity
2
d(d+3)
N ∗ (A, P, ε) ≥ n. That is N ∗ (A, P, ε) ≥ 2
(1 − ε) − 1.
(1). Upper bound for E [y 2 ]. By Lemma 5.7.10,
" 2
#
2 M + M>
E E y = E + (tr[M ])2
M ∼N (0,Id2 ) z∼PZ ,y=z> M z M ∼N (0,Id2 ) 2 F
d(d − 1) d(d + 3)
=d + d + = .
2 2
inf E E [`P (A({zi , yi }ni=1 ))]
A M ∼N (0,I 2 ) (zi ,yi )∼PZ hM
d
([A({zi , yi }ni=1 )](z) 2
= inf E E E − y)
A M ∼N (0,I 2 ) (zi ,yi )∼PZ hM z,y∼PZ ◦hM
d
n 2
= inf E E E ([A({zi , hM (zi )}i=1 )](z) − hM (z))
A M ∼N (0,I 2 ) zi ∼PZ z∼PZ
d
where the inequality is achieved when [A({zi , yi }ni=1 )](z) = E [hM (z) | {zi , yi }ni=1 ].
M
n
Thus it suffices to lower bound VarM [hM (z) | {hM (zi )}i=1 ], for fixed {zi }ni=1 and
z. For convenience we define Sd = {A ∈ Rd×d | A = A> } be the linear space
of all d × d symmetric matrices, where the inner product hA, Bi := tr[A> B] and
Πn : Rd×d → Rd×d as the projection operator for the orthogonal complement of the
152
n-dimensional space spanned by zi z> d
i in S . By definition, we can expand
n
X
>
zz = αi zi z> >
i + Πn (zz ).
i=1
n
X
>
hM (z) = tr[zz ] = αi tr[zi z> >
i M ] + tr[Πn (zz )M ],
i=1
2
still follows a gaussian distribution, N (0, Πn (zz> ) F
).
Note we can always find symmetric matrices Ei with kEi kF = 1 and tr[Ei> Ej ] = 0
such that Πn (A) = ki=1 Ei tr[Ei> A], where the rank of Πn , is at least d(d+1)
P
2
− n. Thus
we have that
2
h i k
2
X
E Πn (zz> ) F
=E Ei tr[Ei> zz> ]
z z
i=1 F
k h i
2
X
= E Ei tr[Ei> zz> ] F
z
i=1
Xk
> > 2
= E (z Ei z) (by Lemma 5.7.10)
z
i=1
Xk
≥ kEi kF2 ≥ k
i=1
d(d + 1)
≥ −n
2
inf E E [`P (A({zi , yi }ni=1 ))]
A M ∼N (0,I 2 ) (zi ,yi )∼PZ hM
d
153
d(d + 1)
≥ − n.
2
Upper bound: We use the same CNN construction as in the proof of The-
nP o
d 2 2
orem 5.6.5, i.e., the function class is FCNN = i=1 a i w z
1 i + b|a i , w 1 , b ∈ R =
nP o
d 2 2 2 2
i=1 ai zi + b|ai , b ∈ R . Thus given d + 1 samples, w.p. 1, (z1 , z2 , . . . , zd , 1) will be
linear independent, which means ERMCNN could recover the ground truth and thus
have 0 loss.
Theorem 5.6.7. Let ti = ei +ei+1 and si = ei +ei+2 4 and P be the uniform distribution
on {(si , 1)}ni=1 ∪ {(ti , −1)}ni=1 , which is the classification problem for local textures in
a 1-dimensional image with d pixels. Then for any permutation equivariant algorithm
A, N (A, P, 18 , 18 ) ≥ N ∗ (A, P, 14 ) ≥ d
10
. Meanwhile, N (ERMCN N , P, 0, δ) ≤ log2 1δ + 2,
where ERMCN N stands for ERMCN N for function class of 2-layer ConvNets.
4
For vector z ∈ Rd , we define zi = z(i−1) mod d+1 if i ∈
/ [d].
154
Given Zn , yn , we define B := {d(z, zk ) ≥ 3, ∀k ∈ [n]} and we have P [B] =
d
d− 10 ∗5
Pz [d(z, zk ) ≥ 3, ∀k ∈ [n]] ≥ d
= 12 . Therefore, we have
∗
where = uses the Definition 5.4.1.
Thus for any permutation equivariant algorithm A, N ∗ (A, {P }, 14 ) ≥ d
10
.
Upper Bound: Take CNN as defined in Section 5.3.2 with d0 = d, r = 1, k = 2, σ :
Rd → R, σ(z) = di=1 zi2 , we have
P
( " d
#)
X
FCNN = z → sign a1 (w1 zi−1 + w2 zi−2 )2 + b|a1 , w1 , w2 , b ∈ R .
i=1
−n −n+1
P [∀i ∈ [n], zi ∈ {sj | j ∈ [d]}]+ P [∀i ∈ [n], zi ∈ {tj | j ∈ [d]}] = 2 × 2 = 2 .
155
equivariant algorithm for FC nets, thus it’s a valid separation even we restrict the
discussion to training algorithm freezing the first layer.)
For any convex non-increasing surrogate loss of 0-1 loss l satisfying l(0) ≥
1, limx→∞ l(x) = 0 e.g. logistic loss, we define the loss of the weight x as
n
X
L(x) = l(CNN[x](zi )yi )
i=1
6 0 with probability 1, which means the data are separable even with
Note w1 w2 =
fixed first layer, i.e. inf a1 ,b L(x) = 0. Further note L(x) is convex in a1 and b, which
implies with sufficiently small step size, GD converges to 0 loss solution. By the
definition of surrogate loss, L(x) < 1 implies for zi , l(zi yi ) < 1 and thus the training
error is 0.
156
Chapter 6
157
rank minimization is more likely to take effect for initialization with practical scale.
Interestingly, despite there is a separation between depth equal to 2 and depth larger
than 3, it turns out that being deeper than 3 (e.g., increasing depth to infinity) has
only marginal value on the implicit bias.
6.1 Introduction
There are usually far more learnable parameters in deep neural nets than the number
of training data, but still deep learning works well on real-world tasks. Even with
explicit regularization, the model complexity of state-of-the-art neural nets is so large
that they can fit randomly labeled data easily [4]. Towards explaining the mystery of
generalization, we must understand what kind of implicit regularization does Gradient
Descent (GD) impose during training. Ideally, we are hoping for a nice mathematical
characterization of how GD constrains the set of functions that can be expressed by a
trained neural net.
As a direct analysis for deep neural nets could be quite hard, a line of works
turned to study the implicit regularization on simpler problems to get inspirations, for
example, low-rank matrix factorization, a fundamental problem in machine learning
and information process. Given a set of observations about an unknown matrix
W ∗ ∈ Rd×d of rank r∗ d, one needs to find a low-rank solution W that is compatible
with the given observations. Examples include matrix sensing, matrix completion,
phase retrieval, robust principal component analysis, just to name a few (see Chi
et al. 94 for a survey). When W ∗ is symmetric and positive semidefinite, one way
to solve all these problems is to parameterize W as W = U U > for U ∈ Rd×r and
optimize L(U ) := 12 f (U U > ), where f ( · ) is some empirical risk function depending
on the observations, and r is the rank constraint. In theory, if the rank constraint
is too loose, the solutions do not have to be low-rank and we may fail to recover
158
W ∗ . However, even in the case where the rank is unconstrained (i.e., r = d), GD
with small initialization can still get good performance in practice. This empirical
observation reveals that the implicit regularization of GD exists even in this simple
matrix factorization problem, but its mechanism is still on debate. Gunasekar et al.
[16] proved that Gradient Flow (GD with infinitesimal step size, a.k.a., GF) with
infinitesimal initialization finds the minimum nuclear norm solution in a special case
of matrix sensing, and further conjectured this holds in general.
Conjecture 6.1.1 (Gunasekar et al. 16, informal). With sufficiently small initializa-
tion, GF converges to the minimum nuclear norm solution of matrix sensing.
Subsequently, Arora et al. [95] challenged this view by arguing that a simple
mathematical norm may not be a sufficient language for characterizing implicit
regularization. One example illustrated in Arora et al. [95] is regarding matrix sensing
with a single observation. They showed that GD with small initialization enhances
the growth of large singular values of the solution and attenuates that of smaller
ones. This enhancement/attenuation effect encourages low-rank, and it is further
intensified with depth in deep matrix factorization (i.e., GD optimizes f (U1 · · · UL )
for L ≥ 2). However, these are not captured by the nuclear norm alone. Gidel et al.
[96], Gissin et al. [97] further exploited this idea and showed in the special case of
full-observation matrix sensing that GF learns solutions with gradually increasing
rank. Razin and Cohen [98] showed in a simple class of matrix completion problems
that GF decreases the rank along the trajectory while any norm grows towards infinity.
More aggressively, they conjectured that the implicit regularization can be explained
by rank minimization rather than norm minimization.
Our Contributions. In this chapter, we move one further step towards resolving
the implicit regularization in the matrix factorization problem. Our theoretical results
show that GD performs rank minimization via a greedy process in a broader setting.
159
Specifically, we provide theoretical evidence that GF with infinitesimal initialization
is in general mathematically equivalent to another algorithm called Greedy Low-Rank
Learning (GLRL). At a high level, GLRL is a greedy algorithm that performs rank-
constrained optimization and relaxes the rank constraint by 1 whenever it fails to
reach a global minimizer of f ( · ) with the current rank constraint. As a by-product,
we refute Conjecture 6.1.1 by demonstrating an counterexample (Example 6.5.9).
We also extend our results to deep matrix factorization Section 6.6, where we
prove that the trajectory of GF with infinitesimal identity initialization converges to
a deep version of GLRL, at least in the early stage of the optimization. We also use
this result to confirm the intuition achieved on toy models [97], that benefits of depth
in matrix factorization is to encourage rank minimization even for initialization with
a relatively larger scale, and thus it is more likely to happen in practice. This shows
that describing the implicit regularization using GLRL is more expressive than using
the language of norm minimization. We validate all our results with experiments
in Section 6.8.
Norm Minimization. The view of norm minimization, or the closely related view of
margin maximization, has been explored in different settings. Besides the nuclear norm
minimization for matrix factorization [16] discussed in the introduction, previous works
have also studied the norm minimization/margin maximization for linear regression
[13, 99–103], deep linear neural nets [104, 105], homogeneous neural nets [106, 107],
ultra-wide neural nets [108–110].
6.3 Preliminaries
Notations. For two matrices A, B, we define hA, Bi := Tr(AB > ) as their inner
product. We use kAkF , kAk∗ and kAk2 to denote the Frobenius norm, nuclear norm
and the largest singular value of A respectively. For a matrix A ∈ Rd×d , we use
λ1 (A), . . . , λd (A) to denote the eigenvalues of A in decreasing order (if they are all
reals). We define Sd as the set of symmetric d × d matrices and S+
d ⊆ Sd as the set of
is a notable special case of matrix sensing in which every measurement has the form
Xi = epi e>
qi , where {e1 , · · · , ed } stands for the standard basis (i.e., exactly one entry is
dU
= −∇L(U ) = −∇f (U U > )U. (6.1)
dt
Let W (t) = U (t)U (t)> ∈ Rd×d . Then the following end-to-end dynamics holds for
W (t):
dW
= −W ∇f (W ) − ∇f (W )W =: g(W ). (6.2)
dt
162
We use φ(W0 , t) to denote the matrix W (t) in (6.2) when W (0) = W0 0. Throughout
this chapter, we assume φ(W0 , t) exists for all t ∈ R, W0 0. It is easy to prove that
U is a stationary point of L( · ) (i.e., ∇L(U ) = 0) iff W = U U > is a critical point of
(6.2) (i.e., g(W ) = 0); see Lemma 6.10.1 for a proof. If W is a minimizer of f ( · ) in
S+
d (i.e., W is a minimizer of (P)), then W is a critical point of (6.2), but the reverse
Before introducing our main results, we illustrate how GD performs greedy learning
using two warmup examples.
v1 , . . . , vd are orthogonal to each other. Then we can write the solution as:
Xd
U (t) = etQ U (0) = eµi t vi vi> U (0). (6.3)
i=1
163
When µ1 > µ2 , the ratio between eµ1 t and eµi t for i 6= 1 increases exponentially fast.
As t → +∞, U (t) and W (t) become approximately rank-1 as long as vi> U (0) 6= 0, i.e.,
lim e−µ1 t U (t) = v1 v1> U (0), lim e−2µ1 t W (t) = (v1> W (0)v1 )v1 v1> . (6.4)
t→∞ t→∞
The analysis for the simple linear case reveals that GD encourages low-rank through
a process similar to power iteration. However, f (W ) is non-linear in general, and the
linear approximation is close to f (W ) only if W is very small. With sufficiently small
initialization, we can imagine that GD still resembles the above power iteration in the
early phase of the optimization. But what if W (t) grows to be so large that the linear
approximation is far from the actual f (W )?
dU dW
= (W ∗ − U U > )U, = (W ∗ − W )W + W (W ∗ − W ). (6.5)
dt dt
Pd
Let W ∗ := i=1 µi vi vi> be the eigendecomposition of W ∗ . Our previous analysis
shows that the dynamics is approximately dU
dt
= W ∗ U in the early phase and thus
encourages low-rank.
√
To get a sense for the later phases, we simplify the setting by specifying U (0) = αI
for a small number α. We can write W (0) and W ∗ as diagonal matrices W (0) =
diag(α, α, · · · , α), W ∗ = diag(µ1 , µ2 , · · · , µd ) with respect to the basis v1 , . . . , vd . It is
easy to see that W (t) is always a diagonal matrix, since the time derivatives of non-
diagonal coordinates stay 0 during training. Let W (t) = diag(σ1 (t), σ2 (t), · · · , σd (t)),
d
then σi (t) satisfies the dynamical equation σ (t)
dt i
= 2σi (t)(µi − σi (t)), and thus
164
αµi
σi (t) = α+(µi −α)e−2µi t
. This shows that every σi (t) increases from α to µi over time.
As α → 0, every σi (t) has a sharp transition from near 0 to near µi at time roughly
( 2µ1 i + o(1)) log α1 , which can be seen from the following limit:
c ∈ (− 2µ1 i , 0),
αµi 0
lim σi ( 2µ1 i + c) log(1/α) = lim =
α→0 α→0 α + (µi − α)α1+2cµi
µ i
c ∈ (0, +∞).
This means for every q ∈ ( 2µ1 i , 2µ1i+1 ) for i = 1, . . . , d − 1 (or q ∈ ( 2µ1 i , +∞) for
i = d), limα→0 W (q log(1/α)) = diag(µ1 , µ2 , . . . , µi , 0, 0, · · · , 0). Therefore, when the
initialization is sufficiently small, GF learns each component of W ∗ one by one,
according to the relative order of eigenvalues. At a high level, this shows a greedy
nature of GD: GD starts learning with simple models; whenever it underfits, it increases
the model complexity (which is rank in our case). This is also called sequential learning
or incremental learning in the literature [96, 97].
However, it is unclear how and why this sequential learning/incremental learning
can occur in general. Through the first warmup example, we may understand why GD
learns a rank-1 matrix in the early phase, but does GD always learn solutions with
rank 2, 3, 4, . . . sequentially? If true, what is the mechanism behind this? The current
paper answers the questions by providing both theoretical and empirical evidence that
the greedy learning behavior does occur in general with a similar reason as for the
first warmup example.
165
6.5 Main Results: Equivalence between Gradi-
(GLRL)
λ1 (−∇f (Wr )) ≤ 0 (see Lemma 6.10.2), then GLRL returns Wr ; otherwise GLRL
enters phase r + 1.
return Wr
166
To set the initial point of GD in phase r, GLRL appends a small column vector
δr ∈ Rd to the resulting stationary point Ur−1 (∞) from the last phase, i.e., Ur (0) ←
[Ur−1 (∞) δr ] ∈ Rd×r (in the case of r = 1, U1 (0) ← [δ1 ] ∈ Rd×1 ). In this way,
Ur (0)Ur> (0) = Wr−1 + δr δr> is perturbed away from the (r − 1)-th critical point. In
√
GLRL, we set δr = ur , where ur is the top eigenvector of −∇f (Wr ) with unit
norm kur k2 = 1, and > 0 is a parameter controlling the magnitude of perturbation
(preferably very small). Note that it is guaranteed that λ1 (−∇f (Wr−1 )) > 0; otherwise
Wr−1 is a minimizer of the convex function f ( · ) in S+
d and GLRL exits before phase
r. Expanding f ( · ) around Wr−1 shows that the loss is decreasing in this choice of δr .
1 1
L(Ur (0)) = f (Wr−1 + δr δr> ) = L(Ur−1 (∞)) + δr> ∇f (Wr−1 )δr + O(kδr k42 )
2 2
= L(Ur−1 (∞)) − λ1 (−∇f (Wr−1 )) + O(2 ).
2
Definition 6.5.1 (Trajectory of GLRL). Let W 0, := 0 be the 0th critical point
of GLRL. For every r ≥ 1, if the (r − 1)-th critical point W r−1, exists and is
not a minimizer of f ( · ) in S+ G >
d , we define Wr, (t) := φ(W r−1, + ur, ur, , t), where
Throughout this chapter, we always focus on the case where the top eigenvalue of
every ∇f (W r−1, ) is unique. In this case, the trajectory of GLRL is unique for every
> 0, since the normalized top eigenvectors can only be ±ur, , and both of them lead
G
to the same Wr, (t).
167
Comparison to existing greedy algorithms for rank-constrained optimiza-
tion. The most related one to GLRL (Algorithm 5) is probably Rank-1 Matrix
Pursuit (R1MP) proposed by Wang et al. [114] for matrix completion, which was later
generalized to general convex loss in [115]. R1MP maintains a set of rank-1 matrices
as the basis, and in phase r, R1MP adds the same ur u> r as defined in Algorithm 5 into
Pr
its basis and solve minα f ( i=1 αi ui u>
i ) for rank-r estimation. The main difference
between R1MP and GLRL is that the optimization in each phase of R1MP is performed
on the coefficients α, while the entire Ur evolves with GD in each phase of GLRL. In
Figure 6.5, we provide empirical evidence that GLRL generalizes better than R1MP
when ground truth is low-rank, although GLRL may have a higher computational
cost depending on η, .
Similar to R1MP, Greedy Efficient Component Optimization (GECO, Shalev-
Shwartz and Singer 116) also chooses the r-th component of its basis as the top
eigenvector of −∇f (Wr ), while it solves minβ f ( 1≤i,j≤r βij ui u>
P
j ) for the rank-r es-
timation. Khanna et al. [117] provided convergence guarantee for GECO assuming
strong convexity. Haeffele and Vidal [118] proposed a local-descent meta algorithm, of
which GLRL can be viewed as a specific realization.
namical System
To prove the equivalence between GF and GLRL, we first introduce our high-level
idea by analyzing the behavior of a more general dynamical system around its critical
point, say 0. A specific example is (6.2) if we set x to be the vectorization of W .
dx
= g(x), where g(0) = 0. (6.6)
dt
168
We use φ(x0 , t) to denote the value of x(t) in the case of x(0) = x0 . We assume that g(x)
is C 2 -smooth with J(x) being the Jacobian matrix and φ(x0 , t) exists for all x0 and t.
For ease of presentation, in the main text we assume J(0) is diagonalizable over R and
defer the same result for the general case into Section 6.12.3. Let J(0) = Ṽ D̃Ṽ −1 be the
eigendecomposition, where Ṽ is an invertible matrix and D̃ = diag(µ̃1 , . . . , µ̃d ) is the
diagonal matrix consisting of the eigenvalues µ̃1 ≥ µ̃2 ≥ · · · ≥ µ̃d . Let Ṽ = (ṽ1 , . . . , ṽd )
and Ṽ −1 = (ũ1 , . . . , ũd )> , then ũi , ṽi are left and right eigenvectors associated with µ̃i
Pd
and ũ>
i ṽj = δij . We can rewrite the eigendecomposition as J(0) =
>
i=1 µ̃i ṽi ũi .
We also assume the top eigenvalue µ̃1 is positive and unique. Note µ̃1 > 0 means
the critical point x = 0 is unstable, and in matrix factorization it means 0 is a strict
saddle point of L( · ).
The key observation is that if the initialization is infinitesimal, the trajectory is
almost uniquely determined. To be more precise, we need the following definition:
Definition 6.5.2. For any x0 ∈ Rd and u ∈ Rd , we say that {xα }α∈(0,1) converges to
D E
−x0
x0 with positive alignment with u if lim xα = x0 and lim inf kxxαα−x 0 k2
, u > 0.
α→0 α→0
xα −x0
A special case is that the direction of xα − x0 converges, i.e., x̄ := limα→0 kxα −x0 k2
exists. In this case, {xα } has positive alignment with either u or −u except for a
zero-measure subset of x̄. This means any convergent sequence generically falls into
either of these two categories.
The following theorem shows that if the initial point xα converges to 0 with
positive alignment with ũ1 as α → 0, the trajectory starting with xα converges
1
to a unique trajectory z(t) := φ(αṽ1 , t + µ̃1
log α1 ). By symmetry, there is another
unique trajectory for sequences {xα } with positive alignment to −ũ1 , which is z 0 (t) :=
1
φ(−αṽ1 , t + µ̃1
log α1 ). This is somewhat surprising: different initial points should lead
to very different trajectories, but our analysis shows that generically there are only
two limiting trajectories for infinitesimal initialization. We will soon see how this
theorem helps in our analysis for matrix factorization in Sections 6.5.2 and 6.5.3.
169
1
Theorem 6.5.3. Let zα (t) := φ(αṽ1 , t + µ̃1
log α1 ) for every α > 0, then z(t) :=
limα→0 zα (t) exists and is also a solution of (6.6), i.e., z(t) = φ(z(0), t). If δα converges
to 0 with positive alignment with ũ1 as α → 0, then ∀t ∈ R, there is a constant C > 0
such that
γ̃
1 1 µ̃1 +γ̃
φ δα , t + µ̃1
log hδα ,ũ1 i
− z(t) ≤ C · kδα k2 , (6.7)
2
for every sufficiently small α, where γ̃ := µ̃1 − µ̃2 > 0 is the eigenvalue gap.
Proof sketch. The main idea is to linearize the dynamics near origin as we have done
for the first warmup example. For sufficiently small x, by Taylor expansion of g(x),
dx
the dynamics is approximately dt
≈ J(0)x, which can be understood as a continuous
version of power iteration. If the linear approximation is exact, then x(t) = etJ(0) x(0).
For large enough t0 , et0 J(0) = di=1 eµ̃i t0 ṽi ũ> µ̃1 t0
ṽ1 ũ> µ̃2 t0
P
i = e 1 + O(e ). Therefore, as
long as the initial point x(0) has a positive inner product with ũ1 , x(t0 ) should be very
close to ṽ1 for some > 0, and the rest of the trajectory after t0 should be close to
the trajectory starting from ṽ1 . However, here is a tradeoff: we should choose t0 to be
large enough so that the power iteration takes effect; but if t0 is so large that the norm
of x(t0 ) reaches a constant scale, then the linearization fails unavoidably. Nevertheless,
if the initialization scale is sufficiently small, we show via a careful error analysis that
there is always a suitable choice of t0 such that x(t0 ) is well approximated by ṽ1 and
the difference between x(t0 + t) and φ(ṽ1 , t) is bounded as well. We defer the details
to Section 6.12.
Now we establish the equivalence between GF and GLRL in the first phase. The main
idea is to apply Theorem 6.5.3 on (6.2). For this, we need the following lemma on the
eigenvalues and eigenvectors.
170
Lemma 6.5.4. Let g(W ) := −W ∇f (W ) − ∇f (W )W and J(W ) be its Jacobian.
Then J(0) is symmetric and thus diagonalizable. Let −∇f (0) = di=1 µi u1[i] u>
P
1[i] be
d X
X d
J(0)[∆] = (µi + µj ) ∆, u1[i] u> >
1[j] u1[i] u1[j] , (6.8)
i=1 j=1
where J(0)[∆] stands for the resulting matrix produced by left-multiplying J(0) to the
vectorization of ∆. For every pair of 1 ≤ i ≤ j ≤ d, µi + µj is an eigenvalue of J(0)
and u1[i] u> >
1[j] + u1[j] u1[i] is a corresponding eigenvector. All the other eigenvalues are 0.
below states that, for every fixed time t, the GF solution φ(Wα , T (Wα ) + t) after
shifting by a time offset T (Wα ) := 1
2µ1
log(hWα , u1 u> −1
1 i ) converges to the GLRL
solution W1G (t) as Wα → 0. The only assumption for this result is that 0 is not a
minimizer of f ( · ) in S+
d (which is equivalent to λ1 (−∇f (0)) > 0) and −∇f (0) has
an eigenvalue gap. In the full observation case, this assumption is satisfied easily if
the ground-truth matrix has a unique top eigenvalue. The proof for Theorem 6.5.6 is
deferred to Section 6.14.1.
Assumption 6.5.5. µ1 > max{µ2 , 0}, where µ1 := λ1 (−∇f (0)), µ2 := λ2 (−∇f (0)).
Theorem 6.5.6. Under Assumption 6.5.5, the following limit W1G (t) exists and is a
solution of (6.2).
W1G (t) := lim W1,
G 1
2µ1
log 1 + t = lim φ u1 u> , 1
1 2µ1 log 1
+ t . (6.9)
→0 →0
171
Let {Wα } ⊆ S+ >
d be PSD matrices converging to 0 with positive alignment with u1 u1
γ̃
φ Wα , 2µ1 1 log 1
+t − W1G (t) 2µ1 +γ̃
≤ C kWα kF (6.10)
hWα ,u1 u>1 i F
It is worth to note that W1G (t) has rank ≤ 1 for any t ∈ R, since every W1,
G
(t) has
rank ≤ 1 and the set S+
d,≤1 is closed. This matches with the first warmup example:
GD does start learning with rank-1 solutions. Interestingly, in the case where the limit
W 1 := limt→+∞ W1G (t) happens to be a minimizer of f ( · ) in S+
d , GLRL should exit
with the rank-1 solution W 1 after the first phase, and the following theorem shows
that this is also the solution found by GF.
Theorem 6.5.8. Under Assumptions 6.5.5 and 6.5.7, if kW1G (t)kF is bounded for all
t ≥ 0, then the limit W 1 := limt→+∞ W1G (t) exists. Further, if W 1 is a minimizer of
f ( · ) in S+ +
d , then for PSD matrices {Wα } ⊆ Sd converging to 0 with positive alignment
with u1 u>
1 as α → 0, it holds that limα→0 limt→+∞ φ(Wα , t) = W 1 .
172
Example 6.5.9 (Counter-example of Conjecture 6.1.1, Gunasekar et al. 16). Theo-
rem 6.5.8 enables us to construct counterexamples of the implicit nuclear norm regular-
ization conjecture in [16]. The idea is to construct a problem where every rank-1 station-
ary point of L(U ) (i.e., ∇L(U ) = 0 and U ∈ Rd×d is rank-1) attains the global minimum
but none of them is minimizing the nuclear norm. Below we give a concrete matrix
completion problem that meets the above requirement. Let M be a partially observed
matrix to be recovered, where the entries in Ω = {(1, 3), (1, 4), (2, 3), (3, 1), (3, 2), (4, 1)}
are observed and the others (marked with “?”) are unobserved. The optimization
problem is defined formally by L(U ) = 12 f (U U > ), f (W ) = 21 (i,j)∈Ω (Wij − Mij )2 .
P
? ?1 R R 1 1 R 1 R 1 R
? 2 2
? R ? 1 R R 1
R R R R
M = , Mnorm = , Mrank = .
1 R ? ? 1 R R 1 1 R 1 R
2 2
R ? ? ? R 1 1 R R R R R
Here R > 1 is a large constant, e.g., R = 100. The minimum nuclear norm solution
is the rank-2 matrix Mnorm , which has kMnorm k∗ = 4R (which is 400 when R = 100).
Mrank is a rank-1 solution with much larger nuclear norm, kMnorm k∗ = 2R2 + 2 (which
is 20002 when R = 100). We can verify that f ( · ) satisfies Assumptions 6.5.5 and 6.5.7
and W1G (t) converges to the rank-1 solution Mrank . Therefore, GF with infinitesimal
initialization converges to Mrank rather than Mnorm , which refutes the conjecture in
[16]. See Section 6.11 for a formal statement.
Theorem 6.5.6 shows that for any fixed time t, the trajectory of GLRL in the first
phase approximates GF with infinitesimal initialization, i.e., W1G (t) = limα→0 W
cα (t),
cα (t) := φ(Wα , 1 log(hWα , u1 u>
where W −1 G
1 i ) + t). However, W1 (∞) 6= limα→0 Wα (∞)
c
2µ1
does not hold in general, unless the prerequisite in Theorem 6.5.8 is satisfied, i.e.,
173
unless W 1 = W1G (∞) is a minimizer of f ( · ) in S+
d . This is because of the well-known
result that GD converges to local minimizers [119, 121]. We adapt Theorem 2 of Lee
et al. [119] to the setting of GF (Theorem 6.14.5) and obtain the following result
(Theorem 6.5.10); see Section 6.14.4 for the proof.
minimizer of f ( · ) with a higher rank and thus away from the rank-1 matrix W 1 . In
other words, W1G (t) only describes the limiting trajectory of GF in the first phase, i.e.,
when GF goes from near 0 to near W 1 . After a sufficiently long time (which depends
on α), GF escapes the critical point W 1 , but this part is not described by W1G (t).
To understand how GF escapes W 1 , a priori, we need to know how GF approaches
W 1 . Using a similar argument for Theorem 6.5.3, Theorem 6.5.11 shows that generically
GF only escapes in the direction of v1 v1> , where v1 is the (unique) top eigenvector of
−∇f (W 1 ), and thus the limiting trajectory exactly matches with that of GLRL in
the second phase until GF gets close to another critical point W 2 ∈ S+
d,≤2 . If W 2 is
top principal component of −∇f (W r−1 ) and gets close to W r . Each W r is a local
minimizer of f ( · ) in S+ +
d,≤r , but none of them is a minimizer of f ( · ) in Sd except W K .
The smaller the initialization is, the longer GF stays around each W r . Moreover,
{W r }K K
r=0 corresponds to {W r, }r=0 in Definition 6.5.1 with infinitesimal > 0.
101
W(0) F = 10 3
101
W(0) F = 10 6
101
W(0) F = 10 12
101
W(0) F = 10 24
101
W(0) F = 10 48
101
W(0) F = 10 96
10 2 10 2 10 2 10 2 10 2 10 2
10 3 10 3 10 3 10 3 10 3 10 3
10 4 10 4 10 4 10 4 10 4 10 4
0.0 0.2 0.4 0.6 0.8 1.0 0.0 0.2 0.4 0.6 0.8 1.0 0.0 0.2 0.4 0.6 0.8 1.0 0.0 0.2 0.4 0.6 0.8 1.0 0.0 0.2 0.4 0.6 0.8 1.0 0.0 0.2 0.4 0.6 0.8 1.0
Continuous Time 1e4 Continuous Time 1e4 Continuous Time 1e4 Continuous Time 1e4 Continuous Time 1e4 Continuous Time 1e4
Figure 6.1: The trajectory of depth-2 GD, WGD (t), converges to the trajectory of GLRL,
WGLRL (t), as the initialization scale goes to 0. We plot dist(t) = mint0 ∈T kWGD (t) −
WGLRL (t0 )kF for different initialization scale kW (0)kF , where T is a discrete subset of R
that δ-covers the entire trajectory of GLRL: maxt mint0 ∈T kWGLRL (t) − WGLRL (t0 )kF ≤
δ for δ ≈ 0.00042. For each kW (0)kF , we run 20 random seeds and plot them
separately. The ground truth W ∗ ∈ R20×20 is a randomly generated rank-3 matrix
with kW ∗ kF = 20. 30% entries are observed. See more in Section 6.8.1.
In this section we elaborate on the theoretical evidence that GF and GLRL are
equivalent generically, including the case where GLRL does not end in the first phase.
The word “generically” used when we want to assume one of the following regularity
conditions:
175
1. We want to assume that GF converges to a local minimizer (i.e., GF does not
get stuck on saddle points);
2. We want to assume that the top eigenvalue λ1 (−∇f (W )) is unique for a critical
point W of (6.2) that is not a minimizer of f ( · ) in S+
d;
cally, the top eigenvalue λ1 (−∇f (W r−1 )) should be unique, i.e., λ1 (−∇f (W r−1 )) >
λ2 (−∇f (W r−1 )). This enables us to apply Theorem 6.14.2 and deduce that the
limiting trajectory
1 1
WrG (t) := lim φ W r−1 + ur u>
r , log + t
→0 2λ1 (−∇f (W r−1 ))
exists, where ur is the top eigenvector of −∇f (W r−1 ). This WrG ( · ) is exactly the
trajectory of GLRL in phase r as → 0.
Note that WrG ( · ) corresponds to a trajectory of GF minimizing L( · ) in Rd×r ,
which should generically converge to a local minimizer of L( · ) in Rd×r . This means
the limit W r := limt→+∞ WrG (t) should generically be a local minimizer of f ( · ) in
S+ +
d,≤r . If W r is further a minimizer in Sd , then λ1 (−∇f (W r )) ≤ 0 and GLRL exits
that is, GF should generically align well with GLRL in the beginning of phase r + 1.
Definition 6.5.12. We say that GF aligns well with GLRL in the beginning of phase
(r) (r)
r if there exists Tα for every α > 0 such that φ(Wα , Tα ) converges to W r−1 with
positive alignment with ur u>
r as α → 0.
In this section, we consider matrix factorization problems with depth L ≥ 3. Our goal
is to understand the effect of the depth-L parametrization W = U1 U2 · · · UL on the
implicit bias — how does depth encourage GF to find low rank solutions? We take the
standard assumption in existing analysis for the end-to-end dynamics that the weight
matrices have a balanced initialization, i.e. Ui> (0)Ui (0) = Ui+1 (0)Ui+1
>
(0), ∀1 ≤ i ≤
L − 1. Arora et al. [29] showed that if {Ui }Li=1 is balanced at initialization, then we
177
have the following end-to-end dynamics. Similar to the depth-2 case, we use φ(W (0), t)
to denote W (t), where
L−1
dW X i i+1
=− (W W > ) L ∇f (W )(W > W )1− L . (6.11)
dt i=0
The lemma below is the foundation of our analysis for the deep case, which greatly
simplifies (6.11). We defer its derivations and applications into Section 6.15.
Lemma 6.6.1. If W (t) is a symmetric solution of (6.11), then for M (t) := W (t)2/L ,
we have
dM
= −∇f (M L/2 )M L/2 − M L/2 ∇f (M L/2 ). (6.12)
dt
Our main result, Theorem 6.6.2, gives a characterization of the limiting trajectory
for deep matrix factorization with infinitesimal identity initialization. Here W (t) :=
−(1−1/P )
limα→0 WαG (t) is the trajectory of deep GLRL, where WαG (t) := φ(αe1 e> α
1 , 2µ1 (P −1) + t)
(see Algorithm 6). The dynamics for general initialization is more complicated. Please
see discussions in Section 6.6.1.
178
L
Theorem 6.6.2. Let P = 2
, L ≥ 3. Suppose k∇f (0)k2 = λ1 (−∇f (0)) >
max{λ2 (−∇f (0)), 0},2
1
α−(1−1/P )
for every fixed t ∈ R, φ αI, 2µ1 (P −1)
+ t − W (t) = O(α P (P +1) ), (6.13)
F
−(1−1/P )
for every fixed t ∈ R, λk φ αI, α2µ1 (P −1) + t = O(α). (6.14)
M.
Suppose we run GF from αI for both depth-2 and depth-L cases. Intuitively, the
1-low-rankness of the depth-2 solution is Ω(α1−µ2 /µ1 ), which can be seen from the
second warmup example in Section 6.4. For the depth-L solution, though it may
diverge from the trajectory of deep GLRL more than the depth-2 solution does, its
2
k∇f (0)k2 = λ1 (−∇f (0)) is a technical assumption which we believe could be removed with a
more refined analysis.
179
1-low-rankness is only O(α), as shown in Theorem 6.6.4. The key idea is to show
that there is a basin in the manifold of rank-1 matrices around W0 such that any GF
starting within the basin converges to W0 . Based on this, we can prove that starting
from any matrix O(α)-close to the basin, GF converges to a solution O(α)-close to
W0 . See Section 6.16 for more details.
Theorem 6.6.4. In the same settings as Theorem 6.6.2, if W (∞) exists and is a
minimizer of f ( · ) in S+
d,≤1 , under the additional regularity assumption 6.16.1, we have
Interpretation for the advantage of depth with multiple phases. For depth-
2 GLRL, the low-rankness is raised to some power less then 1 per phase (depending on
the eigengap). For deep GLRL, we show the low-rankness is only multiplied by some
constant for the first phase and speculate it to be true for later phases. This conjecture
is supported by our experiments; see Figure 6.2. Interestingly, our theory and
experiments (Figure 6.7) suggest that while being deep is good for generalization,
being much deeper may not be much better: once L ≥ 3, increasing the depth
does not improve the order of low-rankness significantly. While this theoretical result
is only for identity initialization, Theorem 6.7.1 and Corollary 6.7.2 further show that
the dynamics of GF (6.11) with any initialization pointwise converges as L → ∞,
under a suitable time rescaling. See Figure 6.4 for experimental verification.
For deep matrix factorization, recall that we only prove that GF with infinitesimal
identity initialization escapes in the direction of the top eigenvector. The main burden
for us to generalize this proof to general initialization is that we don’t know how to
analyze the early phase dynamics of (6.12), i.e., the analytical solution of (6.16) is
180
102
L = 2, W(0) F = 10 6
102
L = 2, W(0) F = 10 12
102
L = 2, W(0) F = 10 24
102
L = 2, W(0) F = 10 48
102
L = 2, W(0) F = 10 96
10 2 10 2 10 2 10 2 10 2
10 6 10 6 10 6 10 6 10 6
10 10 10 10 10 10 10 10 10 10
0 1000 2000 3000 4000 0 1000 2000 3000 4000 0 1000 2000 3000 4000 0 1000 2000 3000 4000 0 1000 2000 3000 4000
102
L = 2, W(0) F = 10 3
102
L = 2, W(0) F = 10 4
102
L = 2, W(0) F = 10 5
102
L = 2, W(0) F = 10 6
102
L = 2, W(0) F = 10 7
10 2 10 2 10 2 10 2 10 2
10 6 10 6 10 6 10 6 10 6
10 10 10 10 10 10 10 10 10 10
102 103 102 103 102 103 102 103 102 103
102
L = 3, W(0) F = 10 3
102
L = 3, W(0) F = 10 4
102
L = 3, W(0) F = 10 5
102
L = 3, W(0) F = 10 6
102
L = 3, W(0) F = 10 7
10 2 10 2 10 2 10 2 10 2
10 6 10 6 10 6 10 6 10 6
10 10 10 10 10 10 10 10 10 10
0.00 0.25 0.50 0.75 1.00 0.0 0.5 1.0 1.5 0.0 0.5 1.0 1.5 2.0 0 1 2 3 0 1 2 3 4
1e4 1e4 1e4 1e4 1e4
102
L = 4, W(0) F = 10 3
102
L = 4, W(0) F = 10 4
102
L = 4, W(0) F = 10 5
102
L = 4, W(0) F = 10 6
102
L = 4, W(0) F = 10 7
10 2 10 2 10 2 10 2 10 2
10 6 10 6 10 6 10 6 10 6
10 10 10 10 10 10 10 10 10 10
0 1000 2000 3000 0.00 0.25 0.50 0.75 1.00 0 1 2 3 0.00 0.25 0.50 0.75 1.00 0 1 2 3
1e4 1e4 1e5 1e5
Continuous Time Continuous Time Continuous Time Continuous Time Continuous Time
distance grad norm r-low-rankness
Figure 6.2: GD passes by the same set of critical points as GLRL when the initialization
scale is small, and gets much closer to the critical points when L ≥ 3. Depth-2 GD
requires a much smaller initialization scale to maintain small low-rankness. Here the
ground truth matrix W ∗ ∈ R20×20 is of rank 3 as stated in Section 6.8.1. In this case,
GLRL has 3 phases and 4 critical points {W r }3r=0 , where W 0 = 0 and W 3 = W ∗ .
For each depth L and initialization scale kW (0)kF , we plot the distance between the
current step of GD and the closest critical point of GLRL, min0≤r≤3 kWGD (t) − W r kF ,
the norm of full gradient, k∇U1:L L(U1:L )kF and the (r + 1)-low-rankness of WGD (t)
with r := argmin0≤r≤3 kWGD (t) − W r kF .
However, unlike the depth-2 case, M can be different from v1 v1> even if v1> M (0)v1 > 0.
We here give an example for diagonal M (0) and ∇f (0) at Section 6.6.3. Nevertheless,
181
we still conjecture that except for a zero measure set of M (0), M = v1 v1> , based on
the following theoretical and experimental evidences:
• If v1> M (0)v1 > 0 and rank(M (0)) = 1, we prove that M = v1 v1> . (See Theo-
rem 6.6.5)
Theorem 6.6.5 (rank-1 initialization escapes along the top eigenvector). When
M (t)
rank(M (0)) = 1, limt→∞ kM (t)kF
= v1 v1> , if v1> M (0)v1 > 0.
Proof. Let u(0) be the vector such that M (0) = u(0)u(0)> and u(t) ∈ Rd be the
solution of
du(t)
= ku(t)kL−2
2 ∇f (0)u(t).
dt
dM du > du >
= u +u = − ∇f (0)M (t) ku(t)kL−2
2 − M (t)∇f (0) ku(t)kL−2
2
dt dt dt
= − ∇f (0)M L/2 − M L/2 ∇f (0).
Rt
Let τ (t) = 0
ku(s)kL−2
2 ds. Then
du du dt 1
= =− dτ
kukL−2
2 ∇f (0)u = −∇f (0)u.
dτ dt dτ dt
182
That is, under time rescaling t → τ (t), the trajectory of u(t) still follows the power
iteration, regardless of the depth L.
Let ∇f (0) = diag(2, 0.9, 0.8, . . . , 0.1) ∈ R10×10 be diagonal. Let W (0) be also diagonal
and W (0)i,i ∼ Unif[0.9, 1.1] · α for i ∈ [10] \ {2}, W (0)2,2 = 16α, where α = 10−16 is a
small constant. Let the depth be 4.
Lemma 6.6.6. With ∇f (0) and W (0) constructed above, v1 M (0)v1> > 0 and M 6=
v1 v1> .
dM (t)i,i
= −2∇f (0)i,i M (t)2i,i , ∀i ∈ [10],
dt
M (t)−1 −1
i,i = M (0)i,i − 2∇f (0)i,i t, ∀i ∈ [10].
For i ∈ [10], the time for M (t)i,i going to infinity is (2M (0)i,i ∇f (0)i,i )−1 . By simple
calculation, M (t)2,2 goes to infinity the fastest, thus M = e2 e> >
2 6= v1 v1 .
We remark that the scales of W (0) and ∇f (0) do not matter as in gradient flow,
as scaling ∇f (0) is equivalent to scaling time (by Lemma 6.6.7 below). And for this
kW (t)kF
reason, the x-axis is the chosen as kW (0)kF
, the relative growth rate.
183
dynamics ddtM = f(0)M2 M2 f(0)
10 1
| v1, ut(t) |
10 3
= 10 4
10 5 = 10 3
= 10 2
dx(t)
= g(x(t)), x(0) = αx0 (0). (6.17)
dt
Theorem 6.7.1 shows that the end-to-end dynamics (6.18) converges point-wise while
L → ∞ if the product of learning rate and depth, ηL, is fixed as constant. Interestingly,
(6.18) also allows us to simulate the dynamics of W (t) for all depths L while the
computation time is independent of L. In Figure 6.4, we compare the effect of depth
184
while fixing the initialization and ηL. We can see that deeper models converge faster.
The difference between L = 1, 2, and 4 is large, while difference among L ≥ 16 is
marginal.
0.4 L = 32
L = 64
0.2 L = 128
0.0
0 250 500 750 1000
Normalized Continuous Time
Figure 6.4: The marginal value of being deeper. The trajectory of GD converges when
depth goes to infinity. Solid (dotted) curves correspond to test (train) loss. The x-axis
stands for the normalized continuous time t (multiplied by L).
dW
= −LŨ Ũ > ∇f (W )Ṽ ◦ K (L) Ṽ > , (6.18)
dt
185
Proof. We start from (6.11):
L−1
dW X l L−1−l
=− (W W > ) L ∇f (W )(W > W ) L
dt l=0
L−1
X 2l 2(L−1−l)
=− Ũ Σ̃ L Ũ > ∇f (W )Ṽ Σ̃ L Ṽ
l=0
" L−1
#
X 2l 2(L−1−l)
= −LŨ L−1 Σ̃ (Ũ > ∇f (W )Ṽ )Σ̃
L L Ṽ.
l=0
2l 2(L−1−l)
Σ̃ L (Ũ > ∇f (W )Ṽ )Σ̃ L = (Ũ > ∇f (W )Ṽ ) ◦ H (l) ,
2l 2(L−1−l)
(l)
where Hi,j = σiL σj L
. Therefore,
L−1 L−1
X 2l 2(L−1−l) X
−1 > −1
L Σ̃ (Ũ ∇f (W )Ṽ )Σ̃
L L =L (Ũ > ∇f (W )Ṽ ) ◦ H (l)
l=0 l=0
PL−1
where K (L) = L−1 l=0 H (l) . Hence,
dW h
> (L)
i
= −LŨ (Ũ ∇f (W )Ṽ ) ◦ K Ṽ.
dt
σ 2−2/L ,
L−1
X 2l 2(L−1−l)
i i = j,
(L) −1
Ki,j =L σi σj
L L
=
σi2 −σj2
l=0
2/L 2/L , i 6= j.
Lσi −Lσj
186
σi2 −σj2
Corollary 6.7.2. As L → ∞, K (L) converges to K ∗ , where Ki,i
∗ ∗
= σi2 , Ki,j = ln σi2 −ln σj2
for i 6= j.
Experiment details. We follow the general setting in Section 6.8.1. The ground
truth W ∗ is different but is generated in the same manner and has the same shape of
20 × 20 and p = 0.3 is used for observation generation. We directly apply (6.18), in
which we compute Ṽ and Ũ through SVD, to simulate the trajectory together with a
10−3
constant learning rate of L
for depth L. W (0) is sampled from 10−3 × N (0, Id ).
6.8 Experiments
Gradient Descent. Let ˜ > 0 be the Frobenius norm of the target random initial-
ization. For the depth-2 case, we sample 2 orthogonal matrices V1 , V2 and a diagonal
matrix D with Frobenius norm ˜, and we set U = V1 D1/2 V2> ; for the depth-L case
with L ≥ 3, we sample L orthogonal matrices V1 , . . . , VL and a diagonal matrix D
187
Depth (L) Simulation method
2 Constant LR, η = 10−3 for 106 iterations
3 Adaptive LR, η = 2 × 10−5 and ε = 10−4 for 106 iterations
4 Adaptive LR, η = 3 × 10−4 and ε = 10−3 for 106 iterations
>
with Frobenius norm ˜, and we set Ui := Vi D1/L Vi+1 (VL+1 = V1 ). In this way,
we can guarantee that the end-to-end matrix W = U1 · · · UL is symmetric and the
initialization is balanced for L ≥ 3.
We discretize the time to simulate gradient flow. When L > 2, gradient flow
stays around saddle points for most of the time, therefore we use full-batch GD with
adaptive learning rate η̃t , inspired by RMSprop [58], for faster convergence:
GLRL. In Figures 6.1, 6.2, 6.5 and 6.6, the GLRL’s trajectory is obtained by running
Algorithm 5 with = 10−7 and η = 10−3 . The stopping criterion is that if the loop
has been iterated for 107 times.
188
6.8.2 Experimental Equivalence between GLRL and Gradi-
ent Descent
Here we provide experimental evidence supporting our theoretical claims about the
equivalence between GLRL and GF for both cases, L = 2 and L ≥ 3.
In Figure 6.1, we show the distance from every point on GF (simulated by GD)
from random initialization is close to the trajectory of GLRL. In Figure 6.2, we first
run GLRL and obtain the critical points {W r }3r=0 passed by GLRL. We also define
the distance of a matrix W to the critical points to be min0≤r≤3 kW − W r kF .
d = 20, W(0) F = 10 2
GLRL, L = 2
100
GD, L = 2
nuclear norm
10 1 R1MP (rank 3)
R1MP (rank 10)
loss
10 2
10 3
0.0 0.5 1.0 1.5 2.0 2.5
Continuous Time 1e4
Figure 6.5: GD with small initialization outperforms R1MP and minimal nuclear
norm solution on synthetic data with low-rank ground truth. Solid (dotted) curves
correspond to test (training) loss. Here the loss f (W ) := d12 kW − W ∗ k2F and f (0) = 1.
We run 10 random seeds for GD and plot them separately (most of them overlap).
189
6.8.4 How does initialization affect the convergence rate to
We use the general setting in Section 6.8.1. In these experiments, we use the constant
learning rate 10−5 for 4 × 107 iterations. The reference matrix Wref is obtained by
running the first stage of GLRL with kW (0)kF = 10−48 and we pick one matrix in the
trajectory with kWref kF about 0.6.
For every = 10i , i ∈ {−1, −2, −3, −4, −5}, we run both gradient descent and
the first phase of GLRL with kW (0)kF = . For gradient descent, we use random
initialization so kW (0)kF is full rank w.p. 1. The distance of a trajectory to Wref
is defined as mint≥0 kW (t) − Wref kF . In practice, as we discretized time to simulate
gradient flow, we check every t during simulation to compute the distance. As a result,
the estimation might be inaccurate when a trajectory is really close to Wref .
The result is shown at Figure 6.6. We observe that GLRL trajectories are closer
to the reference matrix Wref by magnitudes. Thus the take home message here is that
GLRL is in general a more computational efficient method to simulate the trajectory
of GF (GD) with infinitesimal initialization, as one can start GLRL with a much
larger initialization, while still maintaining high precision.
dence on initialization
To verify the our theory in Section 6.6, we run gradient descent with different depth and
initialization. The results are shown in Figure 6.7. We can see that as the initialization
becomes smaller, the final solution gets closer to the ground truth. However, a
depth-2 model requires exponentially small initialization, while deeper models require
polynomial small initialization, though it takes much longer to converge.
190
100
t 0 Wref WGD(t)
10 2
rank 1
10 4 rank d
min
10 6
10 5 10 4 10 3 10 2 10 1
WGD(0) F
Figure 6.6: Using v1 v1> (denoted by “rank 1”) as initialization makes GD much closer
to GLRL compared to using random initialization (denoted by “rank d”), where v1 is
the top eigenvector of −∇f (0). We take a fixed reference matrix on the trajectory of
GLRL with constant norm and plot the distance of GD with each initialization to it
respectively..
4 10 12 10 6 4 10 6 10 6
10 15 10 7 10 7 5.0 10 7
log10 test loss
2 6
6 10 18 7.5
8
8 3 10.0
10
12.5
10 12
4 15.0
12 14
17.5
5 16
0 1 2 3 4 0 2 4 6 0 2 4 6 0 1 2 3
Continuous Time 1e4 Continuous Time 1e4 Continuous Time 1e4 Continuous Time 1e5
Figure 6.7: Deep matrix factorization encourages GF to find low rank solutions at a
much practical initialization scale, e.g. 10−3 . Here the ground truth is rank-3. For
each setting, we run 5 different random seeds. The solid curves are the mean and
the shaded area indicates one standard deviation. We observe that performance of
GD is quite robust to its initialization. Note that for L > 2, the shaded area with
initialization scale 10−7 is large, as the sudden decrement of loss occurs at quite
different continuous times for different random seeds in this case.
191
6.9 Future Directions
Our result on the equivalence between gradient flow with infinitesimal initialization
and GLRL is based on some regularity conditions that we expect to hold generically.
We leave it a future work to justify these condition, possibly through a smoothed
analysis on the objective f ( · ). Another interesting future direction is to find the
counterpart of GLRL in training deep neural nets. This could be one way to go beyond
the view of norm minimization in the study of the implicit regularization of gradient
descent.
Lemma 6.10.1. For U0 ∈ Rd×r and W0 := U0 U0> , the following statements are
equivalent:
Proof. (2) ⇒ (3) is trivial. We only prove (1) ⇒ (2), (3) ⇒ (1).
Proof for (1) ⇒ (2). If U0 is a stationary point, then 0 = ∇L(U0 ) = ∇f (W0 )U0 .
So
∇f (W0 )W0 = (∇f (W0 )U0 ) U0> = 0.
0 = hg(W0 ), ∇f (W0 )i = −2 Tr(∇f (W0 )W0 ∇f (W0 )) = −2k∇f (W0 )U0 k2F ,
∇f (W0 ) 0.
minimizer of f (W ) in S+
d iff
Note that h∇f (W0 ), W0 i = Tr(∇f (W0 )W0 ). By Lemma 6.10.1, h∇f (W0 ), W0 i = 0.
Combining this with (6.19), we know that W0 is a global minimizer iff
Proposition 6.11.2 (Formal Statement for Example 6.5.9). For constant R > 1, let
? ? 1 R R 1 1 R 1 R 1 R
? ? R ? 1 R R 1 R R 2 R R 2
M = , Mnorm = , and Mrank = .
1 R ? ? 1 R R 1 1 R 1 R
2 2
R ? ? ? R 1 1 R R R R R
193
and
1 1 X
L(U ) = f (U U > ), f (W ) = (Wij − Mij )2
2 2
(i,j)∈Ω
where Ω = {(1, 3), (1, 4), (2, 3), (3, 1), (3, 2), (4, 1)}.
Then for any Winit 0, s.t. u>
1 Winit u1 > 0,
Moreover, we have
G
Proof. We define W1, (t), W1G (t) in the same way as in Definition 6.5.1, Theorem 6.5.6.
G
(t) := φ u1 u>
W1, 1 , t ,
2. W1G (t) F
bounded for t ≥ 0;
194
Proof for Item 1. Let M0 := ∇f (0), then
0 0 1 R
0 0 R 0
M0 = .
1 R 0 0
R 0 0 0
√ √
1+ 1+R2 1− 1+R2
Let A := [ R1 R0 ], then we have λ1 (A) = 2
, λ2 (A) = 2
, thus λ1 (A) >
|λ2 (A)| > 0 > λ2 (A). As a result, λ1 (A) = kAk2 . Let v1 ∈ R2 be the top eigenvector
of A. We claim that u1 = [ vv11 ] ∈ R4 is the top eigenvector of ∇f (0). First by definition
2
it is easy to check that M0 u1 = λ1 (A)u1 . Further noticing that M02 = A0 A02 , we
know λ2i (M0 ) ∈ {λ21 (A), λ22 (A)} for all eigenvalues λi (M0 ). That is, λ1 (M0 ) = λ1 (A),
λ2 (M0 ) = −λ2 (A), λ3 (M0 ) = λ2 (A), and λ4 (M0 ) = −λ1 (A). Thus Assumption 6.5.5
is satisfied. Also note that f is quadratic, thus analytic, i.e., Assumption 6.5.7 is also
satisfied.
Proof for Item 2. Let (x (t), y (t)) ∈ R2 be the gradient flow of g(x, y) = 12 (x2 −
√
1)2 + (xy − R)2 starting from (x (0), y (0)) = v1 .
dx(t)
= (1 − x(t)2 )x(t) − 2y(t)(x(t)y(t) − R)
dt (6.21)
dy(t)
= −2x(t)(x(t)y(t) − R)
dt
x (t)
y (t)
W (t) := x (t) y (t) x (t) y (t) .
x (t)
y (t)
195
G
Then it is easy to verify that W (0) = W1, (0) and W (t) satisfies (6.2). Thus by the
G
existence and uniqueness theorem, we have W (t) = W1, (t) for all t. Taking the limit
→ 0, we know that W1G (t) can also be written in the following form:
x(t)
y(t)
G
W1 (t) = x(t) y(t) x(t) y(t) ,
x(t)
y(t)
and (x (t), y (t)) ∈ R2 is a gradient flow of g(x, y) = 12 (x2 − 1)2 + (xy − R)2 .
Since g(x(t), y(t)) is non-increasing overtime, and lim g(x(−t), y(−t)) =
t→−∞
Proof for Item 3. Note that (x(∞), y(∞)) is a stationary point of g(x, y). It is
clear that g(x, y) only has 3 stationary points — (0, 0), (1, R) and (−1, −R). Thus W 1
can only be 0 or Mrank . However, since for all t, f (W1G (t)) < f (0), W 1 = limt→∞ W1G (t)
cannot be 0. So W 1 must be Mrank .
Proof for Item 4. Let mij be (i, j)th element of M . Suppose M 0, we have
196
Thus 4R = minW 0,f (W )=0 kW k∗ , where
the equality is
only attained at mii = R, i =
m11 m14 m22 m23
1, 2, 3, 4. Otherwise, either or will have negative eigenvalues.
m41 m44 m32 m33
Contradiction to that M 0.
Below we will show the rest unknown off-diagonal entries must be 1. Let
1 −1 0 0
V =
0 0 1 0
0 0 0 1
In this section, we prove Theorem 6.5.3 in Section 6.5.1. In Section 6.12.1, we show
how to reduce Theorem 6.5.3 to the case where J(0) is exactly a diagonal matrix, then
we prove this diagonal case in Section 6.12.2. Finally, in Section 6.12.3, we discuss
how to extend it to the case where J(0) is non-diagonalizable.
197
6.12.1 Reduction to the Diagonal Case
Proof for Theorem 6.5.3. We show how to prove Theorem 6.5.3 based on Theo-
dx
rem 6.12.1. Let dt
= g(x) be the dynamical system in Theorem 6.5.3. Let J(0) =
Ṽ D̃Ṽ −1 be the eigendecomposition, where Ṽ is an invertible matrix and D̃ =
diag(µ̃1 , . . . , µ̃d ). Now we define the following new dynamics by changing the ba-
sis:
x̂(t) = Ṽ −1 x(t).
dx̂(t)
Then dt
= ĝ(x̂) for ĝ(x̂) := Ṽ −1 g(Ṽ x̂), and the associated Jacobian matrix is
ˆ
J(x̂) ˆ = diag(µ̃1 , . . . , µ̃d ).
:= Ṽ −1 J(Ṽ x̂)Ṽ , and thus J(0)
Now we apply Theorem 6.12.1 to x̂(t). Then ẑα (t) := Ṽ −1 zα (t) converges to the
limit ẑ(t) := lim ẑα (t). This shows that the limit z(t) = Ṽ ẑ(t) exists in Theorem 6.5.3.
α→0
γ̃
−1 1 1
Ṽ φ δα , t + log − ẑ(t) ≤ C · kδ̂α k2µ̃1 +γ̃ (6.22)
µ̃1 hδα , ũ1 i 2
for every sufficiently small α. As Ṽ are invertible, this directly implies (6.7).
198
6.12.2 Proof for the Diagonal Case
Now we only need to prove Theorem 6.12.1. Let e1 , . . . , ed be the standard basis.
Then ũ1 = ṽ1 = e1 in this diagonal case. We only use e1 to stand for ũ1 and ṽ1 in the
rest of our analysis.
Let R > 0. Since g(x) is C 2 -smooth, there exists β > 0 such that
for all kxk2 , kx + hk2 ≤ R. Then the following can be proved by integration:
Z 1
g(x + h) − g(x) = J(x + ξh)dξ h, (6.24)
0
Let κ := β/µ̃1 . We assume WLOG that R ≤ 1/κ. Let F (x) = log x − log(1 + κx).
It is easy to see that F 0 (x) = 1
x+κx2
and F (x) is an increasing function with range
(−∞, log(1/κ)). We use F −1 (y) to denote the inverse function of F (x). Define
Tα (r) := µ̃11 (F (r) − F (α)) = µ̃11 log αr − log 1+κα
1+κr
.
Our proof only relies on the following properties of J(0) (besides that µ̃1 , e1 are
the top eigenvalue and eigenvector of J(0)):
199
Pd
Proof. For Item 1, h> J(0)h = i=1 µ̃i h2i ≤ µ̃1 khk22 . For Item 2, etJ(0) − eµ̃1 t e1 e>
1 2
=
diag(0, eµ̃2 t , . . . , eµ̃d t ) 2
= eµ̃2 t .
1 + κr
kx(t)k2 ≤ α · eµ̃1 t ≤ r.
1 + κα
1 dkx(t)k22
= hx(t), g(x(t))i ≤ hx(t), J(0)x(t)i + βkx(t)k32 ≤ µ̃1 kx(t)k22 + βkx(t)k32 .
2 dt
dkx(t)k2
This implies dt
≤ µ̃1 (kx(t)k2 + κkx(t)k22 ). Since F 0 (x) = 1
x+κx2
, we further have
d
F (kx(t)k2 ) ≤ µ̃1 .
dt
Lemma 6.12.4. For x(t) = φ(x0 , t) with kx0 k2 ≤ α and t ≤ Tα (r), we have
200
Proof. Let x̂(t) = etJ(0) x0 . Then we have
1d
kx(t) − x̂(t)k22 ≤ hg(x(t)) − J(0)x̂(t), x(t) − x̂(t)i
2 dt
= hg(x(t)) − J(0)x(t), x(t) − x̂(t)i + (x(t) − x̂(t))> J(0)(x(t) − x̂(t))
where the last inequality is due to Lemma 6.12.2. By (6.26) and Lemma 6.12.3, we
have
2
1 + κr
kg(x(t)) − J(0)x(t)k2 ≤ βkx(t)k22 ≤β α · e2µ̃1 t .
1 + κα
d 1+κr
2
So we have dt
kx(t) − x̂(t)k2 ≤ β 1+κα
α · e2µ̃1 t + µ̃1 kx(t) − x̂(t)k2 . By Grönwall’s
inequality,
Z t 2
1 + κr
kx(t) − x̂(t)k2 ≤ β α · e2µ̃1 τ eµ̃1 (t−τ ) dτ.
0 1 + κα
2 2
eµ̃1 t − 1
1 + κr µ̃1 t 1 + κr
kx(t) − x̂(t)k2 ≤ β α e · ≤κ α · eµ̃1 t ≤ κr2 ,
1 + κα µ̃1 1 + κα
Lemma 6.12.5. Let x(t) = φ(x0 , t), x̂(t) = φ(x̂0 , t). If max{kx0 k2 , kx̂0 k2 } ≤ α, then
for t ≤ Tα (r),
kx(t) − x̂(t)k2 ≤ eµ̃1 t+κr kx0 − x̂0 k2 .
1d
kx(t) − x̂(t)k22 = hg(x(t)) − g(x̂(t)), x(t) − x̂(t)i
2 dt Z 1
= (x(t) − x̂(t))> J(xξ (t))dξ (x(t) − x̂(t)),
0
201
where xξ (t) := ξx(t) + (1 − ξ)x̂(t). By Lemma 6.12.3, max{kx(t)k2 , kx̂(t)k2 } ≤
1+κr 1+κr
1+κα
α · eµ̃1 t for all t ≤ Tα (r). So kxξ (t)k2 ≤ 1+κα
α · eµ̃1 t . Combining these with (6.23)
and Lemma 6.12.2, we have
> > > 1 + κr µ̃1 t
h J(xξ (t))h = h J(0)h + h (J(xξ (t)) − J(0))h ≤ µ̃1 + β · α·e khk22 ,
1 + κα
d 1+κr
for all h ∈ Rd . Thus, dt
kx(t) − x̂(t)k2 ≤ µ̃1 + β · 1+κα
α · eµ̃1 t kx(t) − x̂(t)k2 . This
implies
Z t
kx(t) − x̂(t)k2 1 + κr µ̃1 τ
log ≤ µ̃1 + β · α·e dτ
kx(0) − x̂(0)k2 0 1 + κα
1 + κr µ̃1 t
≤ µ̃1 t + κ · αe
1 + κα
≤ µ̃1 t + κr.
Lemma 6.12.6. For every t ∈ (−∞, +∞), z(t) exists and zα (t) converges to z(t) in
the following rate:
kzα (t) − z(t)k2 = O(α),
Proof. We prove the lemma in the cases of t ∈ (−∞, F (R)/µ̃1 ] and t > F (R)/µ̃1
respectively.
α̃
Case 1. Fix t ∈ (−∞, F (R)/µ̃1 ]. Let α̃ be the unique number such that 1+κα̃
=α
(i.e., F (α̃) = log α). Let α0 be an arbitrary number less than α. Let t0 := 1
µ̃1
log αα0 .
Then t0 = 1
µ̃1
(F (α̃) − log α0 ) ≤ Tα0 (α̃). By Lemma 6.12.4, we have
202
Let r := F −1 (µ̃1 t) ≤ R. Then t + 1
µ̃1
log α1 = Tα̃ (r) if α̃ < r.
By Lemma 6.12.3, kφ (α0 e1 , t0 )k2 ≤ α̃. Also, kαe1 k2 = α̃
1+κα̃
≤ α̃. By Lemma 6.12.5,
0 1 1 1 1
kzα (t) − zα0 (t)k2 = φ α e1 , t + log 0 − φ αe1 , t + log
µ̃1 α µ̃1 α 2
0 1 1 1 1
= φ φ(α e1 , t0 ), t + log − φ αe1 , t + log
µ̃1 α µ̃1 α 2
µ̃1 (t+ µ̃1 log 1
)+κr
≤ O(α̃2 · e 1 α
)
2
α̃
≤O .
α
This implies that {zα (t)} satisfies Cauchy’s criterion for every t, and thus the limit
z(t) exists for t ≤ F (R)/µ̃1 . The convergence rate can be deduced by taking limits
for α0 → 0 on both sides.
Case 2. For t = F (R)/µ̃1 + τ with τ > 0, φ(x, τ ) is locally Lipschitz with respect
to x. So
= O(α),
has already been proved in Lemma 6.12.6, where we show kzα (t) − z(t)k2 = O(α).
203
By the continuity of φ( · , t) for every t ∈ R, we have
1 1 1 1
z(t) = lim φ αṽ1 , t + log = φ lim φ αṽ1 , log , t = φ (z(0), t) .
α→0 µ̃1 α α→0 µ̃1 α
Now it is only left to prove (6.7). WLOG we can assume that kδα k2 is decreasing and
α
2
≤ kδα k2 ≤ α (otherwise we can do reparameterization). Then our goal becomes
proving
γ̃
kxα (t) − z(t)k2 = O α µ̃1 +γ̃ . (6.27)
1
where xα (t) := φ δα , t + µ̃1
log hδα1,e1 i . We prove (6.27) in the cases of t ∈
(−∞, F (R)/µ̃1 ] and t > F (R)/µ̃1 respectively.
γ̃
α̃1
Case 1. Fix t ∈ (−∞, (F (R) + log q)/µ̃1 ]. Let α̃1 = α µ̃1 +γ̃ . Let α1 := eF (α̃1 ) = 1+κα̃1
.
1
Let t0 := µ̃1
(F (α̃1 ) − log α) ≤ Tkδα k2 (α̃1 ). At time t0 , by Lemma 6.12.2 we have
µ̃2 α µ̃µ̃2
(F (α̃1 )−log α) 1
e t0 J(0)
−eµ̃1 t0
e1 e>
1 2 =e µ̃2 t0
=e µ̃1
= 1
. (6.28)
α
δα
Let qα := α
, e1 . By Definition 6.5.2, there exists q > 0 such that qα ≥ q for all
sufficiently small α. Then we have
= O(α12 ).
204
By Lemma 6.12.5,
1 1 1 1
kxα (t) − zα1 (t)k2 ≤ φ φ (δα , t0 ) , t + log − φ α1 qα e 1 , t + log
µ̃1 α1 q α µ̃1 α1 qα 2
1 1
µ̃ t+ log +κr
= O α12 · e 1 µ̃1 α1 qα
= O(α1 ).
Combining this with the convergence rate for zα1 (t), we have
kxα (t) − z(t)k2 ≤ kxα (t) − zα1 (t)k2 + kzα1 (t) − z(t)k2 = O(α1 ).
Case 2. For t = (F (R) + log q)/µ̃1 + τ with τ > 0, φ(x, τ ) is locally Lipschitz with
respect to x. So
kxα (t) − z(t)k2 = kφ(xα ((F (R) + log q)/µ̃1 ), τ ) − φ(z((F (R) + log q)/µ̃1 ), τ )k2
= O(kxα ((F (R) + log q)/µ̃1 ) − z((F (R) + log q)/µ̃1 )k2 )
= O(α1 ),
The proof in Section 6.12.2 can be generalized to the case where J(0). Now we state the
theorem formally and sketch the proof idea. We use the notations g(x), φ(x0 , t), J(x)
as in Section 6.5.1, but we do not assume that J(0) is diagonalizable. Instead, we use
µ̃1 , µ̃2 , . . . , µ̃d ∈ C to denote the eigenvalues of J(0), repeated according to algebraic
multiplicity. We sort the eigenvalues in the descending order of the real part of each
eigenvalue, i.e., <(µ̃1 ) ≥ <(µ̃2 ) ≥ · · · ≥ <(µ̃d ), where <(z) stands for the real part of
205
a complex number z ∈ C. We call the eigenvalue with the largest real part the top
eigenvalue.
Theorem 6.12.7. Assume that x = 0 is a critical point and the following regularity
conditions hold:
1. g(x) is C 2 -smooth;
3. The top eigenvalue of J(0) is unique and is a positive real number, i.e.,
Let ṽ1 , ũ1 be the left and right eigenvectors associated with µ̃1 , satisfying ũ>
1 ṽ1 = 1.
1
Let zα (t) := φ(αṽ1 , t + µ̃1
log α1 ) for every α > 0, then ∀t ∈ R, z(t) := lim zα (t) exists
α→0
and z(t) = φ(z(0), t). If δα converges to 0 with positive alignment with ũ1 as α → 0,
then for any t ∈ R and for any > 0, there is a constant C > 0 such that for every
sufficiently small α,
γ̃
−
1
φ δα , t + µ̃1
log hδα1,ũ1 i − z(t) ≤ C · kδα k2µ̃1 +γ̃ , (6.29)
2
a −b
where C = b a ∈ R2×2 .
By linear algebra, the real matrix J(0) can be written in the real Jordan normal
form, i.e., J(0) = Ṽ diag(J[1] , . . . , J[m] )Ṽ −1 , where Ṽ ∈ Rd×d is an invertible matrix,
and each J[j] is a real Jordan block. Recall that there are two types of real Jordan
(r) (r)
blocks, Ja,1 or Ja,b,1 . The former one is associated with a real eigenvalue a, and the
latter one is associated with a pair of complex eigenvalues a ± bi. The sum of sizes
of all Jordan blocks corresponding to a real eigenvalue a is its algebraic multiplicity.
The sum of sizes of all Jordan blocks corresponding to a pair of complex eigenvalues
a ± bi is two times the algebraic multiplicity of a + bi or a − bi (note that a ± bi have
the same multiplicity).
(r) (r)
It is easy to see that Ja,δ = DJa,1 D−1 for D = diag(δ r , δ r−1 , . . . , δ) ∈ Rr×r and
(r) (r)
Ja,b,δ = DJa,b,1 D−1 for D = diag(δ r , δ r , δ r−1 , δ r−1 , . . . , δ, δ) ∈ R2r×2r . This means for
every δ > 0 there exists Ṽδ such that J(0) = Ṽδ Jδ Ṽδ−1 , where Jδ := diag(Jδ[1] , . . . , Jδ[m] ),
(r) (r) (r) (r)
Jδ[j] := Ja,δ if J[j] := Ja,1 , or Jδ[j] := Ja,b,δ if J[j] := Ja,b,1 . Since the top eigenvalue of
J(0) is positive and unique, µ̃1 corresponds to only one block [µ̃1 ] ∈ R1×1 . WLOG we
let J1 = [µ̃1 ], and thus Jδ[1] = [µ̃1 ].
We only need to select a parameter δ > 0 and prove the theorem in the case
of J(0) = Jδ since we can change the basis in a similar way as we have done in
Section 6.12.1. By scrutinizing the proof for Theorem 6.12.1, we can find that we only
207
need to reprove Lemma 6.12.2. However, Lemma 6.12.2 may not be correct since J(0)
is not diagonal anymore. Instead, we prove the following:
0
2. For any µ̃02 ∈ (<(µ̃2 ), µ̃1 ), if δ ∈ (0, µ̃02 − <(µ̃2 )), then etJδ − eµ̃1 t e1 e>
1 2
≤ eµ̃2 t
for all t ≥ 0.
Proof for Item 1. Let K be the set of pairs (k1 , k2 ) such that k1 6= k2 and the
entry of Jδ at the k1 -th row and the k2 -th column is non-zero. Then we have
d
> > Jδ + Jδ> X X
h Jδ h = h h= <(µ̃k )h2k + hk1 hk2 δ
2 k=1 (k1 ,k2 )∈K
d
X X h2k1 + h2k2
≤ <(µ̃k )h2k + δ.
k=1
2
(k1 ,k2 )∈K
Note that <(µ̃k ) ≤ <(µ̃2 ) for k ≥ 2. Also note that there is no pair in K has k1 = 1
or k2 = 1, and for every k ≥ 2 there are at most two pairs in K has k1 = k or k2 = k.
Combining all these together gives
d
X
>
h Jδ h ≤ µ̃1 h21 + (<(µ̃2 ) + δ) h2k ≤ µ̃1 khk22 ,
k=2
Proof for Item 2. Since Jδ is block diagonal, we only need to prove that ketJδ[j] k2 ≤
0 (r)
eµ̃2 t for every j ≥ 2. If Jδ[j] = Ja,δ = aI + δN , where N is the nilpotent matrix, then
208
where the second equality uses the fact that I and N are commutable. So we have
(r)
If Jδ[j] = Ja,δ = D + δN 2 , where D = diag(C, C, . . . , C) and N is the nilpotent matrix,
then
2 2
etJδ[j] = etD+δtN = etD eδtN ,
where the second equality uses the fact that D and N 2 are commutable. Note that
h i
− sin(bt)
etC = eat cos(bt) tD tC at
sin(bt) cos(bt) , which implies ke k2 = ke k2 = e . So we have
2 2k
ketJδ[j] k2 ≤ ketD k2 · keδtN k2 = eat eδtkN 2
≤ e(a+δ)t .
Since δ ∈ (0, µ̃02 − <(µ̃2 )), we know that a + δ < µ̃02 , which completes the proof.
Proof for a fixed δ. Since Item 1 continues to hold for δ ∈ (0, γ̃), Lemmas 6.12.3
to 6.12.6 also hold. This proves that z(t) exists and satisfies (6.6).
It remains to prove (6.29) for any > 0. Let γ̃ 0 ∈ (0, γ̃) be a number such
γ̃ 0 γ̃
that µ̃1 +γ̃ 0
≥ µ̃1 +γ̃
− . Fix µ̃02 = µ̃1 − γ̃ 0 , δ = µ̃02 − <(µ̃2 ). By Item 2, we have
0
etJδ − eµ̃1 t e1 e>
1 2
≤ eµ̃2 t for all t ≥ 0. By scrutinizing the proof for Theorem 6.12.1,
we can find that the only place we use Item 2 in Lemma 6.12.2 is in (6.28). For proving
(6.29), we can repeat the proof while replacing all the occurrences of µ̃2 by µ̃02 . Then
we know that for every t ∈ R, there is a constant C > 0 such that
γ̃ 0
0
Ṽδ−1 φ δα , t + 1
µ̃1
log 1
hδα ,ũ1 i
− Ṽδ−1 z(t) ≤C· kṼδ−1 δα k2µ̃1 +γ̃ ,
2
209
γ̃ 0 γ̃
for every sufficiently small α. By definition of γ̃ 0 , µ̃1 +γ̃ 0
≥ µ̃1 +γ̃
− . Since δα → 0 as
α → 0, we have kṼδ−1 δα k2 < 1 for sufficiently small α. Then we have
φ δα , t + 1
µ̃1
log 1
hδα ,ũ1 i
− z(t) ≤ kṼδ k2 · Ṽδ−1 φ δα , t + 1
µ̃1
log 1
hδα ,ũ1 i
− Ṽδ−1 z(t)
2 2
γ̃ 0
µ̃1 +γ̃ 0
≤ kṼδ k2 · C · kṼδ−1 δα k2
γ̃ 0 γ̃
0 −
≤ C · kṼδ k2 · kṼδ−1 k2µ̃1 +γ̃ · kδα k2µ̃1 +γ̃ .
γ̃ 0
0
Absorbing kṼδ k2 · kṼδ−1 k2µ̃1 +γ̃ into C proves (6.29).
In this section we analyze the eigenvalues of the Jacobian J(W ) at critical points of
(6.2).
For notation simplicity, we write sz(A) := A + A> to denote the symmetric matrix
produced by adding up A and its transpose, and write ac{A, B} = AB + BA to
denote the anticommutator of two matrices A, B. Then g(W ) can be written as
g(W ) := −ac{∇f (W ), W }.
Let U0 ∈ Rd×r be a stationary point of the function L : Rd×r → R, L(U ) =
1
2
f (U U > ), i.e., ∇L(U0 ) = ∇f (U0 U0> )U0 = 0. By Lemma 6.10.1, this implies
210
Define J(W ) := Dg(W ). By simple calculus, we can compute the formula for
J(W0 ):
where ∆, ∆1 , ∆2 ∈ Rd×d .
We can also compute the formula for D2 L(U0 ):
where ∆, ∆1 , ∆2 ∈ Rd×r .
The eigenvalues of J(0) is given in Lemma 6.5.4. Now we provide the proof.
211
Pd
Let −∇f (0) = i=1 µi u1[i] u>
1[i] be the eigendecomposition of the symmetric matrix
d
X
µi u1[i] u> >
J(0)[∆] = 1[i] ∆ + ∆u1[i] u1[i]
i=1
d X
X d
µi u1[i] u> > > >
= 1[i] ∆u1[j] u1[j] + u1[j] u1[j] ∆u1[i] u1[i]
i=1 j=1
d X
X d
= (µi + µj )u1[i] u> >
1[i] ∆u1[j] u1[j]
i=1 j=1
d X
X d
= (µi + µj ) ∆, u1[i] u> >
1[j] u1[i] u1[j] ,
i=1 j=1
212
6.13.2 Eigenvalues at Second-Order Stationary Points
W0 = U0 U0> .
Proof. Assume to the contrary that U0 has rank < r and W0 is a minimizer of f ( · ) in
S+ r
d . The former one implies that there exists a unit vector q ∈ R such that U0 q = 0,
and the latter one implies that there exists v ∈ Rd such that v > ∇f (W0 )v < 0 by
Lemma 6.10.2.
Let ∆ = vq > . Then we have
1
D2 L(U0 )[∆, ∆] = ∇f (W0 ), vv > + D2 f (W0 )[sz(v(U0 q)> ), sz(v(U0 q)> )]
2
1
= ∇f (W0 ), vv > + D2 f (W0 )[0, 0]
2
= ∇f (W0 ), vv > .
By (6.30), the symmetric matrices −∇f (W0 ) and W0 commute, so they can be
simultaneously diagonalizable. Since (6.30) also implies that they have different
column spans, we can have the following diagonalization:
d−r
X d
X
− ∇f (W0 ) = µi vi vi> , W0 = µi vi vi> . (6.31)
i=1 i=d−r+1
213
First we prove the following lemma on the eigenvalues and eigenvectors of the
linear operator −D2 L(U0 ):
then ∆ is an eigenvector of the linear operator −D2 L(U0 )[ · ] : Rd×r → Rd×r associated
with eigenvalue 0. Moreover, the solutions of (6.32) spans a linear space of dimension
r(r−1)
2
.
Proof. Suppose U0 ∆> + ∆U0> = 0. Then we have U0 ∆> = −∆U0> , and thus ∆> =
−U0+ ∆U0> , where U0+ is the pseudoinverse of the full-rank matrix U0 . This implies
that there is a matrix R ∈ Rr×r , such that ∆ = U0 R. Then we have
−D2 L(U0 )[∆] = −∇f (W0 )U0 R − D2 f (W0 )[U0 ∆> + ∆U0> ]U0
= 0.
rd
X
−D2 L(U0 )[∆] = ξp hEp , ∆i Ep
p=1
214
satisfying hEp , Eq i = δpq . We enforce ξp to be 0 and Ep to be a solution of (6.32) for
r(r−1)
every rd − 2
< p ≤ rd.
Lemma 6.13.4. Let A ∈ RD×D be a matrix. If {û1 , . . . , ûK } is a set of linearly inde-
pendent left eigenvectors associated with eigenvalues λ̂1 , . . . , λ̂K and {ṽ1 , . . . , ṽD−K } is a
set of linearly independent right eigenvectors associated with eigenvalues λ̃1 , . . . , λ̃D−K ,
and hûi , ṽj i = 0 for all 1 ≤ i ≤ K, 1 ≤ j ≤ D − K, then λ̂1 , . . . , λ̂K , λ̃1 , . . . , λ̃D−K are
all the eigenvalues of A.
Proof. Let Û := (û1 , . . . , ûK )> ∈ RK×D and Ṽ := (ṽ1 , . . . , ṽD−K ) ∈ RD×(D−K) . Then
both Û and Ṽ are full-rank. Let Û + = Û > (Û Û > )−1 , Ṽ + = (Ṽ > Ṽ )−1 Ṽ > be the
pseudoinverses of Û and Ṽ .
Now we define
Û
P := , Q := Û +
Ṽ .
Ṽ +
Then we have
+
Û Û Û Ṽ
PQ = .
Ṽ + Û + Ṽ + Ṽ
Note that Û Û + = IK , Û Ṽ = 0, Ṽ + Û + = (Ṽ > Ṽ )−1 (Û Ṽ )> (Û Û > )−1 = 0, Ṽ + Ṽ = ID−K .
So P Q = ID , or equivalently Q = P −1 . Then we have
diag(λ̂1 , . . . , λ̂K ) ∗
P −1 AP = ,
0 diag(λ̃1 , . . . , λ̃D−K )
Theorem 6.13.5. The eigenvalues of J(W0 ) can be fully classified into the following
3 types:
215
1. µi + µj is an eigenvalue for every 1 ≤ i ≤ j ≤ d − r, and Ûij := vi vj> + vj vi> is
an associated left eigenvector.
r(r−1)
2. ξp is an eigenvalue for every 1 ≤ p ≤ rd − 2
, and Ṽp := Ep U0> + U0 Ep> is
an associated right eigenvector.
Proof of Theorem 6.13.5. We first prove each item respectively, and then prove that
these are all the eigenvalues of J(W0 ).
Ûij W0 = 0
W0 Ûij = 0
So we have
D E D E
J(W0 )[∆, Ûij ] = (λi + λj ) ∆, Ûij − D2 f (W0 )[∆, 0] = (λi + λj ) ∆, Ûij ,
216
Right-multiplying both sides by U0> , we get
where the second equality uses the fact that ∇f (W0 )U0 = 0 since U0 is a critical point.
Taking both sides into sz(·) gives
= J(W0 )[Ṽp ],
Proof for Item 3. Since ∇f (W ) is symmetric, g(W ) is also symmetric. For any
∆ = −∆> ,
J(W0 )[∆] = J(W0 )[∆> ] = J(W0 )[−∆].
(d − r)(d − r + 1) r(r − 1) d(d + 1)
+ rd − = = dim(Sd ).
2 2 2
D E
Also note that Ûij , Ṽp = 2vi> Ep U0> vj + 2vj> Ep U0> vi = 0. By Lemma 6.13.4, Items 1
and 2 give all the eigenvalues of h, and thus Items 1, 2, 3 give all the eigenvalues of
J(W0 ).
Proof for Theorem 6.5.6. Since W (t) is always symmetric, it suffices to study the
dynamics of the lower triangle of W (t). For any symmetric matrix W ∈ Sd , let
d(d+1) d(d+1)
vecLT (W ) ∈ R 2 be the vector consisting of the 2
entries of W in the lower
triangle, permuted according to some fixed order.
Let g(W ) be the function defined in (6.2), which always maps symmetric matrices to
d(d+1) d(d+1)
symmetric matrices. Let g̃ : R 2 →R 2 be the function such that g̃(vecLT (W )) =
vecLT (g(W )) for any W ∈ Sd . For W (t) evolving with (6.2), we view vecLT (W (t)) as
a dynamical system.
d
vecLT (W (t)) = g̃(vecLT (W (t))).
dt
The proof for Theorem 6.5.8 relies on the following Lemma on the gradient flow around
a local minimizer:
Lemma 6.14.1. If x̄ is a local minimizer of L(x) and for all kx − x̄k2 ≤ r, x satisfies
Lojasiewicz inequality:
k∇L(x)k2 ≥ c (L(x) − L(x̄))µ
for some µ ∈ [1/2, 1), then the gradient flow x(t) = φ(x0 , t) converges to a point x∞
near x̄ if x0 is close enough to x̄, and the distance can be bounded by kx∞ − x̄k2 =
2(1−µ)
O(kx0 − x̄k2 ).
d 1−µ −µ dx
(L(x(t)) − L(x̄)) = (1 − µ) (L(x(t)) − L(x̄)) · ∇L,
dt dt
dx
= −(1 − µ) (L(x(t)) − L(x̄))−µ · k∇Lk2 ·
dt 2
dx
≤ −(1 − µ)c .
dt 2
Rt dx 1 2(1−µ)
Therefore, kx(t) − x0 k2 ≤ 0 dt 2
dt ≤ (1−µ)c
L(x0 )1−µ = O(kx0 − x̄k2 ). If we
choose kx(t) − x̄k2 small enough, then kx(t) − x̄k2 ≤ kx(t) − x0 k2 + kx0 − x̄k2 =
2(1−µ) R +∞ dx
O(kx0 − x̄k2 ) < r, and thus 0 dt 2
dt is convergent and finite. This implies
2(1−µ)
that x∞ := limt→+∞ x(t) exists and kx∞ − x̄k2 = O(kx0 − x̄k2 ).
such that u(t)u(t)> = W1G (t) and u(t) satisfies (6.1), i.e., du
dt
= −∇L(u), where
219
L : Rd → R, u 7→ 12 f (uu> ). If W1G (t) does not diverge to infinity, then so does u(t).
This implies that there is a limit point ū of the set {u(t) : t ≥ 0}.
Let U := {u : L(u) ≥ L(ū)}. Since L(u(t)) is non-increasing, we have u(t) ∈ U for
all t. Note that ū is a local minimizer of L( · ) in U. By analyticity of f ( · ), Lojasiewicz
inequality holds for L( · ) around ū [126]. Applying Lemma 6.14.1 for L restricted
on U, we know that if u(t0 ) is sufficiently close to ū, the remaining length of the
trajectory of u(t) (t ≥ t0 ) is finite and thus limt→+∞ u(t) exists. As ū is a limit point,
this limit can only be ū. Therefore, W 1 := limt→+∞ W1G (t) = ūū> exists.
If W 1 is a minimizer of f ( · ), U = (ū, 0, · · · , 0) ∈ Rd×d is also a minimizer
of L : Rd×d → R, U 7→ 12 f (U U > ). By analyticity of f ( · ), Lojasiewicz inequality
holds for L( · ) around U . For every > 0, we can always find a time t such that
ku(t ) − ūk2 ≤ /2. On the other hand, by Theorem 6.5.6, there exists a number α
such that for every α < α ,
1 1
φ(Wα , T (Wα ) + t ) − W1G (t ) ≤ /2, where T (W ) := log .
2 2µ1 W, u1 u>
1
220
6.14.3 Proof for Theorem 6.5.11
1 1
G
W (t) := lim φ W + v1 v1> , log + t .
→0 2µ1
For {Wα } ⊆ S+
d , if there exists time Tα ∈ R for every α so that φ(Wα , Tα ) converges
to W with positive alignment with the top principal component v1 v1> as α → 0, then
∀t ∈ R, !
1 1
lim φ Wα , Tα + log >
+ t = W G (t).
α→0 2µ1 φ(Wα , Tα ), v1 v1
d
vecLT (W (t)) = g̃(vecLT (W (t))).
dt
>
Let W = U U be a factorization of W , where U ∈ Rd×r . Since W is a local minimizer
of f ( · ) in S+
d,≤r , U is also a local minimizer of L : R
d×r
→ R, U 7→ 12 f (U U > ). Since W
is not a minimizer of f ( · ) in S+
d , by Lemma 6.13.1, U is full-rank. By Theorem 6.13.5,
221
Since U is a local minimizer, ξp ≤ 0 for all p. If µ1 > µ2 , then 2µ1 is the unique
largest eigenvalue, and Theorem 6.13.5 shows that vecLT (v1 v1> ) is a left eigenvector
associated with 2µ1 . The eigenvalue gap γ̃ := 2µ1 − max{µ1 + µ2 , max{ξp : 1 ≤ p ≤
r(r−1)
rd − 2
}} ≥ 2µ1 − max{µ1 + µ2 , 0}.
Also note that φ(Wα , Tα ) − W , v1 v1> = φ(Wα , Tα ), v1 v1> because W , v1 v1> =
0 by (6.31). If φ(Wα , Tα ) converges to W as α → 0, then it has positive alignment
hφ(Wα ,Tα ),v1 v> i
with v1 v1> iff lim inf α→0 φ(W ,T )−W1 > 0. Then it is easy to translate Theorem 6.5.3
k α α kF
to Theorem 6.14.2.
rem 6.5.10)
The proof for Theorem 6.5.10 is based on the following two theorems from the literature.
Theorem 6.14.3 (Theorem 3.1 in Du and Lee 127). Let f : Rd×d → R be a C 2 convex
function. Then L : Rd×k → R, L(U ) = f (U U > ), k ≥ d satisfies that (1). Every
local minimizer of L is also a global minimizer; (2). All saddles are strict. Here
saddles denote those stationary points whose hessian are not positive semi-definite
3
(thus including local maximizers).
Theorem 6.14.5 (GF only finds minimizers, a continuous analog of Theorem 6.14.4).
Let f : Rd → Rd be a C 1 -smooth function, and φ : Rd × R → Rd be the solution of the
Pn
3
Though the original theorem is proven for convex functions of form i=1 `(xi U U > x> i , yi ), where
`(·, ·) is C 2 convex for its first variable. By scrutinizing their proof, we can see the assumption can be
relaxed to f is C 2 convex.
222
following differential equation,
dφ(x, t)
= f (φ(x, t)), φ(x, 0) = x, ∀x ∈ Rd , t ∈ R.
dt
Then the set of initial points that converge to a unstable critical point has measure
zero, µ x0 : limt→∞ φ(x0 , t) ∈ Uf∗ = 0, where Uf∗ = {x : f (x) = 0, λ1 (Df (x)) > 0}
Proof of Theorem 6.14.5. By Theorem 1 in Section 2.3, Perko [128], we know φ(·, ·)
is C 1 -smooth for both x, t. We let g(x) = φ(x, 1), then we know g −1 (x) = φ(x, −1)
and both g, g −1 are C 1 -smooth. Note that Dg −1 (x) is the inverse matrix of Dg(x). So
both of the two matrices are invertible. Thus we can apply Theorem 6.14.4 and we
know µ {x0 : limk→∞ g k (x0 ) ∈ A∗g } = 0.
Note that if limt→∞ φ(x, t) exists, then limk→∞ g k (x) = limt→∞ φ(x, t). It remains
to show that Uf∗ ⊆ A∗g . For f (x0 ) = 0, we have φ(x0 , t) = x0 and thus g(x0 ) = x0 . Now
it suffices to prove that λ1 (Dg(x0 )) > 1. For every t ∈ [0, 1], by Corollary of Theorem
∂
1 in Section 2.3, Perko [128], we have ∂t
Dφ(x, t) = Df (φ(x, t))Dφ(x, t), ∀x, t. Thus,
∂
Dφ(x0 , t) = Df (φ(x0 , t))Dφ(x0 , t) = Df (x0 )Dφ(x0 , t).
∂t
Solving this ODE gives Dg(x0 ) = Dφ(x, 1) = eDf (x0 ) Dφ(x, 0) = eDf (x0 ) , where the
last equality is due to Dφ(x, 0) ≡ I, ∀x. Combining this with λ1 (Df (x0 )) > 0, we
have λ1 (Dg(x0 )) > 1.
Thus we have Uf∗ := {x0 : f (x0 ) = 0, λ1 (Df (x0 )) > 0} ⊆ A∗g , which implies that
{x0 : limt→∞ φ(x0 , t) ∈ U ∗ } ⊆ {x0 : limk→∞ g k (x0 ) ∈ A∗g }
223
minimizers; (2). For any random initialization, GF (6.1) converges to strict saddles
of L(U ) with probability 0.
Proof of Theorem 6.5.10. For (1), by Theorem 6.14.3, we immediately know all the
stationary points of L( · ) are either global minimizers or strict saddles. (2) is just a
direct consequence of Theorem 6.14.5 by setting f in the above proof to −∇L.
Lemma 6.15.1. If W (0) 0, then W (t) 0 and rank(W (t)) = rank(W (0)) for all
t.
Proof. Note that we can always find a set of balanced Ui (t), such that U1 (t) . . . UL (t) =
W (t), d2 = d3 = · · · = dL = rank(W (t)) and write the dynamics of W (t) in the space
of {Ui }Li=1 . Thus it is clear that for all t0 , rank(W (t0 )) ≤ rank(W (t)). We can apply
the same argument for t0 and we know rank(W (t)) ≤ rank(W (t0 )). Thus rank(W (t))
is constant over time, and we denote it by k. Since eigenvalues are continuous matrix
functions, and ∀t, λi (W (t)), i ∈ [k] 6= 0. Thus they cannot change their signs and it
must hold that W (t) 0.
aP −bP
Lemma 6.15.2. ∀a, b, P ∈ R, if a > b ≥ 0, P ≥ 1, then a−b
≤ P aP −1 .
∀M, N 0,
kDF (N )[M ]kF ≤ P kN kP2 −1 kM kF ,
224
F (N +tM )−F (N )
where DF (N )[M ] := limt→0 t
is the directional derivative of F along M .
kF (N + tM ) − F (N )kF
kDF (N )[M ]kF = lim
t→0 t
>
F (Σ + tU M U ) − F (Σ) F
= lim
t→0 t
= DF (Σ)[U > M U ] F
.
Therefore, it suffices to prove the lemma for the case where N is diagonal, i.e., N = Σ.
1
q
Assume P = p
, where p, q ∈ N and q ≥ p > 0. Define G(N ) = N p . Then
G(Σ)p = Σ. Taking directional derivative on both sides along direction M , we have
p
X
G(Σ)i−1 DG(Σ)[M ]G(Σ)p−1 = M,
i=1
So we have
mij
[DG(Σ)[M ]]ij = P k−1 p−k .
p p p
k=1 σi σj
q k−1 q−k
X
[DH(G(Σ))[M ]]ij = mij σi p σj p .
k=1
That is,
k−1 q−k
Pq p
k=1 σi σj p
[DF (Σ)[M ]]ij = mij P k−1 p−k .
p p p
k=1 σi σj
225
q−p
σiP − σjP
|[DF (Σ)[M ]]ij | = |mij | ≤ |mij | P σiP −1 ≤ |mij | P kΣkP2 −1 .
σi − σj
where the first inequality is by Lemma 6.15.2. Thus we conclude the proof.
n o
AP − B P F
≤ P kA − BkF max kAkP2 −1 , kBkP2 −1 .
Proof. Since both sides are continuous in P and Q is dense in R, it suffices to prove
the lemma for P ∈ Q. Let ρ := max {kAk2 , kBk2 } and F (M ) = M P . Define
N : [0, 1] → S+
d , N (t) = (1 − t)A + tB, we have
Therefore,
Z 1
dF (N (t))
kF (N (1)) − F (N (0))kF ≤ dt
0 dt F
Z 1
= kDF (N (t))[B − A]kF dt
t=0
≤ P kA − BkF ρP −1 ,
226
For a locally Lipschitz function f ( · ), the Clarke subdifferential [129–131] of f at
any point x is the following convex set
∂ ◦ f (x) n o
:= co lim ∇f (xk ) : xk → x, f is differentiable at xk ,
∂x k→∞
Theorem 6.15.6 (Theorem 5.3, Hiriart-Urruty and Lewis 132). The Clarke subdif-
ferential of the eigenvalue function λm is given below, where co denotes the convex
hull:
∂ ◦ λm (M )
= co{vv > : M v = λm (M )v, kvk2 = 1}.
∂M
L−1
dW X 2i 2i+2
=− W L ∇f (W )W 2− L . (6.33)
dt i=0
227
Proof for Lemma 6.6.1. Suppose W (t) is a symmetric solution of (6.11). By
Lemma 6.15.1, we know W (t) also satisfies (6.33). Below we prove the lemma for
even L and odd L respectively:
1
• L is odd: let R(t) be the solution of the following ODE with R(0) := (W (0)) L .
1
Note we do not define R(t) by (W (t)) L .
L−1
dR X
=− (−1)i Ri ∇f (RL )RL−1−i . (6.34)
dt i=0
L−1 L−1 X
L−1
dRL X
j dR L−1−j
X
=− R R = (−1)i Ri+j ∇f (RL )R2L−2−i−j
dt j=0
dt j=0 i=0
L−1 X k
!
X
=− (−1)i Rk ∇f (RL )R2L−2−k
k=0 i=0
2L−2 L−1
!
X X
− (−1)i Rk ∇f (RL )R2L−2−k
k=L i=k−L+1
k 2+k
X
=− (RL ) L ∇f (RL )(RL )2− L
0≤k≤2L−2
k even
L−1
2+2i
L 2i
X
=− (R ) ∇f (RL )(RL )2−
L L .
i=0
dM dR dR
=R + R = −∇f (M L/2 )M L/2 − M L/2 ∇f (M L/2 ),
dt dt dt
228
• L is even: let M f(0) := (W (0)) L2 .
f(t) be the solution of the following ODE with M
f(t) by (W (t)) L2 .
Note we do not define M
dMf
f)L/2 )(M
f)L/2 − (M
f)L/2 ∇f ((M
f)L/2 ).
= −∇f ((M (6.35)
dt
f)L/2 L/2−1
d(M
L−1
f)j dM (M
X f X
= (M f)L/2−1−j = − f)j ∇f ((M
(M f)L/2 )(M
f)L−1−j
dt j=0
dt j=0
Now we turn to prove Theorem 6.6.2. Let P = L/2. Then (6.12) can be rewritten as
dM
= − ∇f (M P )M P + M P ∇f (M P ) .
(6.36)
dt
The following lemma about the growth rate of λk (M ) is used later in the proof.
Lemma 6.15.7. Suppose M (t) satisfies (6.36), we have for any T 0 > T , and k ∈ [d],
Z T0
0
λk (M (T )) − λk (M (T )) ≤ 2λk (M (t))P k∇f (M P (t))k2 dt. (6.37)
T
and
Z T0
1 0
λ1−P 1−P
2k∇f (M P (t))k2 dt.
(M (T )) − λ (M (T )) ≤ (6.38)
P −1 k k
T
229
Proof. Since λk (M (t)) is locally Lipschitz in t, by Rademacher’s theorem, we know
λk (M (t)) is differentiable almost everywhere, and the following holds
Z T0
0 dλk (M (t))
λk (M (T )) − λk (M (T )) = dt.
T dt
dλk (M (t))
When dt
exists, we have
∂ ◦ λk (M )
dλk (M (t)) dM (t)
∈ G, :G∈
dt dt ∂M
∂ ◦ λk (M )
P P
= 2λk (M (t)) G, −∇f (M (t)) : G ∈
∂M
To prove Theorem 6.6.2, it suffices to consider the case that M (0) = α̂I where
α̂ := α1/P . WLOG we can assume −∇f (0) = diag(µ1 , . . . , µd ) by choosing a suitable
standard basis. By assumption in Theorem 6.6.2, we have µ1 > max{µ2 , 0} and
µ1 = k∇f (0)k2 . We use φm (M0 , t) to denote the solution of M (t) when M (0) = M0 .
Let R > 0. Since f ( · ) is C 3 -smooth, there exists β > 0 such that
230
Lemma 6.15.8. For any x ∈ [α̂, R] we have
Z α̂−(P −1)
−(P −1) −(P −1) 1
α̂ −x − Fα̂ (x) = 1− dz ≥ 0.
x−(P −1) 1 + κz −P/(P −1)
≤ κ(P − 1)x,
Lemma 6.15.9. Let M0 be a PSD matrix with kM0 k2 ≤ 1. For M (t) := φm (α̂M0 , t)
and t ≤ Tα̂ (c),
1
kM (t)k2 = λ1 (M (t)) ≤ gα̂,c (t) P −1 .
Proof. Since k∇f (M P )k2 ≤ k∇f (0)k2 + βkM kP2 ≤ µ1 + β(λ1 (M ))P , by Lemma 6.15.7,
we have
Z t
λ1 (M (t)) ≤ λ1 (M (0)) + 2(µ1 + β(λ1 (M (τ )))P )(λ1 (M (τ )))P dτ
0
Z t
dτ
= α̂ + 2µ1 (P − 1) 0
0 Fα̂ (λ1 (M (τ ))
So
Fα̂ (λ1 (M (t))) ≤ 2µ1 (P − 1)t.
231
1
If kM (t)k2 < α̂, then kM (t)k2 ≤ gα̂,c (t) P −1 . If kM (t)k2 ≥ α̂, then by Lemma 6.15.8,
Fα̂ (kM (t)k2 ) ≤ 2µ1 (P − 1)Tα̂ (c) = α̂−(P −1) − κ(P − 1)c − c−(P −1) ≤ Fα̂ (c),
so kM (t)k2 ≤ c for all t ≤ Tα̂ (c). Applying Lemma 6.15.8 again, we have
−(P −1)
α̂−(P −1) − kM (t)k2 ≤ F (kM (t)k2 ) + κ(P − 1)c ≤ 2µ1 (P − 1)t + κ(P − 1)c,
1
which implies kM (t)k2 ≤ gα̂,c (t) P −1 by definition.
dMc
P P
= − ∇f (0)M + M ∇f (0) .
c c
dt
We use φ̂m (M
c0 , t) to denote the solution of M
c(t) when M
c(0) = M
c0 . For diagonal
matrix M
c0 , M
c(t) is also diagonal for any t, and it is easy to show that
1
P −1
1
e>
i M0 ei 6= 0,
c
−(P −1)
e> (α̂e>i M
c0 ei ) −2µi (P −1)t
i M (t)ei = (6.39)
c
e>
0
i M0 ei = 0.
c
for diagonal initialization, i.e., (6.39) (note that the identity matrix is diagonal). And
this is the main barrier for extending our two-phase analysis to the case of general
initialization when L ≥ 3. In Section 6.6.1, we give a more detailed discussion on this
barrier.
232
Lemma 6.15.11. Let M0 be a diagonal PSD matrix with kM0 k2 ≤ 1. For M (t) :=
φm (α̂M0 , t) and M
c(t) := φ̂m (α̂M0 , t), we have
dD
= 2 ∇f (0) M P − M
cP + ∇f (M P ) − ∇f (0) M P
dt F F
P
cP kF + k∇f (M P ) − ∇f (0)kF kM P k2
≤ 2 k∇f (0)k2 kM − M
P −1 P −1 2P
≤ 2 µ1 P max{kM k2 , kM k2 }kDkF + β kM k2 ,
c
Z t Z t
dD(τ ) 2P
kD(t)kF ≤ dτ ≤ 2 µ1 P gα̂,r (τ ) kD(τ )kF + βgα̂,r (τ ) P −1 dτ.
τ =0 dτ F 0
So
!
Z Tα̂ (r) Z Tα̂ (r)
2P
kD(Tα̂ (r))kF ≤ 2βgα̂,r (t) P −1 exp 2µ1 P gα̂,r (τ )dτ dt
0 t
Z Tα̂ (r)
2P P gα̂,r (Tα̂ (r))
= 2βgα̂,r (t) P −1 exp ln dt
0 P −1 gα̂,r (t)
Z Tα̂ (r)
P P
= 2βgα̂,r (t) P −1 gα̂,r (Tα̂ (r)) P −1 dt
0
1 1 P
= 2β · gα̂,r (Tα̂ (r)) P −1 · gα̂,r (Tα̂ (r)) P −1
2µ1
P +1
= κgα̂,r (Tα̂ (r)) P −1
= κrP +1 .
233
Lemma 6.15.12. Let M (t) = φm (α̂M0 , t), M f0 , t). If max{kM0 k2 , kM
f(t) = φm (α̂M f0 k2 } ≤
r P P
kM (t) − M
f(t)kF ≤ e2κr kM (0) − M
f(0)kF .
α̂
dD
=2 ∇f (M P ) M P − M
fP + ∇f (M P ) − ∇f (M
fP ) MfP
dt F F
≤ 2 k∇f (M P )k2 kM P − M
fP kF + βkM P − M
fP kF kM
fP k2
≤ 2 µ1 + βkM k2 + βkM k2 P max{kM kP2 −1 , kM
f P P fkP −1 }kDkF ,
2
Lemma 6.15.13. For every t ∈ (−∞, +∞), M (t) exists and Mα̂G (t) converges to
M (t) in the following rate:
234
α̂−(P −1)
Case 1. Fix t ∈ (−∞, T̄ ]. Then 2µ1 (P −1)
+ t ≤ Tα̂ (c). Let α̃ be the unique number
such that κ(P − 1)α̃ + α̃−(P −1) = α̂−(P −1) . Let α̂0 < α̂ be an arbitrarily small number.
(α̂0 )−(P −1) −α̂−(P −1)
Let t0 := Tα̂0 (α̃) = 2µ1 (P −1)
. By Lemma 6.15.11 and (6.39), we have
c P P
φm (α̂0 e1 e> >
1 , t0 + t) − φ(α̂e1 e1 , t) F
≤ e2κc · O(α̃P +1 ) = O(α̃) = O(α̂).
α̃
This implies that {Mα̂G (t)} satisfies Cauchy’s criterion for every t, and thus the limit
M (t) exists for t ≤ T̄ . The convergence rate can be deduced by taking limits for
α̂0 → 0 on both sides.
= O(α̂),
α̂−(P −1)
1
φm α̂I, + t − M (t) = O(α̂ P +1 ), (6.40)
2µ1 (P − 1) F
235
and for any 2 ≤ k ≤ d,
α̂−(P −1)
λk φm α̂I, +t = O(α̂). (6.41)
2µ1 (P − 1)
α̂−(P −1)
Proof. Let Mα̂ (t) := φm α̂I, 2µ 1 (P −1)
+ t . Again we let c be a sufficiently small
−κ(P −1)c−c−(P −1)
constant and T̄ := 2µ1 (P −1)
. We prove in the cases of t ∈ (−∞, T̄ ] and t > T̄
respectively.
1
Case 1. Fix t ∈ (−∞, T̄ ]. Let α̂1 := α̂ P +1 . Let α̃1 be the unique number such that
−(P −1)
−(P −1) −(P −1) α̂−(P −1) −α̂1
κ(P − 1)α̃1 + α̃1 = α̂1 . Let t0 := Tα̂ (α̃1 ) = 2µ1 (P −1)
. Then
= O(α̃1P +1 + α̂)
= O(α̂).
By Lemma 6.15.9, kφm (α̂0 I, t0 )k2 ≤ α̃1 . Then by Lemma 6.15.12, we have
Combining this with the convergence rate for Mα̂G1 (t) proves the bound (6.40).
For (6.41), by Lemma 6.15.7, we have
Z T̄
λ1−P
k (Mα̂ (T̄ )) − λ1−P
k (Mα̂ (t0 )) ≤ 2(P − 1) ∇f (Mα̂P (t)) 2
dt
t0
Z T̄
≤ 2(P − 1)(µ1 + β kMα̂ (t)kP2 ))dt (6.42)
t0
κ 1
≤ −2(P − 1) µ1 (t − T1 ) + · gα̂,c (t) P −1 .
2
236
By Lemma 6.15.11, λ1 (Mα̂ (T̄ )) = Mα̂ (T̄ ) 2
= c + O(cP +1 ). For k ≥ 2,
λ1−P
k (Mα̂ (T̄ )) − λ1−P
k (Mα̂ (T̄ + τ ))
Z T̄ +τ
≤ 2(P − 1) ∇f (Mα̂P (t)) 2 dt
T̄
Z T̄ +τ
2(P − 1) β Mα̂P (t) − (M G )P (t) G P
≤ 2
+ ∇f (M ) (t) 2
dt
T̄
Z T̄ +τ
1 P
≤ 2(P − 1)(O(α̂ 1+P ) + β M G (t) 2
)dt
T̄
≤ O(1).
Thus λ1−P
k (Mα̂ (T̄ + τ )) = Ω(α̂−(P −1) ), that is, λk (Mα̂ (T̄ + τ )) = O(α̂), ∀k ≥ 2.
237
P
Proof of Theorem 6.6.2. Note that M (t) = W (t) and
P !
α̂−(P −1) α̂−(P −1)
φm α̂I, +t = φ αI, +t .
2µ1 (P − 1) 2µ1 (P − 1)
α−(1−1/P )
φ αI, + t − W (t)
2µ1 (P − 1) F
−(P −1)
P
α̂ P
≤ φm α̂I, +t − M (t)
2µ1 (P − 1)
F
−(P −1)
P −1
α̂−(P −1)
α̂
≤ P φm α̂I, + t − M (t) max φm α̂I, +t , M (t) 2
2µ1 (P − 1) F 2µ1 (P − 1) 2
1 1
= O(α̂ P +1 )O(1) = O(α P (P +1) ),
and for 2 ≤ k ≤ d,
In this section, we will present the theorems that guarantee the linear convergence to
a minimizer W0 of f ( · ) if the dynamics (6.43) is initialized sufficiently close to W0 ,
i.e., kW (0) − W0 kF is sufficiently small. In Section 6.16.3, we will apply this result to
prove Theorem 6.6.4.
L−1
dW X 2i 2i+2
=− W L ∇f (W )W 2− L =: g(W ). (6.43)
dt i=0
238
Throughout this section, we assume rank(W0 ) = k and use m := λk (W0 ) to denote
the k-th smallest non-zero eigenvalue of W0 . The tangent space of manifold of rank-k
symmetric matrices at W0 is T = {V W0> + W0 V > : V ∈ Rd×d }. It can be shown that
k(k+1) k(2d−k+1)
dim(T ) = k(d − k) + 2
= 2
.
Let J(W ) be the Jacobian of g(W ) in (6.43). For depth-2 case, we have shown
that T is an invariant subspace of J(W0 ) in Theorem 6.13.5, property 2. This can
be generalize to the deep case where L ≥ 3. Therefore, we can use J(W0 )|T : T → T
2
to denote the linear operator J(W0 ) restricted on T . We also define Πd1 (W ) as the
2 2
projection of W ∈ Rd×d on T , and Πd2 (W ) := W − Πd1 (W ).
Towards showing the main convergence result in the section, we make the following
assumption.
Assumption 6.16.1. Suppose J(W0 )|T is diagonalizable and all eigenvalues are
negative real numbers.
W0 is a minimizer, so it is clear that J(W0 )|T has no eigenvalues with positive real
parts (otherwise there is a descending direction of f ( · ) from W0 , since the loss f ( · )
strictly decreases along the trajectory of (6.43)). If further Assumption 6.16.1 holds,
then we know J(W0 )|T : T → T can be diagonalized as J(W0 )|T [ · ] = V(ΣV −1 ( · )),
where Σi = diag(−µ1 , . . . , −µdim(T ) ), V : Rdim(T ) → T , V(x) = dim(T )
P
i=1 xi Vi , and Vi is
the eigenvector associated with eigenvalue −µi .
As shown in Theorem 6.16.3 below, this assumption implies that if W (0) is rank-k
and is sufficiently close to W0 , then kW (t) − W0 kF ≤ Ce−µ1 t for some constant C.
For depth-2 case, the above assumption is equivalent to that L(U0 ) is “strongly convex”
at U0 , except those 0 eigenvalues due to symmetry, by property 2 of Theorem 6.13.5).
For the case where L ≥ 3, because this dynamics is not gradient flow, in general it
does not correspond to a loss function and strongly convexity does not make any
sense. Nevertheless, in experiments we do observe linear convergence to W0 , so this
assumption is reasonable.
239
6.16.1 Rank-k Initialization
2
2 2
−1
kW kV := V Πd1 (W ) , kW kF,1 := Πd1 (W ) , kW kF,2 := Πd2 (W ) .
F F F
The reason for such definition of norms, as we will see later, is that the norm (or
the difference) in the tangent space of the manifold of symmetric rank-r matrices,
kW − W 0 kF,1 , dominates that in the orthogonal complement of the tangent space,
kW − W 0 kF,2 , when both W, W 0 get very close to the W0 (see a more rigorous statement
in Lemma 6.16.2). WLOG, we can assume
k · kF,1
≤ k · kV ≤ k · kF,1 ,
K
for some constant K, which may depend on f and W0 . This also implies that
k · kV ≤ k · kF . Below we also assume for sufficiently small R, and any W such that
kW − W0 kF ≤ R, we have k∇f (W )k2 ≤ ρ and kJ(W )[∆]kF ≤ β k∆kF for any ∆.
In the proof below, we assume such properties hold as long as we can show the
boundedness of W (t) − W0 .
5r
kW − W 0 kF,2 ≤ kW − W 0 kF,1 .
m
5 kW − W 0 k2F,1
kW − W0 kF,2 ≤ .
m
240
Proof. WLOG we can assume W0 is only non-zero in the first k dimension, i.e.,
[W0 ]ij = 0, for all i ≥ k + 1, j ≥ k + 1. We further denote W and W 0 by
> 0 0>
A B A B
W = and W 0 = ,
B C B0 C 0
kW − W 0 kF,2
= kC − C 0 kF
−1 >
= BA−1 B > − B 0 A0 B 0
F
−1 > −1 >
≤ kB − B 0 kF kA−1 B > kF + kBA−1 kF kA0 − AkF kA0 B 0 kF + kB 0 A0 kF kB > − B 0 kF
2
0 2r 0 2r 2r
≤ kW − W kF,1 + kW − W kF,1 + kW − W 0 kF,1
m m m
5r
≤ kW − W 0 kF,1 .
m
Theorem 6.16.3 (Linear convergence of rank-k matrices). Suppose that rank(W (0)) =
rank(W0 ) = k and
m µ1
kW (0) − W0 kV ≤ R := max , 2 ,
2K K (29β + 10ρ/m)
241
2 2
Proof. For convenience, we define W1 (t) := Πd1 (W (t) − W0 ) , W2 (t) := Πd2 (W (t) − W0 ) =
2
Πd2 (W (t)). We also use h·, ·iV −1 = hV −1 (·) , V −1 (·)i for short.
D E
2
For the first term Πd1 (J(W0 )[W1 (t)]) , W1 (t) , we know W1 (t) ∈ T , and T is an
V −1
invariant space of J(W0 ). Recall J(W0 )|T [·] = V (ΣV −1 (·)), we have
D 2 E
2 Πd1 (J(W0 )[W1 (t)]) , W1 (t) = 2 ΣV −1 (W1 (t)) , V −1 (W1 (t)) ≤ −2µ1 kW1 (t)kF,1 .
V −1
For the second term 2β kJ(W0 )[W2 (t)]kV kW1 (t)kV , we have
2 kJ(W0 )[W2 (t)]kV ≤ 2 kJ(W0 )[W2 (t)]kF ≤ 2 kJ(W0 )k2 kW2 (t)kF = 2ρ kW2 (t)kF .
242
For the third term 2 kg(W (t) − W0 ) − J(W0 )[W (t) − W0 ]kV kW1 (t)kV , we have
Thus we have shown the following. Note so far we have not used the assumption that
W is rank-k.
d kW1 (t)k2V
≤ −2µ1 kW1 (t)k2V +2 kW1 (t)kV ρ kW2 (t)kF + 2βK 2 kW1 (t)k2V + 2β kW2 (t)k2F ,
dt
that is,
Since µ1 < 0, kW1 (t)kV decreases for [0, T ). Thus T must be ∞, otherwise
kW1 (T )kV = limt→T − kW1 (t)kV < R1 . Contradiction.
243
µ1
Therefore, for any t ∈ [0, ∞), we have kW1 (t)kV ≤ kW1 (0)kV e− 2
t
. That is,
Z ∞
2 2R
kW1 (t)kV dt ≤ kW1 (0)kV ≤ .
0 µ1 µ1
Z ∞
K2
kW (t)kV = kW1 (t)kV ≤ kW1 (0)kV exp −µ1 t + (29β + 10ρ/m) kW1 (t)kV dt
2 0
K 2R
≤ kW1 (0)kV exp −µ1 t + (29β + 10ρ/m)
µ1
=: C kW (0)kV e−µ1 t ,
We use M (t) to denote the top-k components of W (t) in SVD, and N (t) to denote
the rest part, i.e., W (t) − M (t). One can think M (t) as the main part and N (t) as
the negligible part.
Below we show that for deep overparametrized matrix factorization, where W (t)
satisfies (6.43), if the trajectory is initialized at some W (0) in a small neighborhood
of the k-th critical point W0 of deep GLRL, and W (0) is approximately rank-k, in
the sense that N (0) is very small, then inf t≥0 kW (t) − W0 kV is roughly at the same
magnitude of N (0).
Theorem 6.16.4 (Linear convergence of almost rank-k matrices, deep case). Suppose
W0 is a critical point of rank k and W0 satisfies Assumption 6.16.1, there exists
constants C0 and r, such that if C0 kN (0)kF ≤ kW1 (0)kV ≤ r, then there exists a time
T and constants C, C 0 , such that
244
(2). kW (T ) − W0 kF ≤ C 0 kN (0)kF .
λmin (W0 )
kM (t) − W0 kF,1 ≤ kW (t) − W0 kF,1 + kN (t)kF,1 ≤ ,
2
Thus we can pick constant C0 large enough and r small enough, such that for any
t ≥ 0, if C0 kN (t)kF ≤ kW1 (t)kV ≤ r, then it holds that:
1
• The spectral norm 2
k∇f (W (t))k2 ≤ k∇f (W0 )k2 =: ρ for all t ≥ 0.
2
κL x L −1 2 L−2
• ∀x < r, (L−2)ρ
> µ1
ln C2r0 x , where κL = 1 − 0.5 L .
Note these conditions can always be satisfied by some C0 and r because we can first
find 3 groups (C0 , r) to satisfy each individual condition, and then take the maximal
C0 and minimal r, it’s easy to check these conditions are still verified. And we let
TC0 ,r be the earliest time that such condition, i.e., C0 kN (t)kF ≤ kW1 (t)kV ≤ r fails.
245
µ1 t
Thus by (6.44), for t ∈ [0, TC0 ,r ), we have kW (t)kV = kW1 (t)kV ≤ kW1 (0)kV e− 2 =
µ1 t
kW (0)kV e− 2 . Thus (1) holds for any T smaller than TC0 ,r . If TC0 ,r = ∞, then clearly
we can pick a sufficiently large T , such that (2) holds. Therefore, below it suffices to
consider the case where TC0 ,r is finite. And we know the condition that fails must be
C0 kN (t)kF ≤ kW1 (t)kV , i.e. C0 kN (TC0 ,r )kF = kW1 (TC0 ,r )kV .
By (6.38) in Lemma 6.15.7, we have
2 2
−1 −1
kN (0)k2L − kN (t)k2L ≤ (L − 2)ρt.
2 −1
2 2
κL kN (0)k2L −1 −1
Define T 0 := (L−2)ρ
, we know for any t < T 0 , we have kN (0)k2L − kN (t)k2L ≤
2
−1
κL kN (t)k2L . That is,
2
−1 h
kN (t)k2L 1 L−2 L−2
i kN (t)k2
2 ∈ 1 − κL , = 0.5 L , 0.5− L =⇒ ∈ [1/2, 2].
kN (0)k2L
−1 1 − κL kN (0)k2
C0 0 0
kN (0)k2 ≤ C0 kN (T 0 )kF ≤ kW1 (T 0 )kV ≤ e−µ1 T /2 kW1 (0)kV ≤ e−µ1 T /2 r.
2
2 −1
κL kN (0)k2L
Therefore, (L−2)ρ
= T0 ≤ 2
µ1
ln C0 kN2r(0)k , which contradicts to the definition of C0
2
and r.
As a result, we have
√
2C0 d kN (0)k2 ≥ 2C0 kN (0)kF ≥ C0 kN (Tc0 ,r )kF = kW1 (TC0 ,r )kV
and therefore,
2 kW1 (0)kV
TC0 ,r ≤ ln √ .
µ1 2 dC0 kN (0)kF
246
Thus by Lemma 6.16.2, we know
= O(kN (0)kF ).
Proof for Theorem 6.6.4. Let C0 , r be the constants predicted by Theorem 6.16.4 w.r.t.
to W (∞). We claim that we can pick large enough constant T , and α0 sufficiently
small, such that for all α ≤ α0 , the initial condition in Theorem 6.16.4 holds, i.e.
−(P −1)
C0 kN (0)kF ≤ kW1 (0)kV ≤ r, where W (0) := φ αI, 2µα−1 (P −1) + T .
1
247
Chapter 7
7.1 Introduction
Implicit bias refers to the phenomenon in machine learning that the solution obtained
from loss minimization has special properties that were not implied by value of the
loss function and instead arise from the trajectory taken in parameter space by the
optimization. Quantifying implicit bias necessarily has to go beyond the traditional
248
black-box convergence analyses of optimization algorithms. Implicit bias can explain
how choice of optimization algorithm can affect generalization [20, 25, 133].
Many existing results about implicit bias treat training (in the limit of infinitesimal
step size) as a differential equation or process {x(t)}t≥0 ⊂ RD . To show the implicit
bias of x(t), the idea is to show for another (more intuitive or better understood)
process {w(t)}t≥0 ⊂ Rd that x(t) is simulating w(t), in the sense that there exists a
mapping G : RD → Rd such that w(t) = G(x(t)). Then the implicit bias of x(t) can
be characterized by translating the special properties of w(t) back to x(t) through
G. A related term, implicit regularization, refers to a handful of such results where
particular update rules are shown to lead to regularized solutions; specifically, x(t) is
simulating w(t) where w(t) is solution to a regularized version of the original loss.
The current paper develops a general framework involving optimization in the
continuous-time regime of a loss L : Rd → R that has been re-parametrized before
optimization1 as w = G(x) for some G : RD → Rd . Then the original loss L(w) in
the w-space induces the implied loss (L ◦ G)(x) ≡ L(G(x)) in the x-space, and the
gradient flow in the x-space is given by
Using w(t) = G(x(t)) and the fact that ∇(L ◦ G)(x) = ∂G(x)> ∇L(G(x)) where
∂G(x) ∈ Rd×D denotes the Jacobian of G at x, the corresponding dynamics of (7.1)
in the w-space is
1
Two examples from recent years, where G does not change expressiveness of the model, involve
(a) overparametrized linear regression where the parameter vector w is reparametrized (for example as
w = u 2 −v 2 [20]) and (b) deep linear nets [95] where a matrix W is factorized as W = W1 W2 · · · WL
where each W` is the weight matrix for the `-th layer.
249
Our framework is developed to fully understand phenomena in recent papers [20,
105, 134–138], which give examples suggesting that gradient flow in the x-space
could end up simulating a more classical algorithm, mirror descent (specifically, the
continuous analog, mirror flow) in the w-space. Recall that mirror flow is continuous-
time limit of the classical mirror descent, written as d∇R(w(t)) = −∇L(w(t))dt where
R : Rd → R ∪ {∞} is a strictly convex function [139, 140], which is called mirror map
or Lengendre function in literature. Equivalently it is Riemannian gradient flow with
metric tensor ∇2 R, an old notion in geometry:
250
2
cannot exist. If only one of such x can be reached by gradient flow, we must decide
which x it is in order to decide the value of ∇2 R using ∂G∂G> . Conversely, Amid and
Warmuth [137] raises the following question: for what Legendre function R can the
corresponding mirror flow be the result of gradient flow after some reparametrization
G? Answering the questions in both directions requires a deeper understanding of the
impact of parametrizations.
The following are the main contributions of the current paper:
• In the reverse direction, we use the famous Nash embedding theorem to show
that every mirror flow in the w-space with respect to some Legendre function R
simulates a gradient flow with commuting parametrization under some embedding
x = F (w) where F : Rd → RD and the parametrization G is the inverse of F
(Theorem 7.5.1). This provides an affirmative and fully general answer to the
2
To avoid such an issue, Amid and Warmuth [137] has to assume all the preimages of G at w
have the same ∂G(∂G)> and a recent paper Ghai et al. [141] assumes that G is injective.
251
question of when such reparametrization functions exist, giving a full answer to
questions raised in a more restricted setting in Amid and Warmuth [137].
Notations. We denote N as the set of natural numbers. For any positive integer n,
we denote {1, 2, . . . , n} by [n]. For any vector u ∈ RD , we denote its i-th coordinate
by ui . For any vector u, v ∈ RD and α ∈ R, we define u v = (u1 v1 , . . . , uD vD )>
and u α
= ((u1 )α , . . . , (uD )α )> . For any k ∈ N ∪ {∞}, we say a function f is C k
if it is k times continuously differentiable, and use C k (M ) to denote the set of all
C k functions from M to R. We use ◦ to denote the composition of functions, e.g.,
f ◦ g(x) = f (g(x)). For any convex function R : RD → R ∪ {∞}, we denote its
domain by dom R = {w ∈ RD | R(w) < ∞}. For any set S, we denote its interior by
int(S) and its closure by S.
We assume that the model has parameter vector w ∈ Rd and C 1 loss func-
tion L : Rd → R. Training involves a reparametrized vector x ∈ RD , which is a
reparametrization of w such that w = G(x) for some differentiable parametrization
function G, and the objective is L(G(x)). From now on, we follow the convention
that d is the dimension of the original parameter w and D is the dimension of the
reparametrized x. We also refer to Rd as the w-space and RD as the x-space.
In particular, we are interested in understanding the dynamics of gradient flow
under the objective L ◦ G on some submanifold M ⊆ RD . Most of our results also
generalize to the following notion of time-dependent loss.
253
Definition 7.3.1 (Time-dependent loss). A time-dependent loss Lt (w) is a function
piecewise constant in time t and continuously differentiable in w ∈ Rd , that is, there
exists k ∈ N, 0 = t1 < t2 < · · · < tk+1 = ∞ and C 1 loss functions L(1) , L(2) , . . . , L(k)
such that for each i ∈ [k] and all t ∈ [ti , ti+1 ),
Vector fields are a natural way to formalize the continuous-time gradient descent (a
good reference is Lee [173]). Let M be any smooth submanifold of RD . A vector field
X on M is a continuous map from M to RD such that for any x ∈ M , X(x) is in the
tangent space of M at x, which is denoted by Tx (M ). Formally, Tx (M ) := { dγ
dt t=0
|
∀ smooth curves γ : R → M, γ(0) = x}.
Definition 7.3.2 (Complete vector field; p.215, Lee 173). Let M be a smooth
submanifold of RD and X be a vector field on M . We say X is a complete vector
field on M if and only if for any initialization xinit ∈ M , the differential equation
dx(t) = X(x(t))dt has a solution on (−∞, ∞) with x(0) = xinit .
We say φtf (x) is well-defined at time t when the above differential equation has a
solution at time t. Moreover, for any differentiable function X : M → Rd , we denote
its Jacobian by
7.3.2 Parametrizations
Note that a regular parametrization G can become irregular when its domain is
changed. For example, G(x) = x2 is regular on R+ , but it is not regular on R as
∂G(0) = 0.
Given a C 2 parametrization G : M → Rd , for any x ∈ M and µ ∈ Rd , we define
when it is well-defined, i.e., the corresponding integral equation has a solution. For
any x ∈ M , we define the domain of ψ(x; ·) as
Indeed, under Assumption 7.3.5, we can show that for any x ∈ M , U(x) is a
hyperrectangle, as summarized in the following lemma. See Section 7.7 for a proof.
For any initialization xinit ∈ M , the set of points that are reachable via gradient
flow under G with respect to some time-dependent loss (see Definition 7.3.1) is a
subset of M that depends on G and xinit .
n o
Ωx (xinit ; G) = φµL11 ◦G ◦ φµL22 ◦G ◦ · · · ◦ φµLkk ◦G (xinit ) ∀k ∈ N, ∀i ∈ [k], Li ∈ C 1 (Rd ), µi ≥ 0 .
Next, we introduce some basic notions for mirror descent [139, 140]. We refer the
readers to Section 7.6 for more preliminaries on convex analysis.
∞.
In particular, we call R a mirror map if R further satisfies the following condition (see
p.298 in Bubeck et al. 175):
258
Usually ∇R is required to be surjective so that after a discrete descent step in the
dual space, it can be projected back to the primal space via (∇R)−1 . Nonetheless,
as long as ∇R(wk ) − η∇L(wk ) is in the range of ∇R, the above discrete update is
well-defined. In the limit of η → 0, (7.6) becomes the continuous mirror flow:
We recall a well-known implicit bias result for mirror flow (which holds for mirror
descent as well) [142], which shows that for a specific type of loss, if mirror flow
converges to some optimal solution, then the convergence point minimizes some convex
regularizer among all optimal solutions.
Theorem 7.3.9. Given any data Z ∈ Rn×d and corresponding label Y ∈ Rn , suppose
the loss L(w) is in the form of L(w) = L(Zw)
e e : Rn → R.
for some differentiable L
Assume that initialized at w(0) = winit , the mirror flow (7.7) converges and the
convergence point w∞ = limt→∞ w(t) satisfies Zw∞ = Y , then
See Section 7.7 for a proof. The above theorem is the building block for proving
the implicit bias induced by any commuting parametrization in overparametrized
linear models (see Theorem 7.4.16).
259
t t
φtGi i (x) −∇Gj φtGi i ◦ φGj j (x) = φGj j ◦ φtGi i (x)
tj
−∇Gi ti ti −∇Gi
tj
x −∇Gj t
φGj j (x)
260
The above definition of commuting parametrizations builds upon the differential
properties of the gradient vector fields {∇Gi }di=1 , where each Lie bracket [∇Gi , ∇Gj ]
characterizes the change of ∇Gj along the flow generated by ∇Gi . In particular,
when G is a commuting parametrization satisfying Assumption 7.3.5, it is further
equivalent to a characterization of ‘commuting’ in the integral form, as summarized in
Theorem 7.4.2. Also see Figure 7.1 for an illustration.
Theorem 7.4.2 (Adapted from Theorem 9.44 in Lee [173]). Let M be a smooth
submanifold of RD and G : M → Rd be a C 2 parametrization. For any i, j ∈ [d],
[∇Gi , ∇Gj ](x) = 0 for all x ∈ M if and only if for any x ∈ M , whenever both
φsGi ◦ φtGj (x) and φtGj ◦ φsGi (x) are well-defined for all (s, t) in some rectangle I1 × I2
where I1 , I2 ⊆ R are open intervals, it holds that φsGi ◦ φtGj (x) = φtGj ◦ φsGi (x) for all
(s, t) ∈ I1 × I2 .
Under Assumption 7.3.5, Lemma 7.3.6 implie s that the domain of φsGi ◦ φtGj (x) is
exactly Ii (x) × Ij (x), and thus the above theorem simplifies into the following.
The commuting condition clearly holds when each Gi only depends on a different
subset of coordinates of x, because we then have ∇2 Gi (·)∇Gj (·) ≡ 0 for any distinct
i, j ∈ [d] as ∇2 Gi and ∇Gj live in different subspaces of RD . We call such G separable
parametrizations, and this case covers all the previous examples [20, 105, 134, 136, 137].
Another interesting example is the quadratic parametrization: We parametrize w ∈ Rd
by G : RD → Rd where for each i ∈ [d], there is a symmetric matrix Ai ∈ RD×D such
that Gi (x) = 12 x> Ai x. Then each Lie bracket [Gi , Gj ](x) = (Aj Ai − Ai Aj )x, and thus
G is a commuting parametrization if and only if matrices {Ai }di=1 commute.
261
For concreteness, we analyze two examples below. The first one is both a separable
parametrization and a commuting quadratic parametrization. The second one is a
quadratic parametrization but not commuting.
2 2
Example 7.4.4 (u −v parametrization, Woodworth et al. [20]). Parametrize
w ∈ Rd by w = u 2
− v 2 . Here D = 2d, and the parametrization G is given by
for x = uv ∈ RD . Since each Gi (x) involves only ui and vi , G is
2 2
G(x) = u −v
a separable parametrization and hence a commuting parametrization. Meanwhile,
each Gi (x) is a quadratic form in x, and it can be directly verified that the matrices
underlying these quadratic forms commute with each other.
Therefore, ∇2 Gij does not commute with ∇2 Gii due to the same reason as in the
rank-1 case.
262
Remark 7.4.6. This non-commuting issue for general matrix factorization does not
conflict with the theoretical analysis in Gunasekar et al. [105] where the measurements
are commuting, or equivalently, only involves diagonal elements, as {Gii }di=1 are
indeed commuting parametrizations. Gunasekar et al. [105] is the first to identify the
above non-commuting issue and conjectured that the implicit bias result for diagonal
measurements can be extended to the general case.
Next, we proceed to present our analysis for gradient flow with commuting parametriza-
tion. The following two lemmas highlight the special properties of commuting
parametrizations. Lemma 7.4.7 shows that the point reached by gradient flow with
any commuting parametrization is determined by the integral of the negative gradient
of the loss along the trajectory.
Rt
Further define µ(t) = 0
−∇Lt (G(x(s)))ds. Suppose µ(t) ∈ U(xinit ) for all t ∈ [0, T )
where T ∈ R ∪ {∞}, then it holds that x(t) = ψ(xinit ; µ(t)) for all t ∈ [0, T ).
Based on Lemma 7.4.7, the next key lemma reveals the essential approach to find
the Legendre function.
263
G(ψ(xinit ; µ)) for all µ ∈ U(xinit ). Moreover, let R be the convex conjugate of Q, then
R is also a Legendre function and satisfies that int(dom R) = Ωw (xinit ; G) and
−1
∇2 R(G(ψ(xinit ; µ))) = ∂G(ψ(xinit ; µ))∂G(ψ(xinit ; µ))>
Next, we present our main result on characterizing any gradient flow with com-
muting parametrization by a mirror flow.
Define w(t) = G(x(t)) for all t ≥ 0, then the dynamics of w(t) is a mirror flow with
respect to the Legendre function R given by Lemma 7.4.8, i.e.,
Moreover, this R only depends on the initialization xinit and the parametrization G,
and is independent of the loss function Lt .
Proof of Theorem 7.4.9. Recall that the gradient flow in the x-space governed by
−∇(Lt ◦ G)(x) is
264
Using w(t) = G(x(t)), the corresponding dynamics in the w-space is
By Lemma 7.4.7, we know that the solution to the gradient flow satisfies x(t) =
Rt
ψ(xinit ; µ(t)) where µ(t) = 0 −∇Lt (G(x(s)))ds. Therefore, applying Lemma 7.4.8,
we get a Legendre function R : Rd → R ∪ {∞} with domain Ωw (xinit ; G) such that
−1
∇2 R(w(t)) = ∇2 R(G(ψ(xinit ; µ(t)))) = ∂G(ψ(xinit ; µ(t)))∂G(ψ(xinit ; µ(t)))
or equivalently,
which is exactly the mirror flow with respect to R initialized at w(0) = G(xinit ).
Further note that the result of Lemma 7.4.8 is completely independent of the loss
function Lt , and thus R only depends on the initialization xinit and the parametrization
G. This finishes the proof.
Theorem 7.4.9 provides a sufficient condition for when a gradient flow with certain
parametrization G is simulating a mirror flow. The next question is then: What are
the necessary conditions on the parametrization G so that it enables the gradient flow
to simulate a mirror flow? We provide a (partial) characterization of such G in the
following theorem.
265
Theorem 7.4.10 (Necessary condition on smooth parametrization to be commuting).
Let M be a smooth submanifold of RD and G : M → Rd be a smooth parametrization.
If for any xinit ∈ M , there is a Legendre function R such that for all time-dependent
loss Lt ∈ L, the gradient flow under Lt ◦ G initialized at xinit can be written as the
mirror flow under Lt with respect to R, then G must be a regular parametrization, and
it also holds that for each x ∈ M ,
Lie≥2 (∂G) x
⊆ ker(∂G(x)), (7.9)
where Lie≥K (∂G) := span [[[[∇Gj1 , ∇Gj2 ], . . .], ∇Gjk−1 ], ∇Gjk ] | k ≥ K, ∀i ∈ [k], ji ∈
[d]} is the subset of the Lie algebra generated by the gradients of coordinate functions of
G only containing elements of order higher than K, and ker(∂G(x)) is the orthogonal
complement of span({∇Gi (x)}di=1 ) in RD .
With the above necessary condition (7.9), we can formally refute the possibility
that one can use mirror flow to characterize the implicit bias of gradient flow for matrix
factorization in general settings, as summarized in Corollary 7.4.11. In Chapter 6
we will constructed a concrete counter example showing that the implicit bias for
commuting measurements, that gradient flow finds the solution with minimal nuclear
norm, does not hold for the general case, where gradient flow could prefer the solution
with minimal rank instead.
Corollary 7.4.11 (Gradient flow for matrix factorization cannot be written as mirror
flow). For any d, r ∈ N, let M be an open set in Rd×r and G : M → Rd×d be a smooth
parametrization given by G(U ) = U U > . Then there exists a initial point xinit ∈ M
and a time-dependent loss Lt such that the gradient flow under Lt ◦ G starting from
Uinit cannot be written as a mirror flow with respect to any Legendre function R under
the loss Lt .
266
Proof of Corollary 7.4.11. It turns out that the necessary condition in Theorem 7.4.10
is already violated by only considering the Lie algebra spanned by {∇G11 , ∇G12 }. We
follow the notation in Example 7.4.5 to define each Eij ∈ Rd as the one-hot matrix with
the (i, j)-th entry being 1, and denote E ij = 12 (Eij + Eji ) and ∆ij = Eij − Eji . Then
[∇G11 , ∇G12 ](U ) = 4(E 11 E 12 − E 12 E 11 )U = ∆12 U and [∇G11 , [∇G11 , ∇G12 ]](U ) =
(E 11 ∆12 − ∆12 E 11 )U = E 12 U . Further noting that h[∇G11 , [∇G11 , ∇G12 ]], ∇G12 i =
2
2 E 12 U F = 12 ri=1 (U1i2 + U2i2 ) must be positive at some U in every open set M , by
P
Theorem 7.4.10, we know such Uinit and Lt exist. Moreover, Lt will only depend on
G11 (U ) and G12 (U ).
The following corollary shows that gradient flow with non-commuting parametriza-
tion cannot be mirror flow, when the dimension of the reachable set matches with
that of the w-space.
(b) There is a Legendre function R such that for any time-dependent loss Lt ∈ L,
the gradient flow governed by −∇(Lt ◦ G) with initialization xinit can be written
as a mirror flow with respect to R.
Proof of Corollary 7.4.12. By the condition (b) and Theorem 7.4.10, we know that
each Lie bracket [∇Gi , ∇Gj ] ∈ ker(∂G). By the condition (a), we know that each Lie
bracket [∇Gi , ∇Gj ] ∈ span{∇Gi }di=1 . Combining these two facts, we conclude that
each [∇Gi , ∇Gj ] ≡ 0, so G is a commuting parametrization.
Next, we establish the convergence of w(t) = G(x(t)) when x(t) is given by some
gradient flow with the commuting parametrization G. Here we require that the convex
267
function R given by Lemma 7.4.8 is a Bregman function (see definition in Section 7.6).
The proofs of Theorem 7.4.13, Corollary 7.4.14 and Theorem 7.4.15 are in Section 7.8.
Theorem 7.4.13. Under the setting of Theorem 7.4.9, further assume that the loss L
is quasi-convex, ∇L is locally Lipschitz and argmin{L(w) | w ∈ dom R} is non-empty
where R : Rd → R ∪ {∞} is the convex function given by Lemma 7.4.8. Suppose
R is a Bregman function, then as t → ∞, w(t) converges to some w∗ such that
∇L(w∗ )> (w − w∗ ) ≥ 0 for all w ∈ dom R. Moreover, if the loss function L is convex,
then w(t) converges to a minimizer in dom R.
Corollary 7.4.14. Under the setting of Theorem 7.4.13, if the reachable set in the w-
space satisfies Ωw (xinit ; G) = Rd , then R is a Bregman function and all the statements
in Theorem 7.4.13 hold.
Theorem 7.4.15. Under the setting of Theorem 7.4.13, consider the commuting
quadratic parametrization G : RD → Rd where each Gi (x) = 12 x> Ai x, for symmetric
matrices A1 , A2 , . . . , Ad ∈ RD×D that commute with each other, i.e., Ai Aj − Aj Ai = 0
for all i, j ∈ [d]. For any xinit ∈ RD , if {∇Gi (xinit )}di=1 = {Ai xinit }di=1 are linearly
independent, then the following holds:
(a) For all µ ∈ Rd , ψ(xinit ; µ) = exp( di=1 µi Ai )xinit where exp(·) is the matrix
P
(b) For each j ∈ [d] and all µ ∈ Rd , Gj (ψ(xinit ; µ)) = 12 xinit > exp( di=1 2µi Ai )Aj xinit .
P
2
kψ(xinit ; µ)k22 =
1 1
Pd
(c) Q(µ) = 4 4
exp( i=1 µi Ai )xinit 2
is a Legendre function with
domain Rd .
(d) R is a Bregman function with dom R = range ∇Q where range ∇Q is the range
of ∇Q, and thus all the statements in Theorem 7.4.13 hold.
268
7.4.3 Solving underdetermined linear regression with com-
muting parametrization
There exists a convex function R (given by Lemma 7.4.8, depending only on the
initialization xinit and the parametrization G), such that for any dataset {(zi , yi )}ni=1 ⊂
Rd × R, if w(t) = G(x(t)) converges as t → ∞ and the convergence point w∞ =
limt→∞ w(t) satisfies Zw∞ = Y , then
that is, gradient flow implicitly minimizes the convex regularizer R among all interpo-
lating solutions.
269
Proof of Theorem 7.4.16. By Theorem 7.4.9, w(t) obeys the following mirror flow:
where the last equality follows from the property of convex conjugate. Combining
(7.10) and (7.11), we get R(w∞ ) ≤ R(w) for all w ∈ int(dom R) such that Zw = Y .
By the continuity of R, this property can be further extended to the entire dom R,
and for any w ∈
/ dom R, we have R(w) = ∞ by definition, so R(w∞ ) ≤ R(w) holds
trivially. This finishes the proof.
270
some interpolating solution, then the convergence point is closest to the initialization
in Euclidean distance among all interpolating solutions. This recovers the well-known
implicit bias of gradient flow for underdetermined regression.
Furthermore, we can recover the results on the quadratically overparametrized
linear model studied in a series of papers [20, 105, 138], as summarized in the following
Corollary 7.4.17. Note that their results assumed convergence in order to characterize
the implicit bias, whereas our framework enables us to directly prove the convergence
as in Theorem 7.4.15, where the convergence guarantee is also more general than
existing convergence results for Example 7.4.4 in Li et al. [133], Pesme et al. [176].
Corollary 7.4.17. Consider the underdetermined linear regression problem with data
Z ∈ Rd×n and Y ∈ Rn . Let L
e : Rn → R be a differentiable loss function such that L
e
e is locally Lipschitz, and Y ∈ Rn is its unique global minimizer.
is quasi-convex, ∇L
Consider solving minw L(Zw)
e by running gradient flow on L(w) = L(Zw)
e with the
quadratic parametrization w = G(x) = u 2 − v 2 where x = uv ∈ R2d
+ , for any
where R is given by
d
1 X w q
i 2 2 2 u0,i
R(w) = wi arcsinh − wi + 4u0,i v0,i − wi ln .
4 i=1 2u0,i v0,i v0,i
271
7.5 Every mirror flow is a gradient flow with com-
muting parametrization
Consider any smooth Legendre function R : Rd → R ∪ {∞}, and recall the correspond-
ing mirror flow:
d∇R(w(t)) = −∇L(w(t))dt.
Note that int(dom R) is a convex open set of Rd , hence a smooth manifold (see
Example 1.26 in Lee [173]). Then ∇2 R is a continuous positive-definite metric on
int(dom R). As discussed previously, the above mirror flow can be further rewritten
as the Riemannian gradient flow on the Riemannian manifold (int(dom R), ∇2 R), i.e.,
272
it holds that w(t) = G(x(t)) for all t ≥ 0 where x(t) is given by the gradient flow under
the objective Lt ◦ G initialized at xinit , i.e.,
To illustrate the idea, let us first suppose such a smooth and regular parametrization
G exists and is a bijection between the reachable set Ωx (xinit ; G) ⊂ RD and int(dom R),
whose inverse is denoted by F . It turns out that we can show
where the second equality follows from the relationship between R and G as discussed
in the introduction on Equation (7.2). Note that this corresponds to expressing the
metric tensor ∇2 R using an explicit map F , which is further equivalent to embedding
the Riemannian manifold (int(dom R), ∇2 R) into a Euclidean space (RD , g) in a way
that preserves its metric. This refers to a notion called isometric embedding in
differential geometry.
273
Nash’s embedding theorem is a classic result in differential geometry that guarantees
the existence of isometric embedding of any Riemannian manifold into a Euclidean
space with a plain geometry.
Theorem 7.5.3 (Nash’s embedding theorem, Nash [177], Gunther [178]). Any d-
dimensional Riemannian manifold has an isometric embedding to (RD , g) for some
D ≥ d.
The other way to understand Theorem 7.4.9 is that we can view ∇2 R(w)−1 ∇L(w)
as the gradient of L with respect to metric tensor gR , where g R is the Hessian
metric induced by strictly convex function R in the sense gxR (u, v) := u> ∇2 R(x)v
for any u, v ∈ Rd . It is well-known that gradient flow is invariant under isometric
embedding and thus we can use Nash’s embedding theorem to write the gradient flow
on riemmanian manifold (int(dom R), g R ) as that on (RD , g).
Despite the recent line of works on the connection between mirror descent and
gradient descent [136–138, 141, 142], so far we have not seen any concrete example
of non-separable parametrizaiton (in the sense of Definition 7.5.4) such that the
reparametrized gradient flow can be written as a mirror flow. In this subsection, we
discuss how we can use Theorem 7.5.1 to construct non-separable, yet commuting
parametrizations.
274
a matrix A ∈ Rd×d and a vector b ∈ Rd , such that
G(x) = AG(x)
b + b, ∀x ∈ M.
as ∇2 G b j = P i ∇2 G
bi ∇G bi Pi · Pj ∇G
bj ≡ 0 for all i =
6 j, so each Lie bracket [∇Gi , ∇Gj ] is
also 0 by the linearity.
As a concrete example, for matrix sensing with commutable measurement
A1 , . . . , Am ∈ Rd×d , let V = (v1 , . . . , vd ) ∈ Rd×d be a common eigenvector matrix
for {Ai }m >
= dj=1 σi,j vi vi> for each i ∈ [m].
P
i=1 such that we can write Ai = V Σi V
However, the bad news is that separable commuting parametrizations can only
express a restricted class of Legendre functions. It is easy to see ∂ G(x)∂
b b > must be
G(x)
diagonal for every x. Thus ∂G(x)∂G(x)> are simultaneously diagonalizable for all x,
and so are the Hessian of the corresponding Legendre function (given by Lemma 7.4.8).
There are interesting Legendre functions that does not always have their Hessians
simultaneously diagonalizable, such as
d
X d
X d
X
R(w) = wi (ln wi − 1) + 1 − wi ln 1 − wi −1 ,
i=1 i=1 i=1
Pd Pd
where each wi > 0 and i=1 wi < 1. We can check that ∇R(w) = i=1 ln 1−Pwdi wi
i=1
2
and ∇ R(w) = diag(w (−1)
)+ 1d 1>
d. It is proposed as an open problem by [137]
that whether we can find a parametrization G such that the reparametrized gradient
flow in the x-space simulates the mirror flow in the w-space with respect to the
aforementioned Legendre function R.
Our Theorem 7.5.1 answers the open problem by [137] affirmatively since it shows
every mirror flow can be written as some reparametrized gradient flow. According
275
to the previous discussion, every mirror flow for Lengendre function whose Hessian
cannot be simultaneously diagonalized always induces a non-separable commuting
parametrization. But this type of construction has two caveats: First, the construction
of the Legendre function uses Nash’s Embedding theorem, which is implicit and hard
to implement; second, the parametrization given by Theorem 7.5.1, though defined on
an open set in RD , is only commuting on the reachable set, which is a d-dimensional
submanifold of RD . This is different from all the natural examples of commuting
parametrizations which are commuting on an open set, leading to the following open
question.
Open Question: Is there any smooth, regular, commuting, yet non-separable (in
the sense of Definition 7.5.4) parametrization from an open subset of RD to Rd , for
some integers D and d?
Theorem 7.5.5. All smooth, regular and commuting parametrizations are non-
separable when D = 1.
Proof of Theorem 7.5.5. Note that [∇Gi , ∇Gj ] ≡ 0 implies that all Gi share the same
set of stationary points, i.e., {x ∈ R | ∇Gi (x) = 0} is the same for all i ∈ [d]. Since
D = 1, without loss of generality, we can assume G0i (x) = ∇Gi (x) > 0 for all x ∈ M
and i ∈ [d] since G is regular. Then it holds that sign(G0i )(ln |G0i |)0 = sign(G0j )(ln |G0j |)0 ,
which implies that |G0i |/|G0j | is equal to some constant independent of x. This completes
the proof.
Remark 7.5.6. We note that the assumption that the parametrization is regular is
necessary for the open question to be non-trivial. Otherwise, consider the following
example with D = 1 and d = 2: Let f1 , f2 : R → R be any smooth function supported
Rx
on (0, 1) and (1, 2) respectively. Define Gi (x) = 0 fi (t)dt for all x ∈ R. Then
parametrization G is non-separable.
276
7.6 Related basics for convex analysis
We first introduce some additional notations. For any function f , we denote its range
(or image) by range f . For any set S, we use S to denote its closure. For any matrix
Λ ∈ Rd×D and set S ⊆ RD , we define ΛS = {Λx | x ∈ S} ⊆ Rd .
Below we collect some related basic definitions and results in convex analysis. We
refer the reader to Rockafellar [179] and Bauschke et al. [180] as main reference sources.
In particular, Sections 2, 3 and 4 in Bauschke et al. [180] provide a clear summary of
the related concepts.
Here we consider a convex function f : Rd → R ∪ {∞} whose domain is dom f =
{w ∈ Rd | f (w) < ∞}. From now on, we assume by default that f is continuous
on dom f , the interior of its domain int(dom f ) is non-empty, and f is differentiable
on int(dom f ).
The notions of essential smoothness and essential strict convexity defined below
describe certain nice properties of a convex function (see Section 26 in Rockafellar
[179]).
Definition 7.6.1 (Essential smoothness and essential strict convexity). If for any
sequence {wn }∞
n=1 ⊂ int(dom f ) going to the boundary of dom f as n → ∞, it holds
The following results characterize the relationship between a convex function and its
conjugate.
277
Theorem 7.6.2 (Theorem 26.3, Rockafellar [179]). A convex function f is essentially
strictly convex if and only if its convex conjugate f ∗ is essentially smooth.
Lemma 7.6.4 (Corollary 2.6, Bauschke et al. [180]). If f is essentially strictly convex,
then it holds for all w ∈ int(dom f ) that ∇f (w) ∈ int(dom f ∗ ) and ∇f ∗ (∇f (w)) = w.
The class of Legendre functions defined in Definition 7.3.8 contains convex functions
that are both essentially smooth and essentially strictly convex.
Next, we introduce the notion of Bregman function [181, 182]. It has been shown
in Bauschke et al. [180] that the properties of Bregman functions are crucial to prove
the trajectory convergence of Riemannian gradient flow where the metric tensor is
given by the Hessian of some Bregman function f .
Definition 7.6.6 (Bregman functions; Definition 4.1, Alvarez et al. [183]). A function
f is called a Bregman function if it satisfies the following properties:
The following theorem from Alvarez et al. [183] provides a convenient tool for
proving the convergence of a Riemannian gradient flow.
Here we first present the proof for the result on the domain of the flow induced by G.
Proof of Lemma 7.3.6. Fix any x ∈ M . For each i ∈ [d], let Ii (x) be the domain of
φtGj (x) in terms of t. If ∇Gi is a complete vector field on M as in Definition 7.3.2, then
Ii (x) = Rd , otherwise φtGj (x) is defined for t in an open interval containing 0 (see, e.g.,
Theorem 2.1 in Lang [184]). Then we claim that for any distinct j1 , j2 , . . . , jk ∈ [d]
µ µj
where k ∈ [d], the set of all (µj1 , . . . , µjk ) ∈ Rk such that φGjj1 ◦ · · · ◦ φGjk (x) is well-
1 k
defined is a hyperrectangle given by Ij1 (x) × Ij2 (x) × · · · × Ijk (x). Then the desired
result can be obtained by letting (j1 , j2 , . . . , jd ) = (1, 2, . . . , d). We prove the claim by
induction over k ∈ [d].
279
The base case for k = 1 has already been established above. Next, assume the claim
holds for 1, 2, . . . , k − 1 where k ≥ 3, and we proceed to show it for k. By the claim for
µ µj
k−2, φGjj3 ◦· · ·◦φGjk (x) is well-defined for (µj3 , . . . , µjk ) ∈ Ij3 (x)×· · ·×Ijk (x). For any
3 k
µj
such (µj3 , . . . , µjk ), φtGj ◦φµG3j ◦· · ·◦φGjk (x) is well-defined for t in and only in the open
1 3 k
µj
interval Ij1 (x) by applying the claim for k − 1, and similarly φtGj ◦ φµG3j ◦ · · · ◦ φGjk (x)
2 3 k
is also well-defined for t in and only in the open interval Ij2 (x). Note that for any
(s, t) ∈ Ij1 (x) × Ij2 (x),
µ µj
φsGj1 ◦ φ−t t j3
Gj ◦ φGj2 ◦ φGj ◦ · · · ◦ φGj (x)
k
2 3 k
−t µ µj
φG j
◦ φsGj1 ◦ φtGj2 ◦ φGjj3 ◦ · · · ◦ φGjk (x)
2 3 k
µ µj
is also well-defined, which further implies that φsGj ◦ φtGj ◦ φGjj3 ◦ · · · ◦ φGjk (x) is
1 2 3 k
µ µj
well-defined. Therefore, we conclude that φGjj1 ◦ ··· ◦ φGjk (x) is well-defined for and
1 k
only for (µj1 , . . . , µjk ) ∈ Ij1 (x) × · · · × Ijk (x). This completes the induction and hence
finishes the proof.
Next, we provide the proof for the implicit bias of mirror flow summarized in
Theorem 7.3.9. We need the following lemma that characterizes the KKT conditions
for minimizing a convex function R in a linear subspace.
Lemma 7.7.1. For any convex function R : Rd → R ∪ {∞} and Z ∈ Rn×d , suppose
∇R(w∗ ) = Z > λ for some λ ∈ Rn , then
280
Proof of Lemma 7.7.1. Consider another convex function defined as R(w)
e = R(w) −
w> Z > λ, then ∇R(w
e ∗ ) = ∇R(w∗ ) − Z > λ = 0, which implies that
Z t
>
∇R(w(t)) − ∇R(w0 ) = −Z ∇L(Zw(s)
e − Y )ds ∈ span(X > ),
0
281
Then applying Lemma 7.7.1 yields
Here we provide the omitted proofs in Section 7.4, including four main parts:
(3) Convergence for gradient flow with commuting parametrization (Section 7.8.3);
We first show the representation formula for gradient flow with commuting parametriza-
tion given in Lemma 7.4.7.
Proof of Lemma 7.4.7. Let µ(t) be given by the following differential equation:
282
For any µ ∈ U(x) and j ∈ [d], µ + δej ∈ U(x) for all sufficiently small δ, thus
where the second equality follows from the assumption that G is a commuting
∂ψ(xinit ;µ)
parametrization and Theorem 7.4.2. Then we have ∂µ
= ∂G(ψ(xinit ; µ))>
for all µ ∈ U(xinit ), and thus when µ(t) ∈ U(xinit ),
∂ψ(xinit ; µ(t))
dψ(xinit ; µ(t)) = dµ(t)
∂µ(t)
= −∂G(xinit ; µ(t))∇Lt (G(ψ(xinit ; µ(t))))dt
Then since ψ(xinit ; µ(0)) = xinit and ψ(xinit ; µ(t)) follows the same differential equation
and has the same initialization as x(t), we have x(t) ≡ ψ(xinit ; µ(t)) for all t ∈ [0, T ).
Therefore,
Z t Z t
µ(t) = µ(0) + −∇Lt (G(ψ(xinit ; µ(s))))ds = −∇Lt (G(x(s)))ds
0 0
Next, to prove Lemma 7.4.8, we need the following lemma which provides a
sufficient condition for a vector function to be gradient of some other function.
283
Proof of Lemma 7.8.1. This follows from a direct application of Corollary 16.27 in
Lee [173].
U(xinit ) such that limk→∞ µk = µ∞ , we have limk→∞ k∇Q(µk )k2 = ∞. Since each
∇Q(µk ) = G(ψ(xinit ; µk )), we only need to show that limk→∞ kG(ψ(xinit ; µk ))k2 = ∞.
Suppose otherwise, then {G(ψ(xinit ; µk )}∞
k=1 is bounded. Note that by Lemma 7.4.7,
Z 1
ψ(xinit ; µk ) = φ1−Hk (xinit ) = xinit + ∇Hk (φs−Hk (xinit ))ds.
0
284
Therefore,
s
Z 1 Z 1
2
kψ(xinit ; µk ) − xinit k2 ≤ ∇Hk (φs−Hk (xinit )) 2 ds ≤ ∇Hk (φs−Hk (xinit )) 2 ds.
0 0
(7.14)
where the second inequality follows from Cauchy-Schwarz inequality. Further note
that
Z 1
d
Hk (ψ(xinit ; µk )) − Hk (xinit ) = Hk (φs−Hk (xinit ))ds
ds
Z0 1
dφs−Hk (xinit )
s
= ∇Hk (φ−Hk (xinit )), ds
0 ds
Z 1
= k∇Hk (φs−Hk (xinit ))k22 ds. (7.15)
0
p
kψ(xinit ; µk ) − xinit k2 ≤ hµk , G(ψ(xinit ; µk )) − G(xinit )i
p
≤ kµk k2 · kG(ψ(xinit ; µk )) − G(xinit )k2 ,
285
Combining the above, it follows that Q is a Legendre function. Let R : Rd →
R ∪ {∞} be the convex conjugate of Q. Then by Theorem 7.6.5, R is also a Legendre
function. Note that for any µ ∈ U(xinit ), by the result in Crouzeix [185], we have
be commuting
Proof of Theorem 7.4.10. Fix any initialization xinit ∈ M , and let the Legendre func-
tion R be given such that for all time-dependent loss Lt , the gradient flow under Lt ◦ G
initialized at x can be written as the mirror flow under Lt with respect to the Legendre
function R. We first introduce a few notations that will be useful for the proof. For
any s ∈ R, we define a time-shifting operator Ts such that for any time-dependent loss
Lt (·), (Ts L)t (·) = Lt−s (·). We say a time-dependent loss Lt is supported on finite time
if Lt = ki=1 1t∈[ti ,ti+1 ) L(i) for some k ≥ 1 where t1 = 0, tk+1 = ∞ and L(k) ≡ 0, and
P
`j,δ
t (w) = 10≤t≤δ · hej , wi (7.16)
286
where ej is the j-th canonical base of Rd .
Now for any k ≥ 2, let {ji }ki=1 be any sequence where each ji ∈ [d]. Then
we recursively define a sequence of time-dependent losses as follows: First define
L1,δ = −`j1 ,δ , then sequentially for each i = 2, 3, . . . , k, we define
√ √ √ √
i,δ i−1, δ ji , δ i−1, δ
L =L k −` k −L k `ji , δ
(7.17)
√ √
i−1, δ
where we write L for convenience. Denote ιi (δ) = len(Li,δ ) for each
= Li−1, δ
√ √
i ∈ [k]. Then ι1 (δ) = δ and ιi (δ) = 2 δ + 2ιi−1 ( δ) for i = 2, 3, . . . , k, which further
implies
i−1
m i−1
X
ιi (δ) = 2m δ 1/2 + 2i−1 δ 1/2 for all i ∈ [k].
m=1
287
well-defined. Moreover, it follows from (7.18) that
√ √ √
Z ιk−1 (δ) Z ιk−1 ( δ) √
Z ιk−1 ( δ)+ δ
∇Lk,δ
t (w(t))dt = ∇Lk−1, δ
(w(t))dt + √
−ejk dt
0 0 ιk−1 ( δ)
√ √ √ √
Z 2ιk−1 ( δ)+ δ √ Z 2ιk−1 ( δ)2 δ
k−1, δ
+ √ √
−∇L (w(t))dt + √ √
ejk dt
ιk−1 ( δ)+ δ 2ιk−1 ( δ)+ δ
√
Z ιk−1 ( δ) √ √
k−1, δ
= ∇Lk−1,
t
δ
(w(t)) − ∇Lt (w(t)) dt = 0
0
√
where the last two equalities follow from the fact that ∇Ltk−1, δ (w) does not depend
on w and is only determined by t by our construction.
Hence, the mirror flow with respect to the Legendre function R for the time-
dependent loss Lk,δ will return to the initialization after ιk (δ) time since
Z ιk (δ)
∇R(w(ιk (δ))) − ∇R(w(0)) = −∇Lk,δ (w(t))dt = 0.
0
ι (δ)
G(xinit ) = G φLkk,δ ◦G (xinit )
for all sufficiently small δ. Then differentiating with δ on both sides yields
ι (δ)
dφ kk,δ (xinit )
∂G(x) · L ◦G = 0. (7.19)
dδ δ=0
ι (δ)
dφLkk,δ ◦G (xinit )
= [[[[∇Gj1 , ∇Gj2 ], . . .], ∇Gjk−1 ], ∇Gjk ](xinit ), (7.20)
dδ δ=0
then combining (7.19) and (7.20) completes the proof, so it remains to verify (7.20).
288
We will prove by induction over k, and now let {ji }∞
i=1 be an arbitrary sequence
ι (δ)
πk,δ (·) := φδ−`jk ,δ (·) and Πk,δ (·) := φLkk,δ (·).
−1 ι (δ)
Then their inverse maps are given by πk,δ (·) = φδ`jk ,δ (·) and Π−1
k,δ (·) = φ
k
k,δ (·) respec-
−L
k
tively. Since G is smooth, each Πk,√δ is a C ∞ function of δ 1/2 , and we can expand it
k
in δ 1/2 as
2 k k
X δ i/2
Πk,√δ (x) = x + ∆k,i (x) + rk,δ (x) (7.21)
i=1
i!
where the remainder term rk,δ (x) is continuous in x and for each x ∈ M , rk,δ (x) = o(δ)
rk,δ (x)
(i.e., limδ→0 δ
= 0), and each ∆k,i is defined as
di Πk,√δ (x)
∆k,i (x) = .
d(δ 1/2k )i δ=0
√ δ
Π1,√δ (x) = π1,√δ (x) = x + δ∇Gj1 (x) + ∂(∇Gj1 )(x)∇Gj1 (x) + r1,δ (x) (7.22)
2
where the second equality holds as well for any other Gj in place of Gj1 , with a different
but similar remainder term. For any fixed K ≥ 2, there is a small open neighborhood
of xinit on M , denoted by Nxinit ⊆ M , such that for all k ∈ [K], we have rk,δ (x) = o(δ)
uniformly over all x ∈ Nxinit , so we can replace all rk,δ (x) by o(δ) when x ∈ Nxinit .
Then we claim that for each k = 2, 3, . . . , K,
2 k−1k
1 X δ i/2
lim √ ∆k,i (x) = [[[∇Gj1 , ∇Gj2 ], . . .], ∇Gjk ](x), ∀x ∈ Nxinit , (7.23)
δ→∞ δ i=1 i!
289
which directly implies (7.20). With a slight abuse of notation, the claim is also true
for k = 1 since ∆1,1 (x) = ∇Gj1 (x) by (7.22), so we use this as the base case of the
induction. Then, assuming (7.23) holds for k − 1 < K, we proceed to prove it for k.
For convenience, further define LieG (j1:k ) = [[[∇Gj1 , ∇Gj2 ], . . .], ∇Gjk ].
Combining the Taylor expansion in (7.21) and (7.23) for k − 1, we obtain for all
x ∈ Nxinit that
k−1
2X
√ δ i/2
k−1
√
Πk−1, δ (x) = x + δ · LieG (j1:(k−1) )(x) + ∆k−1,i (x) + o(δ)
i!
i=2k−2 +1
for sufficiently small δ. Further apply (7.22) with Gjk in place of Gj1 for sufficiently
small δ, and then
Πk−1,√δ πk,√δ (x)
√
√
δ
= Πk−1, δ x + δ∇Gjk (x) + ∂(∇Gjk )(x)∇Gjk (x) + o(δ)
2
√ δ
= x + δ∇Gjk (x) + ∂(∇Gjk )(x)∇Gjk (x) + o(δ)
2
√ √
δ
+ δ · LieG (j1:(k−1) ) x + δ∇Gjk (x) + ∂(∇Gjk )(x)∇Gjk (x) + o(δ)
2
k−1
2X
δ i/2
k−1
√
δ
+ ∆k−1,i x + δ∇Gjk (x) + ∂(∇Gjk )(x)∇Gjk (x) + o(δ)
k−2
i! 2
i=2 +1
√
δ
+ rk−1,δ x + δ∇Gjk (x) + ∂(∇Gjk )(x)∇Gjk (x) + o(δ)
2
where the second equality follows from the Taylor expansion of Πk−1,√δ and that
πk,√δ (x) ∈ Nxinit for sufficiently small δ. Then by the Taylor expansion of LieG (j1:(k−1) )
290
and each ∆k−1,i , we have for all x ∈ Nxinit ,
√ √ δ
Πk−1,√δ πk,√δ (x) = x + δ∇Gjk (x) + δ · LieG (j1:(k−1) )(x) + ∂(∇Gjk )(x)∇Gjk (x)
2
k−1
2X k−1
δ i/2
+ δ · ∂LieG (j1:(k−1) )(x)∇Gjk (x) + ∆k−1,i (x) + o(δ)
k−2
i!
i=2 +1
(7.24)
for sufficiently small δ. For the other way around, we similarly have
k−1
2X
√ δ i/2
k−1
πk,√δ √ √
Πk−1, δ (x) = πk, δ x + δ · LieG (j1:(k−1) )(x) + ∆k−1,i (x) + o(δ)
i!
i=2k−2 +1
√ √ δ
=x+ δ · LieG (j1:(k−1) ) + ∂(∇Gjk )(x)∇Gjk (x)
δ∇Gjk (x) +
2
k−1
2X k
δ i/2
+ δ∂(∇Gjk )(x)LieG (j1:(k−1) )(x) + ∆k−1,i (x) + o(δ)
k−2
i!
i=2 +1
(7.25)
−1√ ◦ Π−1 √ ◦ Π √
for all x ∈ Nxinit , when δ is sufficiently small. Note that x = πk, δ k−1, δ k−1, δ ◦
−1√ ◦ Π−1 √ ◦ π √ ◦ Π √
Πk,δ (x) − x = πk, δ k−1, δ k, δ k−1, δ (x) − x
−1√ ◦ Π−1 √ ◦ π √ ◦ Π √ √ √
= πk, δ k−1, δ k, δ k−1, δ (x) − πk, δ ◦ Πk−1, δ (x)
291
−1√ ◦ Π−1 √ (·) in terms
where the last equality follows from the Taylor expansion of πk, δ k−1, δ
√
of δ. Now, combining (7.24), (7.25) and (7.26), we obtain
Πk,δ (x) − x = δ ∂(∇Gjk )(x)LieG (j1:(k−1) )(x) − ∂LieG (j1:(k−1) )(x)∇Gjk (x) + o(δ)
where the second equality follows from the definition of Lie bracket. Comparing (7.27)
with (7.21) yields (7.23). This completes the induction for k ∈ [K] and hence finishes
the proof as K is arbitrary.
tion
Based on Theorem 7.6.7, we can prove the trajectory convergence of w(t) for the
special case where Ωw (xinit ; G) = Rd as summarized in Corollary 7.4.14.
292
Next, we prove that for a class of commuting quadratic parametrizations, the
corresponding Legendre function is also a Bregman function, thus guaranteeing the
trajectory convergence.
Proof of Theorem 7.4.15. Since A1 , A2 , . . . , Ad commute with each other, these matri-
ces can be simultaneously diagonalized. Thus we can assume without loss of generality
that each Ai = diag(λi ) where λi ∈ RD , then Gi (x) = λ> 2
i x . For convenience, we
for some Legendre function R whose conjugate is denoted by Q. To apply the results
in Theorem 7.4.13, it suffices to show that this R is a Bregman function.
2
To do so, we further denote w
e=x and G(x)
e = x 2 , then w = Λw
e and in this
case G e defined on M = RD
e is a commuting parametrization for w + . Also, we have
∂G(x) = Λ∂ G(x).
e e : Rd → R be defined by L(
Let L e w)
e = L(Λw),
e which satisfies that
293
∇L( e = Λ> ∇L(Λw).
e w) e Then the gradient flow with parametrization G
e governed by
−∇(L
e ◦ G)(x)
e is given by
> e e
dx(t) = −∇(L
e ◦ G)(x)dt
e = −∂ G(x(t))
e ∇L(G(x(t))dt
> >
= −∂ G(x(t))
e Λ ∇L(ΛG(x(t))dt
e
= −∂G(x(t))> ∇L(G(x(t))dt,
which yields the same dynamics as the gradient flow with parametrization G governed by −∇(L ∘ G)(x). Therefore, we have w(t) = G(x(t)) = ΛG̃(x(t)) = Λw̃(t), where again by Theorem 7.4.9, the dynamics of w̃(t) is

dw̃(t) = −∇²R̃(w̃(t))^{−1}∇L̃(w̃(t))dt,    w̃(0) = G̃(x_init),

for some Legendre function R̃ whose conjugate is denoted by Q̃. For µ̃ ∈ R^D, we define ψ̃(x; µ̃) = φ_{G̃₁}^{µ̃₁} ∘ φ_{G̃₂}^{µ̃₂} ∘ · · · ∘ φ_{G̃_D}^{µ̃_D}(x). We need the following lemma.
Lemma 7.8.2. In the setting of the proof of Theorem 7.4.15, for any µ ∈ R^d and x ∈ M, we have ψ(x; µ) = ψ̃(x; Λ⊤µ).
Recall from Lemma 7.4.8 that ∇Q(µ) = G(ψ(x_init; µ)) for any µ ∈ R^d and ∇Q̃(µ̃) = G̃(ψ̃(x_init; µ̃)) for any µ̃ ∈ R^D. Note that

∇Q(µ) = G(ψ(x_init; µ)) = ΛG̃(ψ̃(x_init; Λ⊤µ)) = Λ∇Q̃(Λ⊤µ),    (7.28)

where the second equality follows from Lemma 7.8.2. This implies that Q(µ) = Q̃(Λ⊤µ) + C for some constant C. Recalling the definition of the convex conjugate, we have

R̃(w̃) = sup_{µ̃∈R^D} ⟨µ̃, w̃⟩ − Q̃(µ̃),    R(w) = sup_{µ∈R^d} ⟨µ, w⟩ − Q(µ).
Then for any w̃ ∈ R^D, we have

R(Λw̃) = sup_{µ∈R^d} ⟨µ, Λw̃⟩ − Q(µ) = sup_{µ̃∈Λ⊤R^d} ⟨µ̃, w̃⟩ − Q̃(µ̃) − C ≤ sup_{µ̃∈R^D} ⟨µ̃, w̃⟩ − Q̃(µ̃) − C = R̃(w̃) − C.    (7.29)

Therefore, for any w̃ ∈ dom R̃, it holds that R(Λw̃) ≤ R̃(w̃) − C < ∞, so Λ dom R̃ ⊆ dom R.
Combining the above, we see that dom R = Λ dom R̃. As discussed in Section 7.1, R̃ is a Bregman function with domain dom R̃ = R^D_+. Thus dom R = ΛR^D_+ is also a closed set, and

Λ⊤∇R(Λw̃) = ∇R̃(w̃)    (7.30)

for all w̃ ∈ R^D_+. Then for any w̃ ∈ R^D_+ and y = Λỹ ∈ int(dom R), we have

D_R(Λw̃, y) = R(Λw̃) − R(Λỹ) − ⟨∇R̃(ỹ), w̃ − ỹ⟩
           = R(Λw̃) − R(Λỹ) − R̃(w̃) + R̃(ỹ) + D_R̃(w̃, ỹ)    (7.31)
           ≥ R(Λw̃) − R̃(w̃) + C + D_R̃(w̃, ỹ),
where the inequality follows from (7.29). Therefore, we further have, for any α ∈ R,

{y ∈ int(dom R) | D_R(Λw̃, y) ≤ α} ⊆ Λ{ỹ ∈ R^D_+ | D_R̃(w̃, ỹ) ≤ α − R(Λw̃) + R̃(w̃) − C}.

Since R̃ is a Bregman function, the set on the right-hand side is bounded, and so is the left-hand side.
Finally, we verify the third condition in Definition 7.6.6. Consider any w ∈ dom R and a sequence {w_i}_{i=1}^{∞} ⊂ int(dom R) such that lim_{i→∞} w_i = w. Since dom R = Λ dom R̃, there is some w̃ ∈ R^D_+ such that w = Λw̃, and some w̃_i ∈ R^D_+ for each i ∈ N such that w_i = Λw̃_i. We have that
R(w) − R(w_i) = ∫₀¹ ⟨∇R((1 − t)w_i + tw), w − w_i⟩ dt
             = ∫₀¹ ⟨Λ⊤∇R(Λ((1 − t)w̃_i + tw̃)), w̃ − w̃_i⟩ dt
             = ∫₀¹ ⟨∇R̃((1 − t)w̃_i + tw̃), w̃ − w̃_i⟩ dt
             = R̃(w̃) − R̃(w̃_i).

Hence the third condition for R follows from the corresponding property of R̃, and we conclude that R is a Bregman function.
Proof of Lemma 7.8.2. For each i ∈ [D] and any t > 0, we have

φ^t_{G_i}(x) = x + ∫_{s=0}^{t} −∇G_i(φ^s_{G_i}(x)) ds = x + ∫_{s=0}^{t} −Σ_{j=1}^{D} λ_{i,j}∇G̃_j(φ^s_{G_i}(x)) ds = ψ̃(x; tλ_i),

where the last equality follows from Lemma 7.4.7. Therefore, for any µ ∈ R^d, we further have

ψ(x; µ) = φ_{G₁}^{µ₁} ∘ · · · ∘ φ_{G_d}^{µ_d}(x) = ψ̃(·; µ₁λ₁) ∘ · · · ∘ ψ̃(·; µ_dλ_d)(x) = ψ̃(x; Σ_{i=1}^{d} µ_iλ_i) = ψ̃(x; Λ⊤µ),

where the third equality follows from the assumption that G̃ is a commuting parametrization.
Here we provide the proof of the implicit bias result for the quadratically overparametrized linear model. One can verify that the parametrization G satisfies the conditions in Theorem 7.4.15, which then implies the convergence of w(t).
Next, we identify the function R given by Theorem 7.4.9. We have

ψ(x_init; µ) = (u₀ ⊙ e^{−2µ}, v₀ ⊙ e^{2µ}),

and thus

(∇R(w))_i = (∇Q)_i^{−1}(w) = (1/4) ln( w_i/(2u_{0,i}v_{0,i}) + √(1 + (w_i/(2u_{0,i}v_{0,i}))²) ) + (1/4) ln(v_{0,i}/u_{0,i})
          = (1/4) arcsinh( w_i/(2u_{0,i}v_{0,i}) ) + (1/4) ln(v_{0,i}/u_{0,i}),

which further implies that

R(w) = (1/4) Σ_{i=1}^{d} [ w_i arcsinh( w_i/(2u_{0,i}v_{0,i}) ) − √(w_i² + 4u_{0,i}²v_{0,i}²) − w_i ln(u_{0,i}/v_{0,i}) ] + C.
We first prove the following intermediate result that will be useful in the proof of
Theorem 7.5.1.
Lemma 7.9.1. Under the setting of Theorem 7.5.1, let F be the smooth map that isometrically embeds (int(dom R), g^R) into (R^D, g). Let M = range(F), and denote the inverse of F by G̃ : M → R^d. Then for any w ∈ int(dom R), with x = F(w), it holds for any v ∈ T_x(M) that

v = ∂F(G̃(x))∂G̃(x)v.    (7.32)
Now, for any w ∈ int(dom R), let x = F (w), then for any v ∈ Tx (M ), it follows from
(7.32) that
Note that the column space of ∂F(w) is exactly T_x(M), so for any v in the orthogonal complement of T_x(M), it holds that ∂G̃(x)v = 0, since ∇G̃_i(x) ∈ T_x(M) for any i ∈ [d]. Combining the above yields

∂G̃(F(w))∂F(w) = I_d.
Denote the inverse of F on M by G̃ : M → R^d. Note that (M, G̃) is a global atlas for M, and we have T_x(M) = span({∇G̃_i(x)}_{i=1}^{d}) for all x ∈ M. This G̃ is almost the commuting parametrization that we seek, except that it is only defined on M rather than on an open neighborhood of M. Yet we can extend G̃ to an open neighbourhood of M in the following way. First, by Foote [186], for each x ∈ M there is an open neighbourhood U_x of x on which the projection function P defined by

P(y) = argmin_{y′ ∈ M} ‖y − y′‖₂

is well-defined and smooth.
which implies that v = ∂P(x)v by letting t → 0. Meanwhile, for any v in the orthogonal complement of T_x(M) and sufficiently small δ > 0, P(x + δv) is smooth in δ. Then, since P(x + δv) ∈ M for all sufficiently small δ by its definition, we have
where O(δ 2 ) denotes a term whose norm is bounded by Cδ 2 for a constant C > 0 for
all sufficiently small δ, and the second equality follows from (7.33). Then dividing
both sides by δ and letting δ → 0, we have kvk2 ≥ kv − uk2 . Since u is orthogonal to
v, we must have u = 0. As v is arbitrary, we conclude that ∂P (x) is the orthogonal
projection matrix onto T_x(M). Then differentiating both sides of G(x) = G̃(P(x)) with respect to x yields

∂G(x) = ∂G̃(P(x))∂P(x) = ∂G̃(x),    (7.34)

where the second equality follows from the fact that T_x(M) = span({∇G̃_i(x)}_{i=1}^{d}). This further implies that the solution of Equation (7.13) satisfies dx/dt = −∇(L ∘ G)(x) ∈ T_x(M), and thus x(t) ∈ M for all t ≥ 0.
Now we consider the mirror flow
dw(t) = −(∂F(w(t))⊤∂F(w(t)))^{−1}∇L_t(w(t))dt,
where the third equality follows from Lemma 7.9.1 and (7.34).
Next, we verify that G restricted on M, i.e., G̃, is a commuting and regular parametrization. First, for any x ∈ M, we have ∂G̃(x)⊤ = ∂F(G̃(x))(∂F(G̃(x))⊤∂F(G̃(x)))^{−1} by Lemma 7.9.1 and (7.34). Since ∇²R(w) = ∂F(w)⊤∂F(w) has rank d for all w ∈ int(dom R), ∂F(w) also has rank d for all w ∈ int(dom R), and thus ∂G̃(x) has rank d for all x ∈ M. The commutability of {∇G̃_i}_{i=1}^{d} follows directly from Corollary 7.4.12. It remains to show rank(Ω_x(x; G̃)) = rank(M): on one hand, rank(Ω_x(x; G̃)) ≥ rank(span({∇G̃_i(x)}_{i=1}^{d})) = rank(M); on the other hand, rank(Ω_x(x; G̃)) ≤ rank(M) since Ω_x(x; G̃) ⊂ M, for any x ∈ M.
It remains to identify the corresponding vector field on M. For any x_init ∈ M, consider the loss L_t(w) = ⟨e_j, w⟩; the corresponding gradient flow is

dx(t) = −∇(L_t ∘ G̃)(x(t))dt = −∂G̃(x(t))⊤∇L_t(G̃(x(t)))dt = −∇G̃_j(x(t))dt,

so x(t) = φ^t_{G̃_j}(x_init) for all t ≥ 0. On the other hand, w(t) = G̃(x(t)) satisfies

dw(t) = ∂G̃(x(t))dx(t) = −∂G̃(x(t))∂G̃(x(t))⊤∇L_t(w(t))dt,

where we used Lemma 7.9.1 and Equation (7.34). Therefore, rewriting the above as a mirror flow yields the desired vector field.
Part III
Chapter 8
Understanding the implicit bias of Stochastic Gradient Descent (SGD) is one of the key
challenges in deep learning, especially for overparametrized models, where the local
minimizers of the loss function L can form a manifold. Intuitively, with a sufficiently small learning rate η, SGD tracks Gradient Descent (GD) until it gets close to such a manifold, where the gradient noise prevents further convergence. In this regime,
Blanc et al. [165] proved that SGD with label noise locally decreases a regularizer-like
term, the sharpness of loss, tr[∇2 L].
This chapter gives a general framework for such an analysis by adapting ideas from Katzenberger [187]. It allows in principle a complete characterization of the regularization effect of SGD around such a manifold, i.e., the "implicit bias", using a stochastic differential equation (SDE) describing the limiting dynamics of the parameters, which is determined jointly by the loss function and the noise covariance. This yields two new results: (1) a global analysis of the implicit bias valid for η^{−2} steps, in contrast to the local analysis of Blanc et al. [165] that is only valid for η^{−1.6} steps, and (2) the ability to handle arbitrary noise covariance.
As an application, we show that with arbitrarily large initialization, label noise SGD can always escape the kernel regime and only requires O(κ ln d) samples for learning a κ-sparse overparametrized linear model in R^d [20], while GD initialized in the kernel regime requires Ω(d) samples. This upper bound is minimax optimal and improves the previous Õ(κ²) upper bound [21].
8.1 Introduction
The implicit bias underlies the generalization ability of machine learning models trained by stochastic gradient descent (SGD), yet it remains a mystery how to characterize such bias mathematically. We study SGD in the following formulation:

x_η(k + 1) = x_η(k) − η(∇L(x_η(k)) + √Ξ · σ_{ξ_k}(x_η(k))),    (8.1)

where η is the learning rate (LR), L : R^D → R is the training loss, and σ(x) = [σ₁(x), σ₂(x), . . . , σ_Ξ(x)] ∈ R^{D×Ξ} is a deterministic noise function. Here ξ_k is sampled uniformly from {1, 2, . . . , Ξ} and satisfies E_{ξ_k}[σ_{ξ_k}(x)] = 0 for all x ∈ R^D and all k.
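For concreteness, the following sketch (our own illustration; the quadratic loss and the state-independent noise family are placeholder choices, not from the text) implements one run of (8.1):

```python
import numpy as np

rng = np.random.default_rng(0)
D, Xi, eta = 2, 4, 1e-2

def grad_L(x):                       # gradient of a toy loss L(x) = ||x||^2 / 2
    return x

base = rng.standard_normal((Xi, D))
base -= base.mean(axis=0)            # enforce E_xi[sigma_xi(x)] = 0

def sigma(xi, x):                    # state-independent noise, for simplicity
    return base[xi]

x = rng.standard_normal(D)
for k in range(10_000):
    xi = rng.integers(Xi)            # xi_k ~ Unif({1, ..., Xi})
    x = x - eta * (grad_L(x) + np.sqrt(Xi) * sigma(xi, x))
print("final iterate:", x)
```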
It is widely believed that large LR (or equivalently, small batch size) helps SGD
find better minima. For instance, some previous works argued that large noise enables
SGD to select a flatter attraction basin of the loss landscape which potentially benefits
generalization [164, 188]. However, there is also experimental evidence [50] that small
LR also has equally good implicit bias (albeit with higher training time), and that is
the case studied here. Presumably, low LR precludes SGD from jumping between different basins, since under general conditions this should require Ω(exp(1/η)) steps [189].
In other words, there should be a mechanism to reach better generalization while
staying within a single basin. For deterministic GD similar mechanisms have been
demonstrated in simple cases [14, 16, 100] and referred to as implicit bias of gradient
descent. This chapter presents a study of implicit bias of Stochastic GD, which turns
out to be quite different, mathematically.
Recent work [165] shed light on this direction by analyzing effects of stochasticity
in the gradient. For sufficiently small LR, SGD will reach and be trapped around some
manifold of local minimizers, denoted by Γ (see Figure 8.2). The effect is shown to be
an implicit deterministic drift in a direction corresponding to lowering a regularizer-like
term along the manifold. They showed SGD with label noise locally decreases the
sharpness of loss, tr[∇²L], by Θ(η^{0.4}) in η^{−1.6} steps. However, such an analysis is local in nature, since the natural time scale of the analysis should be η^{−2}, not η^{−1.6}.
The contribution of this chapter is a more general and global analysis of this type.
We introduce a more powerful framework inspired by the classic paper [187].
As derived in Section 8.8.1, SGD (8.1) with LR η can be approximated by the following SDE, where W(t) is a Ξ-dimensional Brownian motion:

dX̃η(t) = −η∇L(X̃η(t))dt + η · σ(X̃η(t))dW(t).    (8.2)
Figure 8.1: (a) Taylor expansion of ∇L; (b) normal space dynamics; (c) tangent space dynamics.

The key observation of Blanc et al. [165] is that the local dynamics of X̃η(t) is completely different in the tangent space and the normal space: the fast random walk in the normal space causes X̃η(t) to move slowly (with velocity Θ(η²)) but deterministically in a certain direction. To explain this, letting ∆(t) = X̃η(t) − X∗, a Taylor expansion of (8.2) gives d∆(t) ≈ −η∇²L(X∗)∆dt + ησ(X∗)dW(t), meaning ∆ behaves locally like an Ornstein–Uhlenbeck (OU) process in the normal space. Its mixing time is Θ(η^{−1}) and its stationary distribution is the standard multivariate Gaussian in the normal space scaled by √η (see Figure 8.1b), because the noise covariance satisfies σσ⊤ = ∇²L. Though this OU process itself does not produce any regularization, it activates the second-order Taylor expansion of ∇L(X∗ + ∆(t)), i.e., −½∂²(∇L)(X∗)[∆(t), ∆(t)], creating a Θ(η²) velocity in the tangent space. Since there is no push-back force in the tangent space, this small velocity accumulates over time, and on a longer time scale of Ω(η^{−1}), the time average of the stochastic velocity is roughly the same as the expected velocity when ∆ is sampled from its stationary distribution. This simplifies the expression of the velocity in the tangent space to −(η²/2)∇_T tr[∇²L] (see Figure 8.1c), where ∇_T means the gradient is only taken in the tangent space.
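The OU picture is easy to verify in one dimension. The sketch below (our own illustration, not an experiment from the text) simulates the linearized normal-space dynamics with Hessian λ and noise variance λ, as for label noise, and checks that the stationary variance is Θ(η):

```python
import numpy as np

rng = np.random.default_rng(0)
eta, lam, T = 1e-3, 2.0, 500_000
g = rng.standard_normal(T)

delta = np.empty(T)
d = 0.0
for k in range(T):
    # Discrete OU step: Delta <- (1 - eta*lam) Delta + eta * sqrt(lam) * noise
    d = (1 - eta * lam) * d + eta * np.sqrt(lam) * g[k]
    delta[k] = d

tail = delta[T // 2:]                # discard burn-in
print("empirical var:", tail.var(), " theory ~ eta/2 =", eta / 2)
```

The exact stationary variance of this recursion is η/(2 − ηλ) ≈ η/2, matching the √η-scaled Gaussian described above.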
However, the above approach only gives a local analysis for O(η^{−1.6}) time, where the total movement due to implicit regularization is O(η^{2−1.6}) = O(η^{0.4}) and thus is negligible when η → 0. In order to get a non-trivial limiting dynamics as η → 0, a global analysis for Ω(η^{−2}) steps is necessary, and it cannot be done by Taylor expansion around a single reference point. Recent work by Damian et al. [166] glues analyses of multiple local phases into a global guarantee that SGD finds an (ε, γ)-stationary point of the regularized loss, but it still does not show convergence of the trajectory as η → 0 and cannot deal with general noise types, e.g., noise lying in the tangent space of the manifold. The main technical difficulty here is that it is not clear how to separate the slow and fast dynamics in the different spaces and how to take the limit only for the slow dynamics, especially when shifting to a new reference point in the Taylor series calculation.
In this work, we tackle this problem from a different angle. First, since the anticipated limiting dynamics has speed Θ(η²), we change the time scaling to accelerate (8.2) by a factor of η^{−2}, which yields

dXη(t) = −η^{−1}∇L(Xη(t))dt + σ(Xη(t))dW(t).    (8.3)

The key idea here is that we only need to track the slow dynamics, or equivalently, some projection Φ(Xη) of Xη onto the manifold Γ. Here Φ : R^D → Γ is some function to be specified, and hopefully we can simplify the dynamics (8.3) by choosing a suitable Φ. To track the dynamics of Φ(Xη), we apply Itô's lemma (a.k.a. the stochastic chain rule; see Lemma 8.4.10) to Equation (8.3), which yields

dΦ(Xη(t)) = −η^{−1}∂Φ(Xη)∇L(Xη)dt + ∂Φ(Xη)σ(Xη)dW(t) + ½ Σ_{i,j=1}^{D} ∂_{ij}Φ(Xη)(σ(Xη)σ(Xη)⊤)_{ij} dt.
Note that the first term −η^{−1}∂Φ(Xη)∇L(Xη) diverges to ∞ as η → 0, so a natural choice for Φ is one that kills the first term. Further noting that −∂Φ(X)∇L(X) is the directional derivative of Φ at X along −∇L, killing the first term is equivalent to making Φ invariant under the gradient flow (GF) of −∇L(X)! Thus it suffices to take Φ(X) to be the limit of GF starting at X. (Formally defined in Section 8.3; see Lemma 8.9.2 for a proof of ∂Φ(X)∇L(X) ≡ 0.)
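Numerically, Φ is straightforward to approximate by integrating the gradient flow for a long horizon. The following sketch (our own toy example, not from the text) does this for L(x) = ¼(‖x‖² − 1)², whose zero-loss manifold is the unit circle and whose gradient flow projects radially, so Φ(x) = x/‖x‖:

```python
import numpy as np

def grad_L(x):                        # gradient of (||x||^2 - 1)^2 / 4
    return (x @ x - 1.0) * x

def Phi(x, dt=1e-3, T=50.0):
    for _ in range(int(T / dt)):      # forward-Euler gradient flow
        x = x - dt * grad_L(x)
    return x

x0 = np.array([1.2, -0.9])
print(Phi(x0), x0 / np.linalg.norm(x0))   # the two should agree closely
```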
Also, intuitively, Xη will be infinitely close to Γ, i.e., d(Xη(t), Γ) → 0 for any t > 0 as η → 0, so we have Φ(Xη) ≈ Xη. Thus we can rewrite the above equation as

dXη(t) ≈ ∂Φ(Xη(t))σ(Xη(t))dW(t) + ½ Σ_{i,j=1}^{D} ∂_{ij}Φ(Xη(t))(σ(Xη(t))σ(Xη(t))⊤)_{ij} dt,    (8.4)

and the solution of (8.4) shall converge (in an intuitive sense) to that of the following:

dX(t) = ∂Φ(X(t))σ(X(t))dW(t) + ½ Σ_{i,j=1}^{D} ∂_{ij}Φ(X(t))(σ(X(t))σ(X(t))⊤)_{ij} dt.    (8.5)
The above argument for SDE was first formalized and rigorously proved by Katzen-
berger [187]. It included an extension of the analysis to the case of asymptotically continuous dynamics (Theorem 8.5.2), including SGD with infinitesimal LR, but the
result is weaker in this case and no convergence is shown. Another obstacle for
applying this analysis is that 2nd order partial derivatives of Φ are unknown. We
solve these issues in Section 8.5 and our main result Theorem 8.5.7 gives a clean and
complete characterization for the implicit bias of SGD with infinitesimal LR in Θ(η −2 )
steps. Finally, our Corollary 8.6.2 shows (8.5) gives exactly the same regularization as
tr[∇2 L] for label noise SGD.
The main contributions of this chapter are summarized as follows.
1. In Section 8.5, we propose a mathematical framework to study the implicit bias
of SGD with infinitesimal LR. Our main theorem (Theorem 8.5.7) gives the
limiting diffusion of SGD with LR η for Θ(η −2 ) steps as η → 0 and allows any
covariance structure.
2. In Section 8.6, we give limiting dynamics of SGD with isotropic noise and label
noise.
3. In Section 8.7, we show for any initialization, SGD with label noise achieves
O(κ ln d) sample complexity for learning a κ-sparse overparametrized linear
model [20]. In this case, the implicit regularizer is a data-dependent weighted
`1 regularizer, meaning noise can help reduce the norm and even escape the
kernel regime. The O(κ ln d) rate is minimax optimal [192] and improves over the Õ(κ²) upper bound by HaoChen et al. [21]. In contrast, vanilla GD requires Ω(d) samples to generalize in the kernel regime.
For technical contributions, we rigorously prove the convergence of GF for OLM
(Lemma 8.7.3), unlike many existing implicit bias analyses which have to assume
the convergence. We also prove the convergence of limiting flow to the global
minimizer of the regularizer (Lemma 8.7.5) by a trajectory analysis via our
framework. It cannot be proved by previous results [165, 166], as they only
assert convergence to stationary point in the best case.
Modelling Stochastic First-Order Methods with Itô SDE Apart from the
discrete-time analysis, another popular approach to study SGD is through the
continuous-time lens using SDE [190, 191, 214]. Such an approach is often more
elegant and can provide fruitful insights like the linear scaling rule [215, 216] and
the intrinsic learning rate [50]. A recent work by Li et al. [217] justifies such SDE
approximation. Xie et al. [218] gave a heuristic derivation explaining why SGD favors
flat minima with SDE approximation. Wojtowytsch [219] showed that the invariant
distribution of the canonical SDE approximation of SGD will collapse to some manifold
of minimizers and, in particular, favors flat minima. By approximating SGD with an SDE with slightly modified covariance for the overparametrized linear model, Pesme et al. [176] relate the strength of implicit regularization to training speed.
8.4 Preliminaries on Stochastic Processes

In this section, we review a few basics of stochastic processes that will be useful for proving our results. We refer the reader to classics like Karatzas and Shreve [220], Billingsley [221], and Pollard [222] for more systematic treatments.
Throughout the rest of this section, let E be a Banach space equipped with norm
k · k, e.g., (R, | · |) and (RD , k · k2 ).
8.4.1 Càdlàg Function and Metric
Definition 8.4.2 (Continuity modulus). For any function f : [0, ∞) → E and any interval I ⊆ [0, ∞), we define ω(f; I) := sup_{s,t∈I} ‖f(s) − f(t)‖. For any N ∈ N and θ > 0, we further define

ω′_N(f, θ) = inf{ max_{i≤r} ω(f; [t_{i−1}, t_i)) : 0 ≤ t₀ < · · · < t_r = N, inf_{i<r}(t_i − t_{i−1}) ≥ θ }.
We then further define Jδ : DRD [0, ∞) → DRD [0, ∞) [187] as
X
Jδ (g)(t) = hδ (k∆g(s)k)∆g(s). (8.6)
0<s≤t
Definition 8.4.4 (Skorokhod metric on D_E[0, ∞)). For each finite T > 0 and each pair of functions f, g ∈ D_E[0, ∞), define d_T(f, g) as the infimum of all those values of δ for which there exist grids 0 ≤ t₀ < t₁ < · · · < t_m and 0 ≤ s₀ < s₁ < · · · < s_m, with t_m, s_m ≥ T, such that |t_i − s_i| ≤ δ for i = 0, . . . , m and the values of f and g on corresponding grid intervals are within δ of each other. Then define

d(f, g) = Σ_{T=1}^{∞} 2^{−T} min{1, d_T(f, g)}.
Definition 8.4.5 (Cross variation). Let X and Y be two {F_t}t≥0-adapted stochastic processes such that X has sample paths in D_{R^{D×e}}[0, ∞) and Y has sample paths in D_{R^e}[0, ∞). Then the cross variation of X and Y on (0, t], denoted by [X, Y](t), is defined to be the limit of

Σ_{i=0}^{m−1} (X(t_{i+1}) − X(t_i))(Y(t_{i+1}) − Y(t_i))
in probability as the mesh size of 0 = t₀ < t₁ < · · · < t_m = t goes to 0, if it exists. Moreover, for Y itself, we write

[Y] := Σ_{i=1}^{e} [Y_i, Y_i].
Definition 8.4.6 (Martingale). Let {X(t)}t≥0 be an {F_t}t≥0-adapted stochastic process. If for all 0 ≤ s ≤ t it holds that

E[X(t) | F_s] = X(s),

then {X(t)}t≥0 is called a martingale.
Definition 8.4.7 (Local martingale). Let {X(t)}t≥0 be an {F_t}t≥0-adapted stochastic process. If there exists a sequence of {F_t}t≥0-stopping times {τ_k}k≥0 such that τ_k → ∞ almost surely and the stopped process X(· ∧ τ_k) is a martingale for each k, then {X(t)}t≥0 is called a local martingale.
Definition 8.4.9 (Itô's Stochastic Integral). If {X(t)}t≥0 and {Y(t)}t≥0 are adapted stochastic processes such that X has sample paths in D_{R^{d×e}}[0, ∞), Y has sample paths in D_{R^e}[0, ∞), and Y is a semimartingale, then the integral ∫_s^t X dY is defined as the limit of Σ_{i=0}^{n−1} X(r_i)(Y(r_{i+1}) − Y(r_i)), where s = r₀ < r₁ < . . . < r_n = t, the limit being in probability as the mesh size goes to 0. Standard results in stochastic calculus imply that this limit exists. We call X the integrand and Y the integrator.
Since all deterministic processes are adapted, the above definition of the integral also makes sense for deterministic functions and generalizes the standard Riemann–Stieltjes integral. The difference is that Itô's stochastic integral uses the left-end value of the integrand, whereas the existence of the Riemann–Stieltjes integral requires the limit to exist for any evaluation point within each interval. When X and Y do not jump together, the Riemann–Stieltjes integral exists and coincides with the Itô integral.
Lemma 8.4.10 (Itô's Lemma). Let {X(t)}t≥0 be defined through the following Itô drift–diffusion process:

dX(t) = µ_t dt + σ_t dW(t),

where {W(t)}t≥0 is the standard Brownian motion. Then for any twice-differentiable function f, it holds that

df(t, X(t)) = ( ∂f/∂t + (∇_x f)⊤µ_t + ½ tr[σ_t⊤(∇²_x f)σ_t] ) dt + (∇_x f)⊤σ_t dW(t).
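As a quick sanity check of the lemma (our own example, not from the text), one can verify a consequence of Itô's lemma by Monte Carlo: for dX = −X dt + dW and f(x) = x², the lemma gives dE[X²] = (1 − 2E[X²])dt, hence E[X_T²] = ½ + (x₀² − ½)e^{−2T}.

```python
import numpy as np

rng = np.random.default_rng(0)
x0, T, n_steps, n_paths = 1.5, 1.0, 1000, 100_000
dt = T / n_steps

X = np.full(n_paths, x0)
for _ in range(n_steps):             # Euler-Maruyama for dX = -X dt + dW
    X += -X * dt + np.sqrt(dt) * rng.standard_normal(n_paths)

theory = 0.5 + (x0**2 - 0.5) * np.exp(-2 * T)
print("Monte Carlo:", np.mean(X**2), " Ito prediction:", theory)
```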
Let (DE [0, ∞), A, d) be a metric space equipped with a σ-algebra A and the Skorokhod
metric defined in the previous subsection.
Let {Xn }n≥0 be a sequence of stochastic processes on a sequence of probability
spaces {(Ωn , Fn , Pn )}n≥0 such that each Xn has sample paths in DE [0, ∞). Also, let
X be a stochastic process on (Ω, F, P) with sample paths on DE [0, ∞).
Definition 8.4.12 (δ-Prohorov distance). Let δ > 0. For any two probability measures
P and Q on a metric space with metric d, let (X, Y ) be a coupling such that P is the
marginalized law of X and Q that of Y . We define
Note this distance is not a metric because it does not satisfy the triangle inequality.
Definition 8.4.13 (Prohorov metric). For any two probability measures P and Q on
a metric space with metric d, let (X, Y ) be a coupling such that P is the marginalized
law of X and Q that of Y . Denote the marginal laws of X and Y by L(X) and L(Y )
respectively. We define the Prohorov metric as
Definition 8.4.15 (Uniform metric on D_E[0, ∞)). For each finite T > 0 and each pair of functions f, g ∈ D_E[0, ∞), let d_U(f, g; T) = sup_{t∈[0,T]} ‖f(t) − g(t)‖. The uniform metric is then defined to be

d_U(f, g) = Σ_{T=1}^{∞} 2^{−T} min{1, d_U(f, g; T)}.
Remark 8.4.17. We shall note the uniform metric defined above is weaker than
supt∈[0,∞) kf (t) − g(t)k. Convergence in the uniform metric on [0, ∞] defined in
Definition 8.4.15 is equivalent to convergence in the uniform metric on each compact
set [0, T ] for T ∈ N+ . The same holds for the Skorokhod topology.
8.5 Limiting Diffusion of SGD around the Manifold of Minimizers
In this section, we first state our assumptions about the loss function in Section 8.5.1. In Section 8.5.2 we recap the main result of Katzenberger [187]. In Section 8.5.3 we derive the closed-form expressions of ∂Φ and ∂²Φ. We present our main result in Section 8.5.4. We remark that we sometimes omit the dependency on t to reduce clutter.
Following Fehrman et al. [203], we make the following important assumption about the loss function.

Assumption 8.5.1. The loss L : R^D → R is a C³ function, and Γ is a (D − M)-dimensional submanifold of R^D for some integer 0 ≤ M ≤ D, such that every x ∈ Γ is a local minimizer of L and rank(∇²L(x)) = M.

Let U be the set of points starting from which gradient flow w.r.t. the loss L converges to some point in Γ, that is, U := {x ∈ R^D | Φ(x) exists and Φ(x) ∈ Γ}. Assumption 8.5.1 implies that U is open and Φ is C³ on U (by Lemma 8.8.2).
When does such a manifold exist? The vast overparametrization in modern
deep learning is a major reason for the set of global minimizers to appear as a
Riemannian manifold (possibly with multiple connected components), instead of
isolated ones. Suppose all global minimizers interpolate the training dataset, i.e.,
∀x ∈ RD , L(x) = minx0 ∈RD L(x0 ) implies fi (x) = yi for all i ∈ [n], then by preimage
theorem [223], the manifold Γ := {x ∈ RD | fi (x) = yi , ∀i ∈ [n]} is of dimension
D − n if the Jacobian matrix [∇f1 (x), . . . , ∇fn (x)] has rank n for all x ∈ Γ. Note
this condition is equivalent to the NTK at x having full rank, which is a common assumption in the literature.
The smoothness assumption is satisfied for networks with smooth activation functions like tanh and GeLU [76]. The assumption rank(∇²L(x)) = M basically says that ∇²L(x) always attains the maximal rank in the normal space of the manifold, which ensures the differentiability of Φ and is crucial to our current analysis, though it is not clear whether it is necessary.
8.5.2 Recap of Katzenberger’s Theorem
Consider the stochastic process

Xn(t) = X(0) + ∫₀ᵗ σ(Xn(s))dZn(s) + ∫₀ᵗ −∇L(Xn(s))dAn(s).    (8.8)
In particular, when the integrator sequence {An}n≥1 increases infinitely fast, meaning that ∀ε > 0, inf_{t≥0}(An(t + ε) − An(t)) → ∞ as n → ∞, we call (8.8) a Katzenberger process.
One difficulty in directly studying the limiting dynamics of Xn(t) is that the point-wise limit as n → ∞ becomes discontinuous at t = 0 if X(0) ∉ Γ. The reason is that clearly lim_{n→∞} Xn(0) = X(0), but for any t > 0, since {An}n≥1 increases infinitely fast, one can prove lim_{n→∞} Xn(t) ∈ Γ! To circumvent this issue, we consider Yn(t) = Xn(t) − φ(X(0), An(t)) + Φ(X(0)). Then for each n ≥ 1, we have Yn(0) = Φ(X(0)) and lim_{n→∞} Yn(t) = lim_{n→∞} Xn(t) for t > 0. Thus Yn(t) has the same limit on (0, ∞) as Xn(t), but the limit of the former is also continuous at t = 0.
Katzenberger [187] showed that, under suitable conditions, Yn converges to the solution Y of

Y(t) = Y(0) + ∫₀ᵗ ∂Φ(Y)σ(Y)dW(s) + ½ Σ_{i,j=1}^{D} ∫₀ᵗ ∂_{ij}Φ(Y)(σ(Y)σ(Y)⊤)_{ij} ds.    (8.9)
Indeed, SGD (8.1) can be rewritten into a Katzenberger process as in the following
lemma.
Lemma 8.5.3. Let {ηn}_{n=1}^{∞} be any positive sequence with lim_{n→∞} ηn = 0, An(t) = ηn⌊t/ηn²⌋, and Zn(t) = ηn Σ_{k=1}^{⌊t/ηn²⌋} √Ξ(1_{ξ_k} − (1/Ξ)1), where ξ₁, ξ₂, . . . iid∼ Unif([Ξ]). Then with the same initialization Xn(0) = x_{ηn}(0) ≡ X(0), Xn(kηn²) defined by (8.8) is a Katzenberger process and is equal to x_{ηn}(k) defined in (8.1) with LR equal to ηn for all k ≥ 1. Moreover, the counterpart of (8.9) is

Y(t) = Φ(X(0)) + ∫₀ᵗ ∂Φ(Y)σ(Y)dW(s) + ½ ∫₀ᵗ ∂²Φ(Y)[Σ(Y)]ds.    (8.10)
However, there are two obstacles preventing us from directly applying Theorem 8.5.2
to SGD. First, the stochastic integral in (8.10) depends on the derivatives of Φ, ∂Φ
and ∂ij Φ, but Katzenberger [187] did not give their dependency on loss L. To resolve
this, we explicitly calculate the derivatives of Φ on Γ in terms of the derivatives of L
in Section 8.5.3.
The second difficulty comes from the convergence of (Yn, Zn), which we take for granted for brevity in Theorem 8.5.2. In fact, the full version of Theorem 8.5.2 (see Theorem 8.8.8) concerns the stopped version of Yn with respect to some compact K ⊂ U, i.e., Yn^{µn(K)}(t) = Yn(t ∧ µn(K)), where µn(K) is the stopping time of Yn leaving K. As noted in Katzenberger [187], we need the convergence of µn(K) for Yn^{µn(K)} to converge, which is a strong condition and difficult to prove in our cases.
We circumvent this issue by proving Theorem 8.8.10, a user-friendly interface for the
original theorem in Katzenberger [187], and it only requires the information about the
limiting diffusion. Building upon these, we present our final result as Theorem 8.5.7.
8.5.3 Closed-Form Expression of the Limiting Diffusion
We can calculate the derivatives of Φ by relating to those of L. Here the key observation
is the invariance of Φ along the trajectory of GF. The proofs of this section are deferred
into Section 8.9.
Lemma 8.5.4. For any x ∈ Γ, ∂Φ(x) is the orthogonal projection matrix onto tangent
space Tx (Γ).
Lemma 8.5.6. Let x be any point in Γ and let Σ = Σ(x) = σσ⊤(x) ∈ R^{D×D} be the noise covariance at x (for notational convenience, we drop the dependency on x below). Then Σ can be decomposed as Σ = Σ∥ + Σ⊥ + Σ∥,⊥ + Σ⊥,∥, where Σ∥ := ∂ΦΣ∂Φ, Σ⊥ := (I_D − ∂Φ)Σ(I_D − ∂Φ), and Σ∥,⊥ = Σ⊥,∥⊤ := ∂ΦΣ(I_D − ∂Φ) are the noise covariances in the tangent space, in the normal space, and across both spaces, respectively.
Then it holds that
(8.11)
Now we are ready to present our main result. It is a direct combination of Theorem 8.8.10 and Lemma 8.5.6.

Theorem 8.5.7. Suppose the loss function L, the manifold of local minimizers Γ and the open neighborhood U satisfy Assumption 8.5.1 and ??, and xη(0) = x(0) ∈ U for all η > 0. If SDE (8.12) has a global solution Y with Y(0) = x(0) and Y never leaves U, i.e., P[Y(t) ∈ U, ∀t ≥ 0] = 1, then for any T > 0, xη(⌊T/η²⌋) converges in distribution to Y(T) as η → 0. Here SDE (8.12) is

dY(t) = Σ∥^{1/2}(Y)dW(t) − ½ ∇²L(Y)† ∂²(∇L)(Y)[Σ∥(Y)] dt
        − ∂Φ(Y)( ∂²(∇L)(Y)[∇²L(Y)† Σ⊥,∥(Y)] + ½ ∂²(∇L)(Y)[L_{∇²L}^{−1}(Σ⊥(Y))] ) dt,    (8.12)

whose four terms we refer to as the Tangent Noise, the Tangent Noise Compensation, the Mixed Regularization, and the Normal Regularization, respectively.
Based on the above theorem, the limiting dynamics of SGD can be understood as follows. (a) The Tangent Noise, Σ∥^{1/2}(Y)dW(t), is preserved, and the second term of (8.12) can be viewed as the necessary Tangent Noise Compensation for the limiting dynamics to stay on Γ. Indeed, Lemma 8.9.7 shows that the value of the second term depends only on Γ itself, i.e., it is the same for all losses L that locally define the same Γ. (b) The noise in the normal space is killed, since the limiting dynamics always stays on Γ. However, its second-order effect (the Itô correction term) takes place as a vector field on Γ, which induces the Mixed Regularization and Normal Regularization terms, corresponding to the mixed and normal noise covariances respectively.
Remark 8.5.8. In Section 8.8.4 we in fact prove a stronger version of Theorem 8.5.7: the sample paths of SGD converge in distribution, i.e., letting x̃η(t) = xη(⌊t/η²⌋), x̃η weakly converges to Y on [0, T]. Moreover, we only assume the existence of a global solution for ease of presentation.
a global solution for ease of presentation. As long as there exists a compact K ⊆ Γ
such that Y stays in K on [0, T ] with high probability, Theorem 8.8.10 still provides
the convergence of SGD iterates (stopped at the boundary of K) before time T with
high probability.
8.6 Implications and Examples
In this section, we derive the limiting dynamics for two notable noise types, where we
fix the expected loss L and the noise distribution, and only drive η to 0. The proofs
are deferred into Section 8.10.
Type I: Isotropic Noise. Isotropic noise means Σ(x) ≡ I_D for any x ∈ Γ [189]. The limiting diffusion with isotropic noise can be viewed as a Brownian motion on the manifold plus a Riemannian gradient flow with respect to the pseudo-determinant of ∇²L: if Σ ≡ I_D on Γ, SDE (8.12) simplifies to

dY(t) = ∂Φ(Y)dW − ½ ∇²L(Y)† ∂²(∇L)(Y)[∂Φ(Y)] dt − ½ ∂Φ(Y)∇(ln|∇²L(Y)|₊) dt,    (8.13)

where the first term is the Brownian motion on the manifold and the last term is the Normal Regularization.
Type II: Label Noise. When doing SGD for ℓ2-regression on a dataset {(z_i, y_i)}_{i=1}^{n}, adding label noise [165, 166] means replacing the true label at iteration k, y_{i_k}, by a fresh noisy label ỹ_{i_k} := y_{i_k} + δ_k, where δ_k iid∼ Unif{−δ, δ} for some constant δ > 0. Then the corresponding loss becomes ½(f_{i_k}(x) − ỹ_{i_k})², where f_{i_k}(x) is the output of the model with parameter x on data z_{i_k}. So the label noise SGD update is

x_{k+1} = x_k − (η/2)·∇_x(f_{i_k}(x_k) − y_{i_k} + δ_k)² = x_k − η(f_{i_k}(x_k) − y_{i_k} + δ_k)∇_x f_{i_k}(x_k).    (8.14)
Suppose the model can achieve the global minimum of the loss L(x) := ½E[(f_i(x) − ỹ_i)²] at x∗; then the model must interpolate the whole dataset, i.e., f_i(x∗) = y_i for all i ∈ [n], and thus here the manifold Γ is a subset of {x ∈ R^D | f_i(x) = y_i, ∀i ∈ [n]}. Here the key property of label noise used in previous works is that Σ(x) = (δ²/n) Σ_{i=1}^{n} ∇_x f_i(x)∇_x f_i(x)⊤ = δ²∇²L(x) on Γ. Lately, Damian et al. [166] further generalized the analysis to other losses, e.g., the logistic loss and exponential loss, as long as they satisfy Σ(x) = c∇²L(x) for some constant c > 0.
In sharp contrast to the delicate discrete-time analysis in Blanc et al. [165] and
Damian et al. [166], the following corollary recovers the same result but with much
simpler analysis – taking derivatives is all you need. Under our framework, we no
longer need to do Taylor expansion manually nor carefully control the infinitesimal
variables of different orders together. It is also worth mentioning that our framework
immediately gives a global analysis of Θ(η −2 ) steps for SGD, far beyond the local
coupling analysis in previous works. In Section 8.7, we will see how such global analysis
allows us to prove a concrete generalization upper bound in a non-convex problem,
the overparametrized linear model [20, 21].
Corollary 8.6.2 (Limiting Flow for Label Noise). If Σ ≡ c∇²L on Γ for some constant c > 0, SDE (8.12) can be simplified into (8.15), where the regularization comes from the noise in the normal space:

dY(t) = −(c/4) ∂Φ(Y) ∇tr[∇²L(Y)] dt.    (8.15)
Example: k-Phase Motor We also give an example with rigorous proof where the
implicit bias induced by noise in the normal space cannot be characterized by a fixed
regularizer, which was first discovered by Damian et al. [166] but was only verified via
experiments.
Since the normal regularization in both the label noise and isotropic noise cases induces a Riemannian gradient flow against some regularizer, it is natural to wonder whether the limiting flow induced by the normal noise can always be characterized by a fixed regularizer. Interestingly, Damian et al. [166] answer this question negatively via experiments in their Section E.2. We adapt their example into the following one, and rigorously prove via our framework that the limiting flow moves around a circle at a constant speed and never stops.
Suppose the dimension is D = k + 2 ≥ 5. For each x ∈ R^D, we decompose x = (x_{1:2}, x_{3:D}), where x_{1:2} ∈ R² and x_{3:D} ∈ R^{D−2}. Let Q_θ = [cos θ, −sin θ; sin θ, cos θ] ∈ R^{2×2} be the rotation matrix of angle θ, and define the loss

L(x) := (1/8)(‖x_{1:2}‖₂² − 1)² + (1/2) Σ_{j=3}^{D} (2 + ⟨Q_α^{j−3} v, x_{1:2}⟩) x_j²,

where α = 2π/(D − 2) and v is any unit-norm vector in R². Here the manifold is given by Γ := {x | L(x) = 0} = {x ∈ R^D | x₁² + x₂² = 1, x_j = 0, ∀j = 3, . . . , D}.
The basic idea is that we can add noise in the ‘auxiliary dimensions’ for j = 3, . . . , D
to get the regularization force on the circle {x21 + x22 = 1}, and the goal is to make the
vector field induced by the normal regularization always point in the same direction, say anticlockwise. However, this cannot be done with a single auxiliary dimension, because from the analysis for label noise we know that when L_{∇²L}^{−1}(Σ⊥) is the identity, the normal regularization term in Equation (8.12) has zero path integral along the unit circle and thus must point in both directions.
can align the magnitude of noise with the strength of the regularization to make the
path integral positive. By using k ≥ 3 auxiliary dimensions, we can further ensure
the normal regularization force is anti-clockwise and of constant magnitude, which is
reminiscent of how a three-phase induction motor works.
Setting the noise covariance to Σ_{ij}(x) = 2 + ⟨Q_α^{j−3} v, x_{1:2}⟩ if i = j ≥ 3 and 0 otherwise, the solution of SDE (8.12) is given by (8.16), which implies that Y(t) moves anticlockwise with a constant angular speed of (D − 2)/2.
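For concreteness, the following sketch (our own construction following the description above; D = 7 and v = e₁ are arbitrary choices) builds this loss and checks that Γ is exactly the unit circle in the first two coordinates:

```python
import numpy as np

D = 7                                  # k = D - 2 = 5 auxiliary dimensions
alpha = 2 * np.pi / (D - 2)
v = np.array([1.0, 0.0])

def Q(theta):                          # rotation matrix of angle theta
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])

def loss(x):
    x12, rest = x[:2], x[2:]
    coef = np.array([2 + (Q((j - 3) * alpha) @ v) @ x12 for j in range(3, D + 1)])
    return 0.125 * (x12 @ x12 - 1) ** 2 + 0.5 * np.sum(coef * rest**2)

on_gamma = np.concatenate([[np.cos(0.3), np.sin(0.3)], np.zeros(D - 2)])
print(loss(on_gamma))                  # ~0: on the manifold
print(loss(on_gamma + 0.1))            # > 0: off the manifold
```

Note that each coefficient 2 + ⟨Q_α^{j−3}v, x_{1:2}⟩ lies in [1, 3] on the circle, so the loss vanishes exactly on Γ.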
8.7 Provable Generalization Benefit of Label Noise
In this section, we show provable benefit of label noise in generalization using our
framework (Theorem 8.8.8) in a concrete setting, the overparametrized linear models
(OLM) [20]. While the existing implicit regularization results for Gradient Flow often
relate the generalization quality to the initialization (e.g., Woodworth et al. [20] show that for OLM, small initialization corresponds to the rich regime and prefers solutions with small ℓ1 norm, while large initialization corresponds to the kernel regime and prefers solutions with small ℓ2 norm), our result Theorem 8.7.1 surprisingly proves that even if an OLM is initialized in the kernel regime, label noise SGD can still help it escape and then enter the rich regime by minimizing its weighted ℓ1 norm. When the groundtruth is κ-sparse, this provides an Õ(κ ln d) vs. Ω(d) sample complexity separation between SGD with label noise and GD when both are initialized in the kernel
regime. Here d is the dimension of the groundtruth. The lower bound for GD in the
kernel regime is folklore, but for completeness, we state the result as Theorem 8.7.7 in
Section 8.7.3 and append its proof in Section 8.11.6.
Theorem 8.7.1. In the setting of OLM, suppose the groundtruth is κ-sparse and n ≥ Ω(κ ln d) training data are sampled from either the i.i.d. Gaussian or Boolean distribution. Then for any initialization x_init (except for a zero-measure set) and any ε > 0, there exist η₀, T > 0 such that for any η < η₀, OLM trained with label noise SGD (8.14) with LR equal to η for ⌊T/η²⌋ steps returns an ε-optimal solution, with probability 1 − e^{−Ω(n)} over the randomness of the training dataset.
The proof roadmap of Theorem 8.7.1 is the following:
1. Show Assumption 8.5.1 is satisfied, i.e., the set of local minimizers Γ is indeed a manifold and the Hessian ∇²L(x) is non-degenerate on Γ (by Lemma 8.7.2);
2. Show ?? is satisfied, i.e., gradient flow starting in U converges to Γ (by Lemma 8.7.3);
3. Show the limiting flow (8.15) converges to the minimizer of the regularizer (by Lemma 8.7.5);
4. Show the minimizer of the regularizer recovers the groundtruth (by Lemma 8.7.6).
Our setting is more general than that of HaoChen et al. [21], who assume w∗ ∈ {0, 1}^d and whose reparametrization can only express positive linear functions, i.e., w = u^{⊙2}. Their Õ(κ²) rate is achieved with a delicate three-phase LR schedule, while our O(κ ln d) rate only uses a constant LR.
Setting: Let {(z_i, y_i)}_{i∈[n]} be the training dataset, where z₁, . . . , z_n iid∼ Unif({±1}^d) or N(0, I_d) and each y_i = ⟨z_i, w∗⟩ for some unknown w∗ ∈ R^d. We assume that w∗ is κ-sparse for some κ < d. Denote x = (u, v) ∈ R^D = R^{2d}; we will use x and (u, v) interchangeably. The OLM predicts f_i(u, v) := ⟨z_i, u^{⊙2} − v^{⊙2}⟩, and the training loss is

L(x) = L(u, v) = (1/n) Σ_{i=1}^{n} ℓ_i(u, v),   where ℓ_i(u, v) = ½(f_i(u, v) − y_i)².    (8.17)

It is straightforward to verify that ∇²L(x) = (4/n) Σ_{i=1}^{n} [z_i ⊙ u; −z_i ⊙ v][z_i ⊙ u; −z_i ⊙ v]⊤ for all x ∈ Γ. For simplicity, we define Z = (z₁, . . . , z_n)⊤ ∈ R^{n×d} and Y = (y₁, . . . , y_n)⊤ ∈ R^n. Consider the following manifold:

Γ := {x = (u, v) ∈ R^D | Z(u^{⊙2} − v^{⊙2}) = Y}.    (8.18)
We verify that the above loss function L and manifold Γ satisfy Assumption 8.5.1 by
Lemma 8.7.2, and that the neighborhood U and Γ satisfy ?? by Lemma 8.7.3.
Lemma 8.7.2. Consider the loss L defined in (8.17) and manifold Γ defined in (8.18).
If data is full rank, i.e., rank(Z) = n, then it holds that (a). Γ is a smooth manifold
of dimension D − n; (b). rank(∇2 L(x)) = n for all x ∈ Γ. In particular, rank(Z) = n
holds with probability 1 for the Gaussian distribution and with probability 1 − c^d for the Boolean distribution, for some constant c ∈ (0, 1).
Lemma 8.7.3. Consider the loss function L defined in (8.17), and the manifold Γ and its open neighborhood U defined in (8.18). For the gradient flow dx_t/dt = −∇L(x_t) starting at any x₀ ∈ U, it holds that Φ(x₀) ∈ Γ.
Remark 8.7.4. In previous works [20, 138], the convergence of gradient flow is only assumed. Recently, Pesme et al. [176] proved it for a specific initialization, i.e., u_j = v_j = α for all j ∈ [d] and some α > 0. Lemma 8.7.3 completely removes this technical assumption.
Therefore, by the result in the previous section, the implicit regularizer on the
manifold is R(x) = tr(Σ(x)) = tr(δ 2 ∇2 L(x)). Without loss of generality, we take
δ = 1. Hence, it follows that

R(x) = (4/n) Σ_{j=1}^{d} Σ_{i=1}^{n} z_{i,j}² (u_j² + v_j²).    (8.19)
The limiting behavior of label noise SGD is described by a Riemannian gradient flow on Γ:

dx_t = −(1/4) ∂Φ(x_t)∇R(x_t) dt.    (8.20)

The goal is to show that the above limiting flow converges to the underlying groundtruth x∗ = (u∗, v∗), where (u∗, v∗) = ([w∗]₊^{1/2}, [−w∗]₊^{1/2}).
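To see what the flow is optimizing, the following NumPy sketch (our own illustration; the data and w below are made up) evaluates the regularizer (8.19) and checks the weighted-ℓ1 picture for Boolean data: R equals 4‖w‖₁ exactly at representations with u_j v_j = 0, and is strictly larger otherwise.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 20, 8
Z = rng.choice([-1.0, 1.0], size=(n, d))       # Boolean features: z_ij^2 = 1

def R(u, v):                                    # regularizer (8.19)
    col = (Z**2).sum(axis=0)                    # = n for each j here
    return 4.0 / n * np.sum(col * (u**2 + v**2))

w = rng.standard_normal(d)
u, v = np.sqrt(np.maximum(w, 0)), np.sqrt(np.maximum(-w, 0))   # u_j v_j = 0
print(np.isclose(R(u, v), 4 * np.abs(w).sum()))  # R equals 4 * ||w||_1

# Another representation of the same w with u_j v_j != 0 pays a higher R.
u2, v2 = np.sqrt(np.maximum(w, 0) + 0.3), np.sqrt(np.maximum(-w, 0) + 0.3)
print(R(u2, v2) > 4 * np.abs(w).sum())
```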
8.7.1 Limiting Flow Converges to Minimizers of Regularizer
In this subsection we show limiting flow (8.15) starting from anywhere on Γ converges
to the minimizer of regularizer R (by Lemma 8.7.5). The proof contains two parts:
(a) the limiting flow converges; (b) the limit point of the flow cannot be sub-optimal
stationary points. These are indeed the most technical and difficult parts of proving
the O(κ ln d) upper bound, where the difficulty comes from the fact that the manifold
Γ is not compact, and the stationary points of the limiting flow are in fact all located
on the boundary of Γ. However, the limiting flow itself is not even defined on the
boundary of the manifold Γ. Even if we can extend ∂Φ(·)∇R(·) continuously to the entire R^D, the continuous extension is not everywhere differentiable.
Thus the non-compactness of Γ brings challenges for both (a) and (b). For (a), the
convergence for standard gradient flow is often for free, as long as the trajectory is
bounded and the objective is analytic or smooth and semialgebraic. The latter ensures
the so-called Kurdyka-Lojasiewicz (KL) inequality [224], which implies finite trajectory
length and thus the convergence. However, since our flow does not satisfy those nice
properties, we have to show that the limiting flow satisfies Polyak-Lojasiewicz condition
(a special case of KL condition) [225] via careful calculation (by Lemma 8.11.16).
For (b), the standard analysis based on center stable manifold theorem shows that
gradient descent/flow converges to strict saddle (stationary point with at least one
negative eigenvalue in hessian) only for a zero-measure set of initialization [119, 121].
However, such analyses cannot deal with the case where the flow is not differentiable
at the sub-optimal stationary point. To circumvent this issue, we prove the non-
convergence to sub-optimal stationary points with a novel approach: we show that for
any stationary point x, whenever there exists a descent direction of the regularizer
R at x, we can construct a potential function which increases monotonically along
the flow around x, while the potential function is equal to −∞ at x, leading to a
contradiction. (See proof of Lemma 8.7.5.)
Lemma 8.7.5. Let {xt }t≥0 ⊆ RD be generated by the flow defined in (8.20) with any
initialization x0 ∈ Γ. Then x∞ = limt→∞ xt exists. Moreover, x∞ = x∗ is the optimal
solution of (8.21).
8.7.2 Minimizer of the Regularizer Recovers the Groundtruth
Note that (1/n) Σ_{i=1}^{n} z_{i,j}² = 1 when z_{i,j} iid∼ Unif{−1, 1}, and we can show that minimizing R(x) on Γ, i.e., (8.21), is equivalent to finding the minimum ℓ1 norm solution of Equation (8.17). Standard results in sparse recovery imply that the minimum ℓ1 norm solution recovers the sparse groundtruth. The Gaussian case is more complicated but can still be handled with techniques from Tropp [226].
minimize   R(x) = (4/n) Σ_{j=1}^{d} Σ_{i=1}^{n} z_{i,j}² (u_j² + v_j²),
subject to Z(u^{⊙2} − v^{⊙2}) = Zw∗.    (8.21)
Lemma 8.7.6. Let z₁, . . . , z_n iid∼ Unif({±1}^d) or N(0, I_d). Then there exist constants C, c > 0 such that if n ≥ Cκ ln d, then with probability at least 1 − e^{−cn}, the optimal solution (û, v̂) of (8.21) is unique up to sign flips of each coordinate and recovers the groundtruth, i.e., û^{⊙2} − v̂^{⊙2} = w∗.
8.7.3 Lower Bound for GD in the Kernel Regime
In this subsection we show GD needs at least Ω(d) samples to learn OLM, when
initialized in the kernel regime. This lower bound holds for all learning rate schedules
and numbers of steps. This is in sharp contrast to the Õ(κ ln d) sample complexity upper bound of SGD with label noise. Following the setting of the kernel regime in [20], we consider the limit of u₀ = v₀ = α1 with α → ∞. It holds that f_i(u₀, v₀) = 0 and ∇f_i(u₀, v₀) = [αz_i, −αz_i] for each i ∈ [n]. Standard convergence analysis for the NTK (Neural Tangent Kernel, Jacot et al. [15]) shows that upon convergence, the distance traveled by the parameters converges to 0, and thus the learned model converges in function space, and so does its generalization performance. For ease of illustration, we directly consider the lower bound for the test loss when the NTK is fixed throughout training.
Theorem 8.7.7. Assume z₁, . . . , z_n iid∼ N(0, I_d) and y_i = z_i⊤w∗ for all i ∈ [n]. Define the loss with the linearized model as L(x) = Σ_{i=1}^{n} (f_i(x₀) + ⟨∇f_i(x₀), x − x₀⟩ − y_i)². Then for any learning rate schedule {η_t}_{t≥1} and any fixed number of steps T, the expected ℓ2 loss of x(T) is at least (1 − n/d)‖w∗‖₂², where x(T) is the T-th iterate of GD on L, i.e., x(t + 1) = x(t) − η_t∇L(x(t)) for all t ≥ 0.
8.8 Complete Derivation of the Limiting Diffusion

In this section, we give a complete derivation of the limiting diffusion of SGD. Here we use ⇒ to denote convergence in distribution. For any U ⊆ R^D, we denote by Ů its interior. For a linear space S, we use S⊥ to denote its orthogonal complement. First, as mentioned in ??, we verify that the mapping Φ is C³ in ??. In Section 8.8.1 we discuss how different time scalings affect the coefficients in SDEs (8.2) and (8.3). Then we check the necessary conditions for applying the results in Katzenberger [187] in Section 8.8.2 and recap the corresponding theorem for the asymptotically continuous case in Section 8.8.3. Finally, we provide a user-friendly interface for Katzenberger's theorem in Section 8.8.4.
Lemma 8.8.1. If limn→∞,n∈N φ(x, n) exists, then Φ(x) also exists and Φ(x) =
limn→∞,n∈N φ(x, n).
Proof of Lemma 8.8.1. Suppose K ⊆ R^D is a compact set and φ(x, t) ∈ K for all t ∈ [0, T]. We have that

(d/dt)‖∇L(φ(x, t))‖² = −2∇L(φ(x, t))⊤∇²L(φ(x, t))∇L(φ(x, t)) ≤ 2ρ_K ‖∇L(φ(x, t))‖²,

where ρ_K denotes sup_{y∈K} ‖∇²L(y)‖. This implies that ‖∇L(φ(x, t))‖ ≤ e^{ρ_K T}‖∇L(x)‖ and ‖φ(x, t) − x‖ ≤ e^{ρ_K T}‖∇L(x)‖ for all t ∈ [0, T].

Now suppose lim_{n→∞,n∈N} φ(x, n) = x∗; then ∇L(x∗) = 0 and ‖∇L(φ(x, n))‖ converges to 0, due to the continuity of ∇L. Take any compact neighborhood of x∗ as the above-defined K; we know there exists N > 0 such that for all n > N, ‖∇L(φ(x, n))‖ and ‖φ(x, n) − x∗‖ are small enough that φ(x, n + δ) ∈ K for all δ ∈ [0, 1]. Therefore, as t → ∞ along the reals, ‖φ(x, t) − φ(x, ⌊t⌋)‖ ≤ e^{ρ_K}‖∇L(φ(x, ⌊t⌋))‖ → 0. This completes the proof.
Proof of Lemma 8.8.2. Since Γ is a (D − M)-dimensional manifold and rank(∇²L(x)) = M, for any x ∈ Γ there exists a small open set V_x containing x such that Γ ∩ V_x is the set of stationary points of L in V_x, i.e., Γ ∩ V_x = {y ∈ V_x | ∇L(y) = 0}. Thus Γ is the set of stationary points of L in the open set V := ∪_{x∈Γ} V_x. We further define f : R^D → R^D by f(x) = φ(x, 1), so that Γ is the set of fixed points of the mapping f. Since L is C^k, ∇L is C^{k−1} and thus f is C^{k−1}. By Theorem 5.1 in [227], there is an open set N containing Γ such that f^∞(x) is well-defined and C³ on N with f^∞(x) ∈ Γ for any x ∈ N, where f^∞(x) := lim_{n→∞} f^n(x), f^n(x) = f(f^{n−1}(x)) and f¹(x) = f(x). By Lemma 8.8.1, we know that Φ(x) = f^∞(x).

Since V ⊇ Γ is open and Φ(x) ∈ Γ for all x ∈ U, for each x ∈ U there is a t > 0 such that φ(x, t) ∈ V. Thus U = ∪_{t≥0} φ(V, −t) is a union of open sets, which is still open. Moreover, Φ(x) = Φ(φ(x, t)) for each x ∈ U and some t > 0 with φ(x, t) ∈ V. Since Φ is C³ in V and φ(·, t) is C^{k−1} for any t, we conclude that Φ is C³ in U.
8.8.1 SDE Approximations of SGD under Different Time Scalings

Let us first clarify how we derive the SDEs (8.2) and (8.3) that approximate SGD (8.1) under different time scalings. Recall that W(t) is a Ξ-dimensional Brownian motion and that σ : R^D → R^{D×Ξ} is a deterministic noise function. As proposed by [190], one approach to approximate SGD (8.1) by an SDE is to consider

dX(t) = −∇L(X(t))dt + √η σ(X(t))dW(t),

in which one step of SGD corresponds to a time interval of length η. Speeding up time by a factor of 1/η via X̃(t) = X(tη) gives

dX̃(t) = dX(tη) = −∇L(X(tη))d(tη) + √η σ(X(tη))dW(tη) = −η∇L(X̃(t))dt + √η σ(X̃(t))dW(tη).

Noting that √η W(tη) has the same law as η W(t), this yields (8.2).
Then, to accelerate the above SDE by a factor of η^{−2}, define X̄(t) = X̃(t/η²). It follows that

dX̄(t) = dX̃(t/η²) = −η∇L(X̃(t/η²))dt/η² + ησ(X̃(t/η²))dW(t/η²)
       = −(1/η)∇L(X̄(t))dt + σ(X̄(t))d(ηW(t/η²)).

Again, note that ηW(t/η²) has the same law as W(t) and thus is also a Ξ-dimensional Brownian motion. Here the time correspondence is t = kη², i.e., evolving the above SDE for constant time approximates Ω(1/η²) steps of SGD. In this way, we derive SDE (8.3) in the main context.
8.8.2 Conditions for Applying Katzenberger's Theorem

Below we collect the necessary conditions imposed on {Zn}n≥1 and {An}n≥1 in Katzenberger [187]. Recall that we consider the following stochastic process:

Xn(t) = X(0) + ∫₀ᵗ σ(Xn(s))dZn(s) − ∫₀ᵗ ∇L(Xn(s))dAn(s).
For any stopping time τ , the stopped process is defined as Xnτ (t) = Xn (t ∧ τ ). For any
compact K ⊂ U , we define the stopping time of Xn leaving K as λn (K) = inf{t ≥ 0 |
Xn (t−) ∈
/ K̊ or Xn (t) ∈
/ K̊}.
Condition 8.8.4. The integrator sequence {An}n≥1 increases infinitely fast: ∀ε > 0,

inf_{t≥0} (An(t + ε) − An(t)) ⇒ ∞.
Condition 8.8.5 (Eq. (5.1), Katzenberger [187]). It holds that

lim_{γ→0} lim sup_{n→∞} P[ sup_{0≤t≤T} (T_{t+γ}(Fn) − T_t(Fn)) > ε ] = 0

for every ε > 0 and T > 0, where T_t(·) denotes the total variation on the interval [0, t].
Lemma 8.8.7. For SGD iterates defined using the notation in Lemma 8.5.3, the sequences {An}n≥1 and {Zn}n≥1 satisfy Conditions 8.8.3, 8.8.4, 8.8.5 and 8.8.6.
Proof of Lemma 8.8.7. Condition 8.8.3 is obvious from the definition of {An}n≥1. Next, for any ε > 0 and t ∈ [0, T], we have

An(t + ε) − An(t) = ηn ⌊(t + ε)/ηn²⌋ − ηn ⌊t/ηn²⌋ ≥ (t + ε − ηn²)/ηn − t/ηn = (ε − ηn²)/ηn,

which implies that inf_{0≤t≤T}(An(t + ε) − An(t)) > ε/(2ηn) for small enough ηn. Then taking n → ∞ yields Condition 8.8.4.
For Condition 8.8.5, note that

∆Zn(t) = ηn√Ξ (1_{ξ_k} − (1/Ξ)1) if t = k·ηn², and ∆Zn(t) = 0 otherwise.

Therefore, we have ‖∆Zn(t)‖₂ ≤ 2ηn√Ξ for all t > 0. This implies that ‖∆Zn(t)‖₂ → 0 uniformly over t > 0 as n → ∞, which verifies Condition 8.8.5.
We proceed to verify Condition 8.8.6. By the definition of Zn , we know that
{Zn (t)}t≥0 is a jump process with independent increments and thus is a martingale.
Therefore, by decomposing Zn = Mn + Fn with Mn being a local martingale and Fn a
finite variation process, we must have Fn = 0 and Mn is Zn itself. It then suffices to
show that [Mn](t ∧ τ_n^m) is uniformly integrable for every t ≥ 0 and m ≥ 1. Since Mn is a pure jump process, we have

[Mn](t ∧ τ_n^m) = Σ_{0<s≤t∧τ_n^m} ‖∆Mn(s)‖₂² ≤ Σ_{0<s≤t} ‖∆Mn(s)‖₂²
               = Σ_{k=1}^{⌊t/ηn²⌋} ‖ηn√Ξ(1_{ξ_k} − (1/Ξ)1)‖₂² ≤ 4Ξ Σ_{k=1}^{⌊t/ηn²⌋} ηn² ≤ 4Ξt.

This implies that [Mn](t ∧ τ_n^m) is universally bounded by 4Ξt, and thus is uniformly integrable. This completes the proof.
Lemma 8.5.3. Let {ηn}_{n=1}^{∞} be any positive sequence with lim_{n→∞} ηn = 0, An(t) = ηn⌊t/ηn²⌋, and Zn(t) = ηn Σ_{k=1}^{⌊t/ηn²⌋} √Ξ(1_{ξ_k} − (1/Ξ)1), where ξ₁, ξ₂, . . . iid∼ Unif([Ξ]). Then with the same initialization Xn(0) = x_{ηn}(0) ≡ X(0), Xn(kηn²) defined by (8.8) is a Katzenberger process and is equal to x_{ηn}(k) defined in (8.1) with LR equal to ηn for all k ≥ 1. Moreover, the counterpart of (8.9) is

Y(t) = Φ(X(0)) + ∫₀ᵗ ∂Φ(Y)σ(Y)dW(s) + ½ ∫₀ᵗ ∂²Φ(Y)[Σ(Y)]ds,    (8.10)
where Σ ≡ σσ > and {W (t)}t≥0 is a Ξ-dimensional standard Brownian motion.
Proof of Lemma 8.5.3. For any n ≥ 1, it suffices to show that given Xn(kηn²) = x_{ηn}(k), we further have Xn((k + 1)ηn²) = x_{ηn}(k + 1). By the definition of Xn(t), and noting that An(t), Zn(t) are constant on [kηn², (k + 1)ηn²), we have that Xn(t) = Xn(kηn²) for all t ∈ [kηn², (k + 1)ηn²), and therefore

Xn((k + 1)ηn²) − Xn(kηn²)
  = −∇L(Xn(kηn²))(An((k + 1)ηn²) − An(kηn²)) + σ(Xn(kηn²))(Zn((k + 1)ηn²) − Zn(kηn²))
  = −ηn∇L(Xn(kηn²)) + ηn√Ξ σ_{ξ_k}(Xn(kηn²))
  = −ηn∇L(x_{ηn}(k)) + ηn√Ξ σ_{ξ_k}(x_{ηn}(k)) = x_{ηn}(k + 1) − x_{ηn}(k),

where the second equality is because An(t) and Zn(t) are constant on the interval [kηn², (k + 1)ηn²). This confirms the alignment between {Xn(kηn²)}k≥1 and {x_{ηn}(k)}k≥1.
For the second claim, note that σ(x)EZn (t) ≡ 0 for all x ∈ RD , t ≥ 0 (since the
noise has zero-expectation) and that {Zn (t) − EZn (t)}t≥0 will converge in distribution
to a Brownian motion by the classic functional central limit theorem (see, for example,
Theorem 4.3.5 in Whitt [228]). Thus, the limiting diffusion of Xn as n → ∞ can be
obtained by substituting Z with the standard Brownian motion W in (8.23). This
completes the proof.
8.8.3 Katzenberger's Theorem for the Asymptotically Continuous Case

The full Katzenberger theorem deals with a more general case, which only requires the sequence of integrators to be asymptotically continuous, thus covering SDE (8.3) and SGD (8.1) as η goes to 0.
To describe the results in Katzenberger [187], we first introduce some definitions.
For each n ≥ 1, let (Ωn, F^n, {F_t^n}t≥0, P) be a filtered probability space, Zn an R^e-valued càdlàg {F_t^n}-semimartingale with Zn(0) = 0, and An a real-valued càdlàg {F_t^n}-adapted nondecreasing process with An(0) = 0. Let σn : U → M(D, e) be
continuous with σn → σ uniformly on compact subsets of U . Let Xn be an RD -valued
càdlàg {Ftn }-semimartingale satisfying, for all compact K ⊂ U ,
Z t Z t
Xn (t) = X(0) + σ(Xn )dZn + −∇L(Xn )dAn (8.22)
0 0
Y(t) = Y(0) + ∫₀^{t∧µ} ∂Φ(Y(s))σ(Y(s))dZ(s)
     + ½ Σ_{i,j=1}^{D} Σ_{k,l=1}^{e} ∫₀^{t∧µ} ∂_{ij}Φ(Y(s))σ(Y(s))_{ik} σ(Y(s))_{jl} d[Z_k, Z_l](s).    (8.23)
8.8.4 A User-friendly Interface for Katzenberger’s Theorem
Based on the Lemma 8.8.7, we can immediately apply Theorem 8.8.8 to obtain the
following limiting diffusion of SGD.
Theorem 8.8.9. Let the manifold Γ and its open neighborhood U satisfy Assumptions 8.5.1 and ??. Let K ⊂ U be any compact set and fix some x₀ ∈ K. Consider the SGD formulated in Lemma 8.5.3 with X_{ηn}(0) ≡ x₀. Define

Y_{ηn}(t) = X_{ηn}(t) − φ(X_{ηn}(0), A_{ηn}(t)) + Φ(X_{ηn}(0))

and µ_{ηn}(K) = min{t ∈ N | Y_{ηn}(t) ∉ K̊}. Then the sequence {(Y_{ηn}^{µ_{ηn}(K)}, Z_{ηn}, µ_{ηn}(K))}n≥1 is relatively compact in D_{R^D×R^n}[0, ∞) × [0, ∞]. Moreover, if (Y, Z, µ) is a limit point of this sequence, then Y(t) ∈ Γ a.s. for all t ≥ 0, µ ≥ inf{t ≥ 0 | Y(t) ∉ K̊}, and Y(t) admits

Y(t) = Y(0) + ∫₀^{t∧µ} ∂Φ(Y(s))σ(Y(s))dW(s) + ½ Σ_{i,j=1}^{D} ∫₀^{t∧µ} ∂_{ij}Φ(Y(s))(σ(Y(s))σ(Y(s))⊤)_{ij} ds.    (8.24)
However, the above theorem is hard to parse and cannot be directly applied if we want to further study the implicit bias of SGD through this limiting diffusion. Therefore, we develop a user-friendly interface for it below. In particular, Theorem 8.5.7 is a special case of Theorem 8.8.10. In Theorem 8.5.7, we replace ∂Φ(Y(t))σ(Y(t)) with Σ∥^{1/2}(Y(t)) to simplify the equation; since ∂Φ(Y(t))σ(Y(t))(∂Φ(Y(t))σ(Y(t)))⊤ = Σ∥(Y(t)), this change does not affect the distribution of the sample paths of the solution.
Theorem 8.8.10. Under the same setting as Theorem 8.8.9, we change the integer index back to η > 0 with a slight abuse of notation. Let µ be any stopping time and {Y(t)}t≥0 any stochastic process such that µ ≥ inf{t ≥ 0 | Y(t) ∉ K̊}, Y(0) = Φ(x₀), and (Y, µ) satisfies Equation (8.24) for some standard Brownian motion W. For any compact set K ⊆ U and T > 0, define µ(K) = inf{t ≥ 0 | Y(t) ∉ K̊} and δ = P(µ(K) ≤ T). Then for any ε > 0, it holds for all sufficiently small LR η that ρ_{2δ}(Yη^{µη(K)∧T}, Y^{µ(K)∧T}) ≤ ε; that is, there is a coupling between the distributions of the stopped processes Yη^{µη(K)∧T} and Y^{µ(K)∧T} such that the uniform metric between them is smaller than ε with probability at least 1 − 2δ. In other words, lim_{η→0} ρ_{2δ}(Yη^{µη(K)∧T}, Y^{µ(K)∧T}) = 0.

Moreover, when {Y(t)}t≥0 is a global solution to the limiting diffusion

Y(t) = Y(0) + ∫₀ᵗ ∂Φ(Y(s))σ(Y(s))dW(s) + ½ Σ_{i,j=1}^{D} ∫₀ᵗ ∂_{ij}Φ(Y(s))(σ(Y(s))σ(Y(s))⊤)_{ij} ds

and Y never leaves U, i.e., P[∀t ≥ 0, Y(t) ∈ U] = 1, it holds that Yη^T converges in distribution to Y^T as η → 0 for any fixed T > 0.
For clarity, we break the proof of Theorem 8.8.10 into two parts, devoted to the
two claims respectively.
Proof of the first claim of Theorem 8.8.10. First, Theorem 8.8.9 guarantees there exist a stopping time µ̃ and a stochastic process {Ỹ(t)}t≥0 such that
1. (Ỹ, µ̃) satisfies Equation (8.24);
2. Ỹ ∈ Γ a.s.;
3. µ̃ ≥ µ̃(K) := inf{t ≥ 0 | Ỹ(t) ∉ K̊}.
The above conditions imply that Ỹ^{µ̃(K)} ∈ Γ a.s. Since the coefficients in Equation (8.24) are locally Lipschitz, we claim that (Ỹ^{µ̃(K)}, µ̃(K)) is equal in distribution to (Y^{µ(K)}, µ(K)). To see this, note that for any compact K ⊆ U, the noise function σ, ∂Φ and ∂²Φ are all Lipschitz on K; thus we can extend their definitions to R^D such that the resulting functions are still locally Lipschitz. Based on this extension, applying the classic theorem on weak uniqueness (e.g., Theorem 1.1.10, Hsu [229]) to the extended version of Equation (8.24) yields the equivalence in law. Thus we only need to prove the first claim for Ỹ.

Let E_T be the event that µ̃(K) > T. Then restricted on E_T, we have Ỹ(T ∧ µ̃) = Ỹ(T ∧ µ̃(K)), as µ̃ ≥ µ̃(K) holds a.s. We first prove the claim for any convergent subsequence of {Yη}η>0.
Now, let {ηm}m≥1 be a sequence of LRs such that ηm → 0 and Y_{ηm}^{µ_{ηm}(K)} ⇒ Ỹ^{µ̃} as m → ∞. By applying the Skorokhod representation theorem, we can put {Y_{ηm}}m≥1 and Ỹ on the same probability space such that Y_{ηm}^{µ_{ηm}(K)} → Ỹ^{µ̃} a.s. in the Skorokhod metric, or equivalently the uniform metric (since Ỹ^{µ̃} is continuous). This further implies that for any ε > 0, there exists some N > 0 such that for all m > N,

P[ d_U(Y_{ηm}^{µ_{ηm}(K)∧T}, Ỹ^{µ̃∧T}) ≥ ε ] ≤ δ.
Restricted on E_T, we have d_U(Y_{ηm}^{µ_{ηm}(K)∧T}, Ỹ^{µ̃∧T}) = d_U(Y_{ηm}^{µ_{ηm}(K)∧T}, Ỹ^{µ̃(K)∧T}), and it follows that for all m > N,

P[ d_U(Y_{ηm}^{µ_{ηm}(K)∧T}, Ỹ^{µ̃(K)∧T}) ≥ ε ]
  ≤ P[ {d_U(Y_{ηm}^{µ_{ηm}(K)∧T}, Ỹ^{µ̃(K)∧T}) ≥ ε} ∩ E_T ] + P[E_T^c]
  = P[ {d_U(Y_{ηm}^{µ_{ηm}(K)∧T}, Ỹ^{µ̃∧T}) ≥ ε} ∩ E_T ] + P[E_T^c]
  ≤ P[ d_U(Y_{ηm}^{µ_{ηm}(K)∧T}, Ỹ^{µ̃∧T}) ≥ ε ] + P[E_T^c]
  ≤ 2δ,
Now we claim that indeed lim_{η→0} ρ_{2δ}(Yη^{µη(K)∧T}, Ỹ^{µ̃(K)∧T}) = 0. We prove this by contradiction. Suppose otherwise; then there exists some ε > 0 such that for all η₀ > 0, there exists some η < η₀ with ρ_{2δ}(Yη^{µη(K)∧T}, Ỹ^{µ̃(K)∧T}) > ε. Consequently, there is a sequence {ηm}m≥1 satisfying lim_{m→∞} ηm = 0 and ρ_{2δ}(Y_{ηm}^{µ_{ηm}(K)∧T}, Ỹ^{µ̃(K)∧T}) > ε for all m. Since {(Y_{ηm}^{µ_{ηm}(K)∧T}, Z_{ηm}, µ_{ηm}(K))}m≥1 is relatively compact, there exists a subsequence (WLOG, the original sequence itself) converging in distribution to (Ỹ^{µ̃∧T}, W, µ̃). However, repeating exactly the same argument as above, we would have ρ_{2δ}(Y_{ηm}^{µ_{ηm}(K)∧T}, Ỹ^{µ̃(K)∧T}) ≤ ε for all sufficiently large m, which is a contradiction. This completes the proof.
Proof of the second claim of Theorem 8.8.10. We first show there exists a sequence of compact sets $\{K_m\}_{m\ge1}$ such that $\cup_{m=1}^\infty K_m = U$ and $K_m \subseteq K_{m+1}$. Let $H_m = \{x \in U \mid d(x, U^c) \ge 1/m\}$ and $K_m = H_m \cap \bar{B}_m(0)$. Since each $K_m$ is bounded and closed, $K_m$ is compact for every $m$. Now we claim $\cup_{m=1}^\infty K_m = U$. Note that $\cup_{m=1}^\infty K_m = \cup_{m=1}^\infty H_m \cap \bar{B}_m(0) = \cup_{m=1}^\infty H_m$. For any $x \in U$, since $U$ is open, there exists $m$ large enough that $d(x, U^c) \ge 1/m$ and $\|x\|_2 \le m$, hence $x \in K_m$.
Now fix a compact set $K'$ with $K \subseteq \mathring{K'}$ and $K' \subseteq U$, and let $\epsilon_0 \le d(K, \mathbb{R}^D \setminus K')$. For $t \le \mu(K)\wedge T$ we have $Y^{\mu(K)\wedge T}(t) \in K$, thus if $\mu_\eta(K') \le T$, then
$$d_U\big(Y^{\mu(K)\wedge T}, Y_\eta^{\mu_\eta(K')\wedge T}\big) \ge 2^{-\lceil T\rceil}\big\|Y^{\mu(K)\wedge T}(\mu_\eta(K')) - Y_\eta^{\mu_\eta(K')\wedge T}(\mu_\eta(K'))\big\|_2 \ge 2^{-\lceil T\rceil} d\big(K, \mathbb{R}^D\setminus K'\big) \ge 2^{-\lceil T\rceil}\epsilon_0.$$
On the other hand, if $\mu_\eta(K') > T$, then $Y_\eta^T = Y_\eta^{\mu_\eta(K')\wedge T}$. Thus we can conclude that $d_U\big(Y^{\mu(K)\wedge T}, Y_\eta^T\big) \ge 2^{-\lceil T\rceil}\epsilon_0$ implies $d_U\big(Y^{\mu(K)\wedge T}, Y_\eta^{\mu_\eta(K')\wedge T}\big) \ge 2^{-\lceil T\rceil}\epsilon_0$. Therefore, we further have
$$\mathbb{P}\Big[d_U\big(Y^{\mu(K)\wedge T}, Y_\eta^T\big) \ge 2^{-\lceil T\rceil}\epsilon_0\Big] \le \mathbb{P}\Big[d_U\big(Y^{\mu(K)\wedge T}, Y_\eta^{\mu_\eta(K')\wedge T}\big) \ge 2^{-\lceil T\rceil}\epsilon_0\Big] \le 3\delta,$$
that is, $\rho_{3\delta}\big(Y^{\mu(K)\wedge T}, Y_\eta^T\big) \le 2^{-\lceil T\rceil}\epsilon_0$.
Proof of Theorem 8.5.7. We first prove that $Y$ never leaves $\Gamma$, i.e., $\mathbb{P}[Y(t) \in \Gamma,\ \forall t \ge 0] = 1$. By the result of Theorem 8.8.9, we know that for each compact set $K \subseteq \Gamma$, $Y^{\mu(K)}$ stays on $\Gamma$ almost surely, where $\mu(K) := \inf\{t \ge 0 \mid Y(t) \notin \mathring{K}\}$ is the earliest time that $Y$ leaves $K$. In other words, for every compact set $K \subseteq \Gamma$, $\mathbb{P}[\exists t \ge 0,\ Y(t) \notin \Gamma,\ Y(t) \in K] = 0$. Let $\{K_m\}_{m\ge1}$ be any sequence of compact sets such that $\cup_{m\ge1}K_m = U$ and $K_m \subseteq U$, e.g., the ones constructed in the proof of the second claim of Theorem 8.8.10. Therefore, we have
$$\mathbb{P}[\exists t \ge 0,\ Y(t) \notin \Gamma] = \mathbb{P}[\exists t \ge 0,\ Y(t) \notin \Gamma,\ Y(t) \in U] \le \sum_{m=1}^\infty \mathbb{P}[\exists t \ge 0,\ Y(t) \notin \Gamma,\ Y(t) \in K_m] = 0.$$
Moreover, SDE (8.12) can be rewritten as
$$dY(t) = \Sigma_\parallel^{1/2}(Y(t))\,dW(t) + \frac{1}{2}\partial^2\Phi(Y(t))[\Sigma(Y(t))]\,dt = \partial\Phi(Y(t))\sigma(Y(t))\,dW(t) + \frac{1}{2}\sum_{i,j=1}^D \partial_{ij}\Phi(Y(t))\big(\sigma(Y(t))\sigma(Y(t))^\top\big)_{ij}\,dt,$$
where the second equality follows from the definition $\Sigma_\parallel = \partial\Phi\,\Sigma\,\partial\Phi = \partial\Phi\,\sigma\sigma^\top\partial\Phi$.
This coincides with the formulation of the limiting diffusion in Theorem 8.8.10.
Therefore, further combining Lemma 8.5.3 and the second part of Theorem 8.8.10, we
obtain the desired result.
Remark 8.8.11. Our result suggests that for tiny LR $\eta$, the SGD dynamics have two phases. In Phase I, which lasts $\Theta(1/\eta)$ steps, the SGD iterates move towards the manifold $\Gamma$ of local minimizers along the gradient flow. Then in Phase II, which lasts $\Theta(1/\eta^2)$ steps, the SGD iterates stay close to $\Gamma$ and diffuse approximately according to (8.12). See Figure 8.2 for an illustration of this two-phase dynamics. However, since the length of Phase I becomes negligible compared to that of Phase II as $\eta \to 0$, Theorem 8.5.7 only reflects the time scaling of Phase II.
Figure 8.2: Illustration of the two-phase dynamics of SGD with the same example as in Figure 8.1. $\Gamma$ is a 1D manifold of minimizers of the loss $L$.
Lemma 8.9.1. For any x ∈ Γ and any v ∈ Tx (Γ), it holds that ∇2 L(x)v = 0.
Proof. For any $v \in T_x(\Gamma)$, let $\{x(t)\}_{t\ge0}$ be a parametrized smooth curve on $\Gamma$ such that $x(0) = x$ and $\frac{dx(t)}{dt}\big|_{t=0} = v$. Then $\nabla L(x(t)) = 0$ for all $t$. Thus $0 = \frac{d\nabla L(x(t))}{dt}\big|_{t=0} = \nabla^2L(x)v$.
Proof. Fix any $x \in \mathbb{R}^D$ and let $\frac{dx(t)}{dt} = -\nabla L(x(t))$ be initialized at $x(0) = x$. Since $\Phi(x(t)) = \Phi(x)$ for all $t \ge 0$, we have
$$\frac{d}{dt}\Phi(x(t)) = -\partial\Phi(x(t))\nabla L(x(t)) = 0.$$
Differentiating once more,
$$\frac{d^2}{dt^2}\Phi(x(t)) = -\partial^2\Phi(x(t))\Big[\frac{dx(t)}{dt}, \nabla L(x(t))\Big] - \partial\Phi(x(t))\nabla^2L(x(t))\frac{dx(t)}{dt} = 0.$$
Lemma 8.5.4. For any x ∈ Γ, ∂Φ(x) is the orthogonal projection matrix onto tangent
space Tx (Γ).
Proof of Lemma 8.5.4. For any $v \in T_x(\Gamma)$, let $\{v(t)\}_{t\ge0}$ be a parametrized smooth curve on $\Gamma$ such that $v(0) = x$ and $\frac{dv(t)}{dt}\big|_{t=0} = v$. Since $v(t) \in \Gamma$ for all $t \ge 0$, we have $\Phi(v(t)) = v(t)$, and thus
$$\frac{dv(t)}{dt}\Big|_{t=0} = \frac{d}{dt}\Phi(v(t))\Big|_{t=0} = \partial\Phi(x)\frac{dv(t)}{dt}\Big|_{t=0}.$$
Hence $v = \partial\Phi(x)v$ for every $v \in T_x(\Gamma)$. Next, for any $u \in T_x^\perp(\Gamma)$, a Taylor expansion gives
$$\nabla L\big(x + t(\nabla^2L(x))^\dagger u\big) = t\nabla^2L(x)(\nabla^2L(x))^\dagger u + o(t) = tu + o(t),$$
where the second equality follows from the assumption that $\nabla^2L(x)$ is full-rank when restricted on $T_x^\perp(\Gamma)$. By Lemma 8.9.2, we have $\partial\Phi\big(x + t(\nabla^2L(x))^\dagger u\big)\nabla L\big(x + t(\nabla^2L(x))^\dagger u\big) = 0$ for all $t > 0$. Then, since $\partial\Phi$ is continuous, dividing by $t$ and letting $t \to 0$ yields $\partial\Phi(x)u = 0$ for all $u \in T_x^\perp(\Gamma)$.
Therefore, under the basis $\{v_1, \ldots, v_D\}$, $\partial\Phi(x)$ is given by
$$\partial\Phi(x) = \begin{pmatrix} I_{D-M} & 0\\ 0 & 0 \end{pmatrix} \in \mathbb{R}^{D\times D},$$
which is exactly the orthogonal projection matrix onto $T_x(\Gamma)$.
Denote the derivatives of $P(t)$, $P^\perp(t)$ and $H(t)$ with respect to $t$ by $P'(t)$, $(P^\perp)'(t)$ and $H'(t)$. Then, differentiating with respect to $t$, we have
$$P'(0) = \begin{pmatrix} P'_{11}(0) & P'_{12}(0)\\ P'_{21}(0) & P'_{22}(0) \end{pmatrix}, \qquad H(0) = \begin{pmatrix} 0 & 0\\ 0 & H_{22}(0) \end{pmatrix}, \tag{8.28}$$
where $P'_{11}(0) \in \mathbb{R}^{(D-M)\times(D-M)}$ and $H_{22}$ is the Hessian of $L$ restricted on $T_x^\perp(\Gamma)$. Also note that
$$P(0)H'(0)P^\perp(0) = \begin{pmatrix} I_{D-M} & 0\\ 0 & 0 \end{pmatrix}\begin{pmatrix} H'_{11}(0) & H'_{12}(0)\\ H'_{21}(0) & H'_{22}(0) \end{pmatrix}\begin{pmatrix} 0 & 0\\ 0 & I_M \end{pmatrix} = \begin{pmatrix} 0 & H'_{12}(0)\\ 0 & 0 \end{pmatrix},$$
$$P'(0)H(0) = \begin{pmatrix} 0 & P'_{12}(0)H_{22}(0)\\ 0 & P'_{22}(0)H_{22}(0) \end{pmatrix} = \begin{pmatrix} 0 & -H'_{12}(0)\\ 0 & 0 \end{pmatrix}.$$
This implies that we must have $P'_{22}(0) = 0$ and $P'_{12}(0)H_{22}(0) = -H'_{12}(0)$. Similarly, by taking transposes in (8.28), we also have $H_{22}(0)P'_{21}(0) = -H'_{21}(0)$.
It then remains to determine the value of $P'_{11}(0)$. Note that since $P(t)P(t) = P(t)$, we have $P'(t)P(t) + P(t)P'(t) = P'(t)$; evaluating at $t = 0$ yields
$$2P'_{11}(0) = P'_{11}(0).$$
Therefore, we must have $P'_{11}(0) = 0$. Combining the above results, we obtain the claimed expression for
$$P'(0) = \frac{d}{dt}\partial\Phi(v(t))\Big|_{t=0} = \partial(\partial\Phi(x))[v].$$
Proof of Lemma 8.9.5. For any $u \in T_x^\perp(\Gamma)$, we define $u(t) = x + t\nabla^2L(x)^\dagger u$ for $t \ge 0$. By Taylor approximation, we obtain (8.29) and (8.30). Combining (8.29) and (8.30) and applying Lemma 8.9.2, it follows that the claimed identity holds, where the last equality follows from Lemma 8.9.3. Dividing both sides by $t^2$ and letting $t \to 0$, we get the desired equation.
With the notion of Lyapunov Operator in Definition 8.5.5, Lemma 8.9.5 can be
further simplified into Lemma 8.9.6.
Lemma 8.9.6. For any $x \in \Gamma$ and $\Sigma \in \operatorname{span}\{uu^\top \mid u \in T_x^\perp(\Gamma)\}$,

Proof of Lemma 8.9.6. Let $A = uu^\top + \nabla^2L(x)^\dagger uu^\top\nabla^2L(x)$ and $B = \nabla^2L(x)^\dagger uu^\top$. The key observation is that $A + A^\top = \mathcal{L}_{\nabla^2L(x)}(B + B^\top)$. Therefore, by Lemma 8.9.5, it holds that
$$\partial^2\Phi(x)\big[\mathcal{L}_{\nabla^2L(x)}(B + B^\top)\big] = \partial^2\Phi(x)[A + A^\top] = 2\partial\Phi(x)\partial^2(\nabla L)(x)[B] = \partial\Phi(x)\partial^2(\nabla L)(x)[B + B^\top].$$
Since $\nabla^2L(x)^\dagger$ is full-rank when restricted to $T_x^\perp(\Gamma)$, we have $\operatorname{span}\{\nabla^2L(x)^\dagger uu^\top + uu^\top\nabla^2L(x)^\dagger \mid u \in T_x^\perp(\Gamma)\} = \operatorname{span}\{uu^\top \mid u \in T_x^\perp(\Gamma)\}$. Thus, by the linearity of the above equation, we can replace $B + B^\top$ by any $\Sigma \in \operatorname{span}\{uu^\top \mid u \in T_x^\perp(\Gamma)\}$, resulting in the desired equation.
Then Lemma 8.5.6 directly follows from Lemmas 8.9.4 and 8.9.5.

Tangent Noise Compensation Only Depends on the Manifold Itself

Here we show that the second term of (8.12), i.e., the tangent noise compensation that keeps the limiting dynamics on $\Gamma$, only depends on $\Gamma$ itself.
Proof of Lemma 8.9.7. Let $\{v(t)\}_{t\ge0}$ be a smooth curve on $\Gamma$ with $v(0) = x$ and $\frac{dv(t)}{dt}\big|_{t=0} = v$. Since $v(t)$ stays on $\Gamma$, we have $\nabla L(v(t)) = 0$ for all $t \ge 0$. Taking the derivative twice yields
$$\partial^2(\nabla L)(v(t))\Big[\frac{dv(t)}{dt}, \frac{dv(t)}{dt}\Big] + \nabla^2L(v(t))\frac{d^2v(t)}{dt^2} = 0.$$
Evaluating it at $t = 0$ and multiplying both sides by $\nabla^2L(x)^\dagger$, we get
$$\nabla^2L(x)^\dagger\partial^2(\nabla L)(x)[v, v] = -\nabla^2L(x)^\dagger\nabla^2L(x)\frac{d^2v(t)}{dt^2}\Big|_{t=0} = -\partial\Phi(x)\frac{d^2v(t)}{dt^2}\Big|_{t=0}.$$
Since $\partial\Phi(x)$ is the projection matrix onto $T_x(\Gamma)$ by Lemma 8.5.4, it does not depend on $L$, so analogously we also have $\nabla^2L'(x)^\dagger\partial^2(\nabla L')(x)[v, v] = -\partial\Phi(x)\frac{d^2v(t)}{dt^2}\big|_{t=0}$. The proof is thus completed. Note that $\partial\Phi(x)\frac{d^2v(t)}{dt^2}\big|_{t=0}$ is indeed the second fundamental form for $v$ at $x$, and its value does not change if we choose another parametrized smooth curve with a different second-order time derivative (see Chapter 6 of do Carmo [230]).
Now we are ready to give the missing proofs in Section 8.6, which yield explicit formulas of the limiting diffusion for label noise and isotropic noise.
$$dY(t) = \underbrace{\partial\Phi(Y)\,dW - \frac{1}{2}\nabla^2L(Y)^\dagger\partial^2(\nabla L)(Y)[\partial\Phi(Y)]\,dt}_{\text{Brownian motion on manifold}} \;-\; \underbrace{\frac{1}{2}\partial\Phi(Y)\nabla\big(\ln|\nabla^2L(Y)|_+\big)\,dt}_{\text{normal regularization}} \tag{8.13}$$
Proof of Corollary 8.6.1. Set $\Sigma_\parallel = \partial\Phi$, $\Sigma_\perp = I_D - \partial\Phi$ and $\Sigma_{\perp,\parallel} = \Sigma_{\parallel,\perp} = 0$ in the decomposition of $\Sigma$ by Lemma 8.5.6; then we need to show $\nabla\big(\ln|\nabla^2L|_+\big) = \partial^2(\nabla L)\big[(\nabla^2L)^\dagger\big]$. Holbrook [231] shows that the gradient of the pseudo-determinant satisfies $\nabla|A|_+ = |A|_+A^\dagger$. Thus we have, for any vector $v \in \mathbb{R}^D$,
$$\big\langle v, \nabla\ln|\nabla^2L|_+\big\rangle = \Big\langle\frac{|\nabla^2L|_+(\nabla^2L)^\dagger}{|\nabla^2L|_+},\ \partial^2(\nabla L)[v]\Big\rangle = \big\langle(\nabla^2L)^\dagger, \partial^2(\nabla L)[v]\big\rangle = \partial^2(\nabla L)\big[v, (\nabla^2L)^\dagger\big] = \big\langle v, \partial^2(\nabla L)\big[(\nabla^2L)^\dagger\big]\big\rangle,$$
which completes the proof.
Corollary 8.6.2 (Limiting Flow for Label Noise). If $\Sigma \equiv c\nabla^2L$ on $\Gamma$ for some constant $c > 0$, SDE (8.12) can be simplified into (8.15), where the regularization comes from the noise in the normal space.

Proof of Corollary 8.6.2. Since $\Sigma = c\nabla^2L$, here we have $\Sigma_\perp = \Sigma$ and $\Sigma_\parallel, \Sigma_{\perp,\parallel}, \Sigma_{\parallel,\perp} = 0$. Thus it suffices to show that $2\partial^2(\nabla L)\big[\mathcal{L}^{-1}_{\nabla^2L}(\Sigma_\perp)\big] = \nabla\operatorname{tr}[\nabla^2L]$. Note that the claimed identity holds for any $v \in \mathbb{R}^D$, where the second equality is because the tangent space of symmetric rank-$(D-M)$ matrices at $\nabla^2L$ is $\{A\nabla^2L + \nabla^2LA^\top \mid A \in \mathbb{R}^{D\times D}\}$, and every element in this tangent space has zero inner product with $\partial\Phi$ by Lemma 8.5.4. Also note that $\mathcal{L}^{-1}_{\nabla^2L}(\nabla^2L) = \frac{1}{2}(I_D - \partial\Phi)$; thus $\langle I_D - \partial\Phi, \partial^2(\nabla L)[v]\rangle = 2\big\langle\mathcal{L}^{-1}_{\nabla^2L}(\nabla^2L), \partial^2(\nabla L)[v]\big\rangle = 2v^\top\partial^2(\nabla L)\big[\mathcal{L}^{-1}_{\nabla^2L}(\nabla^2L)\big]$.
Lemma 8.6.3. Suppose the noise covariance satisfies $\Sigma_{ij}(x) = 1 + \langle Q_\alpha^{j-3}v, x_{1:2}\rangle$ if $i = j \ge 3$ and $0$ otherwise. Then the solution of SDE (8.12) is given by (8.16), which implies that $Y(t)$ moves anti-clockwise with a constant angular speed of $(D-2)/2$.
Proof of Lemma 8.6.3. Note that for any $x \in \Gamma$, it holds that
$$\big(\nabla^2L(x)\big)_{ij} = \begin{cases} 2 + \langle Q_\alpha^{j-3}v, x_{1:2}\rangle & \text{if } i = j \ge 3,\\ x_ix_j & \text{if } i, j \in \{1,2\},\\ 0 & \text{otherwise.} \end{cases} \tag{8.33}$$
Then clearly $\Sigma$ only brings about noise in the normal space, and specifically it holds that $\mathcal{L}^{-1}_{\nabla^2L(x)}(\Sigma(x)) = \operatorname{diag}\big(0, 0, 1 + \langle Q_\alpha^0v, Q_{-\pi/2}x_{1:2}\rangle, \ldots, 1 + \langle Q_\alpha^{D-3}v, Q_{-\pi/2}x_{1:2}\rangle\big)$. Further note that, by the special structure of the Hessian in (8.33) and Lemma 8.9.3, for any $x \in \Gamma$ we have
$$\partial\Phi(x) = (x_2, -x_1, 0, \ldots, 0)^\top(x_2, -x_1, 0, \ldots, 0) = \begin{pmatrix} Q_{-\pi/2}x_{1:2}\,(Q_{-\pi/2}x_{1:2})^\top & 0\\ 0 & 0 \end{pmatrix}.$$
Combining these facts, the dynamics of the first two coordinates in SDE (8.12) can be simplified into
$$\begin{aligned}
\frac{dx_{1:2}(t)}{dt} &= -\frac{1}{2}\Big(\partial\Phi(x(t))\partial^2(\nabla L)(x(t))\big[\mathcal{L}^{-1}_{\nabla^2L}(\Sigma(x(t)))\big]\Big)_{1:2}\\
&= -\frac{1}{2}Q_{-\pi/2}x_{1:2}\,x_{1:2}^\top Q_{-\pi/2}^\top\sum_{j=3}^D\big(1 + \langle Q_\alpha^{j-3}v, Q_{-\pi/2}x_{1:2}\rangle\big)\nabla_{1:2}(\partial_{jj}L)(x)\\
&= -\frac{1}{2}Q_{-\pi/2}x_{1:2}\Big\langle Q_{-\pi/2}x_{1:2},\ \sum_{j=3}^D\big(1 + \langle Q_\alpha^{j-3}v, Q_{-\pi/2}x_{1:2}\rangle\big)Q_\alpha^{j-3}v\Big\rangle\\
&= -\frac{1}{2}Q_{-\pi/2}x_{1:2}\Bigg(\Big\langle Q_{-\pi/2}x_{1:2},\ \sum_{j=3}^DQ_\alpha^{j-3}v\Big\rangle + \sum_{j=3}^D\big\langle Q_\alpha^{j-3}v, Q_{-\pi/2}x_{1:2}\big\rangle^2\Bigg)\\
&= -\frac{1}{2}Q_{-\pi/2}x_{1:2}\Big(0 + \frac{D-2}{2}\big\|Q_{-\pi/2}x_{1:2}\big\|_2^2\Big) = \frac{D-2}{2}Q_{\pi/2}x_{1:2},
\end{aligned}$$
where the second-to-last equality follows from the properties of $Q_\alpha$ and the last equality follows from the fact that $\|x_{1:2}\|_2 = 1$ for all $x \in \Gamma$. Note that we require $k \ge 3$ (or $D \ge 5$) to ensure $\sum_{j=3}^D\big\langle Q_\alpha^{j-3}v, Q_{-\pi/2}x_{1:2}\big\rangle^2 = \frac{D-2}{2}\|Q_{-\pi/2}x_{1:2}\|_2^2$. On the other hand, we have $\frac{dx_{3:D}(t)}{dt} = 0$, as $\partial\Phi$ kills the movement on those coordinates.
The proof is completed by noting that the solution for $x_{1:2}$ is
$$x_{1:2}(t) = \exp\Big(t\cdot\frac{D-2}{2}Q_{\pi/2}\Big)x_{1:2}(0), \qquad \exp\Big(t\cdot\frac{D-2}{2}Q_{\pi/2}\Big) = \big(\exp(Q_{\pi/2})\big)^{\frac{t(D-2)}{2}} = Q_1^{\frac{t(D-2)}{2}} = Q_{\frac{t(D-2)}{2}}.$$

Proof. By definition, for the matrix $A = \begin{pmatrix} 0 & -1\\ 1 & 0 \end{pmatrix}$, $\exp(A) = \sum_{t=0}^\infty\frac{A^t}{t!}$. Note that $A^2 = -I$, $A^3 = -A$ and $A^4 = I$. Using this pattern, we can easily check that
$$\sum_{t=0}^\infty\frac{A^t}{t!} = \begin{pmatrix} \sum_{i=0}^\infty(-1)^i\frac{1}{(2i)!} & -\sum_{i=0}^\infty(-1)^i\frac{1}{(2i+1)!}\\ \sum_{i=0}^\infty(-1)^i\frac{1}{(2i+1)!} & \sum_{i=0}^\infty(-1)^i\frac{1}{(2i)!} \end{pmatrix} = \begin{pmatrix} \cos1 & -\sin1\\ \sin1 & \cos1 \end{pmatrix}.$$
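As a quick numerical sanity check of the series computation above (an illustration, not part of the proof), one can confirm that $\exp(A)$ is the rotation by one radian:

```python
import numpy as np
from scipy.linalg import expm

# exp(A) for the 2x2 skew-symmetric generator should be a 1-radian rotation.
A = np.array([[0.0, -1.0],
              [1.0,  0.0]])
expected = np.array([[np.cos(1.0), -np.sin(1.0)],
                     [np.sin(1.0),  np.cos(1.0)]])
assert np.allclose(expm(A), expected)
print(expm(A))
```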
In this section, we present the missing proofs in Section 8.7 regarding the overparametrized linear model. For convenience, for any $p, r \ge 0$ and $u \in \mathbb{R}^D$, we denote by $B_r^p(u)$ the $\ell_p$-norm ball of radius $r$ centered at $u$. We also write $v_{i:j} = (v_i, v_{i+1}, \ldots, v_j)^\top$ for $i, j \in [D]$.
Theorem 8.7.1. In the setting of OLM, suppose the groundtruth is $\kappa$-sparse and $n \ge \Omega(\kappa\ln d)$ training data are sampled from either the i.i.d. Gaussian or Boolean distribution. Then for any initialization $x_{\mathrm{init}}$ (except a zero-measure set) and any $\epsilon > 0$, there exist $\eta_0, T > 0$ such that for any $\eta < \eta_0$, OLM trained with label noise SGD (8.14) with LR equal to $\eta$ for $\lfloor T/\eta^2\rfloor$ steps returns an $\epsilon$-optimal solution, with probability $1 - e^{-\Omega(n)}$ over the randomness of the training dataset.
Proof of Theorem 8.7.1. First, by Lemma 8.7.6, it holds with probability at least $1 - e^{-\Omega(n)}$ that the solution to (8.21), $x^*$, is unique up to sign flips of each coordinate and satisfies $|x^*| = \psi(w^*)$. Then, on this event, for any $\epsilon > 0$, by Lemma 8.7.5 there exists some $T > 0$ such that $x_T$ given by the Riemannian gradient flow (8.20) is an $\epsilon/2$-optimal solution of the OLM. For this $T$, by Theorem 8.5.7, we know that the $\lfloor T/\eta^2\rfloor$-th SGD iterate, $x_\eta(\lfloor T/\eta^2\rfloor)$, satisfies $\|x_\eta(\lfloor T/\eta^2\rfloor) - x_T\|_2 \le \epsilon/2$ with probability at least $1 - e^{-\Omega(n)}$ for all sufficiently small $\eta > 0$, and thus $x_\eta(\lfloor T/\eta^2\rfloor)$ is an $\epsilon$-optimal solution of the OLM. Finally, the validity of applying Theorem 8.5.7 is guaranteed by Lemmas 8.7.2 and 8.7.3. This completes the proof.
In the following subsections, we provide the proofs of all the components used in
the above proof.
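Before diving into the individual lemmas, the following sketch simulates label noise SGD on the OLM $f_i(u,v) = \langle z_i, u^{\odot2} - v^{\odot2}\rangle$ end to end. All constants (dimensions, LR, noise level, step count) are our illustrative choices and are not tuned to the theorem's constants; the point is only the qualitative behavior of recovering a sparse ground truth.

```python
import numpy as np

# Label noise SGD on the OLM: each step fits one example whose label is
# perturbed by fresh Gaussian noise. We check that u*u - v*v approaches the
# sparse ground truth w_star.
rng = np.random.default_rng(0)
d, n, kappa = 30, 60, 2
Z = rng.standard_normal((n, d))
w_star = np.zeros(d); w_star[:kappa] = 1.0
y = Z @ w_star

u = np.full(d, 0.5); v = np.full(d, 0.5)
eta, sigma = 2e-3, 0.5
for _ in range(400_000):  # ~T / eta^2 steps
    i = rng.integers(n)
    r = Z[i] @ (u * u - v * v) - (y[i] + sigma * rng.standard_normal())
    g = 4.0 * r * Z[i]           # from d/du (r^2) = 4 r z u, d/dv = -4 r z v
    u, v = u - eta * g * u, v + eta * g * v

print(np.round(u * u - v * v, 2))  # expected: close to w_star
```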
Direct computation gives
$$\nabla^2\ell_i(x) = 2\begin{pmatrix} z_i\odot u\\ -z_i\odot v \end{pmatrix}\big((z_i\odot u)^\top,\ -(z_i\odot v)^\top\big) + \big(f_i(u, v) - y_i\big)\cdot\operatorname{diag}(z_i, z_i).$$
So for any $x \in \Gamma$, it holds that
$$\nabla^2L(x) = \frac{2}{n}\sum_{i=1}^n\begin{pmatrix} z_i\odot u\\ -z_i\odot v \end{pmatrix}\big((z_i\odot u)^\top,\ -(z_i\odot v)^\top\big). \tag{8.34}$$
Lemma 8.11.1. For any fixed $x \in \mathbb{R}^D$, suppose $\{\nabla f_i(x)\}_{i\in[n]}$ are linearly independent; then $K(x)$ is full-rank.

Proof of Lemma 8.11.1. Suppose otherwise; then there exists some $\lambda \in \mathbb{R}^n$ such that $\lambda \neq 0$ and $\lambda^\top K(x)\lambda = 0$. However, note that
$$\lambda^\top K(x)\lambda = \sum_{i,j=1}^n\lambda_i\lambda_jK_{ij}(x) = \sum_{i,j=1}^n\lambda_i\lambda_j\langle\nabla f_i(x), \nabla f_j(x)\rangle = \Big\|\sum_{i=1}^n\lambda_i\nabla f_i(x)\Big\|_2^2,$$
which implies that $\sum_{i=1}^n\lambda_i\nabla f_i(x) = 0$. This is a contradiction, since $\{\nabla f_i(x)\}_{i\in[n]}$ are linearly independent by assumption.
Lemma 8.7.2. Consider the loss $L$ defined in (8.17) and the manifold $\Gamma$ defined in (8.18). If the data is full rank, i.e., $\operatorname{rank}(Z) = n$, then (a) $\Gamma$ is a smooth manifold of dimension $D - n$; (b) $\operatorname{rank}(\nabla^2L(x)) = n$ for all $x \in \Gamma$. In particular, $\operatorname{rank}(Z) = n$ holds with probability 1 for the Gaussian distribution and with probability $1 - c^d$ for the Boolean distribution, for some constant $c \in (0,1)$.

Proof of Lemma 8.7.2. For (a), by the preimage theorem [223], it suffices to check that the Jacobian $[\nabla f_1(x), \ldots, \nabla f_n(x)] = 2\Big[\binom{z_1\odot u}{-z_1\odot v}, \ldots, \binom{z_n\odot u}{-z_n\odot v}\Big]$ is full-rank. Similarly, for claim (b), due to (8.34), it is also equivalent to show that $\big\{\binom{z_i\odot u}{-z_i\odot v}\big\}_{i\in[n]}$ has rank $n$. Since $\binom{u}{v} \in \Gamma \subseteq U$, each coordinate is non-zero, so we only need to show that $\{z_i\}_{i\in[n]}$ has rank $n$. This happens with probability 1 in the Gaussian case, and with probability at least $1 - c^d$ for some constant $c \in (0,1)$ in the Boolean case by Kahn et al. [232]. This completes the proof.
We first establish some auxiliary results. The following lemma shows the PL condition along the trajectory of gradient flow.

Lemma 8.11.2. Along the gradient flow generated by $-\nabla L$, it holds that
$$\|\nabla L(x(t))\|_2^2 \ge \frac{16}{n}\lambda_{\min}(ZZ^\top)\cdot\min_{i\in[d]}|u_i(0)v_i(0)|\cdot L(x(t)), \qquad \forall t \ge 0.$$
To prove Lemma 8.11.2, we need the following invariance along the gradient flow.
Lemma 8.11.3. Along the gradient flow generated by −∇L, uj (t)vj (t) stays constant
for all j ∈ [d]. Thus, sign(uj (t)) = sign(uj (0)) and sign(vj (t)) = sign(vj (0)) for any
j ∈ [d].
Proof of Lemma 8.11.3. A direct computation gives
$$\frac{d}{dt}\big(u_j(t)v_j(t)\big) = \frac{du_j(t)}{dt}v_j(t) + u_j(t)\frac{dv_j(t)}{dt} = -\frac{\partial L}{\partial u_j}(x(t))\,v_j(t) - u_j(t)\frac{\partial L}{\partial v_j}(x(t)) = 0,$$
since the two terms cancel by the expressions of the partial derivatives. Therefore, any sign change of $u_j(t)$ or $v_j(t)$ would enforce $u_j(t) = 0$ or $v_j(t) = 0$ at some $t > 0$, since $u_j(t), v_j(t)$ are continuous in $t$. This immediately contradicts the invariance of $u_j(t)v_j(t)$.
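The conservation law above is easy to confirm numerically; the sketch below (with a random instance and step size of our choosing) integrates the gradient flow by forward Euler and tracks the drift of $u_j(t)v_j(t)$.

```python
import numpy as np

# Checking Lemma 8.11.3 numerically: u_j v_j is conserved along the gradient
# flow of L(u, v) = (1/n) * sum_i (<z_i, u*u - v*v> - y_i)^2.
rng = np.random.default_rng(1)
d, n = 5, 8
Z = rng.standard_normal((n, d))
y = Z @ rng.standard_normal(d)
u, v = rng.random(d) + 0.5, rng.random(d) + 0.5
uv0 = u * v

dt = 1e-3
for _ in range(20_000):
    r = (2.0 / n) * (Z @ (u * u - v * v) - y)
    u, v = u - dt * 2.0 * (Z.T @ r) * u, v + dt * 2.0 * (Z.T @ r) * v

print(np.abs(u * v - uv0).max())  # ~0 up to Euler discretization error
```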
Proof of Lemma 8.11.2. Note that
$$\|\nabla L(x)\|_2^2 = \frac{1}{n^2}\sum_{i,j=1}^n(f_i(x) - y_i)(f_j(x) - y_j)\langle\nabla f_i(x), \nabla f_j(x)\rangle \ge \frac{1}{n^2}\sum_{i=1}^n(f_i(x) - y_i)^2\lambda_{\min}(K(x)) = \frac{2}{n}L(x)\lambda_{\min}(K(x)),$$
where $K(x)$ is an $n\times n$ p.s.d. matrix with $K_{ij}(x) = \langle\nabla f_i(x), \nabla f_j(x)\rangle$. Below we lower bound $\lambda_{\min}(K(x))$, the smallest eigenvalue of $K(x)$. Note that $K_{ij}(x(t)) = 4\sum_{h=1}^dz_{i,h}z_{j,h}\big(u_h(t)^2 + v_h(t)^2\big)$, and we have
$$K(x(t)) = 4Z\operatorname{diag}\big(u(t)^{\odot2} + v(t)^{\odot2}\big)Z^\top \succeq 8Z\operatorname{diag}\big(|u(t)\odot v(t)|\big)Z^\top \overset{(*)}{=} 8Z\operatorname{diag}\big(|u(0)\odot v(0)|\big)Z^\top \succeq 8\min_{i\in[d]}|u_i(0)v_i(0)|\,ZZ^\top,$$
where $(*)$ is by Lemma 8.11.3. Thus $\lambda_{\min}(K(x(t))) \ge 8\min_{i\in[d]}|u_i(0)v_i(0)|\,\lambda_{\min}(ZZ^\top)$ for all $t \ge 0$, which completes the proof.
Lemma 8.11.4. All the stationary points in U are global minimizers, i.e., Γ = {x ∈
U | ∇L(x) = 0}.
Proof of Lemma 8.11.4. Since $\Gamma$ is the set of local minimizers, each $x \in \Gamma$ must satisfy $\nabla L(x) = 0$. For the other direction, note that $\operatorname{rank}(\{z_i\}_{i\in[n]}) = n$ implies $\operatorname{rank}(\{\nabla f_i(x)\}_{i\in[n]}) = n$ for $x \in U$; since $\nabla L(x) \propto \sum_{i=1}^n(f_i(x) - y_i)\nabla f_i(x)$, stationarity then forces $f_i(x) = y_i$ for all $i \in [n]$, i.e., $x \in \Gamma$.
Lemma 8.7.3. Consider the loss function $L$ defined in (8.17) and the manifold $\Gamma$ with its open neighborhood $U$ defined in (8.18). For the gradient flow $\frac{dx_t}{dt} = -\nabla L(x_t)$ starting at any $x_0 \in U$, it holds that $\Phi(x_0) \in \Gamma$.
Proof of Lemma 8.7.3. It suffices to prove that the gradient flow $\frac{dx(t)}{dt} = -\nabla L(x(t))$ converges as $t \to \infty$ whenever $x(0) \in U$. Whenever it converges, it must converge to a stationary point in $U$. The proof is then completed by noting that all stationary points of $L$ in $U$ belong to $\Gamma$ (Lemma 8.11.4).
Below we prove that $\lim_{t\to\infty}x(t)$ exists. Denote $C = \frac{16}{n}\min_{i\in[d]}|u_i(0)v_i(0)|\lambda_{\min}(ZZ^\top)$; then it follows from Lemma 8.11.2 that
$$\Big\|\frac{dx(t)}{dt}\Big\|_2 = \|\nabla L(x(t))\|_2 \le \frac{\|\nabla L(x(t))\|_2^2}{\sqrt{CL(x(t))}} = \frac{-\frac{dL(x(t))}{dt}}{\sqrt{CL(x(t))}} = -\frac{2}{\sqrt{C}}\frac{d\sqrt{L(x(t))}}{dt}.$$
Thus the total GF trajectory length is bounded as
$$\int_{t=0}^\infty\Big\|\frac{dx(t)}{dt}\Big\|_2\,dt \le \int_{t=0}^\infty-\frac{2}{\sqrt{C}}\frac{d\sqrt{L(x(t))}}{dt}\,dt \le \frac{2\sqrt{L(x(0))}}{\sqrt{C}},$$
where the last inequality uses that $L$ is non-negative over $\mathbb{R}^D$. Therefore the GF must converge.
Consider the following weighted $\ell_1$-minimization problem:
$$\begin{aligned} \text{minimize}\quad & R(w) = \frac{4}{n}\sum_{j=1}^d\Big(\sum_{i=1}^nz_{i,j}^2\Big)|w_j|,\\ \text{subject to}\quad & Zw = Zw^*. \end{aligned} \tag{8.35}$$
Here we slightly abuse the notation of R and the parameter dimension will be clear
from the context. We can relate the optimal solution to (8.21) to that of (8.35) via a
canonical parametrization defined as follows.
Definition 8.11.5 (Canonical Parametrization). For any $w \in \mathbb{R}^d$, we define $\binom{u}{v} = \psi(w) = \big([w]_+^{\odot1/2}, [-w]_+^{\odot1/2}\big)^\top$ as the canonical parametrization of $w$. Clearly, it holds that $u^{\odot2} - v^{\odot2} = w$.
Indeed, we can show that if (8.35) has a unique optimal solution, it immediately follows that the optimal solution to (8.21) is also unique up to sign flips of each coordinate, as summarized in the lemma below.

Lemma 8.11.6. Suppose the optimal solution to (8.35) is unique and equal to $w^*$. Then the optimal solution to (8.21) is also unique up to sign flips of each coordinate. In particular, one of them is given by $(\tilde u^*, \tilde v^*) = \psi(w^*)$, that is, the canonical parametrization of $w^*$.
Proof of Lemma 8.11.6. Let $(\hat u, \hat v)$ be any optimal solution to (8.21) and write $\hat w = \hat u^{\odot2} - \hat v^{\odot2}$. Since $\hat w$ is feasible for (8.35) and $w^*$ is its optimal solution,
$$\sum_{j=1}^d\Big(\sum_{i=1}^nz_{i,j}^2\Big)|w_j^*| \le \sum_{j=1}^d\Big(\sum_{i=1}^nz_{i,j}^2\Big)|\hat w_j| \le \sum_{j=1}^d\Big(\sum_{i=1}^nz_{i,j}^2\Big)\big(\hat u_j^2 + \hat v_j^2\big). \tag{8.36}$$
Conversely, since $(\tilde u^*, \tilde v^*) = \psi(w^*)$ is feasible for (8.21) and $(\hat u, \hat v)$ is optimal,
$$\sum_{j=1}^d\Big(\sum_{i=1}^nz_{i,j}^2\Big)\big(\hat u_j^2 + \hat v_j^2\big) \le \sum_{j=1}^d\Big(\sum_{i=1}^nz_{i,j}^2\Big)\big((\tilde u_j^*)^2 + (\tilde v_j^*)^2\big) = \sum_{j=1}^d\Big(\sum_{i=1}^nz_{i,j}^2\Big)|w_j^*|. \tag{8.37}$$
Combining (8.36) and (8.37) yields
$$\sum_{j=1}^d\Big(\sum_{i=1}^nz_{i,j}^2\Big)\big(\hat u_j^2 + \hat v_j^2\big) = \sum_{j=1}^d\Big(\sum_{i=1}^nz_{i,j}^2\Big)|w_j^*| = \sum_{j=1}^d\Big(\sum_{i=1}^nz_{i,j}^2\Big)\big|\hat u_j^2 - \hat v_j^2\big|. \tag{8.38}$$
The first equality forces $\hat u_j^2 + \hat v_j^2 = |\hat u_j^2 - \hat v_j^2|$, i.e., $\hat u_j\hat v_j = 0$ for every $j \in [d]$, and shows that $\hat w$ is optimal for (8.35); by uniqueness, $\hat w = w^*$, so $(\hat u, \hat v)$ equals $\psi(w^*)$ up to sign flips of each coordinate.
Therefore, the unique optimality of (8.21) can be reduced to that of (8.35). In the sequel, we show that the latter holds for both Boolean and Gaussian random vectors. We divide Lemma 8.7.6 into Lemmas 8.11.7 and 8.11.8 for clarity.
Lemma 8.11.7 (Boolean Case). Let $z_1, \ldots, z_n \overset{\text{i.i.d.}}{\sim} \mathrm{Unif}(\{\pm1\}^d)$. There exist some constants $C, c > 0$ such that if the sample size satisfies $n \ge C[\kappa\ln(d/\kappa) + \kappa]$, then with probability at least $1 - e^{-cn}$, the optimal solution of (8.21), $(\hat u, \hat v)$, is unique up to sign flips of each coordinate and recovers the groundtruth, i.e., $\hat u^{\odot2} - \hat v^{\odot2} = w^*$.
Proof of Lemma 8.11.7. By the assumption that $z_1, \ldots, z_n \overset{\text{i.i.d.}}{\sim} \mathrm{Unif}(\{\pm1\}^d)$, we have $\sum_{i=1}^nz_{i,j}^2 = n$ for all $j \in [d]$. Then (8.35) is equivalent to the optimization problem (8.39). This model exactly fits Example 6.2 in Tropp [226] with $\sigma = 1$ and $\alpha = 1/\sqrt{2}$. Then, applying Equation (4.2) and Theorem 6.3 in Tropp [226], (8.39) has a unique optimal solution equal to $(u^*)^{\odot2} - (v^*)^{\odot2}$ with probability at least $1 - e^{-ch^2}$ for some constant $c > 0$, given that the sample size satisfies
$$n \ge C\big(\kappa\ln(d/\kappa) + \kappa + h\big)$$
for some absolute constant $C > 0$. Choosing $h = \frac{n}{2C}$ and then adjusting the choices of $C, c$ appropriately yields the desired result. Finally, applying Lemma 8.11.6 finishes the proof.
Lemma 8.11.8 (Gaussian Case). Let $z_1, \ldots, z_n \overset{\text{i.i.d.}}{\sim} N(0, I_d)$. There exist constants $C, c > 0$ such that if the sample size satisfies $n \ge C\kappa\ln d$, then with probability at least $1 - (2d+1)e^{-cn}$, the optimal solution of (8.21), $(\hat u, \hat v)$, is unique up to sign flips of each coordinate of $\hat u$ and $\hat v$ and recovers the groundtruth, i.e., $\hat u^{\odot2} - \hat v^{\odot2} = w^*$.
Proof of Lemma 8.11.8. Since $z_1, \ldots, z_n \overset{\text{i.i.d.}}{\sim} N(0, I_d)$, we have
$$\mathbb{P}\Big[\sum_{i=1}^nz_{i,j}^2 \in [n/2, 3n/2],\ \forall j \in [d]\Big] \ge 1 - 2de^{-cn}$$
for some constant $c > 0$, and we denote this event by $E_n$. Therefore, on $E_n$ we have
$$2\sum_{j=1}^d\big(u_j^2 + v_j^2\big) \le R(x) \le 6\sum_{j=1}^d\big(u_j^2 + v_j^2\big).$$
Define $w^* = (u^*)^{\odot2} - (v^*)^{\odot2}$; then (8.35) is equivalent to the following convex optimization problem:
$$\begin{aligned} \text{minimize}\quad & g(w) = \frac{4}{n}\sum_{j=1}^d\Big(\sum_{i=1}^nz_{i,j}^2\Big)|w_j + w_j^*|,\\ \text{subject to}\quad & Zw = 0. \end{aligned} \tag{8.40}$$
The point $w = 0$ is feasible for (8.40), and we claim that it is the unique optimal solution when $n$ is large enough. In detail, assume there exists a non-zero feasible point $w$ for (8.40) in the descent cone [226] $\mathcal{D}(g, w^*)$ of $g$; then
$$\lambda_{\min}\big(Z; \mathcal{D}(g, w^*)\big) \le \frac{\|Zw\|_2}{\|w\|_2} = 0,$$
where the equality follows from the feasibility of $w$. Therefore, we only need to show that $\lambda_{\min}(Z; \mathcal{D}(g, w^*))$ is bounded away from zero for sufficiently large $n$.

On $E_n$, $g$ belongs to the following function class:
$$\mathcal{G} = \Big\{h: \mathbb{R}^d \to \mathbb{R}\ \Big|\ h(w) = \sum_{j=1}^d\upsilon_j|w_j|,\ \upsilon \in \Upsilon\Big\}, \qquad \Upsilon = \{\upsilon \in \mathbb{R}^d : \upsilon_j \in [2, 6],\ \forall j \in [d]\}.$$
where $S^{n-1}$ denotes the unit sphere in $\mathbb{R}^n$. Applying the same argument as in Tropp [226] yields
$$\mathbb{P}\Big[\lambda_{\min}(Z; \mathcal{D}_\Upsilon) \ge \sqrt{n-1} - w(\mathcal{D}_\Upsilon) - h\Big] \ge 1 - e^{-h^2/2}.$$
Taking the intersection of this event with $E_n$, we obtain from a union bound that
$$\lambda_{\min}\big(Z; \mathcal{D}(g, w^*)\big) \ge \sqrt{n-1} - w(\mathcal{D}_\Upsilon) - h \tag{8.41}$$
with probability at least $1 - e^{-h^2/2} - 2de^{-cn}$. It remains to bound $w(\mathcal{D}_\Upsilon)$, which is defined as
$$w(\mathcal{D}_\Upsilon) = \mathbb{E}_{z\sim N(0, I_d)}\Big[\sup_{p\in\mathcal{D}_\Upsilon\cap S^{d-1}}\langle z, p\rangle\Big] = \mathbb{E}_{z\sim N(0, I_d)}\Big[\sup_{\upsilon\in\Upsilon}\sup_{p\in\mathcal{D}(g_\upsilon, w^*)\cap S^{d-1}}\langle z, p\rangle\Big]. \tag{8.42}$$
For any $p \in \mathcal{D}(g_\upsilon, w^*)$, there exists $\tau > 0$ such that
$$\sum_{j=1}^d\upsilon_j|w_j^* + \tau p_j| \le \sum_{j=1}^d\upsilon_j|w_j^*|,$$
which implies
$$\tau\sum_{j=\kappa+1}^d\upsilon_j|p_j| \le \sum_{j=1}^\kappa\upsilon_j\big(|w_j^*| - |w_j^* + \tau p_j|\big) \le \tau\sum_{j=1}^\kappa\upsilon_j|p_j|,$$
where the second inequality follows from the triangle inequality. Then, since each $\upsilon_j \in [2, 6]$, it follows that
$$\sum_{j=\kappa+1}^d|p_j| \le 3\sum_{j=1}^\kappa|p_j|.$$
Note that this holds for all $\upsilon \in \Upsilon$ simultaneously. Now let us denote $p_{1:\kappa} = (p_1, \ldots, p_\kappa) \in \mathbb{R}^\kappa$ and $p_{(\kappa+1):d} = (p_{\kappa+1}, \ldots, p_d) \in \mathbb{R}^{d-\kappa}$, and similarly for other $d$-dimensional vectors. Then for all $p \in \mathcal{D}_\Upsilon\cap S^{d-1}$, by the Cauchy–Schwarz inequality we have
$$\|p_{(\kappa+1):d}\|_1 \le 3\|p_{1:\kappa}\|_1 \le 3\sqrt{\kappa}\,\|p_{1:\kappa}\|_2 \le 3\sqrt{\kappa},$$
where the last inequality follows from the fact that $p \in S^{d-1}$. Therefore, combining the above inequality with (8.42), we obtain
$$w(\mathcal{D}_\Upsilon) \le \mathbb{E}\Big[\|z_{1:\kappa}\|_2 + 3\sqrt{\kappa}\cdot\max_{j\in\{\kappa+1,\ldots,d\}}|z_j|\Big] \le \sqrt{\kappa} + 3\sqrt{\kappa}\cdot\mathbb{E}\Big[\max_{j\in\{\kappa+1,\ldots,d\}}|z_j|\Big], \tag{8.43}$$
where the second inequality follows from the fact that $\mathbb{E}[\|z_{1:\kappa}\|_2] \le \sqrt{\mathbb{E}[\|z_{1:\kappa}\|_2^2]} = \sqrt{\kappa}$.
To bound the second term in (8.43), applying Lemma 8.11.9, it follows that
$$w(\mathcal{D}_\Upsilon) \le \sqrt{\kappa} + 3\sqrt{2\kappa\ln(2(d-\kappa))}. \tag{8.44}$$
Plugging (8.44) into (8.41) gives
$$\lambda_{\min}\big(Z; \mathcal{D}(g, w^*)\big) \ge \sqrt{n-1} - \sqrt{\kappa} - 3\sqrt{2\kappa\ln(2(d-\kappa))} - h.$$
Therefore, choosing $h = \sqrt{n-1}/2$, as long as $n \ge C\kappa\ln d$ for some constant $C > 0$, we have $\lambda_{\min}(Z; \mathcal{D}(g, w^*)) > 0$ with probability at least $1 - (2d+1)e^{-cn}$. Finally, the uniqueness of the optimal solution to (8.21) in this case follows from Lemma 8.11.6.
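For readers who want to see the geometry at work, the short script below solves the weighted $\ell_1$ program (8.35) as a linear program on a random Gaussian instance (the sizes are our illustrative choices) and checks exact recovery of the sparse ground truth:

```python
import numpy as np
from scipy.optimize import linprog

# Numerical check (illustrative) that the weighted l1 program (8.35)
# recovers a kappa-sparse w* once n is modestly larger than kappa * ln d.
rng = np.random.default_rng(0)
d, n, kappa = 40, 25, 2
Z = rng.standard_normal((n, d))
w_star = np.zeros(d); w_star[:kappa] = (1.0, -2.0)
a = (4.0 / n) * (Z**2).sum(axis=0)        # the weights in R(w)

# min a^T (w+ + w-)  s.t.  Z (w+ - w-) = Z w*,  w+, w- >= 0
res = linprog(c=np.concatenate([a, a]),
              A_eq=np.hstack([Z, -Z]), b_eq=Z @ w_star,
              bounds=[(0, None)] * (2 * d))
w_hat = res.x[:d] - res.x[d:]
print(np.abs(w_hat - w_star).max())       # ~0: exact recovery
```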
Lemma 8.11.9. Let $z \sim N(0, I_d)$; then it holds that $\mathbb{E}\big[\max_{i\in[d]}|z_i|\big] \le \sqrt{2\ln(2d)}$.

Proof of Lemma 8.11.9. Denote $M = \max_{i\in[d]}|z_i|$. For any $\lambda > 0$, by Jensen's inequality, we have
$$e^{\lambda\cdot\mathbb{E}[M]} \le \mathbb{E}\big[e^{\lambda M}\big] = \mathbb{E}\Big[\max_{i\in[d]}e^{\lambda|z_i|}\Big] \le \sum_{i=1}^d\mathbb{E}\big[e^{\lambda|z_i|}\big].$$
Note that $\mathbb{E}[e^{\lambda|z_i|}] \le 2\,\mathbb{E}[e^{\lambda z_i}]$. Thus, by the expression of the Gaussian moment generating function, we further have
$$e^{\lambda\cdot\mathbb{E}[M]} \le 2\sum_{i=1}^d\mathbb{E}\big[e^{\lambda z_i}\big] = 2de^{\lambda^2/2},$$
that is,
$$\mathbb{E}[M] \le \frac{\ln(2d)}{\lambda} + \frac{\lambda}{2}.$$
Choosing $\lambda = \sqrt{2\ln(2d)}$ yields the desired result.
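A quick Monte Carlo check of this bound (illustration only; the sample sizes are arbitrary):

```python
import numpy as np

# Empirically, E[max_i |z_i|] stays below sqrt(2 ln(2d)) for every d.
rng = np.random.default_rng(0)
for d in (10, 100, 1000):
    samples = np.abs(rng.standard_normal((20_000, d))).max(axis=1)
    print(d, round(samples.mean(), 3), round(np.sqrt(2 * np.log(2 * d)), 3))
```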
8.11.5 Proof of Lemma 8.7.5

Instead of studying the convergence of the Riemannian gradient flow directly, it is more convenient to consider it in the ambient space $\mathbb{R}^D$. To do so, we define a Lagrange function $\mathcal{L}(x;\lambda) = R(x) + \sum_{i=1}^n\lambda_i(f_i(x) - y_i)$ for $\lambda \in \mathbb{R}^n$. Based on this Lagrangian, we have the following lemma.

Lemma 8.11.10. The $\ell_2$ norm has a unique minimizer over $\{\nabla_x\mathcal{L}(x;\lambda) \mid \lambda \in \mathbb{R}^n\}$ for any fixed $x \in \mathbb{R}^D$. Thus we can define $F: \mathbb{R}^D \to \mathbb{R}^D$ by $F(x) = \operatorname{argmin}_{g\in\{\nabla_x\mathcal{L}(x;\lambda)\mid\lambda\in\mathbb{R}^n\}}\|g\|_2$. Moreover, it holds that $\langle F(x), \nabla f_i(x)\rangle = 0$ for all $i \in [n]$.
Proof of Lemma 8.11.10. Fix any $x \in \mathbb{R}^D$. Note that $\{\nabla_x\mathcal{L}(x;\lambda) \mid \lambda \in \mathbb{R}^n\}$ is the subspace spanned by $\{\nabla f_i(x)\}_{i\in[n]}$ shifted by $\nabla R(x)$; thus there is a unique minimizer of the $\ell_2$ norm in this set. This implies that $F(x) = \operatorname{argmin}_{g\in\{\nabla_x\mathcal{L}(x;\lambda)\mid\lambda\in\mathbb{R}^n\}}\|g\|_2$ is well-defined.
To show the second claim, denote $h(\lambda) = \|\nabla_x\mathcal{L}(x;\lambda)\|_2^2/2$, which is a quadratic function of $\lambda \in \mathbb{R}^n$. Then we have
$$\nabla h(\lambda) = \begin{pmatrix} \langle\nabla R(x), \nabla f_1(x)\rangle\\ \vdots\\ \langle\nabla R(x), \nabla f_n(x)\rangle \end{pmatrix} + \begin{pmatrix} \sum_{i=1}^n\lambda_i\langle\nabla f_1(x), \nabla f_i(x)\rangle\\ \vdots\\ \sum_{i=1}^n\lambda_i\langle\nabla f_n(x), \nabla f_i(x)\rangle \end{pmatrix} = \begin{pmatrix} \langle\nabla R(x), \nabla f_1(x)\rangle\\ \vdots\\ \langle\nabla R(x), \nabla f_n(x)\rangle \end{pmatrix} + K(x)\lambda.$$
For any $\lambda$ such that $\nabla_x\mathcal{L}(x;\lambda) = F(x)$, we must have $\nabla h(\lambda) = 0$ by the definition of $F(x)$, which by the above implies $\langle\nabla R(x), \nabla f_i(x)\rangle + (K(x)\lambda)_i = 0$ for all $i \in [n]$. Therefore, we further have
$$\langle F(x), \nabla f_i(x)\rangle = \langle\nabla R(x), \nabla f_i(x)\rangle + \sum_{j=1}^n\lambda_j\langle\nabla f_i(x), \nabla f_j(x)\rangle = \langle\nabla R(x), \nabla f_i(x)\rangle + (K(x)\lambda)_i = 0.$$
Hence, with any initialization $x(0) \in \Gamma$, the limiting flow (8.20) is equivalent to the following dynamics:
$$\frac{dx(t)}{dt} = -\frac{1}{4}F(x(t)). \tag{8.45}$$
Thus Lemma 8.7.5 can be proved by showing that the above $x(t)$ converges to $x^*$ as $t \to \infty$. We first present a series of auxiliary results below.
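Numerically, $F(x)$ is just $\nabla R(x)$ minus its projection onto $\operatorname{span}\{\nabla f_i(x)\}$, so it can be sketched with a single least-squares solve (the instance below is a random stand-in, not the OLM objective itself):

```python
import numpy as np

# F(x) from Lemma 8.11.10, computed as the minimal-norm Lagrangian gradient.

def F_of(grad_R, J):
    """grad_R: (D,); J: (D, n) whose columns stand in for grad f_i."""
    lam, *_ = np.linalg.lstsq(J, -grad_R, rcond=None)  # min ||grad_R + J lam||
    return grad_R + J @ lam

rng = np.random.default_rng(0)
D, n = 6, 2
J = rng.standard_normal((D, n))
grad_R = rng.standard_normal(D)
g = F_of(grad_R, J)
print(np.abs(J.T @ g).max())  # ~0, i.e. <F(x), grad f_i> = 0 as the lemma states
```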
Lemma 8.11.11. Suppose $F(x) = 0$ for $x = (u, v)$. Then for every $j \in [d]$, either $u_j = 0$ or $v_j = 0$.

Proof of Lemma 8.11.11. Since $F(x) = 0$, for every $j \in [d]$ we have
$$0 = \frac{\partial R}{\partial u_j}(x) + \sum_{i=1}^n\lambda_i(x)\frac{\partial f_i}{\partial u_j}(x) = 2u_j\Big[\frac{4}{n}\sum_{i=1}^nz_{i,j}^2 + \sum_{i=1}^n\lambda_i(x)z_{i,j}\Big],$$
$$0 = \frac{\partial R}{\partial v_j}(x) + \sum_{i=1}^n\lambda_i(x)\frac{\partial f_i}{\partial v_j}(x) = 2v_j\Big[\frac{4}{n}\sum_{i=1}^nz_{i,j}^2 - \sum_{i=1}^n\lambda_i(x)z_{i,j}\Big].$$
If there exists some $j \in [d]$ such that $u_j \neq 0$ and $v_j \neq 0$, then it follows from the above two identities that
$$\sum_{i=1}^nz_{i,j}^2 = 0,$$
which happens with probability 0 in both the Boolean and Gaussian cases. Therefore, we must have $u_j = 0$ or $v_j = 0$ for all $j \in [d]$.
Lemma 8.11.12. Let F : RD → RD be as defined in Lemma 8.11.10. Then F is
continuous on RD .
Proof. Case I. We first consider the simpler case of a fixed $x^* \in U = (\mathbb{R}\setminus\{0\})^D$, where $K(x^*)$ is full-rank. Lemma 8.11.10 implies that any $\lambda \in \mathbb{R}^n$ with $\nabla_x\mathcal{L}(x^*;\lambda) = F(x^*)$ is determined by the corresponding normal equations. Since $K(x)$ is continuous around $x^*$, there exists a sufficiently small $\delta > 0$ such that for any $x \in B_\delta(x^*)$, $K(x)$ is full-rank, which further implies that $K(x)^{-1}$ is also continuous on $B_\delta(x^*)$. Therefore, by the above characterization of $\lambda$, we see that $\lambda(x)$ is continuous for $x \in B_\delta(x^*)$, and so is $F(x) = \nabla R(x) + \sum_{i=1}^n\lambda_i(x)\nabla f_i(x)$.
Case II. Next, we consider a general $x^* \in \mathbb{R}^D$. Here, for simplicity, we reorder the coordinates as $x = (u_1, v_1, u_2, v_2, \ldots, u_d, v_d)$ with a slight abuse of notation. Without loss of generality, fix any $x^*$ such that for some $q \in [d]$, $(u_i^*)^2 + (v_i^*)^2 > 0$ for all $i = 1, \ldots, q$ and $u_i^* = v_i^* = 0$ for all $i = q+1, \ldots, d$. Then $\nabla R(x^*)$ and $\{\nabla f_i(x^*)\}_{i\in[n]}$ only depend on $\{z_{i,j}\}_{i\in[n],j\in[q]}$, and for all $i \in [n]$, $\nabla f_i(x^*)$ is supported on the first $2q$ coordinates. Note that if we replace $\{\nabla f_i(x)\}_{i\in[n]}$ by any fixed invertible linear transform of itself, it does not affect the definition of $F(x)$. Specifically, we can choose an invertible matrix $Q \in \mathbb{R}^{n\times n}$ such that, for some $q' \in [q]$, $(\tilde z_1, \ldots, \tilde z_n) = (z_1, \ldots, z_n)Q$ satisfies that $\{\tilde z_{i,1:q}\}_{i\in[q']}$ is linearly independent and $\tilde z_{i,1:q} = 0$ for all $i = q'+1, \ldots, n$.
We then consider $\big[\nabla\tilde f_1(x), \ldots, \nabla\tilde f_n(x)\big] = [\nabla f_1(x), \ldots, \nabla f_n(x)]\,Q$ and the corresponding $F(x)$. For notational simplicity, we assume that $Q$ can be chosen as the identity matrix, so that $(z_1, \ldots, z_n)$ itself satisfies the above property.
In the sequel, we use $\lambda$ for $n$-dimensional vectors and $\bar\lambda$ for $q'$-dimensional vectors. Denote²
$$\lambda(x) \in \operatorname*{argmin}_{\lambda\in\mathbb{R}^n}\Big\|\nabla R(x) + \sum_{i=1}^n\lambda_i\nabla f_i(x)\Big\|_2, \qquad \bar\lambda(x) \in \operatorname*{argmin}_{\bar\lambda\in\mathbb{R}^{q'}}\Big\|\Big(\nabla R(x) + \sum_{i=1}^{q'}\bar\lambda_i\nabla f_i(x)\Big)_{1:(2q)}\Big\|_2.$$
Then
$$\Big\|\Big(\nabla R(x^*) + \sum_{i=1}^{q'}\bar\lambda_i(x^*)\nabla f_i(x^*)\Big)_{1:(2q)}\Big\|_2 = \Big\|\nabla R(x^*) + \sum_{i=1}^n\lambda_i(x^*)\nabla f_i(x^*)\Big\|_2 = \|F(x^*)\|_2. \tag{8.48}$$

²We do not care about the specific choice of $\lambda(x)$ or $\bar\lambda(x)$ when there are multiple candidates; we only need their properties according to Lemma 8.11.10, so they can be arbitrary. Also, the minimum of the $\ell_2$-norm over an affine space is always attained, so the argmin exists.
On the other hand, for any $x \in \mathbb{R}^D$, by (8.47) we have
$$\begin{aligned}
\Big\|\Big(\nabla R(x) + \sum_{i=1}^{q'}\bar\lambda_i(x)\nabla f_i(x)\Big)_{1:(2q)}\Big\|_2 &= \min_{\lambda\in\mathbb{R}^n}\Big\|\Big(\nabla R(x) + \sum_{i=1}^n\lambda_i\nabla f_i(x)\Big)_{1:(2q)}\Big\|_2\\
&\le \Big\|\Big(\nabla R(x) + \sum_{i=1}^n\lambda_i(x)\nabla f_i(x)\Big)_{1:(2q)}\Big\|_2 = \|F_{1:(2q)}(x)\|_2\\
&\le \|F(x)\|_2 \le \Big\|\nabla R(x) + \sum_{i=1}^n\lambda_i(x^*)\nabla f_i(x)\Big\|_2,
\end{aligned} \tag{8.49}$$
where the first and third inequalities follow from the definition of $F(x)$. Letting $x \to x^*$, by the continuity of $\nabla R(x)$ and $\{\nabla f_i(x)\}_{i\in[n]}$, we have
$$\lim_{x\to x^*}\Big\|\nabla R(x) + \sum_{i=1}^n\lambda_i(x^*)\nabla f_i(x)\Big\|_2 = \Big\|\nabla R(x^*) + \sum_{i=1}^n\lambda_i(x^*)\nabla f_i(x^*)\Big\|_2. \tag{8.50}$$
Denote $\tilde K(x) = (\tilde K_{ij}(x))_{(i,j)\in[q']^2} = \big(\langle\nabla f_i(x)_{1:(2q)}, \nabla f_j(x)_{1:(2q)}\rangle\big)_{(i,j)\in[q']^2}$. Applying the same argument as in Case I, since $\tilde K(x^*)$ is full-rank, it also holds that $\lim_{x\to x^*}\bar\lambda(x) = \bar\lambda(x^*)$, and thus
$$\lim_{x\to x^*}\Big\|\Big(\nabla R(x) + \sum_{i=1}^{q'}\bar\lambda_i(x)\nabla f_i(x)\Big)_{1:(2q)}\Big\|_2 = \Big\|\Big(\nabla R(x^*) + \sum_{i=1}^{q'}\bar\lambda_i(x^*)\nabla f_i(x^*)\Big)_{1:(2q)}\Big\|_2. \tag{8.51}$$
Combining the above displays, we obtain
$$\lim_{x\to x^*}\|F_{1:(2q)}(x)\|_2 = \lim_{x\to x^*}\min_{\lambda\in\mathbb{R}^n}\Big\|\Big(\nabla R(x) + \sum_{i=1}^n\lambda_i\nabla f_i(x)\Big)_{1:(2q)}\Big\|_2 = \|F(x^*)\|_2. \tag{8.52}$$
Moreover, since $\|F_{(2q+1):D}(x)\|_2 = \sqrt{\|F(x)\|_2^2 - \|F_{1:(2q)}(x)\|_2^2}$, we also have
$$\lim_{x\to x^*}F_{(2q+1):D}(x) = 0 = F_{(2q+1):D}(x^*). \tag{8.53}$$
It then remains to show that $\lim_{x\to x^*}F_{1:(2q)}(x) = F_{1:(2q)}(x^*)$, which directly follows from $\lim_{x\to x^*}\lambda_{1:q'}(x) = \lambda_{1:q'}(x^*) = \bar\lambda(x^*)$.
Now, for any $\epsilon > 0$, due to the convergence of $\bar\lambda(x)$ and the fact that $\tilde K(x^*) \succ 0$, we can pick a sufficiently small $\delta_1$ such that for some constant $\alpha > 0$ and all $x \in B_{\delta_1}(x^*)$, it holds that $\|\bar\lambda(x) - \bar\lambda(x^*)\|_2 \le \epsilon/2$ and
$$\Big\|\Big(\nabla R(x) + \sum_{i=1}^{q'}\bar\lambda_i\nabla f_i(x)\Big)_{1:(2q)}\Big\|_2^2 \ge \Big\|\Big(\nabla R(x) + \sum_{i=1}^{q'}\bar\lambda_i(x)\nabla f_i(x)\Big)_{1:(2q)}\Big\|_2^2 + \alpha\|\bar\lambda - \bar\lambda(x)\|_2^2 \tag{8.54}$$
for all $\bar\lambda \in \mathbb{R}^{q'}$, where the inequality follows from strong convexity. Meanwhile, due to (8.47), we have
$$\lim_{x\to x^*}\Big\|\Big(\nabla R(x) + \sum_{i=1}^n\lambda_i(x)\nabla f_i(x)\Big)_{1:(2q)}\Big\|_2 = \Big\|\Big(\nabla R(x^*) + \sum_{i=1}^{q'}\bar\lambda_i(x^*)\nabla f_i(x^*)\Big)_{1:(2q)}\Big\|_2 = \lim_{x\to x^*}\Big\|\Big(\nabla R(x) + \sum_{i=1}^{q'}\bar\lambda_i(x)\nabla f_i(x)\Big)_{1:(2q)}\Big\|_2,$$
where the first equality follows from (8.52) and the second from (8.51). Therefore, we can pick a sufficiently small $\delta_2$ such that
$$\Big\|\Big(\nabla R(x) + \sum_{i=1}^n\lambda_i(x)\nabla f_i(x)\Big)_{1:(2q)}\Big\|_2^2 \le \Big\|\Big(\nabla R(x) + \sum_{i=1}^{q'}\bar\lambda_i(x)\nabla f_i(x)\Big)_{1:(2q)}\Big\|_2^2 + \frac{\alpha\epsilon^2}{4} \tag{8.55}$$
for all $x \in B_{\delta_2}(x^*)$. Setting $\delta = \min(\delta_1, \delta_2)$, it follows from (8.54) and (8.55) that
$$\|\lambda_{1:q'}(x) - \bar\lambda(x)\|_2 \le \frac{\epsilon}{2} \quad \text{for all } x \in B_\delta(x^*).$$
Recall that we already have $\|\bar\lambda(x) - \bar\lambda(x^*)\|_2 \le \epsilon/2$, and thus
$$\|\lambda_{1:q'}(x) - \lambda_{1:q'}(x^*)\|_2 = \|\lambda_{1:q'}(x) - \bar\lambda(x^*)\|_2 \le \|\lambda_{1:q'}(x) - \bar\lambda(x)\|_2 + \|\bar\lambda(x) - \bar\lambda(x^*)\|_2 \le \epsilon$$
for all $x \in B_\delta(x^*)$. Therefore, we see that $\lim_{x\to x^*}\lambda_{1:q'}(x) = \lambda_{1:q'}(x^*)$. Finally, it follows from the triangle inequality that, as $x \to x^*$, the first term vanishes by the convergence of $\lambda_{1:q'}(x)$ and the continuity of each $\nabla f_i(x)$, the second term converges to 0 by the continuity of $\nabla R(x)$, and the third term vanishes by (8.53). Therefore, we conclude that $\lim_{x\to x^*}F(x) = F(x^*)$.
Proof of Lemma 8.11.13. Let $[0, T)$ be the right-maximal interval of existence of the solution of the Riemannian gradient flow, and suppose $T \neq \infty$. Since $R(x(t))$ is monotonically decreasing, $R(x(t))$ is upper bounded by $R(x(0))$, and therefore $\|\nabla R(x(t))\|$ is also upper bounded. Since $\big\|\frac{dx(t)}{dt}\big\|_2 \le \|\nabla R(x(t))\|_2$ for any $t < T$, the left limit $x(T-) := \lim_{\tau\to T^-}x(\tau)$ must exist. By Corollary 1 of Perko [233], $x(T-)$ belongs to the boundary of $U$, i.e., $u_j(T-) = 0$ or $v_j(T-) = 0$ for some $j \in [d]$ by Lemma 8.11.11.

By the definition of the Riemannian gradient flow in (8.20), we have
$$\frac{d}{dt}\big(u_j(t)v_j(t)\big) = v_j(t)e_j^\top\frac{dx(t)}{dt} + u_j(t)e_{d+j}^\top\frac{dx(t)}{dt} = -\frac{1}{4}\big(v_j(t)e_j^\top + u_j(t)e_{d+j}^\top\big)F(x(t)).$$
By the expression $F(x(t)) = \nabla R(x(t)) + \sum_{i=1}^n\lambda_i(x(t))\nabla f_i(x(t))$, we then have
$$\begin{aligned}
\frac{d}{dt}\big(u_j(t)v_j(t)\big) &= -\Big[\frac{2}{n}\sum_{i=1}^nz_{i,j}^2 + \frac{1}{2}\sum_{i=1}^n\lambda_i(x(t))z_{i,j}\Big]u_j(t)v_j(t) - \Big[\frac{2}{n}\sum_{i=1}^nz_{i,j}^2 - \frac{1}{2}\sum_{i=1}^n\lambda_i(x(t))z_{i,j}\Big]u_j(t)v_j(t)\\
&= -\Big(\frac{4}{n}\sum_{i=1}^nz_{i,j}^2\Big)u_j(t)v_j(t).
\end{aligned}$$
Denote $s_j = \frac{4}{n}\sum_{i=1}^nz_{i,j}^2$. It follows that $|u_j(t)v_j(t)| = |u_j(0)v_j(0)|e^{-s_jt}$ for all $t \in [0, T)$. Taking the limit, we have $|u_j(T-)v_j(T-)| \ge |u_j(0)v_j(0)|e^{-s_jT} > 0$, a contradiction with $T \neq \infty$.
Before showing that $F$ satisfies the PL condition, we need the following two intermediate results. Given two points $u$ and $v$ in $\mathbb{R}^d$, we say $u$ weakly dominates $v$ (written as $u \le v$) if and only if $u_i \le v_i$ for all $i \in [d]$. Given two subsets $A$ and $B$ of $\mathbb{R}^D$, we say $A$ weakly dominates $B$ if and only if for every point $v \in B$, there exists a point $u \in A$ such that $u \le v$.
As a direct implication, for any continuous function $f: P \to \mathbb{R}$ which is coordinate-wise non-decreasing, $\min_{x\in U}f(x)$ can always be achieved. By the induction hypothesis, there exists a $u \in B_r^1(0)\cap Q$ such that $u \le y$. Thus $u \le v$, which meets our requirement.
If $\omega$ has different signs across its coordinates, we take $y_1, y_2$ to be the first intersections of the line $\{v - \lambda|\omega|\}_{\lambda\in\mathbb{R}}$ and the boundary of $U$ in the directions $\lambda > 0$ and $\lambda < 0$, respectively. Again by the induction hypothesis, there exist $u_1, u_2 \in B_r^1(0)\cap Q$ such that $u_1 \le y_1$ and $u_2 \le y_2$. Since $v$ lies on the segment connecting $y_1$ and $y_2$, there exists some $h \in [0, 1]$ such that $v = (1-h)y_1 + hy_2$. It then follows that $(1-h)u_1 + hu_2 \le (1-h)y_1 + hy_2 = v$. Now, since $Q$ is convex, we have $(1-h)u_1 + hu_2 \in Q$, and by the triangle inequality it also holds that $\|(1-h)u_1 + hu_2\|_1 \le r$, so $(1-h)u_1 + hu_2 \in B_r^1(0)\cap Q$. Therefore, we conclude that $B_r^1(0)\cap Q$ weakly dominates $Q$, and thus the proposition holds for $D$. This completes the proof by induction.
sufficiently small r, and thus the proposition holds for D. This completes the proof
by induction.
Lemma 8.11.16 (Polyak–Łojasiewicz condition for $F$). For any $x^*$ such that $L(x^*) = 0$, i.e., $x^* \in \Gamma$, there exist a neighbourhood $U'$ of $x^*$ and a constant $c > 0$ such that $\|F(x)\|_2^2 \ge c\cdot\max(R(x) - R(x^*), 0)$ for all $x \in U'\cap\Gamma$. Note this requirement is only non-trivial when $\|F(x^*)\|_2 = 0$, since $F$ is continuous.

Proof of Lemma 8.11.16. It suffices to show the PL condition for $\{x \mid F(x) = 0\}$. We need to show that for any $x^*$ satisfying $F(x^*) = 0$, there exist some $\epsilon > 0$ and $C > 0$ such that for all $x \in \Gamma\cap B_\epsilon^2(x^*)$ with $R(x) > R(x^*)$, it holds that $\|F(x)\|_2^2 \ge C(R(x) - R(x^*))$.
Canonical Case. We first prove the case where $x = \binom{u}{v}$ itself is a canonical parametrization of $w = u^{\odot2} - v^{\odot2}$, i.e., $u_jv_j = 0$ for all $j \in [d]$. Since $x^*$ satisfies $F(x^*) = 0$, by Lemma 8.11.11 we have $x^* = \psi(w^*)$ where $w^* = (u^*)^{\odot2} - (v^*)^{\odot2}$. In this case, we can rewrite both $R$ and $F$ as functions of $w \in \mathbb{R}^d$: define $R'(w) = R(\psi(w))$ and $F'(w) = F(\psi(w))$ for all $w \in \mathbb{R}^d$. For any $w$ in a sufficiently small neighbourhood of $w^*$, it holds that $\operatorname{sign}(w_j) = \operatorname{sign}(w_j^*)$ for all $j \in [q]$. Below we show that for each possible sign pattern of $w_{(q+1):d}$, there exists some constant $C$ which admits the PL condition in the corresponding orthant; taking the minimum of $C$ over all orthants then completes the proof. W.l.o.g., assume $w_j \ge 0$ for all $j = q+1, \ldots, d$.

We temporarily reorder the coordinates as $x = (u_1, v_1, u_2, v_2, \ldots, u_d, v_d)^\top$. Recall that $Z = [z_1, \ldots, z_n]^\top$ is an $n\times d$ matrix, and we have
$$\|F'(w)\|_2^2 = \min_{\lambda\in\mathbb{R}^n}\big\langle\big(a - \operatorname{sign}(w)\odot Z^\top\lambda\big)^{\odot2},\ |w|\big\rangle,$$
where $a = \frac{8}{n}\sum_{i=1}^nz_i^{\odot2} \in \mathbb{R}^d$. Since $F(x^*) = 0$, there must exist $\lambda^* \in \mathbb{R}^n$ such that the first $2q$ coordinates of $\nabla R(x^*) + \sum_{i=1}^n\lambda_i^*\nabla f_i(x^*)$ are equal to 0. As argued in the proof of Lemma 8.11.12, we can assume the first $q'$ rows of $Z$ are linearly independent on the first $q$ coordinates for some $q' \in [q]$. In other words, $Z$ can be written as $\begin{pmatrix} Z_A & Z_B\\ 0 & Z_D \end{pmatrix}$ where $Z_A \in \mathbb{R}^{q'\times q}$. We further denote $\lambda_a := \lambda_{1:q'}$, $\lambda_b := \lambda_{(q'+1):n}$, $a_a := a_{1:q}$, $a_b := a_{(q+1):d}$, $w_a := w_{1:q}$ and $w_b := w_{(q+1):d}$ for convenience; then we have
$$\|F'(w)\|_2^2 = \min_{\lambda\in\mathbb{R}^n}\Big[\big\langle\big(a_a + \operatorname{sign}(w_a)\odot Z_A^\top\lambda_a\big)^{\odot2},\ |w_a|\big\rangle + \big\langle\big(a_b + Z_B^\top\lambda_a + Z_D^\top\lambda_b\big)^{\odot2},\ w_b\big\rangle\Big]. \tag{8.56}$$
Since every $w$ in $\Gamma$ is a global minimizer, $R'(w) = R'(w) + \sum_{i=1}^n\lambda_i^*(z_i^\top w - y_i) := g^\top w + R'(w^*)$, where $g = \operatorname{sign}(w)\odot a + Z^\top\lambda^*$. Similarly, we define $g_a := g_{1:q}$ and $g_b := g_{(q+1):d}$. It holds that $g_a = 0$, and we may assume $Z_Dg_b = 0$ without loss of generality, because this can always be arranged by picking suitable $\lambda_i^*$ for $i = q'+1, \ldots, n$. (We have such freedom on $\lambda^*_{(q'+1):n}$ because these entries do not affect the first $2q$ coordinates.)

We denote $\lambda_a - \lambda_a^*$ by $\Delta\lambda_a$. Since $0 = g_a = \operatorname{sign}(w_a)\odot Z_A^\top\lambda_a^* + a_a$, we further have
$$\big\langle\big(a_a + \operatorname{sign}(w_a)\odot Z_A^\top\lambda_a\big)^{\odot2},\ |w_a|\big\rangle = \big\langle\big(\operatorname{sign}(w_a)\odot Z_A^\top\Delta\lambda_a\big)^{\odot2},\ |w_a|\big\rangle,$$
and hence
$$\|F'(w)\|_2^2 = \min_{\lambda\in\mathbb{R}^n}\Big[\big\langle\big(Z_A^\top\Delta\lambda_a\big)^{\odot2},\ |w_a|\big\rangle + \big\langle\big(g_b + Z_B^\top\Delta\lambda_a + Z_D^\top\lambda_b\big)^{\odot2},\ w_b\big\rangle\Big]. \tag{8.57}$$
Now suppose $R'(w) - R'(w^*) = g_b^\top w_b = \delta$ for some sufficiently small $\delta$ (which can be controlled by $\epsilon$). We proceed in the following two cases separately.

• Case I.1: $\|\Delta\lambda_a\|_2 = \Omega(\sqrt{\delta})$. Since $Z_A$ has full row rank, $\|Z_A^\top\Delta\lambda_a\|_2^2 \ge \|\Delta\lambda_a\|_2^2\lambda_{\min}^2(Z_A)$ is lower bounded. On the other hand, we can choose $\epsilon$ small enough that $(w_a)_i^2 \ge \frac{1}{2}(w_a^*)_i^2$ for all $i \in [q]$. Thus the first term of Equation (8.57) is lower bounded by $\|\Delta\lambda_a\|_2^2\lambda_{\min}^2(Z_A)\cdot\min_{i\in[q]}\frac{1}{2}(w_a^*)_i^2 = \Omega(\delta) = \Omega(R'(w) - R'(w^*))$.

• Case I.2: $\|\Delta\lambda_a\|_2 = O(\sqrt{\delta})$. Let $u = g_b + Z_B^\top\Delta\lambda_a + Z_D^\top\lambda_b$; then $u \in S + B_{c\sqrt{\delta}}^2(0)$ for some constant $c > 0$, where $S = \{g_b + Z_D^\top\lambda_b \mid \lambda_b \in \mathbb{R}^{n-q'}\}$. By Lemma 8.11.14, there exists some constant $c_0 \ge 1$ such that $\frac{1}{c_0}\cdot S$ weakly dominates $S + B_{c\sqrt{\delta}}^2(0)$. Thus we have $\|F'(w)\|_2^2 \ge \inf_{u\in S + B_{c\sqrt{\delta}}^2(0)}\langle u^{\odot2}, w_b\rangle \ge \inf_{u\in\frac{1}{c_0}\cdot S}\langle u^{\odot2}, w_b\rangle$, where the last step holds because each coordinate of $w_b$ is non-negative. Consequently,
$$\inf_{w:\,R'(w)-R'(w^*)=\delta>0}\frac{\|F'(w)\|_2^2}{R'(w) - R'(w^*)} \ge \inf_{w_b:\,R'(w)-R'(w^*)=\delta>0}\ \inf_{u\in\frac{1}{c_0}\cdot S}\Big\langle u^{\odot2},\ \frac{w_b}{\delta}\Big\rangle \ge \frac{1}{c_0^2}\inf_{w_b\in\frac{\delta}{\|g_b\|_2^2}g_b + A,\ w_b\ge0,\ u\in S}\big\langle u^{\odot2},\ w_b\big\rangle. \tag{8.58}$$
On the other hand, we have $u^\top w_b = \delta > 0$ for all $w_b \in \frac{\delta}{\|g_b\|_2^2}g_b + A$ and $u \in S$, by $Z_Dg_b = 0$ and the definition of $A$. This implies that there exists at least one $i$ with $(w_b)_iu_i > 0$, which further implies $\langle u^{\odot2}, w_b\rangle > 0$. Therefore, we conclude that $\|F'(w)\|_2^2 = \Omega(R'(w) - R'(w^*))$.
General Case. Next, for any general $x = \binom{u}{v}$, we define $w = u^{\odot2} - v^{\odot2}$ and $m = \min\{u^{\odot2}, v^{\odot2}\}$, where the min is taken coordinate-wise. Then we can rewrite $\|F(x)\|_2^2$ as
$$\begin{aligned}
\|F(x)\|_2^2 &= \min_{\lambda\in\mathbb{R}^n}\Big\|\Big[\begin{pmatrix} a\\ a \end{pmatrix} + \begin{pmatrix} Z\\ -Z \end{pmatrix}\lambda\Big]\odot\begin{pmatrix} u\\ v \end{pmatrix}\Big\|_2^2\\
&= \min_{\lambda\in\mathbb{R}^n}\Big\langle\Big[\begin{pmatrix} a\\ a \end{pmatrix} + \begin{pmatrix} Z\\ -Z \end{pmatrix}\lambda\Big]^{\odot2},\ \psi(w)^{\odot2} + \begin{pmatrix} m\\ m \end{pmatrix}\Big\rangle\\
&\ge \min_{\lambda\in\mathbb{R}^n}\Big\|\Big[\begin{pmatrix} a\\ a \end{pmatrix} + \begin{pmatrix} Z\\ -Z \end{pmatrix}\lambda\Big]\odot\psi(w)\Big\|_2^2 + \min_{\lambda\in\mathbb{R}^n}\Big\|\Big[\begin{pmatrix} a\\ a \end{pmatrix} + \begin{pmatrix} Z\\ -Z \end{pmatrix}\lambda\Big]\odot\begin{pmatrix} \sqrt{m}\\ \sqrt{m} \end{pmatrix}\Big\|_2^2.
\end{aligned}$$
Then applying the result for the canonical case yields, for some constant $C \in (0, 1)$,
$$\|F(x)\|_2^2 \ge C\big(R(\psi(w)) - R(\psi(w^*))\big) + \min_{\lambda\in\mathbb{R}^n}\Big\|\Big[\begin{pmatrix} a\\ a \end{pmatrix} + \begin{pmatrix} Z\\ -Z \end{pmatrix}\lambda\Big]\odot\begin{pmatrix} \sqrt{m}\\ \sqrt{m} \end{pmatrix}\Big\|_2^2 \ge C\big(R(\psi(w)) - R(x^*)\big) + 2\big\langle a^{\odot2},\ m\big\rangle \ge C\big(R(x) - R(x^*)\big),$$
where the first equality in the bound uses $x^* = \psi(w^*)$, and the last inequality is due to the fact that both $R(\psi(w)) - R(\psi(w^*))$ and $R(x) - R(\psi(w))$ are non-negative. This completes the proof.
Now, based on the PL condition, we can show that (8.20) indeed converges.
Lemma 8.11.17. The trajectory of the flow defined in (8.20) has finite length for any initialization $x(0) \in \Gamma$, i.e., $\int_{t=0}^\infty\big\|\frac{dx}{dt}\big\|_2\,dt < \infty$. Moreover, $x(t)$ converges to some $x(\infty)$ as $t \to \infty$ with $F(x(\infty)) = 0$.
Proof of Lemma 8.11.17. Note that along the Riemannian gradient flow, $R(x(t))$ is non-increasing; thus $\|x(t)\|_2$ is bounded over time and $\{x(t)\}_{t\ge0}$ has at least one limit point, which we will call $x^*$. Therefore, $R(x^*)$ is a limit point of $R(x(t))$, and again since $R(x(t))$ is non-increasing, it follows that $R(x(t)) \ge R(x^*)$ and $\lim_{t\to\infty}R(x(t)) = R(x^*)$. Below we show $\lim_{t\to\infty}x(t) = x^*$.

Note that $\frac{dR(x(t))}{dt} = \big\langle\nabla R(x(t)), \frac{dx(t)}{dt}\big\rangle = -\big\langle\nabla R(x(t)), \frac{1}{4}F(x(t))\big\rangle = -\frac{1}{4}\|F(x(t))\|_2^2$, where the last equality applies Lemma 8.11.10. By Lemma 8.11.16, there exists a neighbourhood $U'$ of $x^*$ in which the PL condition holds for $F$. Since $x^*$ is a limit point, there exists a time $T_0$ such that $x(T_0) \in U'$. Let $T_1 = \inf\{t \ge T_0 \mid x(t) \notin U'\}$ (which is $\infty$ if $x(t) \in U'$ for all $t \ge T_0$). Since $x(t)$ is continuous in $t$ and $U'$ is open, we know $T_1 > T_0$, and for all $t \in [T_0, T_1)$ we have $\|F(x(t))\|_2 \ge \sqrt{c}\,(R(x(t)) - R(x^*))^{1/2}$. Thus it holds that, for $t \in [T_0, T_1)$,
$$\frac{d(R(x(t)) - R(x^*))}{dt} \le -\frac{\sqrt{c}}{4}(R(x(t)) - R(x^*))^{1/2}\|F(x(t))\|_2,$$
that is,
$$\frac{d(R(x(t)) - R(x^*))^{1/2}}{dt} \le -\frac{\sqrt{c}}{8}\|F(x(t))\|_2.$$
Therefore, we have
$$\int_{t=T_0}^{T_1}\|F(x(t))\|_2\,dt \le \frac{8}{\sqrt{c}}\big(R(x(T_0)) - R(x^*)\big)^{1/2}. \tag{8.59}$$
Thus, if we pick $T_0$ such that $R(x(T_0)) - R(x^*)$ is sufficiently small, $x(T_1)$ would remain in $U'$ if $T_1$ were finite — a contradiction. This implies that $T_1 = \infty$. Therefore, Equation (8.59) shows that the trajectory of $x(t)$ has finite length, so $x(\infty) := \lim_{t\to\infty}x(t)$ exists and is equal to $x^*$. As a by-product, $F(x^*)$ must be 0.
Finally, collecting all the above lemmas, we are able to prove Lemma 8.7.5. Lemma 8.11.17 already shows the convergence of $x(t)$ as $t \to \infty$; the main part of the proof of Lemma 8.7.5 is to show that $x(\infty)$ cannot be a sub-optimal stationary point of $R$ on $\bar\Gamma$, the closure of $\Gamma$. The key idea is to construct a different potential $\varphi$ for each such sub-optimal stationary point $x^*$, such that (1) $\varphi(x(t))$ is locally increasing in a sufficiently small neighborhood of $x^*$ and (2) $\lim_{x\to x^*}\varphi(x) = -\infty$.
Lemma 8.7.5. Let {xt }t≥0 ⊆ RD be generated by the flow defined in (8.20) with any
initialization x0 ∈ Γ. Then x∞ = limt→∞ xt exists. Moreover, x∞ = x∗ is the optimal
solution of (8.21).
Proof of Lemma 8.7.5. We prove by contradiction. Suppose $x(\infty) = \binom{u(\infty)}{v(\infty)} = \lim_{t\to\infty}x(t)$ is not the optimal solution to (8.21). Denote $w(t) = u(t)^{\odot2} - v(t)^{\odot2}$; then $w(\infty) = \lim_{t\to\infty}w(t)$ is not the optimal solution to (8.35). Thus we have $R(w(\infty)) > R(w^*)$. Without loss of generality, suppose there is some $q \in [d]$ such that $(u_i(\infty))^2 + (v_i(\infty))^2 > 0$ for all $i = 1, \ldots, q$ and $u_i(\infty) = v_i(\infty) = 0$ for all $i = q+1, \ldots, d$. Again, as argued in the proof of Lemma 8.11.12, we can assume that the first $q'$ rows of $Z$, for some $q' \in [q]$, are linearly independent on the first $q$ coordinates. Since both $w(\infty)$ and $w^*$ satisfy the constraint $Zw(\infty) = Zw^* = Y$, we further define the potential
$$\varphi(x) = \varphi(u, v) = \sum_{j=q+1}^dw_j^*\big(\ln(u_j)^2\,\mathbb{1}\{w_j^* > 0\} - \ln(v_j)^2\,\mathbb{1}\{w_j^* < 0\}\big).$$
Clearly, $\lim_{t\to\infty}\varphi(x(t)) = -\infty$ if $\lim_{t\to\infty}x(t) = x(\infty)$. Below we derive a contradiction if $x(\infty)$ is suboptimal. Consider the dynamics of $\varphi(x)$ along the Riemannian gradient flow:
$$\frac{d\varphi}{dt}(x(t)) = \Big\langle\nabla\varphi(x(t)),\ \frac{dx(t)}{dt}\Big\rangle = -\frac{1}{4}\big\langle\nabla\varphi(x(t)),\ F(x(t))\big\rangle, \tag{8.62}$$
where $F$ is defined previously in Lemma 8.11.10. Recalling the definition of $F$, we have
$$\frac{1}{4}\langle\nabla\varphi(x(t)), F(x(t))\rangle = \underbrace{\frac{1}{4}\Big\langle\nabla\varphi(x(t)),\ \nabla R(x(t)) + \sum_{i=1}^{q'}\lambda_i(x(t))\nabla f_i(x(t))\Big\rangle}_{I_1} + \underbrace{\frac{1}{4}\Big\langle\nabla\varphi(x(t)),\ \sum_{i=q'+1}^n\lambda_i(x(t))\nabla f_i(x(t))\Big\rangle}_{I_2}. \tag{8.63}$$
Direct computation gives
$$\nabla\varphi(x) = \sum_{j=q+1}^d2w_j^*\Big(\frac{\mathbb{1}\{w_j^* > 0\}}{u_j}\,e_j - \frac{\mathbb{1}\{w_j^* < 0\}}{v_j}\,e_{d+j}\Big),$$
where $e_j$ is the $j$-th canonical basis vector of $\mathbb{R}^D$. Recall that $\nabla f_i(x) = 2\binom{z_i\odot u}{-z_i\odot v}$, and we further have
$$\begin{aligned}
I_2 &= \sum_{i=q'+1}^n\lambda_i(x(t))\sum_{j=q+1}^dw_j^*\Big(\frac{\mathbb{1}\{w_j^* > 0\}}{u_j}z_{i,j}u_j + \frac{\mathbb{1}\{w_j^* < 0\}}{v_j}z_{i,j}v_j\Big)\\
&= \sum_{i=q'+1}^n\lambda_i(x(t))\sum_{j=q+1}^dw_j^*z_{i,j} = \sum_{i=q'+1}^n\lambda_i(x(t))\big\langle z_{i,(q+1):d},\ w^*_{(q+1):d}\big\rangle = 0. \tag{8.64}
\end{aligned}$$
Define
$$\tilde R(w) = R(w) + \sum_{i=1}^{q'}\lambda_i(x(\infty))\big(\tilde f_i(w) - y_i\big).$$
Clearly, for any $w \in \mathbb{R}^d$ satisfying $Zw = Y$, it holds that $\tilde f_i(w) - y_i = 0$ for each $i \in [n]$, and thus $\tilde R(w) = R(w)$. In particular, we have $\tilde R(w(\infty)) = R(w(\infty)) > R(w^*) = \tilde R(w^*)$. Since $\tilde R(w)$ is a convex function, it follows that $\tilde R\big(w(\infty) + s(w^* - w(\infty))\big) \le s\tilde R(w^*) + (1-s)\tilde R(w(\infty)) < \tilde R(w(\infty))$ for all $0 < s \le 1$, which implies $\frac{d\tilde R}{ds}\big(w(\infty) + s(w^* - w(\infty))\big)\big|_{s=0} < -2c < 0$ for some constant $c > 0$. Note that, for all sufficiently small $s > 0$,
$$\begin{aligned}
R\big(w(\infty) + s(w^* - w(\infty))\big) &= \frac{4}{n}\sum_{j=1}^d\Big(\sum_{i=1}^nz_{i,j}^2\Big)\big|w_j(\infty) + s(w_j^* - w_j(\infty))\big|\\
&= \frac{4}{n}\sum_{j=1}^q\Big(\sum_{i=1}^nz_{i,j}^2\Big)\operatorname{sign}(w_j(\infty))\big(w_j(\infty) + s(w_j^* - w_j(\infty))\big) + \frac{4}{n}\sum_{j=q+1}^d\Big(\sum_{i=1}^nz_{i,j}^2\Big)s|w_j^*|.
\end{aligned}$$
Hence
$$\begin{aligned}
-2c &> \frac{d\tilde R}{ds}\big(w(\infty) + s(w^* - w(\infty))\big)\Big|_{s=0}\\
&= \frac{4}{n}\sum_{j=1}^q\Big(\sum_{i=1}^nz_{i,j}^2\Big)\operatorname{sign}(w_j(\infty))\big(w_j^* - w_j(\infty)\big) + \frac{4}{n}\sum_{j=q+1}^d\Big(\sum_{i=1}^nz_{i,j}^2\Big)|w_j^*|\\
&\quad + \sum_{j=1}^q\big(w_j^* - w_j(\infty)\big)\sum_{i=1}^{q'}\lambda_i(x(\infty))z_{i,j} + \sum_{j=q+1}^dw_j^*\sum_{i=1}^{q'}\lambda_i(x(\infty))z_{i,j}, \tag{8.65}
\end{aligned}$$
where we used the fact that $w_{(q+1):d}(\infty) = 0$. Since $x(t)$ converges to $x(\infty)$, we must have $F(x(\infty)) = 0$, which implies that for each $j \in \{1, \ldots, q\}$,
$$0 = \frac{\partial R}{\partial u_j}(x(\infty)) + \sum_{i=1}^{q'}\lambda_i(x(\infty))\frac{\partial f_i}{\partial u_j}(x(\infty)) = 2u_j(\infty)\Big[\frac{4}{n}\sum_{i=1}^nz_{i,j}^2 + \sum_{i=1}^{q'}\lambda_i(x(\infty))z_{i,j}\Big],$$
$$0 = \frac{\partial R}{\partial v_j}(x(\infty)) + \sum_{i=1}^{q'}\lambda_i(x(\infty))\frac{\partial f_i}{\partial v_j}(x(\infty)) = 2v_j(\infty)\Big[\frac{4}{n}\sum_{i=1}^nz_{i,j}^2 - \sum_{i=1}^{q'}\lambda_i(x(\infty))z_{i,j}\Big].$$
Therefore,
$$\frac{4}{n}\sum_{i=1}^nz_{i,j}^2 = -\operatorname{sign}(w_j(\infty))\sum_{i=1}^{q'}\lambda_i(x(\infty))z_{i,j} \quad \text{for all } j \in [q].$$
Plugging this into (8.65) yields
$$\begin{aligned}
-2c &> -\sum_{j=1}^q\operatorname{sign}(w_j(\infty))^2\big(w_j^* - w_j(\infty)\big)\sum_{i=1}^{q'}\lambda_i(x(\infty))z_{i,j} + \frac{4}{n}\sum_{j=q+1}^d\Big(\sum_{i=1}^nz_{i,j}^2\Big)|w_j^*|\\
&\quad + \sum_{j=1}^q\big(w_j^* - w_j(\infty)\big)\sum_{i=1}^{q'}\lambda_i(x(\infty))z_{i,j} + \sum_{j=q+1}^dw_j^*\sum_{i=1}^{q'}\lambda_i(x(\infty))z_{i,j}\\
&= \frac{4}{n}\sum_{j=q+1}^d\Big(\sum_{i=1}^nz_{i,j}^2\Big)|w_j^*| + \sum_{j=q+1}^dw_j^*\sum_{i=1}^{q'}\lambda_i(x(\infty))z_{i,j}. \tag{8.66}
\end{aligned}$$
On the other hand, by directly evaluating $\nabla R(x(t))$ and each $\nabla f_i(x(t))$, we can compute $I_1$ as
$$\begin{aligned}
I_1 &= \sum_{j=q+1}^d\frac{w_j^*\,\mathbb{1}\{w_j^* > 0\}}{u_j(t)}\Big[\frac{2}{n}\sum_{i=1}^nz_{i,j}^2u_j(t) + \frac{1}{2}\sum_{i=1}^{q'}\lambda_i(x(t))z_{i,j}u_j(t)\Big]\\
&\quad - \sum_{j=q+1}^d\frac{w_j^*\,\mathbb{1}\{w_j^* < 0\}}{v_j(t)}\Big[\frac{2}{n}\sum_{i=1}^nz_{i,j}^2v_j(t) - \frac{1}{2}\sum_{i=1}^{q'}\lambda_i(x(t))z_{i,j}v_j(t)\Big]\\
&= \frac{2}{n}\sum_{j=q+1}^d\Big(\sum_{i=1}^nz_{i,j}^2\Big)|w_j^*| + \frac{1}{2}\sum_{j=q+1}^dw_j^*\sum_{i=1}^{q'}\lambda_i(x(\infty))z_{i,j} + \frac{1}{2}\sum_{j=q+1}^dw_j^*\sum_{i=1}^{q'}\big(\lambda_i(x(t)) - \lambda_i(x(\infty))\big)z_{i,j}.
\end{aligned}$$
We already know that $\lambda_{1:q'}(x)$ is continuous at $x(\infty)$ by the proof of Lemma 8.11.12, so the third term converges to 0 as $x(t)$ tends to $x(\infty)$. Now, applying (8.66), we immediately see that there exists some $\delta > 0$ such that $I_1 < -c$ for $x(t) \in B_\delta(x(\infty))$. As we have shown above that $I_2 = 0$, it then follows from (8.62) and (8.63) that
$$\frac{d\varphi}{dt}(x(t)) > c \quad \text{for all } x(t) \in B_\delta(x(\infty)). \tag{8.67}$$
Since $\lim_{t\to\infty}x(t) = x(\infty)$, there exists some $T > 0$ such that $x(t) \in B_\delta(x(\infty))$ for all $t > T$. By the proof of Lemma 8.11.13, we know that $\varphi(x(T)) > -\infty$; then it follows from (8.67) that
$$\lim_{t\to\infty}\varphi(x(t)) = \varphi(x(T)) + \int_T^\infty\frac{d\varphi(x(t))}{dt}\,dt > \varphi(x(T)) + \int_T^\infty c\,dt = \infty,$$
contradicting $\lim_{t\to\infty}\varphi(x(t)) = -\infty$. This completes the proof.
8.11.6 Proof of Theorem 8.7.7

Here we present the lower bound on the sample complexity of GD in the kernel regime.

Theorem 8.7.7. Assume $z_1, \ldots, z_n \overset{\text{i.i.d.}}{\sim} N(0, I_d)$ and $y_i = z_i^\top w^*$ for all $i \in [n]$. Define the loss with the linearized model as $L(x) = \sum_{i=1}^n\big(f_i(x_0) + \langle\nabla f_i(x_0), x - x_0\rangle - y_i\big)^2$. Then for any learning rate schedule $\{\eta_t\}_{t\ge1}$ and any fixed number of steps $T$, the expected $\ell_2$ loss of $x(T)$ is at least $\big(1 - \frac{n}{d}\big)\|w^*\|_2^2$, where $x(T)$ is the $T$-th iterate of GD on $L$, i.e., $x(t+1) = x(t) - \eta_t\nabla L(x(t))$ for all $t \ge 0$.

Proof of Theorem 8.7.7. Exchanging the order of expectations,
$$\mathbb{E}_{w^*}\mathbb{E}_{z_i}\|w^* - (u(T) - v(T))\|_2^2 = \mathbb{E}_{z_i}\mathbb{E}_{w^*}\|w^* - (u(T) - v(T))\|_2^2.$$
Chapter 9
Deep learning experiments by Cohen et al. [234] using deterministic Gradient Descent
(GD) revealed an Edge of Stability (EoS) phase when learning rate (LR) and sharpness
(i.e., the largest eigenvalue of Hessian) no longer behave as in traditional optimization.
Sharpness stabilizes around 2/LR and loss goes up and down across iterations, yet
still with an overall downward trend. This chapter mathematically analyzes a new
mechanism of implicit regularization in the EoS phase, whereby GD updates due
to non-smooth loss landscape turn out to evolve along some deterministic flow on
the manifold of minimizers as introduced in Chapter 8. This is in contrast to many
previous results about implicit bias either relying on infinitesimal updates or noise in
gradient. Formally, for any smooth function $L$ with a certain regularity condition, this effect is demonstrated for (1) Normalized GD, i.e., GD with a varying LR $\eta_t = \frac{\eta}{\|\nabla L(x(t))\|}$ and loss $L$; and (2) GD with constant LR and loss $\sqrt{L - \min_xL(x)}$. Both provably enter
the Edge of Stability, with the associated flow on the manifold minimizing λ1 (∇2 L).
The above theoretical results have been corroborated by an experimental study.
9.1 Introduction
The above-defined stableness is a better indicator for EoS than using only the sharpness at a specific point $x$, i.e., $\eta\lambda_1(\nabla^2L(x)) < 2$, because the loss can still oscillate in the latter case.² A concrete example is $L(x) = |x|$, $x \in \mathbb{R}$. For any $c \in (0, 1)$ and LR $\eta > 0$, the GD iterates $x(2k) = c\eta$ and $x(2k+1) = -(1-c)\eta$ always have zero sharpness for all $k \in \mathbb{N}$, but the Descent Lemma does not apply because the gradient is not continuous around $x = 0$ (i.e., the sharpness is infinite at $x = 0$). As a result, the loss is not stable and oscillates between $c\eta$ and $(1-c)\eta$.

²See such experiments (e.g., ReLU CNN (+BN), Figure 75) in the Appendix of Cohen et al. [234].
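This two-point cycle is trivial to reproduce (a toy illustration; $c$ and $\eta$ are arbitrary choices):

```python
# GD with LR eta on L(x) = |x| started at c * eta hops between c * eta and
# -(1 - c) * eta forever.
eta, c = 0.1, 0.3
x = c * eta
for t in range(6):
    print(t, round(x, 3))
    x -= eta * (1.0 if x > 0 else -1.0)  # sign gradient of L(x) = |x|
```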
Figure 9.1: (a) GD on $\sqrt{L}$; (b) Normalized GD on $L$. GD operating on EoS oscillates around the zero-loss manifold $\Gamma = \{(x, y) \mid y = 0\}$ while slowly moving towards flatter local minima. Here $L(x, y) = (1 + x^2)y^2$, and the sharpness of $L$ decreases as $|x|$ decreases.
The first setting, which is simple yet quite general, considers a modified training loss $f(L)$ where $f: \mathbb{R} \to \mathbb{R}$ is a monotone increasing but non-smooth function. For concreteness, assume GD is performed on $\tilde L := \sqrt{L}$ where $L$ is a smooth loss function with $\min_xL(x) = 0$ and $\nabla^2L \neq 0$ at its minimizers. Note that $\nabla\tilde L = \frac{\nabla L}{2\sqrt{L}}$ and $\nabla^2\tilde L = \frac{2L\nabla^2L - \nabla L\nabla L^\top}{4L^{3/2}}$, which implies $\nabla^2\tilde L$ must diverge whenever $x$ converges to any minimizer where $\nabla^2L$ has rank at least 2, since $\nabla L\nabla L^\top$ is rank-1. (An analysis is also possible when $\nabla^2L$ is rank-1, which is the reason for Definition 9.1.1.)
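To see this divergence numerically, consider the simplest rank-2 example $L(x, y) = x^2 + y^2$ (our choice, not one the chapter analyzes), for which $\sqrt{L}$ is the distance to the origin; the finite-difference sharpness of $\sqrt{L}$ grows like one over that distance:

```python
import numpy as np

# The top Hessian eigenvalue of sqrt(x^2 + y^2) at distance r from the
# minimizer scales like 1/r, so eta * sharpness eventually exceeds 2.
def sharpness_sqrtL(p, eps=1e-6):
    f = lambda q: np.sqrt(q @ q)
    H = np.zeros((2, 2))
    for i in range(2):
        for j in range(2):
            e_i, e_j = np.eye(2)[i] * eps, np.eye(2)[j] * eps
            H[i, j] = (f(p + e_i + e_j) - f(p + e_i - e_j)
                       - f(p - e_i + e_j) + f(p - e_i - e_j)) / (4 * eps**2)
    return np.linalg.eigvalsh(H).max()

for r in (1.0, 0.1, 0.01):
    print(r, round(sharpness_sqrtL(np.array([r, 0.0])), 2))
```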
The second setting assumes that the loss is smooth but the learning rate is effectively adaptive. We focus on a concrete example, Normalized Gradient Descent, $x \leftarrow x - \eta\nabla L/\|\nabla L\|$, which exhibits EoS behavior as $\nabla L \to 0$. We can view Normalized GD as GD with a varying LR $\eta_t = \frac{\eta}{\|\nabla L(x(t))\|}$, which goes to infinity as $\nabla L \to 0$.

These analyses require that (1) the zero-loss set $\{x \mid L(x) = 0\}$ contains a $(D-M)$-dimensional submanifold of $\mathbb{R}^D$ for some $1 \le M \le D$, which we denote by $\Gamma$, and (2) $\nabla^2L(x)$ has rank $M$ for every $x \in \Gamma$.³ Note that while modern deep learning evolved using non-differentiable losses, the recent use of activations such as Swish [236] instead of ReLU has allowed differentiable losses without harming performance.

³Without loss of generality, we assume $\min_{x'}L(x') = 0$ throughout the paper. The main results for Normalized GD still hold if we relax this assumption and only assume $\Gamma$ to be a manifold of local minimizers. For GD on $\sqrt{L}$, we need to replace $\sqrt{L}$ by $\sqrt{L - L_{\min}}$, where $L_{\min}$ is the local minimum.
Novelty of Our Analysis: Our analysis is inspired by the mathematical framework for studying the limiting dynamics of SGD around a manifold of minimizers by Li et al. [133], where the high-level idea is to introduce a projection function $\Phi$ mapping the current iterate $x_t$ to the manifold, so that it suffices to understand the dynamics of $\Phi(x_t)$. It turns out that the one-step update of $\Phi(x_t)$ depends on the second moment of the (stochastic) gradient at $x_t$, $\mathbb{E}[\nabla L(x_t)(\nabla L(x_t))^\top]$. While for SGD the second moment converges to the covariance matrix of the stochastic gradient (see Chapter 8) as $x_t$ gets close to the manifold when $\eta \to 0$, for GD operating on EoS, the update direction $\nabla\sqrt{L}(x_t)$ or $\frac{\nabla L(x_t)}{\|\nabla L(x_t)\|}$ is non-smooth and not even defined on the manifold of minimizers! To show that $\Phi(x_t)$ moves in the direction which decreases the sharpness, the main technical difficulty is to show that $\nabla\sqrt{L}(x_t)$ or $\frac{\nabla L(x_t)}{\|\nabla L(x_t)\|}$ aligns with the top eigenvector of the Hessian $\nabla^2L(x_t)$; the analysis then follows from the framework of Li et al. [133].

To prove the alignment between the gradient and the top eigenvector of the Hessian, it boils down to analyzing Normalized GD on quadratic functions (9.2), which to the best of our knowledge has not been studied before. The dynamics is like a chaotic version of power iteration, and we manage to show that the iterate always aligns with the top eigenvector of the Hessian of the quadratic loss. The proof is based on identifying a novel potential (Section 9.3) and might be of independent interest.
Sharpness: Low sharpness has long been related to flat minima and thus to good generalization [9, 237]. A recent study on predictors of generalization [238] does identify sharpness-related measures as good predictors, leading to the SAM algorithm, which improves generalization by explicitly controlling a parameter related to sharpness [8]. However, Dinh et al. [239] show that, due to positive homogeneity in the network architecture, networks with rescaled parameters can have very different sharpness yet be identical to the original one in function space. This observation weakens the correlation between sharpness and the generalization gap and makes the definition of sharpness ambiguous. In the face of this challenge, multiple notions of scale-invariant sharpness have been proposed [240–243]. In particular, Yi et al. [244] and Kwon et al. [245] derived new algorithms with better generalization by explicitly regularizing new sharpness notions that are aware of the symmetry and invariance in the network. He et al. [246] go beyond the notion of sharpness/flatness and argue that the local minima of modern deep networks can be asymmetric, that is, sharp on one side but flat on the other.
Implicit Bias: The notion that the training algorithm plays an active role in selecting the solution (when multiple optima exist) has been termed the implicit bias of the algorithm [105] and studied in a large number of papers [14, 20, 25, 29, 95, 98, 100, 112, 138, 142, 159, 172, 247]. In the infinite width limit, the implicit bias of Gradient Descent is shown to be the solution with minimal RKHS norm with respect to the Neural Tangent Kernel (NTK) [15, 109, 148, 150–155, 213]. The implicit bias results from these papers are typically proved by performing a trajectory analysis for (Stochastic) Gradient Descent. Most of the results can be directly extended to the continuous limit (i.e., GD with infinitesimal LR), and some even rely heavily on a conservation property which only holds in the continuous limit. In sharp contrast, the implicit bias shown in this chapter — reducing the sharpness along the minimizer manifold — requires finite LR and does not exist in the corresponding continuous limit. Other implicit bias results that fundamentally rely on the finiteness of the LR include stability analysis [248, 249] and implicit gradient regularization [250], which is a special case of the approximation results for stochastic modified equations by Li et al. [190, 214].
To introduce the ideas used in the main results, we sketch the analysis of Normalized GD (9.1) on the quadratic loss $L(x) = \frac{1}{2}x^\top Ax$, where $A \in \mathbb{R}^{D\times D}$ is positive definite with eigenvalues $\lambda_1 > \lambda_2 \ge \ldots \ge \lambda_D$ and corresponding eigenvectors $v_1, \ldots, v_D$:
$$x(t+1) = x(t) - \eta\frac{\nabla L(x(t))}{\|\nabla L(x(t))\|} = x(t) - \eta\frac{Ax(t)}{\|Ax(t)\|}. \tag{9.1}$$
Our main result, Theorem 9.3.1, is that the iterates of Normalized GD $x(t)$ converge to $v_1$ in direction, from which the loss oscillation (Corollary 9.3.2) follows, suggesting that GD is operating in EoS. Since in the quadratic case there is only one local minimum, there is of course no need to talk about implicit bias. However, the observation that the GD iterates always align with the top eigenvector, as well as the technique used in its proof, plays a very important role in deriving the sharpness-reduction implicit bias for general loss functions.
Define $\tilde x(t) = \frac{Ax(t)}{\eta}$; then the update rule (9.2) below holds. It is clear that the convergence of $\tilde x(t)$ to $v_1$ in direction implies the convergence of $x(t)$ as well.
$$\tilde x(t+1) = \tilde x(t) - A\frac{\tilde x(t)}{\|\tilde x(t)\|}. \tag{9.2}$$
Theorem 9.3.1. If $|\langle v_1, \tilde x(t)\rangle| \neq 0$ for all $t \ge 0$, then there exist $0 < C < 1$ and $s \in \{\pm1\}$ such that $\lim_{t\to\infty}\tilde x(2t) = Cs\lambda_1v_1$ and $\lim_{t\to\infty}\tilde x(2t+1) = (C-1)s\lambda_1v_1$.
As a direct corollary, the loss oscillates between time step $2t$ and time step $2t+1$ as $t \to \infty$. This shows that the behavior of the loss is not monotonic and hence indicates the Edge of Stability phenomenon for the quadratic loss.
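The limiting two-cycle is easy to observe directly; the demo below (a toy instance of our choosing) iterates (9.2) and reports the alignment with $v_1$:

```python
import numpy as np

# Theorem 9.3.1 in action: the rescaled iterates x~(t) align with v1 and end
# up oscillating between two points on the v1 axis whose norms sum to lambda_1.
A = np.diag([1.0, 0.4, 0.1])             # lambda_1 > lambda_2 >= lambda_3
v1 = np.array([1.0, 0.0, 0.0])

x = np.array([0.3, 1.1, -0.7])           # plays the role of x~(0)
for _ in range(4000):
    x = x - A @ (x / np.linalg.norm(x))  # update rule (9.2)

for t in range(2):                       # the limiting 2-cycle
    print(t, np.round(x, 4),
          "cos angle to v1 =", round(abs(x @ v1) / np.linalg.norm(x), 6))
    x = x - A @ (x / np.linalg.norm(x))
```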
This vanishing increment over steps turns out to imply that $\tilde x(t)$ must converge to $v_1$ in direction.
Lemma 9.3.3 (Preparation Phase). For any $j \in [D]$ and $t \ge \frac{\lambda_1}{\lambda_j}\ln\frac{\lambda_1}{\lambda_j} + \max\big\{\frac{\|\tilde x(0)\| - \lambda_1}{\lambda_D}, 0\big\}$, it holds that $\tilde x(t) \in I_j$.
Proof of Lemma 9.3.3. First, we show via Lemma 9.9.1 that for any $j \in [D]$, $I_j$ is indeed an invariant set for update rule (9.2). With a straightforward calculation, one can show that for any $j \in [D]$, $\|P^{(j:D)}\tilde x(t)\|$ decreases by $\frac{\lambda_D\|P^{(j:D)}\tilde x(t)\|}{\|\tilde x(t)\|}$ if $\|P^{(j:D)}\tilde x(t)\| \ge \lambda_j$ (Lemma 9.9.2). Setting $j = 1$, we have that $\|\tilde x(t)\|$ decreases by $\lambda_D$ if $\|\tilde x(t)\| \ge \lambda_1$ (Corollary 9.9.3). Thus for all $t \ge \max\big\{\frac{\|\tilde x(0)\| - \lambda_1}{\lambda_D}, 0\big\}$, $\tilde x(t) \in I_1$. Finally, once $\tilde x(t) \in I_1$, we can upper bound $\|\tilde x(t)\|$ by $\lambda_1$, and thus $\|P^{(j:D)}\tilde x(t)\|$ shrinks at least by a factor of $\frac{\lambda_j}{\lambda_1}$ per step, which implies $\tilde x(t)$ will be in $I_j$ within another $\frac{\lambda_1}{\lambda_j}\ln\frac{\lambda_1}{\lambda_j}$ steps (Corollary 9.9.4).
Lemma 9.3.4 (Alignment Phase). If $\tilde x(T) \in \cap_{j=1}^DI_j$ holds for some $T$, then for any
Proof of Lemma 9.3.4. First, Lemma 9.3.5 (proved in Section 9.9) shows that the norm of the iterate $\tilde x(t)$ remains above $0.5\lambda_1$ for only one time step.

Lemma 9.3.5. For any $t$ with $\tilde x(t) \in \cap_{j=1}^DI_j$, if $\|\tilde x(t)\| > \frac{\lambda_1}{2}$, then $\|\tilde x(t+1)\| \le \max\big\{\frac{\lambda_1}{2} - \frac{\lambda_D^2}{2\lambda_1},\ \lambda_1 - \|\tilde x(t)\|\big\}$.
Figure 9.2: Visualization of key concepts and lemmas in the analysis of Normalized GD on a 2D quadratic loss with $\lambda_1 = 1$, $\lambda_2 = 0.4$. Left: invariant sets (defined in Lemma 9.3.3). Middle: $\|\tilde x(t)\|$ drops below $\frac{\lambda_1}{2}$ in the next step whenever it is above $\frac{\lambda_1}{2}$ (Lemma 9.3.5). Right: $|\langle v_1, \tilde x(t)\rangle|$ monotonically increases over all steps with norm below $\frac{\lambda_1}{2}$ (Lemma 9.3.6).
Thus, for any $t$ with $\tilde x(t) \in \cap_{j=1}^DI_j$ and $\|\tilde x(t)\| \le \frac{\lambda_1}{2}$, either $\|\tilde x(t+1)\| \le \frac{\lambda_1}{2}$, or $\|\tilde x(t+1)\| > \frac{\lambda_1}{2}$, which in turn implies that $\|\tilde x(t+2)\| \le \frac{\lambda_1}{2}$ by Lemma 9.3.5. The proof of Lemma 9.3.4 is completed by induction using Lemma 9.3.6.

Lemma 9.3.6. For any step $t$ with $\|\tilde x(t)\| \le \frac{\lambda_1}{2}$ and any $k \in \{1, 2\}$, it holds that $|\langle v_1, \tilde x(t+k)\rangle| \ge |\langle v_1, \tilde x(t)\rangle|$.

The case $k = 1$ of Lemma 9.3.6 follows directly from plugging the assumption $\|\tilde x(t)\| \le \frac{\lambda_1}{2}$ into (9.2) (see Lemma 9.9.5). The case $k = 2$ follows from Lemma 9.9.7. We defer the complete proof of Lemma 9.3.6 to Section 9.9.
To complete the proof for Theorem 9.3.1, we relate the increase in the projection
along v1 at any step t, |hv1 , x
e(t)i|, to the magnitude of the angle between x
e(t) and the
λ1
top eigenspace, θt . Briefly speaking, we show that if ke
x(t)k ≤ 2
, |hv1 , x
e(t)i| has to
increase by a factor of Θ(θt2 ) in two steps. Since |hv1 , x
e(t)i| is bounded and monotone
λ1
increases among {t | ke
x(t)k ≤ 2
} by Lemma 9.3.4, we conclude that θt gets arbitrarily
λ1 λ1
small for sufficiently large t with ke
x(t)k ≤ 2
, ke
x(t + 2)k ≤ 2
satisfied. Since the
one-step normalized GD update Equation (9.2) is continuous when bounded away
404
from origin, with a careful analysis, we conclude θt → 0 for all iterates. Please see
Section 9.9.3 for details.
q
1 >
√ q
1 >
Equivalence to GD on 2
x Ax: Below we show GD on loss L(x) = 2
x Ax,
Equation (9.3), follows the same update rule as Normalized GD on L(x) = 12 x> Ax,
up to a linear transformation.
√ Ax(t)
x(t + 1) = x(t) − η∇ L(x(t)) = x(t) − η p . (9.3)
2x(t)> Ax(t)
9.4 Notations
For any integer k, we denote C k as the set of the k times continuously differentiable
functions. For any mapping F , we use ∂F (x)[u] and ∂ 2 F (x)[u, v] to denote the first and
second order directional derivative of F at x along the derivation of u (and v). Given
the loss function L, the gradient flow (GF) governed by L can be described through a
Rτ
mapping φ : RD × [0, ∞) → RD satisfying φ(x, τ ) = x − 0 ∇L(φ(x, s))ds. We further
define the limiting map of gradient flow as Φ, that is, Φ(x) = limτ →∞ φ(x, τ ).
For a matrix A ∈ RD×D , we denote its eigenvalue-eigenvector pairs by
{λi (A), vi (A))}i∈[D] . For simplicity, whenever Φ is defined at point x, we use
{(λi (x), vi (x))}D 2
i=1 to denote the eigenvector-eigenvalue pairs of ∇ L(Φ(x)), with
λ1 (x) > λ2 (x) ≥ λ3 (x) . . . ≥ λD (x). As an analog to the quadratic case, we use x
e to
1/2
denote ∇2 L(Φ(x))(x − Φ(x)) for Normalized GD on L and (2∇2 L(Φ(x))) (x − Φ(x))
√
for GD on L. Furthermore, when the iterates x(t) are clear in the context, we also
use shorthand λi (t) := λi (x(t)), vi (t) := vi (x(t)) and θt ∈ [0, π2 ] to denote the angle
e(t) and top eigenspace of ∇2 L(Φ(x(t))). Given a differentiable submanifold
between x
Γ of RD and point x ∈ Γ, we use Px,Γ : Γ → RD to denote the projection operator
405
⊥
onto the normal space of Γ at x, and Px,Γ := ID − Px,Γ . As before, for notational
⊥ ⊥
convenience, we use the shorthand Pt,Γ := PΦ(x(t)),Γ and Pt,Γ := PΦ(x(t)),Γ .
In this section, we focus on the setting where LR η goes to 0 and we fix the
initialization xinit and the loss function L throughout this chapter. We use O(·) to
hide constants about xinit and L.
In this section we present the main results of this chapter. In Section 9.5.1, we make
our key assumptions that the minimizers of the loss function form a manifold. In
Sections 9.5.2 and 9.5.3 we present our main results for Normalized GD and GD on
√
L respectively. In Section 9.5.4 we show the above two settings for GD do enter the
regime of Edge of Statbility.
Similar to Chapter 8, we make the following assumption throughout this chapter. The
only difference between the following assumption and Assumption 8.5.1 is below we
assume L is C 4 smooth instead of C 3 . This extra degree of differentiability allows us to
give a non-asymptotic rate instead of asymptotic convergence as shown in Chapter 9.
Let U be the sets of points starting from which, gradient flow w.r.t. loss L
converges to some point in Γ, that is, U := {x ∈ RD | Φ(x) exists and Φ(x) ∈ Γ}.
Assumption 9.5.1 implies that U is open and Φ is C 3 on U . (By Lemma 8.8.2)
We also make the following assumption to ensure that λ1 (∇2 L(·)) is differentiable,
which is necessary for our main results, Theorems 9.5.4 and 9.5.6.
406
Assumption 9.5.2. For any x ∈ Γ, ∇2 L(x) has a positive eigengap, i.e.,
λ1 (∇2 L(x)) > λ2 (∇2 L(x)).
We first denote the iterates of Normalized GD with LR η by xη (t), with xη (0) ≡ xinit
for all η:
∇L(xη (t))
Normalized GD: xη (t + 1) = xη (t) − η (9.4)
k∇L(xη (t))k
The first theorem demonstrates the movement in the manifold, when the iterate
travels from xinit to a position that is O(η) distance closer to the manifold (more
specifically, Φ(xinit )). Moreover, just like the result in the quadratic case, we have more
fine-grained bounds on the projection of xη (t) − Φ(xη (t)) into the bottom-k eigenspace
of ∇2 L(Φ(xη (t))) for every k ∈ [D]. For convenience, we define the following quantity
for all j ∈ [d] and x ∈ U :
v v
uM uM
uX uX
Rj (x) := t hvi (x), x 2
ei − λj (x)η = t λ2i (x)hvi (x), x − Φ(x)i2 − λj (x)η
i=j i=j
In the quadratic case, Lemma 9.3.3 shows that Rj (x) will eventually become
non-positive for normalized GD iterates. Similarly, for the general loss, the following
theorem shows that Rj (xη (t)) eventually becomes approximately non-positive (smaller
than O(η 2 )) in O( η1 ) steps.
Theorem 9.5.3 (Phase I). Let {xη (t)}t∈N be the iterates of Normalized GD (9.4)
with LR η and xη (0) = xinit ∈ U . There is T1 > 0 such that for any T10 > T1 , it
holds that for sufficiently small η that (1) max kxη (t) − Φ(xinit )k ≤ O(η) and (2)
T1 ≤ηt≤T10
max Rj (xη (t)) ≤ O(η 2 ).
T1 ≤ηt≤T10 ,j∈[D]
Our main contribution is the analysis for the second phase (Theorem 9.5.4), which
says just like the quadratic case, the angle between x
eη (t) and the top eigenspace of
407
√
Figure 9.3: Illustration for two-phase dynamics of Normalized GD and GD on L on a
1D zero loss manifold Γ. For sufficiently small LR η, Phase I is close to Gradient Flow
and lasts for Θ(η −1 ) steps, while Phase II is close to the limiting flow which decreases
the sharpness of the loss and lasts for Θ(η −2 ) steps. GD iterate oscillates along the
top eigenvector of the Hessian with the period equal to two steps. (cf. Figure 8.1)
∇2 L(Φ(xη (t))), denoted by θt , will be O(η) on average. And as a result, the dynamics
of Normalized GD tracks the riemannian gradient flow with respect to log(λ1 (∇2 L(·)))
⊥
on manifold, that is, the unique solution of Equation (9.5), where Px,Γ is the projection
matrix onto the tangent space of manifold Γ at x ∈ Γ.
Z τ
1 ⊥
Limiting Flow: X(τ ) = Φ(xinit ) − PX(s),Γ ∇ log λ1 (X(s))ds, X(τ ) ∈ Γ
4 s=0
(9.5)
Note Equation (9.5) is not guaranteed to have a global solution, i.e., a well-defined
solution for all τ ≥ 0, for the following two reasons: (1). when the multiplicity of
top eigenvalue is larger than 1, λ1 (∇2 L(·)) may be not differentiable and (2). the
projection matrix is only defined on Γ and the equation becomes undefined when the
solution leaves Γ, i.e., moving across the boundary of Γ. For simplicity, we make
Assumption 9.5.2 that every point on Γ has a positive eigengap. Or equivalently, we
can work with a slightly smaller manifold Γ0 = {x ∈ Γ | λ1 (x) > λ2 (x)}.
Towards a mathematical rigorous characterization of the dynamics in the second
phase, we need to make the following modifications: (1). we add negligible noise
408
Algorithm 7 Perturbed Normalized Gradient Descent
Input: loss function L : RD → R, initial point xinit , maximum number of iteration
T , LR η, Frequency parameter Tfreq = Θ(η −0.1 ), noise parameter r = Θ(η 100 ).
for t = 1 to T do
Generate n(t) ∼ B0 (r) if t mod Tfreq = 0, else set n(t) = 0.
∇L(x(t))
x(t) ← x(t − 1) − η k∇L(x(t))k + n(t).
of magnitude O(η 100 ) every η −0.1 steps, (2). we assume for each η > 0, there exist
some step t = Θ(1/η) in phase I, except the guaranteed condition (1) and (2) (by
Theorem 9.5.3, the additional condition (3) also holds. This assumption is mild
T1 T10
because we only require (3) to hold for one step among Θ(1/η) steps from η
to η
,
where T1 is the constant given by Theorem 9.5.3 and T10 is arbitrary constant larger
than T1 . This assumption also holds empirically for all our experiments in Section 9.7.
Theorem 9.5.4 (Phase II). Let {xη (t)}t∈N be the iterates of perturbed Normalized GD
(Algorithm 7) with LR η. Under Assumptions 9.5.1 and 9.5.2, if the initialization xη (0)
satisfy that (1) kxη (0) − Φ(xinit )k ≤ O(η) where xinit ∈ U , (2) maxj∈[D] Rj (xη (t)) ≤
O(η 2 ), and additionally (3) min{|hv1 (xη (0)), xη (0) − Φ(xη (0))i| , −R1 (xη (0))} ≥ Ω(η),
then for any time T2 > 0 till which the solution of (9.5) exists, it holds for suffi-
ciently small η, with probability at least 1 − O(η 10 ), that kΦ(xη (bT2 /η 2 c)) − X(T2 )k =
P 2 /η2 c
O(η) and bT21/η2 c bTt=0 θt ≤ O(η), where θt ∈ [0, π2 ] denotes the angle between
∇2 L(Φ(xη (t)))(xη (t) − Φ(xη (t))) and top eigenspace of ∇2 L(Φ(xη (t))).
√
9.5.3 Results for GD on L
√
In this subsection, we denote the iterates of GD on L with LR η by xη (t), with
xη (0) ≡ xinit for all η:
√ √
GD on L: xη (t + 1) = xη (t) − η∇ L(xη (t)) (9.6)
409
√
Algorithm 8 Perturbed Gradient Descent on L
Input: loss function L : RD → R, initial point xinit , maximum number of iteration
T , LR η, Frequency parameter Tfreq = Θ(η −0.1 ), noise parameter r = Θ(η 100 ).
for t = 1 to T do
Generate n(t) ∼ B0 (r)√if t mod Tfreq = 0, else set n(t) = 0.
x(t) ← x(t − 1) − η∇ L(x(t)) + n(t).
Similar to Normalized GD, we will have two phases. The first theorem demonstrates
the movement in the manifold, when the iterate travels from xinit to a position that is
O(η) distance closer to the manifold. For convenience, we will denote the quantity
qP
M
p
2
i=j λi (x)hvi (x), x − Φ(x)i − η 1/2λj (x) by Rj (x) for all j ∈ [M ] and x ∈ U .
Theorem 9.5.5 (Phase I). Let {xη (t)}t∈N be the iterates of Normalized GD (9.6)
with LR η and xη (0) = xinit ∈ U . There is T1 ∈ R+ such that for any T10 ∈ R+ ,
it holds for sufficiently small η that (1) max kxη (t) − Φ(xinit )k ≤ O(η) and (2)
T1 ≤ηt≤T10
max
0
Rj (xη (t)) ≤ O(η 2 ).
T1 ≤ηt≤T1 ,j∈[D]
The next result demonstrates that close to the manifold, the trajectory implicitly
minimizes sharpness.
√
Theorem 9.5.6 (Phase II). Let {xη (t)}t∈N be the iterates of perturbed GD on L
(Algorithm 8). Under Assumptions 9.5.1 and 9.5.2, if the initialization xη (0) sat-
isfy that (1) kxη (0) − Φ(xinit )k ≤ O(η), where xinit ∈ U , (2) maxj∈[D] Rj (xη (t)) ≤
O(η 2 ), and additionally (3) min{|hv1 (xη (0)), xη (0) − Φ(xη (0))i| , −R1 (xη (t))} ≥ Ω(η),
then for any time T2 > 0 where the solution of (9.7) exists, it holds for suffi-
ciently small η, with probability at least 1 − O(η 10 ), that kΦ(xη (bT2 /η 2 c)) − X(T2 )k =
P 2 /η2 c
O(η 1/2 ) and bT21/η2 c bT
t=0 θt ≤ O(η 1/2 ), where θt ∈ [0, π2 ] denotes the angle between
p
∇2 L(Φ(xη (t)))(xη (t) − Φ(xη (t))) and top eigenspace of ∇2 L(Φ(xη (t))).
Z τ
1 ⊥
X(τ ) = Φ(xinit ) − PX(s),Γ ∇λ1 (X(s))ds, X(τ ) ∈ Γ. (9.7)
8 s=0
410
9.5.4 Operating on the Edge of Stability
√
In this section, we show that both Normalized GD on L and GD on L is on Edge
of Stability in their phase II, that is, at least in one of every two consecutive steps,
the stableness is at least 2 and the loss oscillates in every two consecutive steps.
Interestingly, the average loss over two steps still monotonically decreases, even when
operating on the edge of Stability (see Figure 9.1 for illustration), as indicated by
the following theorems. Note that Theorems 9.5.4 and 9.5.6 ensures that the average
√
of θt are O(η) and O( η). We defer their proofs into Sections 9.13.5 and 9.15.4
respectively.
Theorem 9.5.7 (Stableness, Normalized GD). Under the setting of Theorem 9.5.4,
η
by viewing Normalized GD as GD with time-varying LR ηt := k∇L(xη (t))k
, we have
p
[SL (xη (t), ηt )]−1 +[SL (xη (t+1), ηt+1 )]−1 = 1+O(θt +η). Moreover, we have L(xη (t))+
q
2
L(xη (t + 1)) = η λ1 (∇ L(x η (t)))
p
2
+ O(ηθt ).
√
Theorem 9.5.8 (Stableness, GD on L). Under the setting of Theorem 9.5.6, we
p p
have [S√L (xη (t), ηt )] ≥ Ω( θ1t ). Moreover, we have L(xη (t)) + L(xη (t + 1)) =
ηλ1 (∇2 L(xη (t))) + O(ηθt ).
We sketch the proof of the Normalized GD in phase I and II respectively in Section 9.6.2.
√
Then we briefly discuss how to prove the results for GD with L with same analysis
in Section 9.6.3. We start by introducing the properties of limit map of gradient flow
Φ in Section 9.6.1, which plays a very important role in the analysis.
9.6.1 Properties of Φ
The limit map of gradient flow Φ lies at the core of our analysis. When LR η is
small, one can show xη (t) will be O(η) close to manifold and Φ(xη (t)). Therefore,
411
Φ(xη (t)) captures the essential part of the implicit regularization of Normalized GD
and characterization of the trajectory of Φ(xη (t)) immediately gives us that of Φ(xη (t))
up to O(η).
Below we first recap a few important properties of Φ that will be used later this
section, which makes the analysis of Φ(xη (t)) convenient.
Lemma 9.6.1. Under Assumption 9.5.1, Φ satisfies the following two properties:
⊥
2. For any x ∈ Γ, if λ1 (x) > λ2 (x), ∂ 2 Φ(x)[v1 (x), v1 (x)] = − 12 Px,Γ ∇ log λ1 (x).
(Lemmas 9.10.16 and 9.10.18)
∇L(xη (t))
Note that xη (t + 1) − xη (t) = −η k∇L(x η (t))k
, using a second order taylor expansion
of Φ, we have
where we use the first claim of Lemma 9.6.1 in the final step. Therefore, we have
Φ(xη (t + 1)) − Φ(xη (t)) = O(η 2 ), which means Φ(xη (t)) moves slowly along the
manifold, at a rate of at most O(η 2 ) step. The Taylor expansion of Φ, (9.9) plays a
crucial role in our analysis for both Phase I and II and will be used repeatedly.
Analysis for Phase I, Theorem 9.5.3: The Phase I itself can be divided into
two subphases: (A). Normalized GD iterate xη (t) gets O(η) close to manifold; (B).
counterpart of preparation phase in the quadratic case: local movement in the O(η)-
412
neighborhood of the manifold which decreases Rj (xη (t)) to O(η 2 ). Below we sketch
their proofs respectively:
• Subphase (A): First, with a very classical result in ODE approximation theory,
normalized GD with small LR will track the normalized gradient flow, which is
a time-rescaled version of standard gradient flow, with O(η) error, and enter a
small neighborhoods of the manifold where Polyak-Lojasiewicz (PL) condition
holds. Since then, Normalized GD decreases the fast loss with PL condition and
the gradient has to be O(η) small in O( η1 ) steps. (See details in Section 9.11.1).
With a similar proof technique, we show xη (t) enters ainvariant set around the
manifold Γ, that is, {x ∈ U | Rj (x) ≤ O(η 2 ), ∀j ∈ [D]}. Formally, we show the
following analog of Lemma 9.3.3:
Analysis for Phase II, Theorem 9.5.4: Similar to the subphase (B) in the Phase
I, the high-level idea here is again that xη (t) locally evolves like normalized GD with
quadratic loss around Φ(xη (t)) and with an argument similar to the alignment phase
413
of quadratic case (though technically more complicated), we show xη (t) − Φ(xη (t))
approximately aligns to the top eigenvector of ∇2 L(Φ(xη (t))), denoted by v1 (t) and so
does ∇L(xη (t)). More specifically, it corresponds to the second claim in Theorem 9.5.4,
P 2 /η2 c
that bT21/η2 c bT
t=0 θt ≤ O(η).
We now have a more detailed look at the movement in Φ. Since Φ(xη (t)) belongs to
the manifold, we have ∇L(Φ(xη (t))) = 0 and so ∇L(xη (t)) = ∇2 L(Φ(xη (t)))(xη (t) −
Φ(xη (t))) + O(η 2 ) using a Taylor expansion. This helps us derive a relation between
the Normalized GD update and the top eigenvector of the hessian (simplified version
of Lemma 9.10.9):
∇L(xη (t))
∃s ∈ {±1}, = sv1 (t) + O(θt + η). (9.10)
k∇L(xη (t))k
Incorporating the above into the movement in Φ(xη (t)) from Equation (9.9) gives:
η2 2
Φ(xη (t + 1)) − Φ(xη (t)) = ∂ Φ(xη (t))[v1 (t), v1 (t)] + O(η 2 θt + η 3 ) (9.11)
2
Applying the second property of Lemma 9.6.1 on Equation (9.11) above yields
Lemma 9.6.3.
η2 ⊥
Φ(xη (t + 1)) − Φ(xη (t)) = − Pt,Γ ∇ log λ1 (t) + O(η 3 + η 2 θt ).
4
To complete the proof of Theorem 9.5.4, we show that for small enough η, the
P 2 /η2 c
trajectory of Φ(xη (τ /η 2 )) is O(η 3 bT2 /η 2 c+η 2 bT
t=0 θt )-close to X(τ ) for any τ ≤ T2 ,
414
PbT2 /η2 c
where X(·) is the flow given by Equation (9.5). This error is O(η), since t=0 θt =
O(bT2 /η 2 cη).
One technical difficulty towards showing the average of ηt is only O(η) is that
our current analysis requires |hv1 (xη (t)), xη (t) − Φ(xη (t))i| doesn’t vanish, that is, it
remains Ω(η) large throughout the entire training process. This is guaranteed by
Lemma 9.3.4 in quadratic case – since the alignment monotone increases whenever
λ1
it’s smaller 2
, but the analysis breaks when the loss is only approximately quadratic
and the alignment |hv1 (xη (t)), xη (t) − Φ(xη (t))i|could decrease decrease by O(θt η 2 )
per step. Once the alignment becomes too small, even if the angle θt is small, the
normalized GD dynamics become chaotic and super sensitive to any perturbation.
Our current proof technique cannot deal with this case and that’s the main reason we
have to make the additional assumption in Theorem 9.5.4.
Role of η 100 noise. Fortunately, with the additional assumption that the
initial alignment is at least Ω(η), we can show adding any poly(η) perturbation
(even as small as Ω(η 100 )) suffices to prevent the aforementioned bad case, that is,
|hv1 (xη (t)), xη (t) − Φ(xη (t))i| stays Ω(η) large. The intuition why Ω(η 100 ) perturba-
e = cv1 for any |c| ≤ 1
tion works again comes from quadratic case – it’s clear that x
is a stationary point for two-step normalized GD updates for quadratic loss under
the setting of Section 9.3. But if c is smaller than critical value determined by the
eigenvalues of the hessian, the stationary point is unstable, meaning any deviation
away from the top eigenspace will be amplified until the alignment increases above
the critical threshold. Based on this intuition, the formal argument, Lemma 9.13.11
uses the techniques from the ‘escaping saddle point’ analysis [252]. Adding noise is
not necessary in experiments to observe the predicted behavior (see ‘Alignment’ in
Figure 9.4 where no noise is added). On one hand, it might be because the floating
point errors served the role of noise. On the other hand, we suspect it’s not necessary
415
even for theory, just like GD gets stuck at saddle point only when initialized from a
zero measure set even without noise [119, 121].
√
9.6.3 Analysis for GD on L
In this subsection we will make an additional assumption that L(x) = 0 for all x ∈ Γ.
The analysis then will follow a very similar strategy as the analysis for (Normalized)
GD. However, the major difference from the analysis for Normalized GD comes from
the update rule for xη (t) when it is O(η)-close to the manifold:
√ p
∃s ∈ {±1}, ∇ L(xη (t)) = s λ1 (t)v1 (t) + O(η + θt ).
p
Thus, the effective learning rate is λ1 (t)η at any step t. This shows up, when we
compute the change in the function Φ. Thus, we have the following lemma showcasing
√
the movement in the function Φ with the GD update on L:
Lemma 9.6.4 (Movement in the manifold, Informal version of Lemma 9.15.1). Under
the setting in Theorem 9.5.6, for sufficiently small η, we have at any step t ≤ bT2 /η 2 c,
2
Φ(xη (t + 1)) − Φ(xη (t)) = − η8 Pt,Γ
⊥
∇λ1 (t) + O(η 3 + η 2 θt ).
9.7 Experiments
416
Top Eigenvalue Alignment 4.0Lower bound on Stableness Test Acc
Square root Loss 1.0
20 3.5 50
Normalized GD 0.9 3.0
15 0.8 2.5 40
0.7 2.0
10 30
1.5
0.6
5 Square root Loss 1.0 Square root Loss 20 Square root Loss
0.5 Normalized GD 0.5 Normalized GD Normalized GD
10
0 5000 10000 15000 0 5000 10000 15000 0.0 0 5000 10000 15000 0 5000 10000 15000
Gradient Steps
Figure 9.4: We verify our theoretical claims in the second phase —(a) the sharpness
decreases; (b) gradient aligns with the top eigenvector of Hessian; (c) stableness will
be higher than 2 — under the setting √ of training VGG-16 on CIFAR-10 dataset with
Normalized GD on L and GD with L loss respectively.
417
140
Top Eigenvalue Hessian trace Test loss Test Acc
0.32 90.2
Normalized GD 4000
120 Riemannian Flow 90.1
3500 0.31
100 90.0
80 3000
0.30 89.9
60 2500
89.8
0.29
40 2000 89.7
20 1500 0.28 89.6
0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3
Continuous time
Figure 9.5: Normalized GD and Riemannian flow have almost the same behavior
under proper time scalings, for a 2-layer network on MNIST initialized with tiny loss.
0.020
0.015
0.010
0 1 2 3
Continuous time
Figure 9.6: The trajectory of Normalized GD is very close to that of the limiting flow
minimizing the sharpness on manifold, as predicted by our theory. Absolute difference
is the norm of the difference between the parameters of the two trajectories at the
same continuous time, while relative parameter difference is the ratio of the norm of
the difference to the norm of parameters of each runs.
418
Verifying Convergence to Limiting Flow on MNIST: We further verify the
closeness between the Riemannian gradient flow w.r.t. the top eigenvalue and Nor-
malized GD, as predicted by Theorem 9.5.4, on a 1 hidden-layer fully connected
network on MNIST [256]. The network had 784 hidden units, with GeLU activation
function. We use `2 loss to ensure the existence of minimizers, which is necessary for
the existence of the manifold. For efficient training on a single GPU, we train on a
random subset of training data of size 1000.
We first trained the model with full to reach loss of order 10−3 . Starting from
this checkpoint, we make two different runs, one for Normalized GD and another for
Riemannian gradient flow w.r.t. the top eigenvalue (see Section 9.16 for details). We
plot the behavior of the network w.r.t. continuous time defined for Normalized GD as
#GradientSteps × η 2 /4, and for Riemannian flow as #GradientSteps × η, where η is
the learning rate. We track the behavior of Test Loss, Test accuracy, the top eigenvalue
of the Hessian and also the trace of the Hessian in Figure 9.5. We see that there is an
exact match between the behavior of the four functions, which supports our theory.
Moreover, Figure 9.6 computes the norm of the difference in the parameters between
the two runs, and shows that the runs stay close to each other in the parameter space
throughout training.
One limitation of our analysis is that it only applies close to the manifold of local
minimizers. In contrast, in experiments the EoS phenomenon, including the control of
sharpness, begins much sooner. Addressing this gap, as well as analying the EoS for
√
the loss L itself (as opposed to L as done here) is left for future work. Very likely
this will require novel understanding of properties of deep learning losses, which we
√
were able to circumvent by looking at L instead. Exploration of EoS-like effects in
419
SGD setting would also be interesting, although we first need definitive experiments
analogous to Cohen et al. [234].
tions
We first recall the settings and notations. Let A be a positive definite matrix. Without
loss of generality, we can assume A is diagonal, i.e., A = diag(λ1 , λ2 , . . . , λD ) ∈ RD×D ,
where λ1 > λ2 ≥ λ3 ≥ . . . ≥ λD > 0 and the eigenvectors are the standard basis
vectors e1 , · · · , eD of the D-dimensional space. We will denote P (j:D) = D >
P
i=j ei ei as
x
e(t)
x e(t) − A
e(t + 1) = x . (9.2)
ke
x(t)k
Theorem 9.3.1. If |hv1 , x 6 0, ∀t ≥ 0, then there exists 0 < C < 1 and s ∈ {±1}
e(t)i| =
such that limt→∞ x e(2t + 1) = (C − 1)sλ1 v1 .
e(2t) = Csλ1 v1 and limt→∞ x
√
We also note that GD on L with any LR η can also be reduced to update
rule (9.2), as shown in the discussion at the end of Section 9.3.
420
9.9.1 Proofs for Preparation Phase
P (j:D) A
(j:D) (j:D) (j:D) x
e(t)
P x
e(t + 1) = P e(t) − P
x A = I− P (j:D) x
e(t),
ke
x(t)k kex(t)k
which implies
P (j:D) A
P (j:D) x
e(t + 1) ≤ I − P (j:D) x
e(t) . (9.12)
kex(t)k
λj P (j:D) A P (j:D) A λj
− I4− 4I− 4I4 I.
kP (j:D) x
e(t)k kex(t)k kex(t)k kP (j:D) x
e(t)k
P (j:D) A λj
Therefore I − ≤ and thus we conclude P (j:D) x
e(t + 1) ≤ λj .
kex(t)k kP (j:D) xe(t)k
P (j:D) A
Proof of Lemma 9.9.2. Since λj ≤ P (j:D) x
e(t) ≤ ke
x(t)k, we have 0 4 I − kex(t)k
4
λD P (j:D) A λD
1− ke
x(t)k
. Therefore I − kex(t)k
≤1− ke
x(t)k
. The proof is completed by plugging
this into Equation (9.12).
421
Lemma 9.9.2 has the following two direct corollaries.
ke
x(0)k−λ1
e(0) and t ≥
Corollary 9.9.3. For any initialization x λD
, ke
x(t)k ≤ λ1 , that is,
e(t) ∈ I1 .
x
Proof of Corollary 9.9.3. Set j = 1 in Lemma 9.9.2, it holds that ke x(t + 1)k ≤
kex(0)k−λ1
ke
x(t)k − λD whenever ke x(t)k ≥ λ1 . Thus x
e( λD
) ≤ λ1 . The proof is
completed as I1 is an invariant set by Lemma 9.9.1.
−T
λD λj
P (j:D) x
e(T ) ≤ e λ1
P (j:D) x
e(0) ≤ ke
x(0)k ≤ λj .
λ1
The proof is completed since Ij is a invariant set for any j ∈ [D] by Lemma 9.9.1.
In this subsection, we analyze how normalized GD align to the top eigenvector once
e(t) ∈ ∩D
it goes through the preparation phase, meaning x j=1 Ij for all t in alignment
phase.
λ1
Lemma 9.3.5. For any t with x e(t) ∈ ∩D
j=1 Ij , if ke
x(t)k > 2
, then ke
x(t + 1)k ≤
λ2
max λ21 − 2λD1 , λ1 − ke
x(t)k .
422
Proof. The update at step t as:
x(t)k − λ1 )e
(ke x1 (t)
1 1 (ke
x (t)k − λ 2 )e
x 2 (t)
x
e(t + 1) = x(t)k I − A) x
(ke e(t) = .
ke
x(t)k ke
x(t)k
..
.
x(t)k − λD )e
(ke xD (t)
Let the index k be the smallest integer such that λk+1 < 2 ke
x(t)k − λ1 . If no such
index exists, then one can observe that ke
x(t + 1)k ≤ λ1 − ke
x(t)k. Assuming that such
an index exists in [D], we have λk ≥ 2 ke
x(t)k − λ1 and ke
x(t)k − λj ≤ λ1 − ke
x(t)k,
∀j ≤ k. Now consider the following vectors:
v (2) (t) := (2 ke
x(t)k − λ1 − λk )P (k:D) x
e(t),
By definition of k, | ke
x(t)k − λj | ≤ | ke
x(t)k − λ1 |. Thus
(ke x(t)k − λ1 )e x1 (t)
..
.
(ke
x (t)k − λ )e
x (t)
1 1 k
ke
x(t + 1)k ≤
ke
x(t)k
x(t)k − λk+1 )e
(ke xk+1 (t)
.
..
(kex(t)k − λD )e xD (t)
1
= v (1) (t) + v (2) (t) + . . . + v (D−k+2) (t)
ke
x(t)k
1
v (1) (t) + v (2) (t) + . . . + v (D−k+2) (t) .
≤
ke
x(t)k
423
e(t) ∈ ∩D
By assumption, we have x j=1 Ij . Thus
v (2) (t) ≤ (2 ke
x(t)k − λ1 − λk )λk
Hence,
X X
v (j) (t) = (2 ke
x(t)k − λ1 − λk )λk + (λj − λj+1 )λj+1
j≥2 j≥k
X X
= (2 ke
x(t)k − λ1 )λk + λj λj+1 − λ2j
j≥k j≥k
(2 ke 2
x(t)k − λ1 ) + λ2k X λ2j + λ2j+1 X
≤ + − λ2j
2 j≥k
2 j≥k
2
(2 ke
x(t)k − λ1 ) λ2D
≤ − ,
2 2
1
v (1) (t) + v (2) (t) + . . . + v (D−k+1) (t)
ke
x(t + 1)k ≤
kex(t)k
x(t)k − λ1 )2
(2 ke λ2D
≤ − + λ1 − ke x(t)k
2 ke
x(t)k 2 ke
x(t)k
λ2 − λ2D
x(t)k + 1
= ke − λ1
2 kex(t)k
λ1 λ2
≤ − D,
2 2λ1
λ1
where the final step is because 2
≤ ke
x(t)k ≤ λ1 and that the maximal value of a
convex function is attained at the boundary of an interval.
424
λi
Lemma 9.9.5. At any step t and i ∈ [D], if ke
x(t)k T 2
, then |e
xi (t + 1)| S |e
xi (t)|,
where T denotes larger than, equal to and smaller than respectively. (Same for S, but
in the reverse order)
λi
Proof. From the Normalized GD update rule, we have x ei (t) 1 −
ei (t+1) = x ke
x(t)k
, for all i ∈
[D]. Thus
λ1 λ1
S 2 ⇐⇒ 1 − S 1 ⇐⇒ |e
xi (t + 1)| S |e
xi (t)| ,
ke
x(t)k ke
x(t)k
λ1
Lemma 9.9.6. At any step t, if ke
x(t)k ≤ 2
, then
λ λ
(λ1 − ke
x(t)k) cos θt ≤ ke
x(t + 1)k ≤ λ1 − ke
x(t)k − 1− λ1 sin2 θt ,
2λ1 λ1
kP (2:D) xe(t)k
where θt = arctan and λ = min(λ1 − λ2 , λD ).
|e>1 xe(t)|
Proof. We first show that the left side inequality holds by the following update rule
for he1 , x
e(t)i:
he1 , x
e(t)i
he1 , x x(t)k − λ1 )
e(t + 1)i = (ke .
kex(t)k
Since ke
x(t + 1)k ≥ |he1 , x
e(t + 1)i| and θt denotes the angle between e1 and x
e(t + 1),
we get the left side inequality.
Now, we focus on the right hand side inequality. First of all, the update in the
coordinate j ∈ [2, D] is given by
hej , x
e(t)i
hej , x x(t)k − λj )
e(t + 1)i = (ke .
kex(t)k
425
Then, we have
D
X
2
ke
x(t + 1)k = e(t + 1)i2
hej , x
j=1
D 2
X
2 hej , x
e(t)i
= x(t)k − λj )
(ke
j=1
kex(t)k
D 2
2 2
X hej , x
e(t)i 2
x(t)k − λ1 ) cos θt +
= (ke x(t)k − λj )
(ke
j=2
kex(t)k
D 2
2 2 2
X hej , x
e(t)i
≤ (ke
x(t)k − λ1 ) cos θt + (ke
x(t)k − λ)
j=2
kex(t)k
where in the fourth step, we have used λ = argmaxλi |2≤i≤D |ke x(t)k − λi | . The final
√
x(t)k < λ21 . Hence, using the fact that 1 − y ≤ 1 − y/2 for any y ≤ 1, we
step uses ke
have
1
ke
x(t + 1)k ≤ λ1 − ke
x(t)k − λ(λ1 − λ) sin2 θt
2(λ1 − ke
x(t)k)
λ λ
≤ λ1 − ke
x(t)k − 1− λ1 sin2 θt ,
2λ1 λ1
λ1
where again in the final step, we have used ke
x(t)k < 2
. The above bound can be
further bounded by
λ λ
ke
x(t + 1)k ≤ λ1 − ke
x(t)k − 1− λ1 sin2 θt
2λ1 λ1
λ0 λ0
1
≤ λ1 − ke
x(t)k − min 1− λ1 sin2 θt
2 λ ∈{λ2 ,λD } λ1
0 λ1
1 λ λ
= λ1 − ke
x(t)k − 1− λ1 sin2 θt ,
2 λ1 λ1
426
where we have used λ = min(λ1 − λ2 , λD ).
λ1 λ λ
ke
x(t)k ≤ =⇒ |e x1 (t)| (1 + 2 (1 − ) sin2 θt ),
x1 (t + 2)| ≥ |e
2 λ1 λ1
kP (2:D) xe(t)k
where θt = arctan , and λ = min(λ1 − λ2 , λD ).
|e>1 xe(t)|
λ1 λ1
e1 (t + 1) = 1 −
x x
e1 (t), e1 (t + 2) = 1 −
x x
e1 (t + 1).
ke
x(t)k ke
x(t + 1)k
λ1 λ1
|e
x1 (t + 2)| = 1 − 1− |e
x1 (t)|
ke
x(t)k ke
x(t + 1)k
λ2 − λ1 (ke
x(t)k − ke
x(t + 1)k)
= 1+ 1 |e
x1 (t)|
ke
x(t)k ke
x(t + 1)k
≥ |e
x1 (t)| ,
427
Hence, retracing the steps we followed before, we have
λ21 − λ1 (ke
x(t)k + kex(t + 1)k)
|e
x1 (t + 2)| = 1 + |e
x1 (t)|
ke
x(t)k kex(t + 1)k
λ(λ1 − λ) sin2 θt
≥ 1+ |e
x1 (t)|
kex(t)k kex(t + 1)k
λ λ
≥ 1 + 2 (1 − ) sin2 θt |e x1 (t)| ,
λ1 λ1
1. Preparation phase: e(t) enters and stays in an invariant set around the
x
origin, that is, ∩D x| D e(t)i2 ≤ λ2j }. (See Lemma 9.3.3,
P
j=1 Ij , where Ij := {e i=j hei , x
which is a direct consequence of Lemmas 9.9.1 and 9.9.1 and corollary 9.9.3.)
Below we elaborate the convergence argument in the alignment phase. For con-
venience, we will use θt to denote the angle between e1 and x
e(t) and we assume
e(0) ∈ ∩D Ij without loss of generality. We first define S := {t ∈ N | ke λ1
j=1 x(t)k ≤ 2
}
and S 0 := {t ∈ S | t + 2 ∈ S}. The result in alignment phase says that 1
λ1
|e
x1 (t)|
monotone increases and converges to some constant C ∈ (0, 12 ] among all t ∈ S, thus
|e
x1 (t+2)|
lim |e
x1 (t)|
= 1. By Lemma 9.9.7, we have lim θt = 0. Since the one-step
t→∞,t∈S 0 t→∞,t∈S 0
428
update function F (e e − A kexxek is uniformly lipschitz when ke
x) = x xk is bounded away
from zero, we know lim θt+k = 0, ∀k ∈ N.
t→∞,t∈S 0
e(t + 2) − x
flips its sign per step and thus lim x e(t) = 0, lim ke
x(t + 1)k + ke
x(t)k = λ1 .
t→∞ t→∞
1 λ1
If C = 2
, then we must have lim ke
x(t)k = 2
and we are done in this case. If C < 12 ,
t→∞
case)
For a general loss function L satisfying Assumption 9.5.1, the loss landscape looks like
a strongly convex quadratic function locally around its minimizer. When sufficient
small learning rate, the dynamics will be sufficiently close to the manifold and behaves
like that in quadratic case with small perturbations. Thus it will be very useful to
have more refined analysis for the quadratic case, as they allow us to bound the error
in the approximate quadratic case quantitatively. Lemmas 9.9.8 to 9.9.11 are such
examples. Note that they are only used in the proof of the general loss case, but not
in the quadratic loss case.
Lemma 9.9.8 is a slightly generalized version of Lemma 9.3.5.
429
λ2D
Lemma 9.9.8. Suppose at time t, P (j:D) x
e(t) ≤ λj (1 + λ21
), for all j ∈ [D], if
λ1 λ1
ke
x(t)k > 2
, then ke
x(t + 1)k ≤ 2
.
Proof of Lemma 9.9.8. The proof is similar to the proof of Lemma 9.3.5. Let the
index k be the smallest integer such that λk+1 < 2 ke
x(t)k − λ1 . If no such index exists,
then one can observe that ke
x(t + 1)k ≤ λ1 − ke
x(t)k. Assuming that such an index
exists in [D], we have λk ≥ 2 ke
x(t)k − λ1 and ke
x(t)k − λj ≤ λ1 − ke
x(t)k, ∀j ≤ k. With
λ2D
e(t) ∈ ∩D
the same decomposition and estimation, since x j=1 (1 + λ21
)Ij , we have
Thus we conclude
1
v (1) (t) + v (2) (t) + . . . + v (D−k+1) (t)
ke
x(t + 1)k ≤
ke
x(t)k
λ1 λ2 λ2 λ1
≤ (1 − D2 )(1 + 21 ) ≤ ,
2 λ1 λD 2
• |he1 , x
e(t)i| ≤ (1 − 2c)g(λk ).
p
• θt ≤ c |he1 , x
e(t)i|,
kP (2:D) (ex(t))k
where θt = arctan |he1 ,e
x(t)i|
.
430
Then, we have
hek , x
e(t + 2)i hek , x
e(t)i
≥ (1 + c) .
he1 , x
e(t + 2)i he1 , x
e(t)i
Proof of Lemma 9.9.9. From the quadratic update, we have the update rule as:
λk
x ek (t) 1 −
ek (t + 1) = x , for all k ∈ {1, . . . , D}.
ke
x(t)k
hek , x
e(t + 2)i λ1 − λk λ1 − λk hek , x
e(t)i
= 1− 1−
he1 , x
e(t + 2)i λ1 − ke
x(t)k λ1 − kex(t + 1)k he1 , x e(t)i
(λ1 − λk )(λ1 + λk − ke
x(t)k − kex(t + 1)k) hek , x e(t)i
= 1− .
(λ1 − kex(t + 1)k)(λ1 − kex(t)k) he1 , x
e(t)i
(λ1 − λk )(λ1 + λk − ke
x(t)k − ke
x(t + 1)k)
≥ 2 + c,
(λ1 − kex(t + 1)k)(λ1 − ke
x(t)k)
we must have
hek , x
e(t + 2)i hek , x
e(t)i
≥ (1 + c) .
he1 , x
e(t + 2)i he1 , x
e(t)i
We can use (λ1 − ke
x(t)k) cos θt ≤ ke
x(t + 1)k ≤ λ1 − ke
x(t)k − λ
2λ1
1− λ
λ1
λ1 sin2 θt ,
where λ = min(λ1 − λ2 , λD ) from Lemma 9.9.6 to show the following with additional
algebraic manipulation:
(λ1 − λk )(λ1 + λk − ke
x(t)k − ke
x(t + 1)k) (λ1 − λk )λk
≥ .
(λ1 − kex(t + 1)k)(λ1 − ke
x(t)k) (λ1 − (λ1 − ke
x(t)k) cos θt )(λ1 − ke
x(t)k)
431
Hence, it suffices to show that
(λ1 − λk )λk
≥ 2 + c.
(λ1 − (λ1 − ke
x(t)k) cos θt )(λ1 − ke
x(t)k)
p
where the last step we use that |θt | ≤ c |he1 , x
e(t)i|, we only need
r
λ1
Lemma 9.9.10. Consider the function g : R → R, with g(λ) = 2
1− 1− 2 λλ1 1− λ
λ1
.
Consider any coordinate 2 ≤ k ≤ D. For any constant 0 < c < 4 λλk1 (1 − λk
λ1
), consider
e(t) ∈ ∩D
any t with x j=1 Ij , with x
e(t) satisfying
0.5λ1 ≥ ke
x(t)k ≥ (1 + c)g(λk ).
hek , x
e(t + 2)i hek , x
e(t)i
≤ (1 − 0.5c) ,
he1 , x
e(t + 2)i he1 , x
e(t)i
432
Proof. By the Normalized GD update, we have:
λk λk
! !
hek , x
e(t + 2)i 1− ke
x(t+1)k
1− ke
x(t)k hek , x
e(t)i
= λ1 λ1
he1 , x
e(t + 2)i 1− ke
1−
x(t+1)k ke
x(t)k
he1 , x
e(t)i
(λ1 − λk )(λ1 + λk − ke
x(t)k − ke
x(t + 1)k) hek , xe(t)i
= 1− .
(λ1 − kex(t + 1)k)(λ1 − ke
x(t)k) he1 , x
e(t)i
(9.13)
λ1
1. If ke
x(t)k ≥ λk , which is only possible when λk ≤ 2
, we find that
ratio(λ1 , λk , ke
x(t)k , ke
x(t + 1)k) is a monotonically decreasing function w.r.t.
ke
x(t + 1)k, keeping other terms fixed. Using the fact that ke
x(t + 1)k ≤
λ1 − ke
x(t)k from Lemma 9.9.6, we can bound the term as:
2. If ke
x(t)k ≤ λk , we find that ratio(λ1 , λk , ke
x(t)k , ke
x(t + 1)k) is a monotonically
increasing function w.r.t. ke
x(t + 1)k, keeping other terms fixed. Using the fact
433
that ke
x(t + 1)k ≤ λ1 − ke
x(t)k from Lemma 9.9.6, we can bound the term as:
Continuing in the similar way as the previous case, we show that ratio(λ1 , λk , a, 0)
is at least 1−(λk /λ1 )2 in the range ((1+c)g(λk ), min(0.5λ1 , λk )). ratio(λ1 , λk , a, λ1 −
a) is maximized in the range ((1 + c)g(λk ), min(0.5λ1 , λk )) at a = (1 + c)g(λk )
λk (λ1 −λk )
and is given by (1+c)g(λk )(λ1 −(1+c)g(λk ))
. From the definition of λk , we observe that
λ1 − (1 + c)g(λk ) is atleast (1 − 4c )(λ1 − (1 + c)g(λk )) for any c ∈ (0, 1). Thus,
we have
λk (λ1 − λk ) 1 λk (λ1 − λk )
≤ c
(1 + c)g(λk )(λ1 − (1 + c)g(λk )) (1 + c)(1 − 4 ) g(λk )(λ1 − g(λk ))
2
= ≤ 2 − 0.5c,
(1 + c)(1 − 4c )
where the final step holds true for any c ∈ (0, 1).
λk λk λk λk λk 2
2 (1 − ) ≤ min 4 (1 − ), 1 − ( )
λ1 λ1 λ1 λ1 λ1
(λ1 − λk )(λ1 + λk − ke
x(t)k − ke
x(t + 1)k)
≤ ≤ 2 − 0.5c.
(λ1 − kex(t + 1)k)(λ1 − ke
x(t)k)
λ1
Lemma 9.9.11. At any step t, if ke
x(t)k ≤ 2
,
λ1
2. |tan(∠(e
x(t + 2), e1 ))| ≤ ke
x(t)k
|tan(∠(e
x(t), e1 ))|.
434
Proof of Lemma 9.9.11. From the Normalized GD update rule, we have
λi
x ei (t) 1 −
ei (t + 1) = x , for all i ∈ [D],
ke
x(t)k
1
implying |e
xi (t + 1)| < 1− ke
x(t)k
|e
xi (t)| for all i ∈ [2, D], since λi < 1.
λ1
Since λi < λ1 and ke
x(t)k ≤ 2
, it holds that
λi
|e
xi (t + 1)| 1− ke
x(t)k |e
xi (t)| λ1 − λi |e
xi (t)| λi λi |e
xi (t)|
= λ1
= 1− ≤ max( , 1 − 2 ) .
|e
x1 (t + 1)| 1− ke
x(t)k
|e
x1 (t)| λ1 − ke
x(t)k |e
x1 (t)| λ1 λ1 |e
x1 (t)|
Finally we conclude
P (2:D) x
e(t + 1) λ2 λD P (2:D) x
e(t)
≤ max( , 1 − 2 ) .
|e
x1 (t + 1)| λ1 λ1 |e
x1 (t)|
kP (2:D) vk
Recall |tan(∠(v, e1 ))| = |he1 ,vi|
for any vector v, the first claim follows from re-
arranging the terms.
For the second claim, it suffices to apply the above inequality to t + 1, which yields
that
λ2 λ1 − λi λ1
|tan(∠(e
x(t + 2), e1 ))| ≤ max( , − 1) |tan(∠(e
x(t + 1), e1 ))| ≤ |tan(∠(e
x(t +
λ1 λ1 − kxt+1 k λ1 − kxt+1 k
Before we start the analysis for Normalized GD for general loss functions in Section 9.11,
we need to introduce some new notations and terminologies to complete the formal
435
setup. We start by first recapping some core assumptions and definitions in the main
paper and provide the missing proof in the main paper.
Z τ
Φ(x) = lim φ(x, τ ), where φ(x, τ ) = x − ∇L(φ(x, s))ds. (9.14)
τ →∞ 0
Let U be the sets of points starting from which, gradient flow w.r.t. loss L converges
to some point in Γ, that is, U := {x ∈ RD | Φ(x) exists and Φ(x) ∈ Γ}. We have that
U is open and Φ is C 3 on U . (By Lemma 8.8.2)
For a matrix A ∈ RD×D , we denote its eigenvalue-eigenvector pairs by
{λi (A), vi (A))}i∈[D] . For simplicity, whenever Φ is defined and C 2 at point x,
we use {(λi (x), vi (x))}D 2
i=1 to denote the eigenvector-eigenvalue pairs of ∇ L(Φ(x)),
1:D
Therefore, Px,Γ = Px,Γ by Lemma 9.10.17. Additionally, for any x ∈ U , we use θ(x)
436
to denote the angle between x
e and the top eigenspace of the hessian at Φ(x), i.e.
(2:M )
PΦ(x),Γ x
e
θ(x) = arctan |hv1 (x),e
xi|
. Furthermore, when the iterates x(t) is clear in the context,
⊥ ⊥
we use shorthand λi (t) := λi (x(t)), vi (t) := vi (x(t)), Pt,Γ := PΦ(x(t)),Γ , Pt,Γ := PΦ(x(t)),Γ
and θt to denote θ(x(t)). We define the function gt : R → R for every t ∈ N as
s !
1 λ λ
gt (λ) = 1− 1−2 1− .
2 λ1 (t) λ1 (t)
Given any two points x, y, we use xy to denote the line segment between x and y,
i.e., {z | ∃λ ∈ [0, 1], z = (1 − λ)x + λy}.
The main result of this chapter focuses on the trajectory of Normalized GD
from fixed initialization xinit with LR η converges to 0, which can be roughly split
into two phases. In the first phase, Theorem 9.5.3 shows that the normalized GD
trajectory converges to the gradient flow trajectory, φ(xinit , ·). In second phase,
Theorem 9.5.4 shows that the normalized GD trajectory converges to the limiting flow
which decreases sharpness on Γ, (9.5). Therefore, for sufficiently small η, the entire
trajectory of normalized GD will be contained in a small neighbourhood of gradient
flow trajectory Z and limiting flow trajectory Y . The convergence rate given by our
proof depends on the various local constants like smoothness of L and Φ in this small
neighbourhood, which intuitively can be viewed as the actual ”working zone” of the
algorithm. The constants are upper bounded or lower bounded from zero because this
”working zone” is compact after fixing the stopping time of (9.5), which is denoted by
T2 .
Z τ
1 ⊥
X(τ ) = Φ(xinit ) − PX(s),Γ ∇ log λ1 (X(s))ds, X(τ ) ∈ Γ (9.5)
4 s=0
Below we give formal definitions of the ”working zones” and the corresponding
properties. For any point y ∈ RD and positive number r, we define Br (y) := {x ∈
RD | ky − xk < r} as the open `2 norm ball centered at y and B r (y) as its closure. For
437
any set S and positive number r, we define S r := ∪y∈S Br (y) and B r (S) := ∪y∈S B r (y).
Given the stopping time T2 > 0, we denote the trajectory of limiting flow Equation (9.5)
{X(τ )}Tτ =0
2
by Y and we use the notation Y r := ∪y∈Y B y (r) for any r > 0. By definition,
Y r are compact for any r > 0.
We construct the ”working zone” of the second phase, Y ρ and Y in Lemmas 9.10.2
and 9.10.5 respectively, where 0 < < ρ, implying Y ⊂ Y ρ . The reason that we
need the two-level nested ”working zones” is that even though we can ensure all the
points in Y ρ have nice properties as listed in Lemma 9.10.2, we cannot ensure the
trajectory of gradient flow from x ∈ Y ρ to Φ(x) or the line segment xΦ(x) is in Y ρ ,
which will be crucial for the geometric lemmas (in Section 9.10.1) that we will heavily
use in the trajectory analysis around the manifold. For this reason we further define
Y and Lemma 9.10.5 guarantees the trajectory of gradient flow from x to Φ(x) or
the line segment xΦ(x) whenever x ∈ Y ρ .
Definition 9.10.1 (PL condition). A function L is said to be µ-PL in a set U iff for
all x ∈ U ,
k∇L(x)k2 ≥ 2µ(L(x) − inf L(x)).
x∈U
For convenience, we define ∆ := 12 inf x∈Y λ1 (∇2 L(x)) − λ2 (∇2 L(x))) and µ :=
1
4
inf x∈Y λM (∇2 L(x)). By Assumption 9.5.1, we have µ > 0. By Assumption 9.5.2,
∆ > 0.
Lemma 9.10.2. Given Y , there are sufficiently small ρ > 0 such that
1. Y ρ ∩ Γ is compact;
2. Y ρ ⊂ U ;
3. L is µ-PL on Y ρ ; (see Definition 9.10.1)
4. inf x∈Y ρ λ1 (∇2 L(x)) − λ2 (∇2 L(x))) ≥ ∆ > 0;
5. inf x∈Y ρ λM (∇2 L(x)) ≥ µ > 0.
438
Proof of Lemma 9.10.2. We first claim for every y ∈ Y , for all sufficiently small
ρy > 0 (i.e. for all ρy smaller than some threshold depending on y), the following
three properties hold (1) B y (ρy ) ∩ Γ is compact; (2) B y (ρy ) ∩ Γ ⊂ U and (3) L is
µ-PL on B y (ρy ∩ Γ).
Among the above three claims, (2) is immediate. (1) holds because B y (ρy ) ∩ Γ
is bounded and we can make ρy small enough to ensure B y (ρy ) ∩ Γ is closed. For
(3), by Proposition 7 of [203], we define p(y) := argminx∈Γ kx − yk which is uniquely
defined and C 1 in B y (ρy ) for sufficiently small ρy . Moreover, Lemma 14 in [203] shows
that k∇L(x) − ∇2 L(p(x))(x − p(x))k ≤ c kx − p(x)k22 for all x in By (ρy ) uniformly
and some constant c. Thus for small enough ρy ,
and that
Thus for any c0 > 0, for sufficiently small ρy , (x − p(x))> ∇2 L(p(x))(x − p(x)) ≥
c0 kx − p(x)k3 . Combining Equations (9.15) and (9.16), we conclude that for sufficiently
small ρy ,
439
Again for sufficiently small ρy , by Taylor expansion of L at p(x), we have
1
(x − p(x))> ∇2 L(p(x))(x − p(x)) ≥ L(x) − O(kx − p(x)k3 ).
2
Thus we conclude
Meanwhile, since λM (∇2 L(p(x))) and λ1 (∇2 L(p(x))) − λ2 (∇2 L(p(x))) are
continuous functions in x, we can also choose a sufficiently small ρy such that
1 1
for all x ∈ B y (ρy ), λM (∇2 L(p(x))) ≥ λ (∇2 L(p(y)))
2 M
= λ (∇2 L(y))
2 M
> ∆
and λ1 (∇2 L(p(x))) − λ2 (∇2 L(p(x))) ≥ 12
λ1 (∇2 L(p(y))) − λ2 (∇2 L(p(y))) =
1
2
λ1 (∇2 L(y)) − λ2 (∇2 L(y)) ≥ µ. Further note Y ⊂ ∪y∈Y By (ρy ) and Y is a compact
set, we can take a finite subset of Y , Y 0 , such that Y ⊂ ∪y∈Y 0 By (ρy ). Taking
ρy
ρ := miny∈Y 0 2
completes the proof.
440
We assume each of the constants ζ, ν, Υ, ξ, χ are at least 1 for simplicity (otherwise
we can set them to be 1 and our bound still holds)
Lemma 9.10.5. Given ρ as defined in Lemma 9.10.2, there is an ∈ (0, ρ) such that
2 5
1. supx∈Y L(x) − inf L(x) < min( µρ8 , ν2µ
2 ζ 2 );
x∈Y
ρ
2. ∀x ∈ Y , Φ(x) ∈ Y 2 .
continuous. Further note Y ⊂ ∪y∈Y By (y ) and Y is a compact set, we can take a
y
finite subset of Y , Y 0 , such that Y ⊂ ∪y∈Y 0 By (y ). Taking := miny∈Y 0 2
completes
the proof.
Summary for Setups: The initial point xinit is chosen from an open neighborhood
of manifold Γ, U , where the infinite-time limit of gradient flow Φ is well-defined and
for any x ∈ U , Φ(x) ∈ Γ. We consider normalized GD with sufficiently small LR
η such that the trajectory enters a small neighborhood of limiting flow trajectory,
Y ρ . Moreover, L is µ-PL on Y ρ and the eigengaps and smallest eigenvalues are
uniformly lower bounded by positive ∆, µ respectively on Y ρ . Finally, we consider
a proper subset of Y ρ , Y , as the final ”working zone” in the second phase (defined
in Lemma 9.10.5), which enjoys more properties than Y ρ , including Lemmas 9.10.7
to 9.10.10.
In this subsection we present several geometric lemmas which are frequently used in
the trajectory analysis of normalized GD. In this section, O(·) only hides absolute
constants. Below is a brief summary:
441
• Lemma 9.10.6: Inequalities connecting various terms: the distance between x
and Φ(x), the length of GF trajectory from x to Φ(x), square root of loss and
gradient norm;
• Lemma 9.10.7: For any x ∈ Y , the gradient flow trajectory from x to Φ(x) and
the line segment between x and Φ(x) are all contained in Y ρ , so it’s ”safe” to
use Taylor expansions along GF trajectory or xΦ(x) to derive properties;
Lemma 9.10.6. If the trajectory of gradient flow starting from x, φ(x, t), stays in
Y ρ for all t ≥ 0, then we have
s
∞
2(L(x) − L(Φ(x))) k∇L(x)k
Z
dφ(x, t)
kx − Φ(x)k ≤ dt ≤ ≤ .
t=0 dt µ µ
Proof of Lemma 9.10.6. Since Φ(x) is defined as limt→∞ φ(x, t) and φ(x, 0) = x, the
left-side inequality follows immediately from triangle inequality. The right-side in-
equality is by the definition of PL condition. Below we prove the middle inequality.
Since ∀t ≥ 0, φ(x, t) ∈ Y ρ , it holds that k∇L(φ(x, t))k2 ≥ 2µ(L(φ(x, t))−L(Φ(x)))
by the choice of ρ in Lemma 9.10.2. Without loss of generality, we assume L(y) =
442
0, ∀y ∈ Γ. Thus we have
∞ ∞
k∇L(φ(x, t))k2
Z Z
k∇L(φ(x, t))k dt ≤ p dt.
t=0 t=0 2µL(φ(x, t))
s
∞
k∇L(φ(x, t))k2 ∞ ∞
r
−dL(φ(x, t))
Z Z Z
2 p 2L(φ(x, 0))
p dt ≤ p = d L(φ(x, t)) = .
t=0 2µL(φ(x, t)) t=0 2µL(φ(x, t)) t=0 µ µ
Lemma 9.10.7. Let ρ, be defined in Lemmas 9.10.2 and 9.10.5. For any x ∈ Y ,
we have
2
2. Moreover, kΦ(x) − φ(x, t)k ≤ min(ρ, 2µ
νζ
), ∀t ≥ 0.
Proof of Lemma 9.10.7. Let time τ ∗ ≥ 0 be the smallest time after which the trajec-
tory of GF is completely contained in Y ρ , that is, τ ∗ := inf{t ≥ 0 | ∀t0 ≥ t, φ(x, t0 ) ∈
Y ρ }. Since Y ρ is closed and φ(x, ·) is continuous, we have φ(x, τ ∗ ) ∈ Y ρ .
Since ∀τ ≥ τ ∗ , φ(x, τ ) ∈ Y ρ , by Lemma 9.10.6, it holds that kφ(x, τ ∗ ) − Φ(x)k ≤
q
2(L(φ(x,τ ∗ ))−L(Φ(x)))
µ
.
Note that loss doesn’t increase along GF, we have L(φ(x, τ ∗ )) − L(Φ(x)) ≤ L(x) −
µρ2
L(Φ(x)) ≤ 8
, which implies that kφ(x, τ ∗ ) − Φ(x)k ≤ ρ2 . Therefore τ ∗ must be 0,
otherwise there exists a 0 < τ 0 < τ ∗ such that kφ(x, τ ) − Φ(x)k ≤ ρ for all τ 0 < τ < τ ∗
by the continuity of φ(x, ·). This proves the first claim.
Given the first claim is proved, the second claim follows directly from Lemma 9.10.6.
The following theorem shows that the projection of x in the tangent space of
Φ(x) is small when x is close to the manifold. In particular if we can show that in a
443
discrete trajectory with a vanishing learning rate η, the iterates {xη (t)} stay in Y ,
we can interchangeably use kxη (t) − Φ(xη (t))k with kPt,Γ (xη (t) − Φ(xη (t)))k, with an
additional error of O(η 3 ), when kPt,Γ (xη (t) − Φ(xη (t)))k ≤ O(η).
νζ
⊥
PΦ(x),Γ (x − Φ(x)) ≤ kx − Φ(x)k2 ,
4µ2
and that
2 2 νζ 1
PΦ(x),Γ (x − Φ(x)) ≥ kx − Φ(x)k 1 − 2
kx − Φ(x)k ≥ kx − Φ(x)k2 .
4µ 2
Proof of Lemma 9.10.8. First of all, we can track the decrease in loss along the
Gradient flow trajectory starting from x. At any time τ , we have
d d
L(φ(x, τ )) = h∇L(φ(x, τ )), φ(x, τ )i = − k∇L(φ(x, τ ))k2 ,
dτ dτ
d
L(φ(x, τ )) ≤ −2µL(φ(x, τ )),
dτ
which implies
444
By Lemma 9.10.6, we have
s
2√
r
2L(φ(x, 0))e−2µτ
kφ(x, τ ) − Φ(x)k ≤ L(φ(x, τ )) ≤ . (9.17)
µ µ
Moreover, we can relate L(φ(x, 0) with kΦ(x) − xk with a second order taylor
expansion:
where in the final step, we have used the fact that L(Φ(x)) = 0 and ∇L(Φ(x)) = 0. By
Lemma 9.10.7, we have xΦ(x) ⊂ Y ρ . Thus maxs∈[0,1] k∇2 L(sx + (1 − s)Φ(x))k ≤ ζ
from Definition 9.10.4 and it follows that
Z 1
ζ
L(x) ≤ (1 − s)ζ kx − Φ(x)k2 ds = kΦ(x) − xk2 , (9.18)
s=0 2
Z ∞ Z ∞
⊥ ⊥ ⊥
PΦ(x),Γ (φ(x, ∞) − φ(x, 0)) ≤ PΦ(x),Γ ∇L(φ(x, τ )) dτ ≤ PΦ(x),Γ ∇L(φ(x, τ )) dτ.
0 0
(9.19)
ν
∇L(φ(x, τ )) − ∇2 L(Φ(x)) Φ(x) φ(x, τ ) − Φ(x) ≤ kφ(x, τ ) − Φ(x)k2 .
2
445
⊥ ⊥
Since PΦ(x),Γ is the projection matrix for the tangent space, PΦ(x),Γ ∇2 L(Φ(x)) = 0 and
thus by Equation (9.17)
⊥ ν 2 νL(φ(x, 0))e−2µτ
PΦ(x),Γ ∇L(φ(x, τ )) ≤ kφ(x, τ ) − Φ(x)k ≤ (9.20)
2 µ
∞
νL(φ(x, 0))e−2µτ νζ kx − Φ(x)k2
Z
⊥ νL(x)
PΦ(x),Γ (φ(x, ∞) − x) ≤ = ≤
τ =0 µ 2µ2 4µ2
(9.21)
For the second claim, simply note that
⊥
PΦ(x),Γ (x − Φ(x))
q
2
= kx − Φ(x)k2 − PΦ(x),Γ (x − Φ(x))
2
PΦ(x),Γ (x − Φ(x))
≥ kx − Φ(x)k − .
kx − Φ(x)k
The left-side inequality of the second inequality is proved by plugging the first
claim into the above inequality Equation (9.21) and rearranging the terms. Note by
νζ
the second claim in Lemma 9.10.7, 4µ2
kx − Φ(x)k ≤ 12 , the right-side inequality is
also proved.
1
∇L(x) − ∇2 L(Φ(x))(x − Φ(x)) ≤ ν kx − Φ(x)k2 .
2
and
k∇L(x)k ν
− 1 ≤ kx − Φ(x)k ,
k∇2 L(Φ(x))(x − Φ(x))k µ
446
Moreover, the normalized gradient of L can be written as
1
≤ ν kx − Φ(x)k2 .
2
≥µ PΦ(x),Γ (x − Φ(x)) ,
we have
k∇L(x)k ν kx − Φ(x)k2 ν
2
−1 ≤ ≤ kx − Φ(x)k ,
k∇ L(Φ(x))(x − Φ(x))k 2µ PΦ(x),Γ (x − Φ(x)) µ
where we use Lemma 9.10.8 since x ∈ Y . Thus, the normalized gradient at any step
t can be written as
447
Lemma 9.10.10. Consider any point x ∈ Y . Then,
∇L(x) ν
v1 (x), ≥ cos θ − O( kx − Φ(x)k),
k∇L(x)k µ
(2:M )
PΦ(x),Γ x
e = ∇2 L(Φ(x))(x − Φ(x)).
e
where θ = arctan |hv1 (x),e
xi|
, with x
∇L(x)
Lemma 9.10.11. For any xy ∈ Y where y = x − η k∇L(x)k is the one step Normalized
GD update from x, we have
1
kΦ(y) − Φ(x)k ≤ ξη 2 .
2
1
λk (∇2 L(Φ(x))) − λk (∇2 L(Φ(y))) ≤ νξη 2 ,
4
and
1 νξη 2 νξη 2 ν 2 ξ2 η4
v1 (∇2 L(Φ(x))) − v1 (∇2 L(Φ(y))) ≤ = + O( ).
2 ∆ − 14 νξη 2 2∆ ∆
448
Proof. By Lemma 9.10.14, we have ∂Φ(x)∇L(x) = 0 for all x ∈ U . Thus we have
1
∇L(x) ∇L(x)
Z
kΦ(y) − Φ(x)k =η ∂Φ x − sη ds
s=0 k∇L(x)k k∇L(x)k
Z 1
∇L(x) ∇L(x)
=η ∂Φ x − sη − ∂Φ(x) ds
s=0 k∇L(x)k k∇L(x)k
Z 1
∇L(x)
≤η ∂Φ x − sη − ∂Φ(x) ds
s=0 k∇L(x)k
Z 1
≤η 2
s sup ∇2 Φ((1 − s0 )x + s0 y) ds
s=0 s0 ∈[0,s]
2
η
= sup ∇2 Φ((1 − s0 )x + s0 y)
2 s0 ∈[0,1]
1
≤ ξη 2 ,
2
≤ ∇2 L(Φ(x)) − ∇2 L(Φ(y))
Z 1
= (1 − s)∂ 2 (∇L)(Φ(sx + (1 − s)y))(Φ(x) − Φ(y))ds
s=0
Z 1
≤ (1 − s)ds max ∂ 2 (∇L)(Φ(sx + (1 − s)y)) kΦ(x) − Φ(y))k
s=0 s∈[0,1]
1
≤ νξη 2 ,
4
449
The third claim follows from using Theorem 9.14.4. Again,
where we borrow the bound on k∇2 L(Φ(x)) − ∇2 L(Φ(y))k from our previous calcula-
tions. The final step follows from the constants defined in Definition 9.10.4.
∇L(x)
Lemma 9.10.12. For any xy ∈ Y where y = x − η k∇L(x)k is the one step Normalized
GD update from x, we have that
η2 ⊥
Φ(y) − Φ(x) = − P ∇(log λ1 (∇2 L(Φ(x))))
4 Φ(x),Γ
νξ kx − Φ(x)k η 2
+ O(η 2 ξθ) + O( ) + O(χ kx − Φ(x)k η 2 ) + O(χη 3 ).
µ
(2:M )
PΦ(x),Γ x
e = ∇2 L(Φ(x))(x − Φ(x)). Additionally, we have that
e
Here θ = arctan |hv1 (x),e
xi|
, with x
νξ
PΦ(x),Γ (Φ(y) − Φ(x)) ≤ O(χ kx − Φ(x)k η 2 ) + O(χη 3 ) + O( kx − Φ(x)k η 2 ).
µ
1
Φ(y) − Φ(x) =∂Φ(x) (y − x) + ∂ 2 Φ(x)[y − x, y − x] + O(χ ky − xk3 )
2
η2 2
∇L(x) ∇L(x) ∇L(x)
=∂Φ(x) −η + ∂ Φ(x) , + O(χη 3 )
k∇L(x)k 2 k∇L(x)k k∇L(x)k
η2 2
∇L(x) ∇L(x)
= ∂ Φ(x) , + O(χη 3 ),
2 k∇L(x)k k∇L(x)k
450
where in the pre-final step, we used the property of Φ from Lemma 9.10.14. In the
final step, we have used a second order taylor expansion to bound the difference
∇L(x)
between ∂ 2 Φ(x) and ∂ 2 Φ(Φ(x)). Additionally, we have used y − x = η k∇L(x)k from the
Normalized GD update rule.
Applying Taylor expansion on Φ again but at Φ(x), we have that
η2 2
∇L(x) ∇L(x)
Φ(y) − Φ(x) = ∂ Φ(Φ(x)) , + O(χ kx − Φ(x)k η 2 ) + O(χη 3 )
2 k∇L(x)k k∇L(x)k
(9.23)
Also, at Φ(x), since v1 (x) is the top eigenvector of the hessian ∇2 L, we have that
from Corollary 9.10.21,
1
∂ 2 Φ(Φ(x)) v1 (x)v1 (x)> = − ∂Φ(Φ(x))∂ 2 (∇L)(Φ(x))[v1 (x), v1 (x)].
(9.24)
2λ1 (x)
∇L(x) ∇L(x)
sign , v1 (x) − v1 (x)
k∇L(x)k k∇L(x)k
θ ν kx − Φ(x)k ν kx − Φ(x)k
≤2 sin + O( ) ≤ θ + O( ). (9.25)
2 µ µ
Plug Equations (9.24) and (9.25) into Equation (9.23), we have that
η2 1
Φ(y) − Φ(x) = − ∂Φ(Φ(x))∂ 2 (∇L)(Φ(x))[v1 (x), v1 (x)]
2 2λ1 (x)
νξ kx − Φ(x)k η 2
+ O(η 2 ξθ) + O( ) + O(χ kx − Φ(x)k η 2 ) + O(χη 3 ).
µ
By Lemma 9.10.16, for any x ∈ Γ, ∂Φ(x) is the projection matrix onto the tangent
⊥
space TΦ(x) Γ. Thus, ∂Φ(Φ(x)) = PΦ(x),Γ . Thus the proof of the first claim is completed
⊥
by noting that ∂Φ(Φ(x))∂ 2 (∇L)(Φ(x))[v1 (x), v1 (x)] = PΦ(x),Γ ∇λ1 (∇2 L(Φ(x))) by
Corollary 9.10.22.
451
For the second claim, continuing from Equation (9.23), we have that
η2 2
∇L(x) ∇L(x)
Φ(y) − Φ(x) = ∂ Φ(Φ(x)) , + O(χ kx − Φ(x)k η 2 ) + O(χη 3 )
2 k∇L(x)k k∇L(x)k
2
η νξ
= ∂ 2 Φ(Φ(x)) [Σ] + O(χ kx − Φ(x)k η 2 ) + O(χη 3 ) + O( kx − Φ(x)k η 2 ),
2 µ
>
∇L(x) ∇L(x)
where Σ = PΦ(x),Γ k∇L(x)k PΦ(x),Γ k∇L(Φ(x))k and the last step is by Lemma 9.10.9.
Here PΦ(x),Γ denotes the projection matrix of the subspace spanned by v1 (x), . . . , vM (x).
By Lemmas 9.10.16, 9.10.17 and 9.10.20, we have that PΦ(x),Γ ∂ 2 Φ(Φ(x)) [Σ] =
−PΦ(x),Γ ∂Φ(x)∂ 2 (∇L)(x)[L−1
∇2 L(x) Σ] = 0, we conclude that
νξ
PΦ(x),Γ (Φ(y) − Φ(x)) ≤ O(χ kx − Φ(x)k η 2 ) + O(χη 3 ) + O( kx − Φ(x)k η 2 ),
µ
∇L(x)
Lemma 9.10.13. Let Lmin = miny∈U L(y). For any xy ∈ Y where y = x − η k∇L(x)k
is the one step Normalized GD update from x, if k∇L(xη (t))k ≥ ζη, we have that
√
p p 2µ
L(y) − Lmin ≤ L(x) − Lmin − η .
4
ζη 2
L(y) ≤ L(x) − η k∇L(x)k + .
2
√
η 2µ p
L(y) − L(x) ≤ − k∇L(x)k ≤ −η L(x) − Lmin ≤ 0,
2 2
452
where the last step is because L is µ-PL on Y . In other words, we have that
p √ √
p p L(x) − Lmin 2µ 2µ
L(y) − Lmin − L(x) − Lmin ≤ −η p p ≤ −η ,
L(y) − Lmin + L(x) − Lmin 2 4
where in the last step we use L(y) − L(x) ≤ 0. This completes the proof.
The following results Lemmas 9.10.14 to 9.10.18 and 9.10.20 and definition 9.10.19
are from Chapter 8.
Lemma 9.10.14. For any x ∈ U , it holds that (1). ∂Φ(x)∇L(x) = 0 and (2).
∂ 2 Φ(x)[∇L(x), ∇L(x)] = −∂Φ(x)∇2 L(x)∇L(x).
Lemma 9.10.16. For any x ∈ Γ, ∂Φ(x) ∈ RD×D is the projection matrix onto the
⊥
tangent space Tx Γ, i.e. ∂Φ(x) = Px,Γ .
Corollary 9.10.21. For any x ∈ Γ, let v1 be the unit top eigenvector of ∇2 L(x), then
1
∂ 2 Φ(x)[v1 v1> ] = − ∂Φ(x)∂ 2 (∇L)(x)[v1 , v1 ]
2λ1 (∇2 L(x))
Corollary 9.10.22. For any x ∈ Γ, let v1 be the unit top eigenvector of ∇2 L(x), then
1 ⊥
∂ 2 Φ(x)[v1 v1> ] = − Px,Γ ∇ log(λ1 (∇2 L(x))).
2
Proof of Corollary 9.10.22. The proof follows from using Corollary 9.10.21 and the
derivative of λ1 from Theorem 9.14.1.
Corollary 9.10.23. For any x ∈ Γ, let v1 be the unit top eigenvector of ∇2 L(x), then
1 ⊥
∂ 2 Φ(x)[λ1 (∇2 L(x))v1 v1> ] = − Px,Γ ∇(λ1 (∇2 L(x))).
2
Functions
We restate the theorem concerning Phase I for the Normalized GD algorithm. Recall
the following notation for each 1 ≤ j ≤ M :
v
uM
uX
Rj (x) := t λ2i (x)hvi (x), x − Φ(x)i2 − λj (x)η, for all x ∈ U.
i=j
454
Theorem 9.5.3 (Phase I). Let {xη (t)}t∈N be the iterates of Normalized GD (9.4)
with LR η and xη (0) = xinit ∈ U . There is T1 > 0 such that for any T10 > T1 , it
holds that for sufficiently small η that (1) max kxη (t) − Φ(xinit )k ≤ O(η) and (2)
T1 ≤ηt≤T10
max Rj (xη (t)) ≤ O(η 2 ).
T1 ≤ηt≤T10 ,j∈[D]
The intuition behind the above theorem is that for sufficiently small LR η, xη (t)
will track the normalized gradient flow starting from xinit , which is a time-rescaled
version of the standard gradient flow. Thus the normalized GF will enter Y and so
does normalized GD. Since L satisfies PL condition in Y , the loss converges quickly
and the iterate xη (t) gets η to manifold. To finish, we need the following theorem,
which is the approximately-quadratic version of Lemma 9.3.3 when the iterate is O(η)
close to the manifold.
Lemma 9.11.1. Suppose {xη (t)}t≥0 are iterates of Normalized GD (9.4) with a
learning rate η and xη (0) = xinit . There is a constant C > 0, such that for any
kxη (t0 )−Φ(xη (t0 ))k
constant ς > 1, if at some time t0 , xη (t0 ) ∈ Y and satisfies η
≤ ς, then
for all t̄ ≥ t0 + C ζς
µ
log ςζ
µ
, the following must hold true for all 1 ≤ j ≤ M :
v
uM
uX
e(t̄)i2 ≤ ηλj (t̄) + O(η 2 ),
t hvi (t̄), x (9.26)
i=j
455
Let Tx be the length of the GF trajectory starting from x, and we know
limτ →Tx φ(x, τ ) = Φ(x), where φ(x, τ ) is defined as the Normalized gradient flow
starting from x. In Lemmas 9.10.2 and 9.10.5 we show there is a small neigh-
bourhood around Φ(xinit ), Y such that L is µ-PL in Y . Thus we can take some
time T0 < Txinit such that φ(xinit , T0 ) ∈ Y /2 and L(φ(xinit ), T0 ) ≤ 12 Lcritical , where
2 µ
Lcritical := 8
. (Without loss of generality, we assume miny∈Y L(y) = 0) By standard
ODE approximation theory, we know there is some small η0 , such that for all η ≤ η0 ,
xη (dT0 /ηe) − φ(xinit , T0 ) = O(η), where O(·) hides constants depending on the
initialization xinit and the loss function L.
Without loss of generality, we can assume η0 is small enough such that xη (dT0 /ηe) ∈
Y and L(xη (dT0 /ηe)) ≤ Lcritical . Now let tη be the smallest integer (yet still
larger than dT0 /ηe) such that xη (tη )xη (tη − 1) 6⊂ Y and we claim that there is
t ∈ {dT0 /ηe, . . . , tη }, k∇L(xη (t))k < ζη. By the definition of tη , we know for any t ∈
{dT0 /ηe + 1, . . . , tη − 1}, by Lemma 9.10.11 we have kΦ(xη (t)) − Φ(xη (t − 1))k ≤ ξη 2 ,
√
and by Lemma 9.10.13, L(xη (t)) − xη (t − 1) ≤ −η 42µ if k∇L(xη (t))k ≥ ζη. If
p p
√
the claim is not true, since L(xη (t)) decreases η 42µ per step, we have
p
√
2µ
q q
0 ≤ L(xη (tη − 1)) ≤ L(xη (dT0 /ηe)) − (tη − dT0 /ηe − 1)η ,
4
ξη 2 ξη
kΦ(xη (tη − 1)) − Φ(xη (dT0 /ηe))k ≤ (tη − dT0 /ηe − 1) =
2 2
Thus we have
≤ kΦ(xη (tη − 1)) − Φ(xη (dT0 /ηe))k + Φ(xη (dT0 /ηe) − Φ(φ(xinit , T0 ))) = O(η).
456
q
2L(xη (tη −1))
Meanwhile, by Lemma 9.10.6, we have kΦ(xη (tη − 1)) − xη (tη − 1)k ≤ µ
≤
q
2L(xη (dT0 /ηe))
µ
= 2 . Thus for any κ ∈ [0, 1], we have kκxη (tη ) + (1 − κ)xη (tη − 1) − Φ(xinit )k
is upper bounded by
κ kxη (tη ) − xη (tη − 1)k + kΦ(xη (tη − 1)) − xη (tη − 1)k + kΦ(xη (tη − 1)) − Φ(xinit )k
=κη + + O(η),
2
which is smaller than since we can set η0 sufficiently small. In other words,
Φ(xη (tη ))Φ(xη (tη − 1)) ⊂ Y , which contradicts with the definition of tη . So far we
have proved our claim that there is some t0η ∈ {dT0 /ηe, . . . , tη }, ∇L(xη (t0η )) < ζη.
√
Moreover, since L(xη (t)) decreases η 42µ per step before t0η , we know t0η −dT0 /ηe ≤ η .
p
ζη
By Lemma 9.10.6, we know xη (t0η ) − Φ(xη (t0η )) ≤ µ
.
Now we claim that for any T10 , there is some sufficiently small threshold
T10
η0 , tη ≥ η
+ 1 if η ≤ η0 . Below we prove this claim by contradiction. If
T10
the claim is not true, that is, tη < η
+ 1. if tη ≤ C ζς
µ
log ςζ
µ
+ t0η with
ζ
ς = µ
, we know kxη (tη ) − Φ(xinit )k ≤ xη (tη ) − xη (t0η ) + xη (t0η ) − Φ(xη (t0η )) +
Φ(xη (t0η )) − Φ(xinit ) = O(η), which implies that xη (tη )xη (tη − 1) ∈ Y . If
tη ≥ C ζς
µ
log ςζ
µ
+ t0η , by Lemma 9.11.1, we have kxη (tη ) − Φ(xη (tη ))k = O(η). By
Lemma 9.10.11, we have kΦ(xη (tη )) − Φ(xη (dT0 /ηe))k ≤ O(η). Thus again we
have that kxη (tη ) − Φ(xinit )k ≤ kxη (tη ) − Φ(xη (tη ))k + kΦ(xη (tη )) − Φ(xη (dT0 /ηe))k
+ kΦ(xη (dT0 /ηe)) − Φ(xinit )k = O(η), which implies that xη (tη )xη (tη − 1) ∈ Y . In
both cases, the implication is in contradiction to the definition of tη .
T10
Thus for any T10 , tη ≥ η
+ 1 for sufficiently small threshold η0 and η ≤ η0 . To
complete the proof of Theorem 9.5.3, we pick T1 to be any real number strictly
larger than + T0 , as T1
η
> C ζς
µ
log ςζ
µ
+
η
+ dT0 /ηe ≥ C ζς
µ
log ςζ
µ
+ t0η when η is
sufficiently small with ς = µζ . By Lemma 9.11.1 the second claim of Theorem 9.5.3
457
T10
is proved. Using the same argument again, we know ∀ Tη1 ≤ t ≤ η
, it holds that
kΦ(xη (t)) − Φ(xinit )k ≤ O(η).
We first restate the main theorem that demonstrates that the trajectory implicitly
minimizes sharpness.
Theorem 9.5.4 (Phase II). Let {xη (t)}t∈N be the iterates of perturbed Normalized GD
(Algorithm 7) with LR η. Under Assumptions 9.5.1 and 9.5.2, if the initialization xη (0)
satisfy that (1) kxη (0) − Φ(xinit )k ≤ O(η) where xinit ∈ U , (2) maxj∈[D] Rj (xη (t)) ≤
O(η 2 ), and additionally (3) min{|hv1 (xη (0)), xη (0) − Φ(xη (0))i| , −R1 (xη (0))} ≥ Ω(η),
then for any time T2 > 0 till which the solution of (9.5) exists, it holds for suffi-
ciently small η, with probability at least 1 − O(η 10 ), that kΦ(xη (bT2 /η 2 c)) − X(T2 )k =
P 2 /η2 c
O(η) and bT21/η2 c bTt=0 θt ≤ O(η), where θt ∈ [0, π2 ] denotes the angle between
∇2 L(Φ(xη (t)))(xη (t) − Φ(xη (t))) and top eigenspace of ∇2 L(Φ(xη (t))).
To show the closeness between the continuous and the discrete dynamic, we
will need to use the following classic differential inequality from [257]. The original
statement is for differential equations defined on RD . Without loss of generality, we
can restrict it to an open subset of RD with the same proof.
Theorem 9.11.2. [Adaption of “Variant form of Theorem 10.2”, p.59, [257]] Let U
be an open subset of RD . Suppose that {y(τ ) ∈ U }Tτ=0 is a solution of the differential
dy
equation dτ
= f (y(τ )), y(τ ) = y0 , and that v(τ ) ∈ U is a piecewise linear curve. If
f (y) is βlip -Lipschitz in y, that is, ∀y, y 0 ∈ U , τ ∈ [0, T ], kf (y) − f (y 0 )k ≤ βlip , then
458
for any 0 ≤ τ ≤ T , it holds that for any τ ∈ [0, T ],
Z τ
−βlip τ
ky(τ ) − v(τ )k ≤e T βlip
kv(0) − y(0)k + e kv (τ + 0) − f (v(τ ))k dτ 0
0 0
0
Zτ τ=0
≤e T βlip
kv(0) − y(0)k + kv (τ + 0) − f (v(τ ))k dτ 0 ,
0 0
τ 0 =0
v(τ 0 +δ)−v(τ 0 )
where v 0 (τ 0 + 0) := limδ→0 δ
is the right time derivative of v at τ 0 .
Proof of Theorem 9.5.4. Without loss of generality, we can change assumption (3) in
xη (0)k ≤ ηλ1 (0)/2 + Ψnorm η 2 and |hv1 (xη (0)), x
the theorem statement into ke eη (0)i| ≥
Ω(η). (Constant Ψnorm is defined in Lemma 9.12.1) This is because we know from
λ1 (·)
Lemma 9.12.1, that the norm can’t stay above 2
η + Ψnorm η 2 for two consecutive
eη (0)| ≥ Ω(η) but ηλ1 (0)/2 + Ω(Ψnorm η 2 ) ≤ ke
steps. Moreover, if |v1 (0), x xη (0)k ≤
ηλ1 (0) − Ω(η), we can further show that |v1 (1), x
eη (1)| ≥ Ω(η) from the update rule of
Normalized GD (Lemma 9.10.9). Thus, we can shift our analysis by one time-step
if our assumption isn’t true at step 0. This simplification of assumption helps us to
prove the second claim using Lemma 9.13.5.
To prove the first claim, we first show the movement in the manifold for the
discrete trajectory for Algorithm 7 by Lemma 9.10.12: for each step t, provided
Φ(xη (t))Φ(xη (t + 1)) ∈ Y , it holds that
η2 ⊥
Φ(xη (t + 1)) − Φ(xη (t)) = − PΦ(xη (t)),Γ ∇ log λ1 (xη (t)) + O((θt + kxη (t) − Φ(xη (t))k)η 2 ).
4
(9.27)
Z τ
1 ⊥
X(τ ) = Φ(xinit ) − PX(s),Γ ∇ log λ1 (X(s))ds, X(τ ) ∈ Γ (9.5)
4 s=0
The high-level idea for the proof of the first claim is to bound the gap between
Equation (9.27) and Equation (9.5) using Theorem 9.11.2. And the first claim
459
eventually boils down to upper bound the average angle by O(η), which is exactly the
second claim.
Formally, let t2 be the largest integer no larger than bT2 /η 2 c such that for any
0 ≤ t ≤ t2 , it holds that Φ(xη (t))Φ(xη (t + 1)) ∈ Y .
To apply Theorem 9.11.2, we let y(τ ) = X(τ ), f : U → RD , f (y) =
⊥
PΦ(y),Γ ∇ log λ1 (y), βlip be an upper bound for lipschitzness of f on compact set Y
and v(τ ) = Φ(xη (bτ /η 2 c)) + (τ /η 2 − bτ /η 2 c)) (Φ(xη (bτ /η 2 c + 1)) − Φ(xη (bτ /η 2 c))).
In other words, v is a piecewise linear curve interpolating all xη (t) at time tη 2 .
Therefore, by Equation (9.27), note Φ(x) = Φ(Φ(x)) for all x ∈ U , it holds that
Since we started from a point that has max1≤j≤M Rj (xη (0)) ≤ O(η 2 ), we have from
Lemma 9.11.1, that the iterate satisfies the condition max1≤j≤M Rj (xη (t)) ≤ O(η 2 ) at
step t as well, meaning that kxη (t) − Φ(xη (t))k ≤ O(η).
Therefore, for any τ ≤ t2 η 2 , note that v 0 (τ + 0) = v 0 (bτ /η 2 c + 0) and that
kf (v(bτ /η 2 c + 0)) − f (v(τ ))k = O(kΦ(xη (bτ /η 2 c + 1)) − Φ(xη (bτ /η 2 c))k) = O(η 2 ),
we have that
Z t2 η 2 t2
X
2 2 2 2
X(t2 η )−xη (t2 ) = y(t2 η )−v(t2 η ) = O η+ (θbτ /η2 c +η) = O(η+ η (θt +η) ) = O(η),
τ =0 t=0
460
where in the last step we use the second claim. This implies that t2 must be equal
to bT2 /η 2 c for sufficiently small η otherwise xη (t2 )xη (t2 + 1) ⊆ Y . This is because
kxη (t2 + 1) − xη (t2 )k = O(η) and X(t2 η 2 ) ∈ Y . The proof is completed by noting
that kX(T2 ) − X(bT2 /η 2 c)k = O(η 2 ).
Proof of Lemma 9.11.1. The Normalized GD update at any step t can be written as
(from Lemma 9.10.9)
From Lemma 9.10.11, we have kΦ(xη (t)) − Φ(xη (t + 1))k ≤ O(ξη 2 ), which further
implies,
k∇2 L(Φ(xη (t + 1))) − ∇2 L(Φ(xη (t)))k ≤ O(νξη 2 ). Thus, using the notation x
e =
∇2 L(Φ(x))(x − Φ(x)), we have
eη (t + 1) − x
x eη (t + 1) − ∇2 L(Φ(xη (t)))(xη (t + 1) − Φ(xη (t)))
eη (t) = x
461
That is,
∇2 L(Φ(xη (t)))
eη (t + 1) = I − η
x eη (t) + O(η 2 ) + O(η kxη (t) − Φ(xη (t))k).
x
ke
xη (t)k
(9.29)
Below we will show that kxη (t) − Φ(xη (t))k ≤ O(η), and thus the trajectory of
eη is similar to the trajectory in the qudratic model with an O(η 2 ) error, with the
x
hessian fixed at ∇2 L(Φ(xη (t))), and hence we can apply the same techniques from
Corollary 9.9.4 and Lemma 9.9.1.
eη (t) for t0 + 1 ≤ t ≤ t. We will show
First, we consider the norm of the vector x
the following induction hypothesis:
ke
xη (t)k ≤ 1.01ηζς.
2. Induction case:(t > t0 ). Suppose the hypothesis holds true for t − 1. Then,
1 1.01ηςζ
kxη (t − 1) − Φ(xη (t − 1))k ≤ ke
xη (t − 1)k ≤ .
λM (t) µ
462
(a) If ke
xη (t − 1)k ≥ ηλ1 (t). We can directly apply Corollary 9.9.3 on (9.29) to
show that
ηλM (t − 1)
ke
xη (t)k ≤ 1 − xη (t − 1)k + O(νξη 2 )
ke
2 ke
xη (t − 1)k
νζ
+ O( η kxη (t − 1) − Φ(xη (t − 1))k)
µ
ηλM (t − 1) νζ
≤ ke
xη (t − 1)k − + O(νξη 2 ) + O( 2 ςη 2 )
2 µ
ηλM (t − 1)
≤ ke
xη (t − 1)k − ,
4
(b) If ke
xη (t − 1)k ≤ ηλ1 (t). Then, we can directly apply Lemma 9.9.1 on (9.29)
to show that
νζ
xη (t)k ≤ ηλ1 (t) + O(νξη 2 ) + O(
ke η kxη (t − 1) − Φ(xη (t − 1))k)
µ
νζς
≤ ηλ1 (t) + O(νξη 2 ) + O( 2 η 2 )
µ
≤ 1.01ηλ1 (t).
1 1.01ηςζ
Hence, we have shown that, kxη (t) − Φ(xη (t))k ≤ λM (t)
ke
xη (t)k ≤ µ
for all
time t0 ≤ t ≤ t.
We complete the proof of Lemma 9.11.1 with a similar argument as that for the
quadratic model (see Corollary 9.9.4 and Lemma 9.9.1). The major difference from the
quadratic model is that here the hessian changes over time, along with its eigenvectors
and eigenvalues. Hence, we need to take care of the errors introduced in each step by
the change of hessian.
The high-level idea is to divide the eigenvalues at each step t into groups such that
eigenvalues in the same group are O(η) close and eigenvalues from different groups
463
are at least 2η far away from each other. Formally, we divide [M ] into disjoint subsets
(t) (t)
S1 , · · · , Sp(t) (with 1 ≤ p(t) ≤ M ) such that
and
(t)
∀k ∈ [p(t)], i, i + 1 ∈ Sk λi (t) − λi+1 (t) ≤ η.
For S ⊂ [M ], we denote by PtS the projection matrix at time t onto the subspace
spanned by {vi (t)}i∈S . From Lemma 9.10.11, we have kΦ(xη (t + 1)) − Φ(xη (t + 1))k ≤
ξη 2 , which further implies, k∇2 L(Φ(xη (t + 1))) − ∇2 L(Φ(xη (t)))k ≤ O(νξη 2 ). That
implies, using Theorem 9.14.2, |λj (t) − λj (t)| ≤ O(νξη 2 ) for any j ∈ [M ]. Therefore,
(t) (t)
S` S
we can use Theorem 9.14.4 to have for any ` ∈ [p] Pt − Pt+1
`
≤ O(νξη), since
we have created the eigen subspaces such that the eigenvalue gap between any two
distinct eigen subspaces is at least 0.5η in the desired interval.
(t) (t)
Thus for any t0 ≤ t ≤ t − 1 and k ∈ [p(t)], suppose i ∈ Sk and j = min Sk , we
have that
v
uM
uX
eη (t + 1)i2
t hv (t + 1), x
h
h=i
v v
uM u p(t) 2
uX uX S (t)
eη (t + 1)i2 = t
≤t hvh (t + 1), x Pt+1
`
x
eη (t + 1)
h=j `=k
v v
u p(t) 2 uM
uX S (t) uX
2
≤t Pt ` x
eη (t + 1) eη (t + 1)i2 + O(η 2 )
+ O(η ) = t hvh (t), x
`=k h=j
464
Therefore, we have that
v v
uM uM
1 uX 1 uX
eη (t + 1)i2 ≤
t hv (t + 1), x
h eη (t + 1)i2 + O(η)
t hvh (t), x
ηλi (t + 1) h=i
ηλj (t) h=j
Next we will use the results from the quadratic case to upper bound
qP qP
M 2 M
h=j hvh (t), x
eη (t + 1)i using eη (t)i2 . For all 1 ≤ j ≤ M , we
h=j hvh (t), x
qP
M
2. If eη (t)i2
h=j hvh (t), x ≤ ηλj (t), then we can apply Lemma 9.9.1 on (9.29) to
show that
v
uM
uX νζ
eη (t + 1)i2 ≤ ηλj (t) + O(νξη 2 ) + O( η kxη (t) − Φ(xη (t))k)
t hvh (t), x
h=j
µ
νζ 2 ς 2
≤ ηλj (t) + O(νξη 2 ) + O( η ).
µ2
465
Thus we conclude that
v
uM
1 uX
max h eη (t + 1)i2
t hv (t + 1), x
i∈[M ] ηλi (t + 1)
h=i
v
uM
µ 1 u X
≤ max 1, (1 − )· eη (t)i2 + O(η),
max t hvh (t), x
2ζς ηλj (t) j∈[M ] h=j
and therefore following the same proof of quadratic case Corollary 9.9.4, for t ≥
qP
M
t0 + Ω( ςζ
µ
log ζς
µ
), it holds that ∀j ∈ [M ], e(t̄)i2 ≤ ηλj (t̄) + O(η 2 ).
i=j hvi (t̄), x
By Lemma 9.11.1, the following condition will continue to hold true for all 1 ≤ j ≤ M
eη (t) leaves Y :
before x
v
uM
uX
eη (t)i2 ≤ λj (t)η + O(η 2 ),
t hvi (t), x (9.30)
i=j
eη (t) = ∇2 L(Φ(xη (t)))(xη (t) − Φ(xη (t))). We will call the above condition as
where x
the alignment condition from now onwards.
From the alignment condition (9.30), we can derive the following property that
continues to hold true throughout the trajectory, once the condition is satisfied:
Lemma 9.12.1. There is some constant Ψnorm > 0, such that if the condition (9.30)
ηλ1 (t)
holds true and ke
xη (t)k > 2
, we have:
ηλ1 (t)
ke
xη (t + 1)k ≤ + Ψnorm η 2 .
2
The proof follows from applying Lemma 9.9.8 using the alignment condition (9.30).
eη (t) can’t stay at norm larger than 0.5ηλ1 (t) + Ψnorm η 2 for time
Hence, the iterate x
larger than 1.
466
Another useful lemma is to about the change of the angle between x
eη (t) and the top
ηλ1 (t)
eigenvector when ke
xη (t)k ≤ 2
+ Ψnorm η 2 , which is a noisy version of Lemma 9.9.11
for a quadratic model.
ηλ1 (t)
Lemma 9.12.2. Consider any time t such that ke
xη (t)k ≤ 2
+ Ψnorm η 2 , and the
condition (9.30) holds true, then we have that
min(∆, 2µ) η
tan θt+1 ≤ 1− tan θt + O( ).
ζ ke
xη (t)k
and that
ηλ1 η2
tan θt+2 ≤ tan θt + O( ).
ke
xη (t)k ke
xη (t)k
ηλk (t)
Corollary 9.12.3. If for some 1 ≤ k ≤ M , ke
xη (t)k ≤ 2
+ Ψnorm η 2 and condition
(9.30) holds true, the following must hold true:
vk (t + 1)> x
eη (t + 1) ≥ vk (t)> x
eη (t) − O(η 2 ).
The proof follows from using the noisy quadratic update for Normalized GD in
Lemma 9.10.9 (Equation (9.22)) and the behavior in a quadratic model along the
non-top eigenvectors in Lemma 9.9.5.
The main lemma in this section is Lemma 9.13.5 in Section 9.13.2, which says the
sum of the angles across the entire trajectory in any interval [0, t2 ] with t2 = Ω(1/η 2 ),
is at most O(ηt2 ). Before proving the main lemma, we will first recap and introduce
some notations that will be used.
467
In Phase II, we start from a point xη (0), such that (1) kxη (0) − Φ(xinit )k ≤ O(η),
(2) maxj∈[D] Rj (xη (t)) ≤ O(η 2 ), and additionally (3) |hv1 (xη (0)), xη (0) − Φ(xη (0))i| =
Ω(η).
(2:M )
Pt,Γ x
eη (t)
Formally, recall our notation on θt as θt = arctan |hv1 (t),e
xη (t)i|
, with our notation
eη (t) as ∇2 L(Φ(xη (t)))(xη (t) − Φ(xη (t))). Moreover, recall the definition of the
of x
function gt : R → R as
s !
λ1 (t) λ λ
gt (λ) = 1− 1−2 1− .
2 λ1 (t) λ1 (t)
v
uM
uX
eη (t)i2 ≤ λj (t)η + O(η 2 ).
t hvi (t), x
i=j
ηλ1 (t)
Thus, the iterate can’t stay greater than 2
+ Ψnorm η 2 for more than 1 step. This
lemma allows us to divice all time steps into groups of length 1 and 2.
3. ∀t, t ∈ N1 ⇐⇒ t + 1 ∈ N2 .
4. N0 ∪ N1 ∪ N2 = [e
t, t], and the intersection between each pair of them is empty.
469
be in N1 as t is not in N2 . Thus t − 2 must be in N0 and Lemma 9.12.1 implies that
xη (t)k ≤ 0.5λ1 (t)η + Ψnorm η 2 .
ke
ηλ1 2
2. tan θt+2 ≤ Gt
tan θt + O( Gη t )
3. If Gt ≥ 1.02gmax (t)η, tan θt+2 ≤ 1 − min 0.01, min 2λi (t)
λ1 (t)
(1 − λi (t)
λ1 (t)
) tan θt +
i≤M
O(η).
Proof of Lemma 9.13.3. The proof follows from using the noisy update rule for Nor-
malized GD, as derived in Equation (9.29), which says that the Normalized GD
update is very close to the update in a quadratic model with an additional O(η 2 )
error. Using the property of N0 and N1 outlined above, we have the norm of x
eη (t)
at most 0.5λ1 (t)η + Ψnorm η 2 . The result then follows from using Lemma 9.9.11 and
Lemma 9.9.10, that computes the convergence rate towards the top eigenvector for a
quadratic model. (The first two properties are stated in Lemma 9.12.2).
t = max N1 ∩ {e
Lemma 9.13.4. Given any t with θt = Ω(1), let e t | e
t ≤ t}. If
Get ≥ Ω(η), then θet = Ω(1).
Lemma 9.13.5 (Average of the Angles). For any T2 > 0 for which solution of
Equation (9.5) exists, consider an interval [0, t2 ], with Ω(1/η 2 ) ≤ t2 ≤ bT2 /η 2 c.
470
Suppose Algorithm 7 is run with learning rate η for t2 steps, starting from a point
xη (0) that satisfies (1) maxj∈[D] Rj (xη (0)) ≤ O(η 2 ), and (2) G0 := |hv1 (0), x
e(0)i| ≥
ηλ1 (0)
βη, ke
xη (0)k ≤ 2
+ Ψnorm η 2 for some constant β independent of η. The following
holds true with probability at least 1 − η 10 :
t2
1X
θ` ≤ O (η) ,
t2 `=0
provided η is set sufficiently small, and for all time 0 ≤ t ≤ t2 −1, xη (t)xη (t + 1) ⊂ Y .
Proof. We split the entire interval [0, t2 ) into small trunks in the following way,
0=e
t0 < e
t1 < e t` = t2 with e
t2 . . . e t` denoting the starting step of each trunk. Each e
ti is
defined from e
ti−1 for i > 0. The behavior of each trunk depends on the magnitude of
the iterate along the top eigenvector of hessian. We classify the trunks on the basis of
3 possibilities: Consider a general e
ti ,
A. If Geti ≥ 1.02gmax (e
ti ), then we define e
ti+1 as
ti ) ≤ Geti ≤ 1.02gmax (e
B. If 0.98gmin (e ti ) then we define e
ti+1 as
C. If Geti ≤ 0.98gmax (e
ti ), then we define e
ti+1 as
471
Case (A). First of all, since Gt ≥ 1.02gmax (t) for all e
ti ≤ t < e
ti+1 we can show from
Lemma 9.13.3 that the angle with the top eigenvector quickly drops to O(η) in at
most O(ln 1/η) time-steps. Moreover, the iterate’s magnitude can only drop along the
top eigenvector when the angle with the top eigenvector is smaller than O(η), and the
drop is at most O(η 3 ) (Lemma 9.13.10). Thus, during alignment of the iterate to the
top eigenvector, Gt never drops. Moreover, after the alignment, it takes Ω( η12 ) steps for
the iterate’s magnitude along the top eignvector to drop below 1.01 maxk∈[M ] gt (λk (t)).
Hence,
ti+1
e
1 X
e ti ≥ Ω
ti+1 − e , θt ≤ O (e
ti+1 − e ti+1 − e
ti )η + log 1/η = O (e ti )η .
η2
t=e
ti
Case (B). From Lemma 9.13.6 we have that the sum of angle over this time is
ti+1
e q
X
θt = O ti+1 − e
e ti+1 − e
ti + η(e ti ) .
t=e
ti
Case (C). We claim Gt will become larger than 0.99gmax (t) in O(η −0.1 ) steps with
probability at least 1 − O(η 12 ), because of the η 100 perturbation added per Θ(η −0.1 )
ti , θt ≤ Ω(η), then by Lemma 9.13.11, we know that in
steps. If for some t > e
O(log 1/η) steps after the perurbation, with probability at least 1 − O(η 12 ), we have
θt0 ≥ Ω(η) for some t0 ≤ t + O(log 1/η). And thus we can apply Lemma 9.13.7
and θt0 = Ω(1). By Lemma 9.13.4, we know the θet = Ω(1) as well, where e
t is the
t. Then by Lemma 9.13.10, Get+2 ≥ Get + Ω(η). If
largest step in N1 yet smaller than e
t+2 ∈
e / N1 , then e t + 3 ∈ N1 . Again by Lemma 9.13.9, we have
t + 2 must be in N0 and e
Get+3 ≥ Get+2 − O(η 2 ) ≥ Get + Ω(η). Thus Gt will increase Ω(η) every O(log 1/η) steps
among those steps in N1 (among the steps in N1 and N0 , Gt decreases at most O(η 3 )
472
ti + O(log 1/η) + O(η −0.1 ) = e
ti+1 ≤ e
by Lemma 9.13.10. Thus e ti + O(η −0.1 ). Thus,
Peti+1
t=e
θ = O(η −0.1 ).
t t i
Now it remains to upper bound the number of occurrence of (A),(B) and (C).
Since our goal is to show average angle is O(η), which is equal to the average angle in
case (A), so the number of occurrence of case (A) doesn’t matter. For case (B), if it is
followed by case (A), then there is an Ω(1/η 2 ) gap before next occurrence of (B). If
(B) is followed by case (C), then by Lemma 9.13.10, it takes at least Ω(1/η 2 ) steps to
escape from (B). Thus we can have O(1) occurrence of case (B). For the same reason,
there could be at most O(1) occurrence of case (C).
All in all, with probability at least 1 − O(η 12 · η 12 ) = 1 − O(η 10 ), we must have
t2
X X q
t0 )η + O(1) · O(η −0.1 ) + O(
θt ≤ O (e
t` − e ti+1 − e
e ti
t=0 i:case (B)
s X X
t0 )η + O(η −0.1 ) + O
≤ O (e
t` − e ti+1 − e
(e ti ) 1
i:case (B) i:case (B)
= O (t2 η) + O (t2 )
= O(t2 η)
where we use t2 ≥ Ω( η12 ) in the last step and and the number of occurrence of case
(B) is O(1) in the second to the last step.
Lemma 9.13.6. Consider the setting of Lemma 9.13.5. Consider any time interval
[t, t0 ], where t ≤ ` < t0 , xη (`)xη (` + 1) ⊂ Y and Ω(η) ≤ Gl := |hv1 (t), x
eη (t)i| ≤
λ1 (`)η
2
− Ω(η), we have that
X X p
θt = θt ≤ O( t0 − t + (t0 − t)η).
t∈[t,t0 ] t∈N0 ∪N1 ∪N2
Proof of Lemma 9.13.6. The noisy update rule for Normalized GD, as derived in
Lemma 9.10.9, which says that the Normalized GD update is very close to the update
473
in a quadratic model with an additional O(η 2 ) error. Keeping this in mind, we then
divide our trajectory in the interval (t, t0 ) as per Algorithm 9 into three subsets
N0 , N1 , N2 . (Please see Section 9.13.1 for a summary on the properties of these 3 sets.)
Consider any t ∈ N1 . Using the behavior of Gt from Lemma 9.13.10, we can
show that in each of the time-frames, Gt+2 ≥ (1 + Ω(sin2 θt ))Gt − O(η 2 (η + ηt )) ≥
Gt + Ω(θt2 η) − O(η 2 (η + θt )).
P
Next we want to telescope over Gt+2 − Gt to get an upper bound for t∈N1 θt . If
t + 2 is also in N1 then it’s fine. If t + 2 ∈ N0 , then t + 3 ∈ N1 by Lemma 9.13.1 and
we proceed in the following two cases.
Since total increase in Gt during this interval can is most O(η), we conclude that
2
P P
t∈N1 θt = O(1) + η t∈N1 (η + θt ) and thus it holds that
X s X p s X
θt ≤ (t0 − t) θt2 ≤ t0 − t · O 1 + η θt + η 2 (t0 − t)
t∈N1 t∈N1 t∈N1
p
≤O( t0 − t + (t0 − t)η)
Moreover, by Lemma 9.13.3, we must have θt < θt−1 + O(η) for any time t ∈ N2 ,
and t − 1 must be in N1 . By Lemma 9.13.9, we have θt ≤ Ω(θt−2 ) for any t ∈ N0 and
474
t − 2 must be in N2 . That implies,
X X p
= θt ≤ O( t0 − t + (t0 − t)η),
t∈[t,t0 ] t∈N0 ∪N1 ∪N2
2. θt ≥ Ω(1).
Proof of Lemma 9.13.7. We will prove by contradiction. Suppose neither of the two
condition happens, we will show θt grows exponentially and thus the condition (2)
must be false in O(log 1/η) steps.
ηλ1 (t) ηλ1 (t)
First of all, because θt = O(1) and Gt ≤ 2
whenever ke
xη (t)k ≤ 2
+ Ψnorm η 2 ,
ηλ1 (t) ηλ1 (t+1)
by Lemma 9.13.9, we know if ke
xη (t)k ≤ 2
+ Ψnorm η 2 , then ke
xη (t + 1)k > 2
+
ηλ1 (t+2)
Ψnorm η 2 . And thus ke
xη (t + 2)k ≤ 2
+ Ψnorm η 2 . In other words, t ∈ N1 ∪ N2 for
t + 2k ∈ N1 for all natural numbers k with e
all t. Therefore, for e t + 2k ≤ t. Moreover,
Get+2k ≥ Gt − O(kη 2 ) = Ω(η) by Lemma 9.13.10 for k ≤ O( η1 ).
Now, we can use Equation (9.22) (Lemma 9.10.9) to show that the Normalized GD
update is equivalent to update in quadratic model, up to an additional O(η 2 ) error.
475
Similar to Lemma 9.9.9, consider the coordinate k, we have that
The third step follows from using the same argument as the one used for the quadratic
update in Lemma 9.9.9 and the assumption that Gt ≥ 0.99gt (λk (t)). The final step
holds true because we can pick α as a large enough constant and by assumption
|hvk (t),e
xη (t)i|
|hv1 (t),e
xη (t)i|
≥ αη.
We then bound vk (t) − vk (e t)) by O(η 2 (t − e
t) and Φ(xη (t) − Φ(xη (e t)) using
Lemma 9.10.11. Combining everything, we conclude that at least one of the two
assumptions has to break for some t ≤ e
t + O(log 1/η).
eη (t)i ≤ O(η 2 ).
hvk (t), x
and
∇2 L(Φ(xη (t)))(xη (t) − Φ(xη (t))) ≤ 0.5λ1 (t)η + Ψnorm η 2 .
476
The proof of Lemma 9.13.8 is very similar to the proof of Lemma 9.13.7 and thus
we omit the proof. The only difference will be that we need to use Lemma 9.9.10 in
place of Lemma 9.9.9, when we use the result for the quadratic model.
Here, we will state two important lemmas that we used for the proof of Lemma 9.13.5,
which is about the behavior of the iterate along the top eigenvector. Lemma 9.13.10
can be viewed as perturbed version for Lemma 9.9.7 in the quadratic case, and
We assume in all the lemmas, that Equation (9.30) holds true for the time under
consideration, which we showed in Lemma 9.11.1, and also that we start Phase II
from a point where the alignment along the top eigenvector is non negligible.
The following lemmas give the properties of dynamics in the top eigenspace in
Phase II for one-step and two-step updates respectively. Recall we use Gt to denote
the quantity |hv1 (t), x
e(t)i|.
Lemma 9.13.9 (Behavior along the top eigenvector, one step). For sufficiently small
1
η, consider any time t, such that xη (t) ∈ Y , ke
xη (t)k ≤ 2
ηλ1 (t) + Ψnorm η 2 and
Gt ≥ Ω(η) holds true , the following holds:
λ1 (t)
Gt+1 ≥ − 1 Gt − O(η 2 )
ke
xi (t)k
Lemma 9.13.10 (Behavior along the top eigenvector, two steps). For sufficiently
xη (t)k ≤ 21 ηλ1 (t) + Ψnorm η 2 and
small η, consider any time t, such that xη (t) ∈ Y , ke
477
Gt ≥ Ω(η) holds true, the following holds:
∇L(xη (t))
v1 (t), = cos ∠(v1 (t), ∇L(xη (t)))
k∇L(xη (t))k
From Lemma 9.10.11, we have kΦ(xη (t)) − Φ(xη (t + 1))k ≤ O(ξη 2 ), which further
implies, k∇2 L(Φ(xη (t + 1))) − ∇2 L(Φ(xη (t)))k ≤ O(η 2 ). Thus, we can use Theo-
νξη 2
rem 9.14.4 to have kv1 (t) − v1 (t + 1)k ≤ O( λ1 (t)−λ 2 (t)
) = O(η 2 ). From Lemma 9.10.12,
478
we have |hv1 (t), Φ(xη (t + 1)) − Φ(xη (t))i| ≤ O(η 3 ). Thus we have that
and therefore,
hv1 (t + 1), x
eη (t + 1)i − hv1 (t), x
eη (t)i
λ1 (t)
Gt+1 = 1 − η Gt + O((θt + η)η 2 )
ke
x(t)k
Therefore, we have the following inequality by applying the same argument above
to t + 1:
λ1 (t + 1)
Gt+2 = 1 − η Gt+1 + O((θt+1 + η)η 2 )
ke
x(t + 1)k
λ1 (t + 1) λ1 (t)
= 1−η 1−η Gt
ke
x(t + 1)k kex(t)k
λ1 (t + 1)
+ 1−η · O((θt + η)η 2 ) + O((θt+1 + η)η 2 )
ke
x(t + 1)k
479
η∇2 L(Φ(xη (t)))
Next we will show ke
xη (t + 1)k − (I − ke
xη (t)k
)e
xη (t) = O(η 2 θt ). For con-
venience, we denote ∇2 L(Φ(xη (t))) by H. First we have that
2 2
∇L(xη (t)) x
eη (t)
eη (t) − ηH
x − x eη (t) − ηH
k∇L(xη (t))k ke
xη (t)k
∇L(xη (t)) x
eη (t) ∇L(xη (t)) x
eη (t)
= 2exη (t) − ηH + , ηH −
k∇L(xη (t))k ke xη (t)k k∇L(xη (t))k ke xη (t)k
2 ∇L(xη (t)) x
eη (t) ∇L(xη (t)) x
eη (t)
= 2H x eη (t) − ηH + ,η −
k∇L(xη (t))k ke xη (t)k k∇L(xη (t))k ke xη (t)k
2 ∇L(xη (t)) x
eη (t) ∇L(xη (t)) x
eη (t)
eη (t) − ηH
= 2H x + η − cos α
k∇L(xη (t))k ke xη (t)k k∇L(xη (t))k ke xη (t)k
∇L(xη (t)) ∇L(xη (t))
where α is the angle between − xxeηη (t)k
k∇L(xη (t))k ke
(t)
eη (t)−ηH 2
and 2H x k∇L(xη (t))k
+ x
eη (t)
ke
xη (t)k
.
Note that and that both ∠(e xη (t), v1 (t)), ∠(∇L(xη (t)), v1 (t)) = O(ηt + η), we have
∇L(xη (t)) x
eη (t) 2 ∇L(xη (t)) x
eη (t)
that the angle between k∇L(x η (t))k
+ ke
xη (t)k
and 2H x
e η (t) − ηH k∇L(xη (t))k
+ ke
xη (t)k
∇L(xη (t)) x
eη (t)
is at most O(ηt + η). Further note that k∇L(xη (t))k
− ke
xη (t)k
is perpendicular to
∇L(xη (t)) x
eη (t)
k∇L(xη (t))k
+ ke
xη (t)k
, we know cos α ≤ O(θt + η). Therefore we have that
∇L(xη (t)) x
eη (t)
eη (t) − ηH
x − x
eη (t) − ηH
k∇L(xη (t))k ke
xη (t)k
2 2
∇L(xη (t))
eη (t) − ηH k∇L(x
x η (t))k
eη (t) − ηH kexxeηη (t)k
− x (t)
=
∇L(xη (t))
eη (t) − ηH k∇L(x
x η (t))k
eη (t) − ηH kexxeηη (t)k
+ x (t)
=O(η 2 (η + θt )).
480
Thus we have proved a perturbed version of Lemma 9.9.6, that is,
1 λj (t)(λ1 (t) − λj (t)) 2
ke
xη (t)k + ke
xη (t + 1)k ≤ ηλ1 (t) 1 − min sin θt + O(η 2 (η + θt )).
2λ1 (t) 2≤j≤M λ21 (t)
The proof of the first inequality is completed by plugging the above equation into
Equation (9.33).
The second inequality is immediate by noting that ηθt2 + C 2 η 3 ≥ 2Cη 2 θt for any
C > 0.
Threshold
In this section, we will show that the projection along the top eigenvector cannot
drop below a certain threshold. Formally, we will show the following lemma that
predicts the increase in the projection Gt = |hv1 (t), x
eη (t)i| along the top eigenvector in
O(log 1/η) steps, whenever the projection drops below a certain threshold gmax (t) :=
maxk∈[M ] gt (λk (t)).
Lemma 9.13.11. Denote r = η 100 . For any constant 0 < β, there is a constant
α > 0, such that for any step t and xη (t) ∈ Y with the following conditions hold:
1. βη ≤ Gt ≤ 0.98gmax (t)η.
Then, with probability at least 1 − η 12 , after perturbing xη (t) with noise generated
uniformly from B0 (r) followed by tesc + 2 = Θ(log 1/η) steps of Normalized GD
481
(2:M )
(t = t+tesc +2), it holds that Pt,Γ eη (t) ≥ Ω(η 2 ) provided that xη (t0 )xη (t0 + 1) ⊂ Y
x
for all time t ≤ t0 ≤ t.
Lemma 9.13.12. Consider any time t, with xη (t) ∈ Y . Suppose xη (t) satisfies the
conditions in Lemma 9.13.11. The constants cesc , gmax (t), r, α, and β have been taken
from Lemma 9.13.11. Define Xstuck as the region in Bxη (t) (r) such that starting from
any point u ∈ Xstuck , the points {u(e
t)}et∈[tesc ] , with u(0) := u, obtained using tesc steps
of Normalized GD satisfy:
(2:M )
Pt,Γ t) − Φ(xη (t))) ≤ αη 2 ,
(u(e t ∈ [tesc ],
for all e (9.34)
(2:M )
where Pt,Γ denotes the subspace spanned by v2 (t), . . . , vM (t).
Consider two points u and w in Bxη (t) (r), with the property w = u + Kη 12 rvk (t), 4
where K ≥ 1 can be arbitrary number and vk (t) denotes the eigenvector corresponding
to the eigenvalue λk (t) = argmaxλi (t)|1≤i≤M gt (λi (t)). Then, at least one of u and w is
not present in the region Xstuck .
We will first prove Lemma 9.13.11 and then we turn to the proof of Lemma 9.13.12.
Proof of Lemma 9.13.11. Lemma 9.13.12 shows that if some point u ∈ Bxη (r) is in
Xstuck , then it holds that
The other words, Xstuck is only a thin slice of width at most η 12 r of Bxη (t) (r), which
implies vol(Xstuck )/vol(Bxη (t) (t)) = O(η 1 2), where vol(·) denotes the volume of the
set.
4 12
η can be replaced by any η p , and the final success probability in Lemma 9.13.11 becomes
1 − η p−2 .
482
Proof of Lemma 9.13.12. We will prove by contradiction. Consider the two sequences
obtained with tesc steps of Normalized GD, {u(e
t), w(e
t)}et∈[tesc ] :
∇L(u(e
t)) ∇L(w(e
t))
u(0) = u, w(0) = w, u(e t) − η
t) = u(e , w(e t) − η
t + 1) = w(e .
∇L(u(e
t)) ∇L(w(e
t))
(2:M )
Pt,Γ (u(tesc ) − w(tesc )) ≥ Ω(η 2 ).
v v
uM uM
uX uX
max t hvi (t), w(
ee t)i2 , t hvi (t), u t)i2 ≤ λj (t)η + O(Ψnorm η 2 ).
e(e
i=j i=j
Note that the condition has been slightly changed to use {vi (t)} as reference
coordinate system and Φ(xη (t)) as reference point. The above lemma follows from the
fact that both u(0) and w(0) are r-close to xη (t), which itself satisfies the alignment
condition (Equation (9.30)). Thus, both u(0) and w(0) initially follow the desired
condition. Since, both the trajectories follow Normalized GD updates, the proof
will follow from applying the same technique used in the proof of Lemma 9.11.1.
Another result to keep in mind is the following modified version of Corollary 9.12.3,
Lemma 9.13.14.
483
Lemma 9.13.14. If u t) ≤ η λ12(t) +Ψnorm η 2 , then v1 (t)> u
e(e t + 1) ≥ v1 (t)> u
e(e t) −
e(e
O(η 2 ). The same results hold for w(
ee t) as well.
If w(
eet) ≤ η λ12(t) + Ψnorm η 2 , u t) ≤ η λ12(t) + Ψnorm η 2 , and z(γ) denotes γu(0) +
e(e
∇L(x)
(1 − γ)w(0) for any γ ∈ [0, 1], let F (x) = x − η k∇L(x)k , we have
The above lemma uses {vi (t)} as reference coordinate system and Φ(xη (t)) as
reference point. The above lemma follows from showcasing Normalized GD updates
of u(e
t) and w(e
t) as equivalent to the update in a quadratic model, with an additional
noise of O( νζ
µ
η 2 ), similar to Equation (9.29).
Continuing with the proof of Lemma 9.13.12, we first consider the behavior of u.
Since u ∈ Xstuck , we have for any time-step e
t:
t) − Φ(xη (t))
u(e
min − sv1 (t) ≤ η (9.35)
s∈{±1} t) − Φ(xη (t))
u(e
Further, applying the same technique from Lemma 9.13.10, we can show that
|hv1 (t), u
e(0)i − hv1 (t), x
eη (t)i| ≤ O(r).
484
t) − Φ(xη (t))i − hv1 (t), u(0) − Φ(xη (t))i ≤ O(η 3 tesc ) for all even e
Hence, hv1 (t), u(e t∈
[tesc ]. With tesc ∼ O(log 1/η) , we must have
for any t ≤ e
t ≤ t + tesc . The same argument applies to w(·) as well. By Equation (9.34),
we know u
e(e
t) , w(
eet) = o(η).
Now, we consider the behavior of w(·) and u(·). Consider an even time step
0≤e
t ≤ tesc . From the update rule of w and u, we have
t + 2) = F (2) (w(e
t + 2) − u(e
w(e t)) − F (2) (u(e
t)),
∇L(v)
where the function F : RD → RD , F (v) = v − η k∇L(v)k is the one-step update rule of
Normalized GD and F (2) = F ◦ F .
Now, we use taylor expansion of F around u(e
t) to get
1 2
max ∇2 F (2) (z(γ))) w(e t) − u(e
t)
γ∈[0,1] 2
1 2
= max k∇[∇F (F (z(γ))))∇F (z(γ))]k w(e t) − u(e
t)
γ∈[0,1] 2
1 1 2
≤ max η · O 2 + 2 t) − u(e
w(e t)
γ∈[0,1] k∇L(z(γ))k k∇L(F (z(γ)))k
2 2
· max ∂ 2 (∇L)(z(γ)) , ∇2 L(z(γ)) , ∂ 2 (∇L)(F (z(γ))) , ∇2 L(F (z(γ)))
1 1 2
≤ max 2 + 2 · O(η w(e t) − u(e
t) ).
γ∈[0,1] k∇L(z(γ))k k∇L(F (z(γ)))k
485
Using taylor expansion: ∇L(z(γ)) = ∇2 L(Φ(xη (t)))(z(γ)−Φ(xη (t)))+O(ν kz(γ) − Φ(xη (t))k2 )
and hence, we must have k∇L(z(γ))k ≥ Ω(η).
With u
e(e
t) = o(η), we can apply Lemma 9.13.14 to show
hv1 (t), ∇2 L(Φ(xη (t)))[F (z(γ)) − Φ(xη (t))]i ≥ hv1 (t), ∇2 L(Φ(xη (t)))[z(γ) − Φ(xη (t))]i +O(η 2 ).
1 2
w(e t + 2) − ∂F (2) (u(e
t + 2) − u(e t) − u(e
t))(w(e t)) ≤ O( w(e
t) − u(e
t) ), (9.38)
η
" #
∇L(u(e
t))∇L(u(et))> ∇2 L(u(et))
t)) = I − η I −
Aet := ∂F (u(e 2 ,
∇L(u(et)) ∇L(u(e t))
and µ(e
t) is given by
1 1
µ(e
t) = max 2 + , (9.39)
γ∈[0,1]:z(γ)=γu(e
t)+(1−γ)w(e
t) k∇L(z(γ))k k∇L(F (z(γ)))k2
Now we define Bet and claim Aet can be approximated as below with kBet k = O(η).
Furthermore, kAet k ≤ O(1).
∇2 L(Φ(xη (t)))
Bet = Aet − I − η I − v1 (t)v1 (t)>
,
hv1 (t), ∇2 L(Φ(xη (t)))[u(e
t) − Φ(xη (t))]i
The following strategies have been used to obtain the above approximation. First,
∇2 L(u(e
t)) − ∇2 L(Φ(xη (t))) ≤ O( u(e
t) − Φ(xη (t)) ) = O( u(e
t) − Φ(u(0)) ) +
O(kΦ(u(0)) − Φ(xη (t))k) = O(η). Therefore, Using taylor expansion, ∇L(u(e
t)) =
∇2 L(Φ(xη (t)))(u(e
t) − Φ(xη (t))) + O(η 2 ) = u t) + O(η 2 ).
e(e Using the update
486
from Equation (9.36) and note tesc = O(log 1/η), we must have ∇L(u(e
t)) ≥
t) − Φ(xη (t))i − O(η 2 ) ≥ βη − O(η 3 tesc ) ≥ Ω(η). Finally we use the condi-
hv1 (t), u(e
>
∇L(u(e
t)) ∇L(u(e
t))
tion from Equation (9.35) to show that ∇L(u(et)) = v1 (t)v1 (t)> + O(η).
k k k∇L(u(et))k
Similarly, we can show that:
∇2 L(Φ(xη (t)))
= I − η I − v1 (t)v1 (t)>
Aet+1 + Bet+1 ,
ηλ1 (t) − hv1 (t), u
e(e
t)i
Y
err(e t + 2) − u(e
t) := w(e t + 2) − H(u(i))(w(0) − u(0)), (9.40)
0≤i≤e
t:2|i
Finally, we use Lemma 9.13.15 and Lemma 9.13.16 to handle the main and error
terms in Equation (9.42),
Y
|hvk (t), w(tesc ) − u(tesc )i| = vk (t)> t))(w(0) − u(0)) + vk (t)> err(tesc )
H(u(e
t≤tesc :2|e
0≤e t
Y
≥ vk (t)> t))(w(0) − u(0)) − kerr(tesc )k
H(u(e
t≤tesc :2|e
0≤e t
Lemma 9.13.15.
Y
vk (t)> t))(w(0) − u(0)) ≥ Ω(η 2 ).
H(u(e
t≤tesc :2|e
0≤e t
487
Lemma 9.13.16.
Proof of Lemma 9.13.15. For simplicity of presentation, we have used Met to define
" #" #
h i ∇2 L(Φ(xη (t))) h i ∇2 L(Φ(x (t)))
η
I − η I − v1 (t)v1 (t)> I − η I − v1 (T )v1 (t) >
.
ηλ1 (t) − hv1 (t), u
e(e
t)i hv1 (t), u
e(et)i
Y
t))(w(0) − u(0))
H(u(e
t≤tesc :2|e
0≤e t
Y
= η3r H(u(e
t))vk (t)
t≤tesc :2|e
0≤e t
Y
= η3r Aet+1 Aet vk (t)
t≤tesc :2|e
0≤e t
Y
= η3r Met vk (t) + rem,
t≤tesc :2|e
0≤e t
where using the bounds on {Aet , Aet+1 , Bet , Bet+1 }0≤et≤tesc , we have
X Y
kremk ≤ max (kBet k + Bet+1 ) · max (kAet k + Aet+1 ) · Mj
t≤tesc
e t≤tesc
e
t≤tesc :2|e
0≤e t 0≤j≤tesc :2|j,j6=e
t
X Y
≤ O(kη 1 2r) · Mj . (9.41)
t≤tesc :2|e
0≤e t 0≤j≤tesc :2|j,j6=e
t
From the behavior of u(e t) from Equation (9.37), we have hv1 (t), u
e(et)i ≤
1
1 − 200M gmax (t)η. Recall that gmax (t) was chosen as max1≤k≤M gt (λk (t)). It
turns out that for the chosen upper bound of gmax (t), vk (t) acts as the top eigenvector
t ≤ tesc .
of Met for any e
488
For all j ∈ [2, M ] and e
t ∈ [tesc ], we have:
" #" #
λj (t)/λ1 (t) λj (t)/λ1 (t)
Met vj (t) = 1 − η 1−η vj (t),
η − hv1 (t), u(e
t) − Φ(xη (t))i hv1 (t), u(et) − Φ(xη (t))i
with Met v1 (t) = v1 (t). When hv1 (t), u t)i ≤ gmax (t), kMet vk t)k ≥ kMet v1 (t)k, for all
e(e
j ≥ 2. Furthermore, kMet vj (t)k maximizes when j = k. Therefore,
" #" #
λk (t)/λ1 (t) λk (t)/λ1 (t)
kMet k = 1 − η 1−η
η − hv1 (t), u(e
t) − Φ(xη (t))i hv1 (t), u(e
t) − Φ(xη (t))i
λk (t) λk (t)
≥ λ1 (t) − 1− , for all e t ∈ [tesc ],
λ1 (t) − 0.99gmax (t) 0.99gmax (t)
* +
Y
1
vk (t), Kη 2r Met vk (t) = Θ(η 2 ).
t≤tesc :2|e
0≤e t
That is, we select the time step e t, where the magnitude of the useful term
Q
t Me
t≤tesc :2|e
0≤e t vk (t) along the eigenvector vk (t) reaches cesc η. With gmax (t) = gt (λk (t)),
h ih i
λk (t)/λ1 (t) λk (t)/λ1 (t)
we have 1 − 1−0.99g max (t)
1 − 0.99gmax (t)
≥ 1.001 and so, we just need tesc ≤
O(log(cesc /η)).
With this choice of tesc , we must have from Equation (9.41), kremk ≤ O(η 3 ) and
therefore
* +
Y
vk (t), t))(w(0) − u(0))
H(u(e ≥ Ω(η 2 ) − O(η 3 ) ≥ Ω(η 2 ),
t≤tesc :2|e
0≤e t
Thus, we have shown that with the appropriate choice of tesc , the magnitude of
2
Q
t H(u(t))(w(0) − u(0)) can reach at least Ω(η ) along the eigenvector vk (t).
t≤tesc :2|e
0≤e
e
489
Proof of Lemma 9.13.16. We first recall the definition of the error term:
Y
err(e t) − u(e
t) := w(e t) − H(u(i))(w(0) − u(0)), (9.42)
t−2:2|i
0≤i≤]
2
t + 2) − H(u(e
err(e t) ≤ %( w(e
t))err(e t) − u(e
t) /η)
We will use induction hypothesis to show for all even t ≤ tesc , err(e
t) ≤
2
t) − w(e
C u(e t) /η for some sufficiently large constant C. The base case is t = 0
which holds by definition. Now suppose the induction hypothesis holds for all even
0 ≤ t0 ≤ e
t} and below we will show for e
t + 2.
t − 2, we know
First, by induction hypothesis at e
Y 2
t) − u(e
w(e t) − H(u(i))(w(0) − u(0)) ≤ err(e
t) ≤ C( w(e
t) − u(e
t) /η).
t−2:2|i
0≤i≤e
Y
t) − u(e
w(e t) ≤ (1 + O(η)) H(u(i))(w(0) − u(0)) .
0≤i≤e
t:2|i
490
Meanwhile, we have
err(e
t + 2)
Y
≤ u(e
t + 2) − w(e
t + 2) + H(u(i))(w(0) − u(0))
0≤i≤e
t+2:2|i
Y
≤ u(e
t + 2) − w(e
t + 2) + H(u(e t) − u(e
t)) w(e t) − H(u(i))(w(0) − u(0))
0≤i≤e
t:2|i
H(u(e
t)) t) − u(e
w(e t)
Y
≤(1 + O(η)) H(u(e
t)) H(u(i))(w(0) − u(0))
0≤i≤e
t:2|i
Y
=(1 + O(η)) H(u(i))(w(0) − u(0))
0≤i≤e
t+2:2|i
≤(1 + O(η)) t + 2) − w(e
u(e t + 2) + err(e
t + 2) .
2 2 2
H(u(e
t)) t) − u(e
w(e t) ≤ (1 + O(η)) u(e
t + 2) − w(e
t + 2) + O(η 2 ) err(e
t + 2)
491
Denote by ϕ = minet≤tesc kMet k. From previous analysis we know ϕ > 1 and thus we
have
η err(e
t + 2)
2
≤% w(e
t + 2) − u(e
t + 2) + H(u(e
t))err(e
t) η
2 2
≤% w(e
t + 2) − u(e
t + 2) + C H(u(e
t)) t) − u(e
w(e t)
2 C 2
≤% w(e
t + 2) − u(e
t + 2) +( t + 2) − w(e
+ O(η)) u(e t + 2) + O(η 2 ) err(e
t + 2) .
ϕ
C 2
t + 2) ≤ (% + O(η)) w(e
η err(e t + 2) − u(e
t + 2) .
ϕ
Proof of Theorem 9.5.7. According to the proof of Theorem 9.5.4, we know for all t,
it holds that Rj (xη (t)) ≤ O(η 2 ). Thus SL (xη (t), ηt ) = ηt · sup0≤s≤ηt λ1 (∇2 L(xη (t) −
k∇L(xη (t))k
s∇L(xη (t)))) = ηt (λ1 (t) + O(η)), which implies that [SL (xη (t), ηt )]−1 = ηλ1 (t)
+
ke
xη (t)k
O(η) = ηλ1 (t)
+ O(η). The proof for the first claim is completed by noting that
1
η
(ke
xη (t)k+ ke
xη (t + 1)k) = λ1 (t) + O(η + θt ) as an analog of the quadratic case.
ke
xη (t)k
p
For the second claim, it’s easy to check that L(xη (t)) = √ + O(ηθt ).
2λ1 (t)
kexη (t)k ke
xη (t+1)k
p p
Thus have L(xη (t)) + L(xη (t + 1)) = √ +√ + O(η(θt + θt+1 )). Note
2λ1 (t) 2λ1 (t+1)
p
that λ1 (t) − λ1 (t + 1) = O(η 2 ) and θt+1 = O(θt ), we conclude that L(xη (t)) +
q
2
L(xη (t + 1)) = η λ1 (∇ L(x η (t))
p
2
) + O(ηθt ).
492
9.14 Some Useful Lemmas About Eigenvalues and
Eigenvectors
λ(X0 ) = λ0 , u(X0 ) = u0 ,
and
Moreover, the functions λ and u are C ∞ on N (X0 ) and the differentials at X0 are
dλ = u>
0 (dX)u0 , du = (λ0 In − X0 )† (dX)u0 .
λi − λ
bi ≤ Σ − Σ
b .
2
The next theorem is the Davis-Kahan sin(θ) theorem, that bounds the change in
the eigenvectors of a matrix on perturbation. Before presenting the theorem, we need
493
to define the notion of unitary invariant norms. Examples of such norms include the
frobenius norm and the spectral norm.
b ∈ Rp×p be symmetric,
Theorem 9.14.4. [Davis-Kahan sin(θ) theorem [260]] Let Σ, Σ
with eigenvalues λ1 ≥ . . . ≥ λp and λ
b1 ≥ . . . ≥ λ
bp respectively. Fix 1 ≤ r ≤ s ≤ p, let
Here Θ(Vb , V ) ∈ Rd×d , with Θ(Vb , V )j,j = arccos σj for any j ∈ [d] and Θ(Vb , V )i,j = 0
for all i 6= j ∈ [d]. σ1 ≥ σ2 ≥ · · · ≥ σd denotes the singular values of Vb > V. [sin Θ]ij is
defined as sin(Θij ).
√
9.15 Analysis of L
The analysis will follow the same line of proof used for the analysis of Normalized
GD. Hence, we write down the main lemmas that are different from the analysis of
Normalized GD. Rest of the lemmas are nearly the same and hence, we have omitted
them.
√
The major difference between the results of Normalized GD and GD with L is in
the behavior along the manifold Γ (for comparison, see Lemma 9.10.12 for Normalized
√
GD and Lemma 9.15.10 for GD with L). Another difference between the results of
494
√
Normalized GD and GD with L is in the error rates mentioned in Theorem 9.5.4
and Theorem 9.5.6. The difference comes from the stronger behavior of the projection
along the top eigenvector that we showed for Normalized GD in Lemma 9.13.10, but
√
doesn’t hold for GD with L (see Lemma 9.15.6). This difference shows up in the
sum of angles across the trajectory (for comparison, see Lemma 9.13.5 for Normalized
√
GD and Lemma 9.15.4 for GD with L), and is finally reflected in the error rates.
9.15.1 Notations
The notations will be the same as Section 9.10 . However, here we will use x
eη (t) to
1/2
denote (2∇2 L(Φ(xη (t)))) (xη (t) − Φ(xη (t))). We will now denote Y as the limiting
flow given by Equation (9.7).
Z τ
1 ⊥
X(τ ) = Φ(xinit ) − PX(s),Γ ∇λ1 (X(s))ds, X(τ ) ∈ Γ. (9.7)
8 s=0
v
uM
uX
eη (t̄)i2 ≤ ηλj (t̄) + O(η 2 ),
t hvi (t̄), x (9.43)
i=j
495
Proof. The proof exactly follows the strategy used in Lemma 9.11.1. We can use the
noisy update formulation from Lemma 9.15.7 and the bound on the movement in Φ
from Lemma 9.15.10 to get for any time t with t̄ ≥ t ≥ t0 (similar to Equation (9.29)):
∇2 L(Φ(xη (t)))
eη (t + 1) = I − η
x eη (t) + O(η 2 ) + O(kxη (t) − Φ(xη (t))k η),
x
ke
xη (t)k
Hence, similar to Lemma 9.12.1, we can derive the following property that continues
to hold true throughout the trajectory, once the condition Equation (9.43) is satisfied:
Lemma 9.15.2. There is some constant Ψnorm > 0, such that if the condition Equa-
ηλ1 (t)
tion (9.43) holds true and ke
xη (t)k > 2
, the following must hold true:
ηλ1 (t)
ke
xη (t + 1)k ≤ + Ψnorm η 2 .
2
We also have the counterpart of Corollary 9.12.3 with the same proof, which follows
√
from using the noisy update of GD on L from Lemma 9.15.7 and using the quadratic
update result from Lemma 9.9.5.
ηλ1 (t)
Lemma 9.15.3. If at time t, ke
xη (t)k ≤ 2
+ Ψnorm η 2 and stability condition
(Equation (9.43)) holds true, the following must hold true:
v1 (t + 1)> x
eη (t + 1) ≥ v1 (t)> x
eη (t) − O(η 2 ).
Z τ
1 ⊥
X(τ ) = Φ(xinit ) − PX(s),Γ ∇λ1 (X(s))ds, X(τ ) ∈ Γ. (9.7)
8 s=0
496
Let T2 be the time up until which solution to the limiting flow exists.
Lemma 9.15.10 shows the movement in Φ, which can be informally given as follows:
in each step t,
η2 ⊥
Φ(xη (t + 1)) − Φ(xη (t)) = − Pt,Γ ∇λ1 (∇2 L(Φ(xη (t)))) + O(η 2 (θt + kxη (t) − Φ(xη (t))k)),
8
(9.44)
Average of the angles The first lemma shows that the sum of the angles in an
interval [0, t2 ] of length Ω(1/η 2 ) is at most O(t2 η 1/2 ).
Lemma 9.15.4. For any T2 > 0 for which solution of Equation (9.7) exists, con-
sider an interval [0, t2 ], with Ω(η −2 ) = t2 ≤ bT2 /η 2 c. Suppose Algorithm 8 is
run with learning rate η for t2 steps, starting from a point xη (0) that satisfies (1)
maxj∈[D] Rj (xη (0)) ≤ O(η 2 ), and (2) |v1 (0), xη (0) − Φ(xη (0))| ≥ βη for some constant
√ 2
0.5λ1 (0)
0 < β independent of η, with ke xη (0)k ≤ 2
η + Ψnorm η 2 , the following holds true
with probability at least 1 − η 10 :
t2
1X √
θ` ≤ O ( η) ,
t2 `=0
497
Proof of Lemma 9.15.4. The proof is very similar to the proof of Lemma 9.13.5, except
we replace Lemma 9.13.6 by Lemma 9.15.5 in the analysis of case (B). Hence the final
√
average angle becomes O( η).
Lemma 9.15.5. Consider the setting of Lemma 9.15.4. Consider any time interval
[t, t0 ], where t ≤ ` < t0 , xη (`)xη (` + 1) ⊂ Y and Ω(η) ≤ Gl := |hv1 (t), x
eη (t)i| ≤
λ1 (`)η
2
− Ω(η), we have that
X X p √
θt = θt ≤ O( t0 − t + (t0 − t) η).
t∈[t,t0 ] t∈N0 ∪N1 ∪N2
Proof. The proof will follow exactly as Lemma 9.13.6, except we replace Lemma 9.13.10
√ √
by Lemma 9.15.6, which changes the rate into O( t0 − t + (t0 − t) η)
Lemma 9.15.6. [Behavior along the top eigenvector] Consider any time t, such that
xη (t)k ≤ 21 ηλ1 (t) + Ψnorm η 2 holds true, then the following holds
xη (t) ∈ Y , where ke
true:
Proof. Here, we will follow a much simpler approach than Lemma 9.13.10 to have a
weaker error bound. The stronger error bounds in Lemma 9.13.10 were due to the
very specific update rule of Normalized GD.
p
eη (t) = 2∇2 L(Φ(x))(xη (t) − Φ(xη (t))). By Lemma 9.15.10, we have
First recall x
p
kΦ(xη (t + 1)) − Φ(xη (t))k = O(η 2 ), thus x eη (t) = 2∇2 L(Φ(x))(xη (t +
eη (t + 1) − x
1) − xη (t)) = η 2∇2 L(Φ(x)) ∇L(x
√ η (t)) . From Lemma 9.15.7, we have
p
2 L(xη (t))
498
where we have used the fact that ke
xη (t)k = O(η).
Hence, the update is similar to the update in a quadratic model, with ∇2 L(Φ(xη (t)))
guiding the updates with an additional O(η 2 ) perturbation. As a result we also get a
O(η 2 ) perturbation in Gt . Here we use the assumption Gt = Ω(η) so that GD updates
are O(1)-lipschitz.
√
9.15.5 Geometric Lemmas for L
(2:M )
p PΦ(x),Γ x
e
e = 2∇2 L(Φ(x))(x − Φ(x)) and θ = arctan
First recall our notations, x |hv1 (x),e
xi|
.
499
And therefore,
∇L(x) p ζ 1/2 ν
p ≤ 2λ1 (Φ(x)) + O( kx − Φ(x)k) = O(ζ 1/2 ).
L(x) µ
1
∇L(x) − ∇2 L(Φ(x))(x − Φ(x)) ≤ ν kx − Φ(x)k2 .
2
Since Φ(x) is a local minimizer of zero loss, we have ∇L(Φ(x)) = 0, we have that
1
L(x) = ∂ 2 L(Φ(x))[x − Φ(x), x − Φ(x)] + O(ν kx − Φ(x)k3 ).
2
2
By Lemma 9.10.8, we know ∂ 2 L(Φ(x))[x−Φ(x), x−Φ(x)] ≥ Ω( kx−Φ(x)k
µ
) and therefore
q
1 2
2
∂ L(Φ(x))[x − Φ(x), x − Φ(x)] ν
p = 1 + O( kx − Φ(x)k).
L(x) µ
p
For the second claim, with x
e= 2∇2 L(Φ(x))(x − Φ(x)), we have that
r
∇2 L(Φ(x))(x − Φ(x)) 1 2 x p
∇ L(Φ(x)) ≤ 2λ1 (x).
e
q =
1 2
∂ L(Φ(x))[x − Φ(x), x − Φ(x)] 2 ke
xk
2
2µ ζ 1/2 ν µ
By Lemma 9.10.7, we have kx − Φ(x)k ≤ ζν
, thus µ
kx − Φ(x)k = O( ζ 1/2 )=
O(ζ 1/2 ).
500
The following two lemmas are direct implications of Lemma 9.15.7.
And therefore,
(2:M )
PΦ(x),Γ x
e p
where θ = arctan |hv1 (x),e
xi|
, with x
e= 2∇2 L(Φ(x))(x − Φ(x)).
p
Lemma 9.15.10. For any xy ∈ Y where y = x − η∇ L(x) is the one step update
√
on L loss from x, we have
η2 ⊥
Φ(y) − Φ(x) = − Px,Γ ∇(λ1 (∇2 L(x)))
8
ζ 3/2 νξ
+O(η 2 ζξθ) + O( 3/2 kx − Φ(x)k η 2 ) + O(ζχ kx − Φ(x)k η 2 ).
µ
(2:M )
PΦ(x),Γ x
e p
Here θ = arctan |hv1 (x),e
xi|
, with x
e= 2∇2 L(Φ(x))(x − Φ(x)).
501
Proof of Lemma 9.15.10. We outline the major difference from the proof of
Lemma 9.10.12. Using taylor expansion for the function Φ, we have
Φ(y) − Φ(x)
1
= ∂Φ(x) (y − x) + ∂ 2 Φ(x)[y − x, y − x] + err
2
! " # 3
∇L(x) η2 2 ∇L(x) ∇L(x) ∇L(x)
= ∂Φ(x) −η p + ∂ Φ(x) p , p + O(χη 3 p )
2 L(x) 2 2 L(x) 2 L(x) L(x)
" #
η2 ∇L(x) ∇L(x)
= ∂ 2 Φ(x) p , p + O(χζ 3/2 η 3 ),
2 2 L(x) 2 L(x)
where in the final step, we used the property of Φ from Lemma 9.10.14 to kill the first
∇L(x)
term and use the bound on √ from Lemma 9.15.7 for the third term.
L(x)
Also, at Φ(x), since v1 (x) is the top eigenvector of the hessian ∇2 L, we have from
Corollary 9.10.21,
1
∂ 2 Φ(Φ(x)) v1 (x)v1 (x)> = − ∂Φ(Φ(x))∂ 2 (∇L)(Φ(x))[v1 (x), v1 (x)].
2λ1 (x)
∇L(x)(∇L(x))> ζ 3/2 ν
λ1 (x)v1 (x)v1 (x)> − ≤ ζθ + O( 3/2 kx − Φ(x)k)
2L(x) µ
(2:M )
PΦ(x),Γ (x−Φ(x))
where recall our notation of θ = arctan |hv1 (x),x−Φ(x)i|
.
With further simplification, it turns out that
η2 2
∂ Φ(Φ(x)) λ1 (x)v1 (x)v1 (x)>
Φ(y) − Φ(x) = −
8
ζ 3/2 νξ
+O(η 2 ζξθ) + O( 3/2 kx − Φ(x)k η 2 ) + O(ζχ kx − Φ(x)k η 2 ).
µ
We provide the code for running a single step of the riemannian flow (9.5) corresponding
to Normalized GD. The pseudocode is given in Algorithm 10.
Loss setting: The algorithm described in Algorithm 10 works for the following
scenario. The loss L is equal to the average of n loss functions `i : RD → R+ and
each `i is defined as follows. Suppose we use n functions fi : RD → R, that share a
common parameter x ∈ RD , to approximate n true labels {bi ∈ R}ni=1 Then, we define
each `i (x) = `(fi (x), bi ), where ` : R × R → R+ denotes a general loss function, that
takes in the prediction of a function and the true label and returns a score. ` should
have the following properties:
Example of such a loss function ` is the `2 loss function. The scenario described above
contains regression tasks. Moreover, it can also represent binary classification tasks,
since binary classification can be viewed as regression with 0, 1 label.
503
We can also represent the multi-class classification tasks, which we use for our
experiments. Consider the setting, where we are trying to train some function (e.g. a
neural network) f : RD × Rd → R|C| with the parameter space in RD , input examples
from Rd , and the set of classes being denoted by C. For each class c, we can think of
fc (x, a) as the likelihood score for label c to an input a returned by the function with
parameter x. If S = {(a, b) ∈ Rd × C} denotes the set of all input and label pairs in
the training set, we define our loss function as
1 X X
L(x) = |fc (x, a) − I(b = c)|2 .
|S||C| c∈C
(a,b)∈S
Thus, each `i in Algorithm 10 represents one of the terms {|fc (x, a) − I(b = c)|2 }(a,b)∈S,c∈C
in the multi-class classification setting.
504
Algorithm 10 Simulation for the limiting flow (9.5) of Normalized GD
Input: n loss functions `i : RD → R+ , initial point x(0) with L(x(0)) ≈ 0, maximum
number of iteration T , LR η, Projection LR ηproj , maximum number of projection
iterations Tproj . P
Define L(x) as n1 ni=1 `i (x) and Px,Γ as projection matrix onto the subspace spanned
by ∇f1 (x), · · · , ∇fn (x) for any x ∈ RD .
for t = 0 to T − 1 do
Compute v1 , the top eigenvector of ∇2 L(x(t)).
Compute ∇λ1 (x(t)) = ∇3 L(x(t))[v1 , v1 ]. //This is by Theorem 9.14.1.
Compute Px(t),Γ ∇λ1 (x(t)) by solving least square.
η
y(0) ← x(t) − λ1 (x(t)) (I − Px(t),Γ )∇λ1 (x(t)).
t = 0 to Tproj − 1 do
for e
y(e
t +1) = y(e t)−ηproj ∇L(y(et)). // Inner loop: project GD back to manifold.
x(t + 1) ← y(Tproj ).
505
Bibliography
[1] Diederik P. Kingma and Jimmy Ba. Adam: A method for stochastic optimization.
In 3rd International Conference on Learning Representations, ICLR 2015, San
Diego, CA, USA, May 7-9, 2015, Conference Track Proceedings, 2015. Cited on
pages 2 & 90.
[2] John Duchi, Elad Hazan, and Yoram Singer. Adaptive subgradient methods
for online learning and stochastic optimization. Journal of Machine Learning
Research, 12(Jul):2121–2159, 2011. Cited on page 2.
[4] Chiyuan Zhang, Samy Bengio, Moritz Hardt, Benjamin Recht, and Oriol Vinyals.
Understanding deep learning requires rethinking generalization. In International
Conference on Learning Representations, 2017. Cited on pages 4 & 158.
[5] Shengchao Liu, Dimitris Papailiopoulos, and Dimitris Achlioptas. Bad global
minima exist and sgd can reach them. Advances in Neural Information Processing
Systems, 33:8543–8552, 2020. Cited on page 4.
[7] Yiding Jiang, Pierre Foret, Scott Yak, Daniel M Roy, Hossein Mobahi,
Gintare Karolina Dziugaite, Samy Bengio, Suriya Gunasekar, Isabelle Guyon,
and Behnam Neyshabur. Neurips 2020 competition: Predicting generalization
in deep learning. arXiv preprint arXiv:2012.07976, 2020. Cited on page 4.
[8] Pierre Foret, Ariel Kleiner, Hossein Mobahi, and Behnam Neyshabur. Sharpness-
aware minimization for efficiently improving generalization. In International
Conference on Learning Representations, 2021. URL https://openreview.net/
forum?id=6Tm1mposlrM. Cited on pages 5 & 399.
[9] Nitish Shirish Keskar, Dheevatsa Mudigere, Jorge Nocedal, Mikhail Smelyanskiy,
and Ping Tak Peter Tang. On large-batch training for deep learning: General-
ization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016. Cited
on pages 5, 312 & 399.
506
[10] David McAllester. Simplified pac-bayesian margin bounds. In Learning theory
and Kernel machines, pages 203–215. Springer, 2003. Cited on page 5.
[11] Gintare Karolina Dziugaite and Daniel M Roy. Computing nonvacuous general-
ization bounds for deep (stochastic) neural networks with many more parameters
than training data. arXiv preprint arXiv:1703.11008, 2017. Cited on page 5.
[12] Yiding Jiang*, Behnam Neyshabur*, Hossein Mobahi, Dilip Krishnan, and
Samy Bengio. Fantastic generalization measures and where to find them. In
International Conference on Learning Representations, 2020. URL https://
openreview.net/forum?id=SJgIPJBFvH. Cited on page 5.
[13] Daniel Soudry, Elad Hoffer, and Nathan Srebro. The implicit bias of gradient
descent on separable data. In International Conference on Learning Representa-
tions, 2018. Cited on pages 5 & 160.
[14] Kaifeng Lyu and Jian Li. Gradient descent maximizes the margin of homogeneous
neural networks. arXiv preprint arXiv:1906.05890, 2019. Cited on pages 5, 252,
306 & 400.
[15] Arthur Jacot, Franck Gabriel, and Clément Hongler. Neural tangent kernel: Con-
vergence and generalization in neural networks. arXiv preprint arXiv:1806.07572,
2018. Cited on pages 5, 121, 252, 312, 333 & 400.
[16] Suriya Gunasekar, Blake E Woodworth, Srinadh Bhojanapalli, Behnam
Neyshabur, and Nati Srebro. Implicit regularization in matrix factorization. In
I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan,
and R. Garnett, editors, Advances in Neural Information Processing Systems
30, pages 6151–6159. Curran Associates, Inc., 2017. Cited on pages 6, 10, 157,
159, 160, 161, 162, 173, 193 & 306.
[17] Jeremy Cohen, Simran Kaur, Yuanzhi Li, J Zico Kolter, and Ameet Talwalkar.
Gradient descent on neural networks typically occurs at the edge of stabil-
ity. In International Conference on Learning Representations, 2020. Cited on
pages 7 & 65.
[18] Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating deep
network training by reducing internal covariate shift. In International conference
on machine learning, pages 448–456. PMLR, 2015. Cited on pages 8, 14, 19,
91 & 119.
[19] Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E Hinton. Layer normalization.
arXiv preprint arXiv:1607.06450, 2016. Cited on pages 8, 15, 19 & 91.
[20] Blake Woodworth, Suriya Gunasekar, Jason D Lee, Edward Moroshko, Pedro
Savarese, Itay Golan, Daniel Soudry, and Nathan Srebro. Kernel and rich
regimes in overparametrized models. In Conference on Learning Theory, pages
3635–3673. PMLR, 2020. Cited on pages 11, 249, 250, 251, 252, 261, 262, 271,
306, 311, 312, 326, 328, 330, 332 & 400.
507
[21] Jeff Z HaoChen, Colin Wei, Jason D Lee, and Tengyu Ma. Shape mat-
ters: Understanding the implicit bias of the noise covariance. arXiv preprint
arXiv:2006.08680, 2020. Cited on pages 11, 252, 306, 311, 312, 326 & 329.
[22] Zhiyuan Li and Sanjeev Arora. An exponential learning rate schedule for deep
learning. In International Conference on Learning Representations, 2019. Cited
on pages 12, 91, 92 & 97.
[23] Zhiyuan Li, Srinadh Bhojanapalli, Manzil Zaheer, Sashank Reddi, and Sanjiv
Kumar. Robust training of neural networks using scale invariant architectures.
In International Conference on Machine Learning, pages 12656–12684. PMLR,
2022. Cited on pages 12 & 401.
[24] Zhiyuan Li, Yi Zhang, and Sanjeev Arora. Why are convolutional nets more
sample-efficient than fully-connected nets? In International Conference on
Learning Representations, 2020. Cited on pages 12 & 392.
[25] Zhiyuan Li, Yuping Luo, and Kaifeng Lyu. Towards resolving the implicit
bias of gradient descent for matrix factorization: Greedy low-rank learning. In
International Conference on Learning Representations, 2020. Cited on pages 12,
249, 252, 312 & 400.
[26] Zhiyuan Li, Tianhao Wang, and Sanjeev Arora. What happens after sgd reaches
zero loss?–a mathematical framework. In International Conference on Learning
Representations, 2021. Cited on page 12.
[27] Sanjeev Arora, Zhiyuan Li, and Abhishek Panigrahi. Understanding gradient
descent on the edge of stability in deep learning. In Kamalika Chaudhuri,
Stefanie Jegelka, Le Song, Csaba Szepesvari, Gang Niu, and Sivan Sabato,
editors, Proceedings of the 39th International Conference on Machine Learning,
volume 162 of Proceedings of Machine Learning Research, pages 948–1024. PMLR,
17–23 Jul 2022. URL https://proceedings.mlr.press/v162/arora22a.html.
Cited on page 12.
[28] Shibani Santurkar, Dimitris Tsipras, Andrew Ilyas, and Aleksander Madry.
How does batch normalization help optimization? In S. Bengio, H. Wallach,
H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett, editors, Ad-
vances in Neural Information Processing Systems 31, pages 2488–2498. Curran
Associates, Inc., 2018. Cited on pages 14, 18 & 31.
[29] Sanjeev Arora, Nadav Cohen, and Elad Hazan. On the optimization of deep
networks: Implicit acceleration by overparameterization. In International Con-
ference on Machine Learning, pages 244–253. PMLR, 2018. Cited on pages 14,
51, 177, 312 & 400.
[30] Yuxin Wu and Kaiming He. Group normalization. In Proceedings of the European
conference on computer vision (ECCV), pages 3–19, 2018. Cited on pages 15,
19 & 91.
508
[31] Dmitry Ulyanov, Andrea Vedaldi, and Victor Lempitsky. Instance normalization:
The missing ingredient for fast stylization. arXiv preprint arXiv:1607.08022,
2016. Cited on pages 15 & 19.
[32] Elad Hoffer, Itay Hubara, and Daniel Soudry. Fix your classifier: the marginal
value of training the last weight layer. In International Conference on Learning
Representations, 2018. URL https://openreview.net/forum?id=S1Dh8Tg0-.
Cited on pages 15 & 54.
[33] Leslie N Smith. Cyclical learning rates for training neural networks. In 2017
IEEE Winter Conference on Applications of Computer Vision (WACV), pages
464–472. IEEE, 2017. Cited on page 16.
[34] Ilya Loshchilov and Frank Hutter. SGDR: Stochastic Gradient Descent with
Warm Restarts. arXiv e-prints, art. arXiv:1608.03983, Aug 2016. Cited on
pages 16 & 33.
[35] Minhyung Cho and Jaehyung Lee. Riemannian approach to batch normalization.
In I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan,
and R. Garnett, editors, Advances in Neural Information Processing Systems
30, pages 5225–5235. Curran Associates, Inc., 2017. Cited on page 17.
[36] Sanjeev Arora, Zhiyuan Li, and Kaifeng Lyu. Theoretical analysis of auto
rate-tuning by batch normalization. In International Conference on Learning
Representations, 2018. Cited on pages 18, 19, 29, 31, 52, 62, 91, 93, 94 & 401.
[37] Xiaoxia Wu, Rachel Ward, and Léon Bottou. WNGrad: Learn the Learning
Rate in Gradient Descent. arXiv preprint arXiv:1803.02865, 2018. Cited on
pages 18 & 51.
[38] Jonas Kohler, Hadi Daneshmand, Aurelien Lucchi, Ming Zhou, Klaus Neymeyr,
and Thomas Hofmann. Exponential convergence rates for batch normalization:
The power of length-direction decoupling in non-convex optimization. arXiv
preprint arXiv:1805.10694, 2018. Cited on page 18.
[39] Nils Bjorck, Carla P Gomes, Bart Selman, and Kilian Q Weinberger. Understand-
ing batch normalization. In S. Bengio, H. Wallach, H. Larochelle, K. Grauman,
N. Cesa-Bianchi, and R. Garnett, editors, Advances in Neural Information
Processing Systems 31, pages 7705–7716. Curran Associates, Inc., 2018. Cited
on page 18.
[40] Guodong Zhang, Chaoqi Wang, Bowen Xu, and Roger Grosse. Three mecha-
nisms of weight decay regularization. In International Conference on Learning
Representations, 2019. Cited on pages 18 & 91.
[41] Elad Hoffer, Ron Banner, Itay Golan, and Daniel Soudry. Norm matters:
efficient and accurate normalization schemes in deep networks. arXiv preprint
arXiv:1803.01814, 2018. Cited on pages 18 & 91.
[42] Twan Van Laarhoven. L2 regularization versus batch and weight normalization.
arXiv preprint arXiv:1706.05350, 2017. Cited on pages 18, 91 & 95.
[43] Ilya Sutskever, James Martens, George Dahl, and Geoffrey Hinton. On the
importance of initialization and momentum in deep learning. In Proceedings
of the 30th International Conference on International Conference on Machine
Learning - Volume 28, ICML’13, pages III–1139–III–1147. JMLR.org, 2013. URL
http://dl.acm.org/citation.cfm?id=3042817.3043064. Cited on page 19.
[44] Adam Paszke, Sam Gross, Soumith Chintala, Gregory Chanan, Edward Yang,
Zachary DeVito, Zeming Lin, Alban Desmaison, Luca Antiga, and Adam Lerer.
Automatic differentiation in PyTorch. 2017. Cited on pages 19 & 58.
[45] Robert Mansel Gower, Nicolas Loizou, Xun Qian, Alibek Sailanbayev, Egor
Shulgin, and Peter Richtárik. SGD: General analysis and improved rates. arXiv
preprint arXiv:1901.09401, 2019. Cited on page 46.
[46] Sanjoy Dasgupta and Anupam Gupta. An elementary proof of a theorem of
Johnson and Lindenstrauss. Random Structures & Algorithms, 22(1):60–65, 2003.
Cited on page 46.
[47] Yang You, Igor Gitman, and Boris Ginsburg. Large Batch Training of Convolu-
tional Networks. arXiv preprint arXiv:1708.03888, 2017. Cited on page 53.
[48] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning
for image recognition. In Proceedings of the IEEE conference on computer vision
and pattern recognition, pages 770–778, 2016. Cited on pages 53 & 117.
[49] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Identity mappings
in deep residual networks. In European conference on computer vision, pages
630–645. Springer, 2016. Cited on page 53.
[50] Zhiyuan Li, Kaifeng Lyu, and Sanjeev Arora. Reconciling modern deep learning
with traditional optimization analyses: The intrinsic learning rate. Advances
in Neural Information Processing Systems, 33, 2020. Cited on pages 65, 91, 93,
306, 312 & 395.
[51] Ekaterina Lobacheva, Maxim Kodryan, Nadezhda Chirkova, Andrey Malinin,
and Dmitry P Vetrov. On the periodic behavior of neural network training
with batch normalization and weight decay. Advances in Neural Information
Processing Systems, 34, 2021. Cited on page 65.
[52] Ramon van Handel. Probability in high dimension. 2016. Cited on page 74.
[53] Liyuan Liu, Xiaodong Liu, Jianfeng Gao, Weizhu Chen, and Jiawei Han. Under-
standing the difficulty of training transformers. In Bonnie Webber, Trevor Cohn,
Yulan He, and Yang Liu, editors, Proceedings of the 2020 Conference on Empir-
ical Methods in Natural Language Processing, EMNLP 2020, Online, November
16-20, 2020, pages 5747–5763. Association for Computational Linguistics, 2020.
Cited on pages 86, 88 & 90.
[54] Jingzhao Zhang, Sai Praneeth Karimireddy, Andreas Veit, Seungyeon Kim,
Sashank J. Reddi, Sanjiv Kumar, and Suvrit Sra. Why are adaptive methods
good for attention models? In Hugo Larochelle, Marc’Aurelio Ranzato, Raia
Hadsell, Maria-Florina Balcan, and Hsuan-Tien Lin, editors, Advances in Neural
Information Processing Systems 33: Annual Conference on Neural Information
Processing Systems 2020, NeurIPS 2020, December 6-12, 2020, virtual, 2020.
Cited on pages 86, 90 & 94.
[55] Aitor Lewkowycz, Yasaman Bahri, Ethan Dyer, Jascha Sohl-Dickstein, and Guy
Gur-Ari. The large learning rate phase of deep learning: the catapult mechanism.
arXiv preprint arXiv:2003.02218, 2020. Cited on pages 88 & 401.
[57] John Duchi, Elad Hazan, and Yoram Singer. Adaptive subgradient methods
for online learning and stochastic optimization. Journal of Machine Learning
Research, 12(Jul):2121–2159, 2011. Cited on page 90.
[58] Tijmen Tieleman and Geoffrey Hinton. Lecture 6.5-rmsprop: Divide the gradient
by a running average of its recent magnitude. COURSERA: Neural networks
for machine learning, 4(2):26–31, 2012. Cited on pages 90, 184 & 188.
[59] Sashank J Reddi, Satyen Kale, and Sanjiv Kumar. On the convergence of ADAM
and beyond. arXiv preprint arXiv:1904.09237, 2019. Cited on page 90.
[60] Yang You, Jing Li, Sashank J. Reddi, Jonathan Hseu, Sanjiv Kumar, Srinadh
Bhojanapalli, Xiaodan Song, James Demmel, Kurt Keutzer, and Cho-Jui Hsieh.
Large batch optimization for deep learning: Training BERT in 76 minutes. In
8th International Conference on Learning Representations, ICLR 2020, Addis
Ababa, Ethiopia, April 26-30, 2020. OpenReview.net, 2020. Cited on page 90.
[61] Noam Shazeer and Mitchell Stern. Adafactor: Adaptive learning rates with
sublinear memory cost. In International Conference on Machine Learning, pages
4596–4604. PMLR, 2018. Cited on page 90.
[62] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need.
In Advances in neural information processing systems, pages 5998–6008, 2017.
Cited on pages 90 & 92.
[63] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: Pre-
training of deep bidirectional transformers for language understanding. arXiv
preprint arXiv:1810.04805, 2018. Cited on pages 90, 92 & 101.
[64] Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang,
Michael Matena, Yanqi Zhou, Wei Li, and Peter J Liu. Exploring the limits
of transfer learning with a unified text-to-text transformer. arXiv preprint
arXiv:1910.10683, 2019. Cited on page 90.
[65] Rohan Anil, Vineet Gupta, Tomer Koren, and Yoram Singer. Memory efficient
adaptive optimization. Advances in Neural Information Processing Systems, 32,
2019. Cited on page 90.
[66] Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav
Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton,
Sebastian Gehrmann, et al. PaLM: Scaling language modeling with Pathways.
arXiv preprint arXiv:2204.02311, 2022. Cited on page 90.
[67] Elad Hazan, Kfir Levy, and Shai Shalev-Shwartz. Beyond convexity: Stochas-
tic quasi-convex optimization. In Advances in Neural Information Processing
Systems, pages 1594–1602, 2015. Cited on page 90.
[68] Kfir Y Levy. The power of normalization: Faster evasion of saddle points. arXiv
preprint arXiv:1611.04831, 2016. Cited on page 90.
[69] Lei Huang, Xianglong Liu, Bo Lang, and Bo Li. Projection based weight
normalization for deep neural networks. arXiv preprint arXiv:1710.02338, 2017. Cited on
page 90.
[70] Razvan Pascanu, Tomas Mikolov, and Yoshua Bengio. On the difficulty of
training recurrent neural networks. In Sanjoy Dasgupta and David McAllester,
editors, Proceedings of the 30th International Conference on Machine Learning,
volume 28 of Proceedings of Machine Learning Research, pages 1310–1318,
Atlanta, Georgia, USA, 17–19 Jun 2013. PMLR. URL https://proceedings.
mlr.press/v28/pascanu13.html. Cited on page 90.
[71] Xiangyi Chen, Zhiwei Steven Wu, and Mingyi Hong. Understanding gradient
clipping in private SGD: A geometric perspective. CoRR, abs/2006.15429, 2020.
URL https://arxiv.org/abs/2006.15429. Cited on page 90.
[72] Jingzhao Zhang, Tianxing He, Suvrit Sra, and Ali Jadbabaie. Why gradient
clipping accelerates training: A theoretical justification for adaptivity. In 8th
International Conference on Learning Representations, ICLR 2020, Addis Ababa,
Ethiopia, April 26-30, 2020. OpenReview.net, 2020. Cited on page 90.
[73] Tim Salimans and Durk P Kingma. Weight normalization: A simple reparame-
terization to accelerate training of deep neural networks. Advances in neural
information processing systems, 29:901–909, 2016. Cited on page 91.
[74] Dmitry Ulyanov, Andrea Vedaldi, and Victor Lempitsky. Instance normalization:
The missing ingredient for fast stylization. arXiv preprint arXiv:1607.08022,
2016. Cited on page 91.
[75] Ruosi Wan, Zhanxing Zhu, Xiangyu Zhang, and Jian Sun. Spherical motion
dynamics: Learning dynamics of neural network with normalization, weight
decay, and SGD. arXiv preprint arXiv:2006.08419, 2020. Cited on pages 91 & 93.
[76] Dan Hendrycks and Kevin Gimpel. Gaussian error linear units (gelus). arXiv
preprint arXiv:1606.08415, 2016. Cited on pages 97, 320 & 417.
[77] Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, and Percy Liang. SQuAD:
100,000+ questions for machine comprehension of text. arXiv preprint
arXiv:1606.05250, 2016. Cited on page 103.
[78] Pranav Rajpurkar, Robin Jia, and Percy Liang. Know what you don’t know:
Unanswerable questions for SQuAD. arXiv preprint arXiv:1806.03822, 2018. Cited
on page 103.
[79] Adina Williams, Nikita Nangia, and Samuel Bowman. A broad-coverage chal-
lenge corpus for sentence understanding through inference. In Proceedings of the
2018 Conference of the North American Chapter of the Association for Compu-
tational Linguistics: Human Language Technologies, Volume 1 (Long Papers),
pages 1112–1122. Association for Computational Linguistics, 2018. Cited on
page 103.
[80] Alex Krizhevsky, Ilya Sutskever, and Geoffrey E Hinton. ImageNet classification
with deep convolutional neural networks. In Advances in neural information
processing systems, pages 1097–1105, 2012. Cited on page 117.
[81] Gao Huang, Zhuang Liu, Laurens Van Der Maaten, and Kilian Q Weinberger.
Densely connected convolutional networks. In Proceedings of the IEEE conference
on computer vision and pattern recognition, pages 4700–4708, 2017. Cited on
page 117.
[83] Simon S Du, Yining Wang, Xiyu Zhai, Sivaraman Balakrishnan, Russ R Salakhut-
dinov, and Aarti Singh. How many samples are needed to estimate a convolutional
neural network? In Advances in Neural Information Processing Systems, pages
373–383, 2018. Cited on page 120.
[84] Yossi Arjevani and Ohad Shamir. On the iteration complexity of oblivious
first-order optimization algorithms. In International Conference on Machine
Learning, pages 908–916, 2016. Cited on page 120.
[85] Colin Wei, Jason D Lee, Qiang Liu, and Tengyu Ma. Regularization matters:
Generalization and optimization of neural nets vs their induced kernel. In
Advances in Neural Information Processing Systems, pages 9709–9721, 2019.
Cited on page 121.
[86] Zeyuan Allen-Zhu and Yuanzhi Li. What can resnet learn efficiently, going
beyond kernels? In Advances in Neural Information Processing Systems, pages
9015–9025, 2019. Cited on page 121.
[87] Anselm Blumer, A. Ehrenfeucht, David Haussler, and Manfred K. Warmuth.
Learnability and the Vapnik-Chervonenkis dimension. J. ACM, 36(4):929–965,
October 1989. ISSN 0004-5411. doi: 10.1145/76359.76371. URL https://doi.
org/10.1145/76359.76371. Cited on page 123.
[88] Gyora M Benedek and Alon Itai. Learnability with respect to fixed distributions.
Theoretical Computer Science, 86(2):377–389, 1991. Cited on pages 134 & 137.
[89] Philip M. Long. On the sample complexity of PAC learning half-spaces against
the uniform distribution. IEEE Transactions on Neural Networks, 6(6):1556–
1559, 1995. Cited on page 134.
[90] Michel Talagrand. Upper and lower bounds for stochastic processes: modern
methods and classical problems, volume 60. Springer Science & Business Media,
2014. Cited on page 141.
[91] Stanislaw J Szarek. Metric entropy of homogeneous spaces. arXiv preprint
math/9701213, 1997. Cited on page 145.
[92] Zongming Ma and Yihong Wu. Volume ratio, sparsity, and minimaxity under
unitarily invariant norms. IEEE Transactions on Information Theory, 61(12):
6939–6956, 2015. Cited on page 146.
[93] Roman Vershynin. High-Dimensional Probability: An Introduction with Ap-
plications in Data Science. Cambridge Series in Statistical and Probabilistic
Mathematics. Cambridge University Press, 2018. doi: 10.1017/9781108231596.
Cited on page 146.
[94] Yuejie Chi, Yue M Lu, and Yuxin Chen. Nonconvex optimization meets low-rank
matrix factorization: An overview. IEEE Transactions on Signal Processing, 67
(20):5239–5269, 2019. Cited on page 158.
[95] Sanjeev Arora, Nadav Cohen, Wei Hu, and Yuping Luo. Implicit regularization
in deep matrix factorization. arXiv preprint arXiv:1905.13655, 2019. Cited on
pages 159, 161, 249, 252, 312 & 400.
[96] Gauthier Gidel, Francis Bach, and Simon Lacoste-Julien. Implicit regulariza-
tion of discrete gradient dynamics in linear neural networks. In H. Wallach,
H. Larochelle, A. Beygelzimer, F. d’ Alché-Buc, E. Fox, and R. Garnett, editors,
Advances in Neural Information Processing Systems 32, pages 3196–3206. Curran
Associates, Inc., 2019. Cited on pages 159, 161 & 165.
[97] Daniel Gissin, Shai Shalev-Shwartz, and Amit Daniely. The implicit bias of depth:
How incremental learning drives generalization. In International Conference on
Learning Representations, 2020. Cited on pages 159, 160, 161, 164 & 165.
[98] Noam Razin and Nadav Cohen. Implicit regularization in deep learning may
not be explainable by norms. arXiv preprint arXiv:2005.06398, 2020. Cited on
pages 159, 161, 252 & 400.
[99] Ashia C Wilson, Rebecca Roelofs, Mitchell Stern, Nathan Srebro, and Benjamin
Recht. The marginal value of adaptive gradient methods in machine learning.
arXiv preprint arXiv:1705.08292, 2017. Cited on page 160.
[100] Daniel Soudry, Elad Hoffer, Mor Shpigel Nacson, Suriya Gunasekar, and Nathan
Srebro. The implicit bias of gradient descent on separable data. The Journal of
Machine Learning Research, 19(1):2822–2878, 2018. Cited on pages 160, 252,
306, 312 & 400.
[101] Mor Shpigel Nacson, Jason Lee, Suriya Gunasekar, Pedro Henrique Pamplona
Savarese, Nathan Srebro, and Daniel Soudry. Convergence of gradient descent
on separable data. In Kamalika Chaudhuri and Masashi Sugiyama, editors,
Proceedings of Machine Learning Research, volume 89 of Proceedings of Ma-
chine Learning Research, pages 3420–3428. PMLR, 16–18 Apr 2019. Cited on
pages 160 & 252.
[102] Mor Shpigel Nacson, Nathan Srebro, and Daniel Soudry. Stochastic gradient
descent on separable data: Exact convergence with a fixed learning rate. In
Kamalika Chaudhuri and Masashi Sugiyama, editors, Proceedings of Machine
Learning Research, volume 89 of Proceedings of Machine Learning Research,
pages 3051–3059. PMLR, 16–18 Apr 2019. Cited on page 160.
[103] Ziwei Ji and Matus Telgarsky. A refined primal-dual analysis of the implicit
bias. arXiv preprint arXiv:1906.04540, 2019. Cited on page 160.
[104] Ziwei Ji and Matus Telgarsky. Gradient descent aligns the layers of deep linear
networks. arXiv preprint arXiv:1810.02032, 2018. Cited on page 160.
[106] Mor Shpigel Nacson, Suriya Gunasekar, Jason Lee, Nathan Srebro, and Daniel
Soudry. Lexicographic and depth-sensitive margins in homogeneous and non-
homogeneous deep models. In Kamalika Chaudhuri and Ruslan Salakhutdinov,
editors, Proceedings of the 36th International Conference on Machine Learning,
volume 97 of Proceedings of Machine Learning Research, pages 4683–4692, Long
Beach, California, USA, 09–15 Jun 2019. PMLR. Cited on page 160.
[107] Kaifeng Lyu and Jian Li. Gradient descent maximizes the margin of homogeneous
neural networks. In International Conference on Learning Representations, 2020.
Cited on page 160.
[108] Arthur Jacot, Franck Gabriel, and Clément Hongler. Neural tangent kernel:
Convergence and generalization in neural networks. In S. Bengio, H. Wallach,
H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett, editors, Ad-
vances in Neural Information Processing Systems 31, pages 8571–8580. Curran
Associates, Inc., 2018. Cited on page 160.
[109] Sanjeev Arora, Simon S Du, Wei Hu, Zhiyuan Li, Ruslan Salakhutdinov, and
Ruosong Wang. On exact computation with an infinitely wide neural net. arXiv
preprint arXiv:1904.11955, 2019. Cited on pages 160 & 400.
[110] Lénaïc Chizat and Francis Bach. Implicit bias of gradient descent for wide two-
layer neural networks trained with the logistic loss. In Conference on Learning
Theory, volume 125 of Proceedings of Machine Learning Research, pages 1305–1338.
PMLR, 09–12 Jul 2020. Cited on page 160.
[111] Lénaïc Chizat, Edouard Oyallon, and Francis Bach. On lazy training in differen-
tiable programming. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d’ Alché-
Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing
Systems 32, pages 2937–2947. Curran Associates, Inc., 2019. Cited on page 161.
[112] Yuanzhi Li, Tengyu Ma, and Hongyang Zhang. Algorithmic regularization in over-
parameterized matrix sensing and neural networks with quadratic activations. In
Conference On Learning Theory, pages 2–47. PMLR, 2018. Cited on pages 161,
252, 312 & 400.
[113] Mohamed Ali Belabbas. On implicit regularization: Morse functions and appli-
cations to matrix factorization. arXiv preprint arXiv:2001.04264, 2020. Cited
on page 161.
[114] Zheng Wang, Ming-Jun Lai, Zhaosong Lu, Wei Fan, Hasan Davulcu, and Jieping
Ye. Rank-one matrix pursuit for matrix completion. In International Conference
on Machine Learning, pages 91–99, 2014. Cited on pages 168 & 189.
[115] Quanming Yao and James Tin Yau Kwok. Greedy learning of generalized low-
rank models. In IJCAI International Joint Conference on Artificial Intelligence,
2016. Cited on page 168.
[116] Shai Shalev-Shwartz and Yoram Singer. On the equivalence of weak learnabil-
ity and linear separability: New relaxations and efficient boosting algorithms.
Machine learning, 80(2-3):141–163, 2010. Cited on page 168.
[117] Rajiv Khanna, Ethan Elenberg, Alexandros G Dimakis, and Sahand Negahban.
On approximation guarantees for greedy low rank optimization. arXiv preprint
arXiv:1703.02721, 2017. Cited on page 168.
[118] Benjamin D. Haeffele and René Vidal. Structured Low-Rank Matrix Factor-
ization: Global Optimality, Algorithms, and Applications. IEEE Transactions
on Pattern Analysis and Machine Intelligence (PAMI), 42(6):1468–1482, 2019.
Cited on page 168.
[119] Jason D Lee, Ioannis Panageas, Georgios Piliouras, Max Simchowitz, Michael I
Jordan, and Benjamin Recht. First-order methods almost always avoid saddle
points. arXiv preprint arXiv:1710.07406, 2017. Cited on pages 172, 174, 222,
331 & 416.
[120] Ioannis Panageas, Georgios Piliouras, and Xiao Wang. First-order methods
almost always avoid saddle points: The case of vanishing step-sizes. In H. Wallach,
H. Larochelle, A. Beygelzimer, F. d’ Alché-Buc, E. Fox, and R. Garnett, editors,
Advances in Neural Information Processing Systems 32, pages 6474–6483. Curran
Associates, Inc., 2019. Cited on page 172.
[121] Jason D Lee, Max Simchowitz, Michael I Jordan, and Benjamin Recht. Gradient
descent only converges to minimizers. In Conference on learning theory, pages
1246–1257. PMLR, 2016. Cited on pages 174, 331 & 416.
[122] Jeff Bezanson, Stefan Karpinski, Viral B Shah, and Alan Edelman. Julia: A
fast dynamic language for technical computing. arXiv preprint arXiv:1209.5145,
2012. Cited on page 187.
[123] Adam Paszke, Sam Gross, Francisco Massa, Adam Lerer, James Bradbury,
Gregory Chanan, Trevor Killeen, Zeming Lin, Natalia Gimelshein, Luca Antiga,
Alban Desmaison, Andreas Kopf, Edward Yang, Zachary DeVito, Martin Raison,
Alykhan Tejani, Sasank Chilamkurthy, Benoit Steiner, Lu Fang, Junjie Bai,
and Soumith Chintala. Pytorch: An imperative style, high-performance deep
learning library. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d’ Alché-Buc,
E. Fox, and R. Garnett, editors, Advances in Neural Information Processing
Systems 32, pages 8024–8035. Curran Associates, Inc., 2019. Cited on page 187.
[125] Akshay Agrawal, Robin Verschueren, Steven Diamond, and Stephen Boyd. A
rewriting system for convex optimization problems. Journal of Control and
Decision, 5(1):42–60, 2018. Cited on page 189.
[131] Francis H Clarke, Yuri S Ledyaev, Ronald J Stern, and Peter R Wolenski.
Nonsmooth analysis and control theory, volume 178. Springer Science & Business
Media, 2008. Cited on page 227.
[133] Zhiyuan Li, Tianhao Wang, and Sanjeev Arora. What happens after SGD
reaches zero loss? A mathematical framework. In International Conference on
Learning Representations, 2022. Cited on pages 249, 252, 271, 399 & 400.
[134] Tomas Vaskevicius, Varun Kanade, and Patrick Rebeschini. Implicit regulariza-
tion for optimal sparse recovery. Advances in Neural Information Processing
Systems, 32:2972–2983, 2019. Cited on pages 250, 251, 252, 261 & 312.
[135] Chulhee Yun, Shankar Krishnan, and Hossein Mobahi. A unifying view on
implicit bias in training linear neural networks. arXiv preprint arXiv:2010.02501,
2020. Cited on pages 250, 251 & 252.
[136] Ehsan Amid and Manfred K Warmuth. Winnowing with gradient descent. In
Conference on Learning Theory, pages 163–182. PMLR, 2020. Cited on pages 250,
251, 252, 261 & 274.
[138] Shahar Azulay, Edward Moroshko, Mor Shpigel Nacson, Blake Woodworth,
Nathan Srebro, Amir Globerson, and Daniel Soudry. On the implicit bias
of initialization shape: Beyond infinitesimal mirror descent. arXiv preprint
arXiv:2102.09769, 2021. Cited on pages 250, 251, 252, 271, 274, 330 & 400.
[139] Arkadij Semenovič Nemirovskij and David Borisovich Yudin. Problem complexity
and method efficiency in optimization. 1983. Cited on pages 250 & 258.
[140] Amir Beck and Marc Teboulle. Mirror descent and nonlinear projected sub-
gradient methods for convex optimization. Operations Research Letters, 31(3):
167–175, 2003. Cited on pages 250 & 258.
[141] Udaya Ghai, Zhou Lu, and Elad Hazan. Non-convex online learning via algo-
rithmic equivalence. arXiv preprint arXiv:2205.15235, 2022. Cited on pages 251,
253 & 274.
[142] Suriya Gunasekar, Jason Lee, Daniel Soudry, and Nathan Srebro. Characterizing
implicit bias in terms of optimization geometry. In International Conference
on Machine Learning, pages 1832–1841. PMLR, 2018. Cited on pages 252, 259,
274, 312 & 400.
[143] Ziwei Ji and Matus Telgarsky. Characterizing the implicit bias via a primal-dual
analysis. In Algorithmic Learning Theory, pages 772–804. PMLR, 2021. Cited
on page 252.
[144] Edward Moroshko, Suriya Gunasekar, Blake Woodworth, Jason D Lee, Nathan
Srebro, and Daniel Soudry. Implicit bias in deep linear classification: Initializa-
tion scale vs training accuracy. arXiv preprint arXiv:2007.06738, 2020. Cited
on page 252.
[145] Ziwei Ji and Matus Telgarsky. Directional convergence and alignment in deep
learning. In H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, and H. Lin,
editors, Advances in Neural Information Processing Systems, volume 33, pages
17176–17186. Curran Associates, Inc., 2020. Cited on page 252.
[146] Ziwei Ji and Matus Telgarsky. Risk and parameter convergence of logistic
regression. arXiv preprint arXiv:1803.07300, 2018. Cited on page 252.
[147] Ziwei Ji and Matus Telgarsky. The implicit bias of gradient descent on nonsepa-
rable data. In Conference on Learning Theory, pages 1772–1798. PMLR, 2019.
Cited on page 252.
[148] Lénaïc Chizat, Edouard Oyallon, and Francis Bach. On lazy training in differen-
tiable programming. arXiv preprint arXiv:1812.07956, 2018. Cited on pages 252,
312 & 400.
[149] Simon S Du, Xiyu Zhai, Barnabas Poczos, and Aarti Singh. Gradient de-
scent provably optimizes over-parameterized neural networks. arXiv preprint
arXiv:1810.02054, 2018. Cited on pages 252 & 312.
[150] Simon Du, Jason Lee, Haochuan Li, Liwei Wang, and Xiyu Zhai. Gradient
descent finds global minima of deep neural networks. In International Conference
on Machine Learning, pages 1675–1685. PMLR, 2019. Cited on pages 252,
312 & 400.
[151] Zeyuan Allen-Zhu, Yuanzhi Li, and Zhao Song. A convergence theory for deep
learning via over-parameterization. In International Conference on Machine
Learning, pages 242–252. PMLR, 2019. Cited on pages 252, 312 & 400.
[152] Zeyuan Allen-Zhu, Yuanzhi Li, and Yingyu Liang. Learning and generalization
in overparameterized neural networks, going beyond two layers. Advances in
neural information processing systems, 2019. Cited on pages 252, 312 & 400.
[153] Difan Zou, Yuan Cao, Dongruo Zhou, and Quanquan Gu. Gradient descent
optimizes over-parameterized deep ReLU networks. Machine Learning, 109(3):
467–492, 2020. Cited on pages 252, 312 & 400.
[154] Sanjeev Arora, Simon Du, Wei Hu, Zhiyuan Li, and Ruosong Wang. Fine-
grained analysis of optimization and generalization for overparameterized two-
layer neural networks. In International Conference on Machine Learning, pages
322–332. PMLR, 2019. Cited on pages 252, 312 & 400.
[155] Greg Yang. Scaling limits of wide neural networks with weight sharing: Gaussian
process behavior, gradient independence, and neural tangent kernel derivation.
arXiv preprint arXiv:1902.04760, 2019. Cited on pages 252, 312 & 400.
[156] Arthur Jacot, François Ged, Franck Gabriel, Berfin Şimşek, and Clément Hongler.
Deep linear networks dynamics: Low-rank biases induced by initialization scale
and L2 regularization. arXiv preprint arXiv:2106.15933, 2021. Cited on page 252.
[157] Suriya Gunasekar, Jason D Lee, Daniel Soudry, and Nati Srebro. Implicit
bias of gradient descent on linear convolutional networks. Advances in Neural
Information Processing Systems, 31, 2018. Cited on page 252.
[158] Lénaïc Chizat and Francis Bach. Implicit bias of gradient descent for wide two-
layer neural networks trained with the logistic loss. In Conference on Learning
Theory, pages 1305–1338. PMLR, 2020. Cited on page 252.
[159] Kaifeng Lyu, Zhiyuan Li, Runzhe Wang, and Sanjeev Arora. Gradient descent
on two-layer nets: Margin maximization and simplicity bias. Advances in Neural
Information Processing Systems, 34, 2021. Cited on pages 252 & 400.
[160] Noam Razin, Asaf Maman, and Nadav Cohen. Implicit regularization in hi-
erarchical tensor factorization and deep convolutional neural networks. arXiv
preprint arXiv:2201.11729, 2022. Cited on page 252.
[161] Dominik Stöger and Mahdi Soltanolkotabi. Small random initialization is akin
to spectral learning: Optimization and generalization guarantees for overpa-
rameterized low-rank matrix reconstruction. Advances in Neural Information
Processing Systems, 34, 2021. Cited on page 252.
[162] Rong Ge, Yunwei Ren, Xiang Wang, and Mo Zhou. Understanding deflation pro-
cess in over-parametrized tensor decomposition. Advances in Neural Information
Processing Systems, 34, 2021. Cited on page 252.
[163] Greg Yang and Edward J Hu. Tensor programs iv: Feature learning in infinite-
width neural networks. In International Conference on Machine Learning, pages
11727–11737. PMLR, 2021. Cited on page 252.
[164] Yuanzhi Li, Colin Wei, and Tengyu Ma. Towards explaining the regularization
effect of initial large learning rate in training neural networks. arXiv preprint
arXiv:1907.04595, 2019. Cited on pages 252, 306 & 312.
[165] Guy Blanc, Neha Gupta, Gregory Valiant, and Paul Valiant. Implicit regulariza-
tion for deep neural networks driven by an Ornstein-Uhlenbeck like process. In
Conference on learning theory, pages 483–513. PMLR, 2020. Cited on pages 252,
305, 307, 308, 311, 325, 326 & 400.
[166] Alex Damian, Tengyu Ma, and Jason Lee. Label noise SGD provably prefers flat
global minimizers. arXiv preprint arXiv:2106.06530, 2021. Cited on pages 252,
309, 311, 325, 326, 327 & 400.
[167] Difan Zou, Jingfeng Wu, Vladimir Braverman, Quanquan Gu, Dean P Foster,
and Sham Kakade. The benefits of implicit regularization from sgd in least
squares problems. Advances in Neural Information Processing Systems, 34:
5456–5468, 2021. Cited on page 252.
[168] Qian Qian and Xiaoyuan Qian. The implicit bias of AdaGrad on separable
data. Advances in Neural Information Processing Systems, 32, 2019. Cited on
page 252.
[169] Bohan Wang, Qi Meng, Huishuai Zhang, Ruoyu Sun, Wei Chen, and Zhi-
Ming Ma. Momentum doesn’t change the implicit bias. arXiv preprint
arXiv:2110.03891, 2021. Cited on page 252.
[170] Bohan Wang, Qi Meng, Wei Chen, and Tie-Yan Liu. The implicit bias for adap-
tive optimization algorithms on homogeneous neural networks. In International
Conference on Machine Learning, pages 10849–10858. PMLR, 2021. Cited on
page 252.
[171] Ziwei Ji, Nathan Srebro, and Matus Telgarsky. Fast margin maximization via
dual acceleration. In International Conference on Machine Learning, pages
4860–4869. PMLR, 2021. Cited on page 252.
[172] Suriya Gunasekar, Blake Woodworth, and Nathan Srebro. Mirrorless mirror
descent: A natural derivation of mirror descent. In International Conference on
Artificial Intelligence and Statistics, pages 2305–2313. PMLR, 2021. Cited on
pages 252 & 400.
[176] Scott Pesme, Loucas Pillaud-Vivien, and Nicolas Flammarion. Implicit bias
of SGD for diagonal linear networks: a provable benefit of stochasticity. arXiv
preprint arXiv:2106.09524, 2021. Cited on pages 271, 313 & 330.
[177] John Nash. The imbedding problem for Riemannian manifolds. Annals of
Mathematics, pages 20–63, 1956. Cited on page 274.
[180] Heinz H Bauschke, Jonathan M Borwein, et al. Legendre functions and the
method of random Bregman projections. Journal of Convex Analysis, 4(1):27–67,
1997. Cited on pages 277 & 278.
[181] Lev M Bregman. The relaxation method of finding the common point of convex
sets and its application to the solution of problems in convex programming.
USSR Computational Mathematics and Mathematical Physics, 7(3):200–217, 1967.
Cited on page 278.
[182] Yair Censor and Arnold Lent. An iterative row-action method for interval convex
programming. Journal of Optimization Theory and Applications, 34(3):321–353,
1981. Cited on page 278.
[183] Felipe Alvarez, Jérôme Bolte, and Olivier Brahic. Hessian Riemannian gradient
flows in convex programming. SIAM Journal on Control and Optimization, 43(2):
477–501, 2004. Cited on pages 278 & 279.
[184] Serge Lang. Introduction to differentiable manifolds. Springer Science & Business
Media, 2006. Cited on page 279.
[186] Robert L Foote. Regularity of the distance function. Proceedings of the American
Mathematical Society, 92(1):153–155, 1984. Cited on page 301.
[188] Stanislaw Jastrzebski, Zachary Kenton, Devansh Arpit, Nicolas Ballas, Asja
Fischer, Yoshua Bengio, and Amos Storkey. Three factors influencing minima in
SGD. arXiv preprint arXiv:1711.04623, 2017. Cited on pages 306 & 312.
[189] Bin Shi, Weijie J Su, and Michael I Jordan. On learning rates and Schrödinger
operators. arXiv preprint arXiv:2004.06977, 2020. Cited on pages 306 & 325.
[190] Qianxiao Li, Cheng Tai, and E Weinan. Stochastic modified equations and
adaptive stochastic gradient algorithms. In International Conference on Machine
Learning, pages 2101–2110. PMLR, 2017. Cited on pages 307, 312, 335 & 401.
[191] Xiang Cheng, Dong Yin, Peter Bartlett, and Michael Jordan. Stochastic gradient
and Langevin processes. In International Conference on Machine Learning, pages
1810–1819. PMLR, 2020. Cited on pages 307 & 312.
[192] Garvesh Raskutti, Martin J Wainwright, and Bin Yu. Minimax-optimal rates
for sparse additive models over kernel classes via convex programming. Journal
of Machine Learning Research, 13(2), 2012. Cited on page 311.
[193] C Daniel Freeman and Joan Bruna. Topology and geometry of half-rectified
network optimization. arXiv preprint arXiv:1611.01540, 2016. Cited on page 311.
[194] Timur Garipov, Pavel Izmailov, Dmitrii Podoprikhin, Dmitry Vetrov, and
Andrew Gordon Wilson. Loss surfaces, mode connectivity, and fast ensembling
of DNNs. arXiv preprint arXiv:1802.10026, 2018. Cited on page 311.
[195] Felix Draxler, Kambis Veschgini, Manfred Salmhofer, and Fred Hamprecht.
Essentially no barriers in neural network energy landscape. In International
conference on machine learning, pages 1309–1318. PMLR, 2018. Cited on
page 311.
[196] Luca Venturi, Afonso S Bandeira, and Joan Bruna. Spurious valleys in two-layer
neural network optimization landscapes. arXiv preprint arXiv:1802.06384, 2018.
Cited on page 311.
[197] Shiyu Liang, Ruoyu Sun, Yixuan Li, and Rayadurgam Srikant. Understanding
the loss surface of neural networks for binary classification. In International
Conference on Machine Learning, pages 2835–2843. PMLR, 2018. Cited on
page 311.
[198] Quynh Nguyen, Mahesh Chandra Mukkamala, and Matthias Hein. On the loss
landscape of a class of deep neural networks with no bad local valleys. arXiv
preprint arXiv:1809.10749, 2018. Cited on page 311.
[200] Rohith Kuditipudi, Xiang Wang, Holden Lee, Yi Zhang, Zhiyuan Li, Wei Hu,
Sanjeev Arora, and Rong Ge. Explaining landscape connectivity of low-cost
solutions for multilayer nets. arXiv preprint arXiv:1906.06247, 2019. Cited on
page 311.
[201] Yaim Cooper. The loss landscape of overparameterized neural networks. arXiv
preprint arXiv:1804.10200, 2018. Cited on page 312.
[203] Benjamin Fehrman, Benjamin Gess, and Arnulf Jentzen. Convergence rates
for the stochastic gradient descent method for non-convex objective functions.
Journal of Machine Learning Research, 21, 2020. Cited on pages 312, 320 & 439.
[204] Yann A LeCun, Léon Bottou, Genevieve B Orr, and Klaus-Robert Müller.
Efficient backprop. In Neural networks: Tricks of the trade, pages 9–48. Springer,
2012. Cited on page 312.
[205] Elad Hoffer, Itay Hubara, and Daniel Soudry. Train longer, generalize better:
closing the generalization gap in large batch training of neural networks. arXiv
preprint arXiv:1705.08741, 2017. Cited on page 312.
[206] Zhanxing Zhu, Jingfeng Wu, Bing Yu, Lei Wu, and Jinwen Ma. The anisotropic
noise in stochastic gradient descent: Its behavior of escaping from sharp minima
and regularization effects. arXiv preprint arXiv:1803.00195, 2018. Cited on
page 312.
[207] Jian Li, Xuanyuan Luo, and Mingda Qiao. On generalization error bounds of
noisy gradient methods for non-convex learning. arXiv preprint arXiv:1902.00621,
2019. Cited on page 312.
[208] Yeming Wen, Kevin Luk, Maxime Gazeau, Guodong Zhang, Harris Chan, and
Jimmy Ba. Interplay between optimization and generalization of stochastic
gradient descent with covariance noise. arXiv preprint arXiv:1902.08234, 2019.
Cited on page 312.
[209] Jingfeng Wu, Wenqing Hu, Haoyi Xiong, Jun Huan, Vladimir Braverman, and
Zhanxing Zhu. On the noisy gradient descent that generalizes as sgd. In
International Conference on Machine Learning, pages 10367–10376. PMLR,
2020. Cited on page 312.
[210] Jianqing Fan, Zhuoran Yang, and Mengxin Yu. Understanding implicit reg-
ularization in over-parameterized nonlinear statistical model. arXiv preprint
arXiv:2007.08322, 2020. Cited on page 312.
[211] Peng Zhao, Yun Yang, and Qiao-Chu He. Implicit regularization via hadamard
product over-parametrization in high-dimensional linear regression. arXiv
preprint arXiv:1903.09367, 2019. Cited on page 312.
[212] Amit Daniely. SGD learns the conjugate kernel class of the network. arXiv
preprint arXiv:1702.08503, 2017. Cited on page 312.
[213] Yuanzhi Li and Yingyu Liang. Learning overparameterized neural networks via
stochastic gradient descent on structured data. arXiv preprint arXiv:1808.01204,
2018. Cited on pages 312 & 400.
[214] Qianxiao Li, Cheng Tai, and E Weinan. Stochastic modified equations and dy-
namics of stochastic gradient algorithms i: Mathematical foundations. The
Journal of Machine Learning Research, 20(1):1474–1520, 2019. Cited on
pages 312 & 401.
[215] Alex Krizhevsky. One weird trick for parallelizing convolutional neural networks.
arXiv preprint arXiv:1404.5997, 2014. Cited on page 312.
[216] Priya Goyal, Piotr Dollár, Ross Girshick, Pieter Noordhuis, Lukasz Wesolowski,
Aapo Kyrola, Andrew Tulloch, Yangqing Jia, and Kaiming He. Accurate, large
minibatch SGD: Training ImageNet in 1 hour. arXiv preprint arXiv:1706.02677,
2017. Cited on page 312.
[217] Zhiyuan Li, Sadhika Malladi, and Sanjeev Arora. On the validity of modeling SGD
with stochastic differential equations (SDEs). arXiv preprint arXiv:2102.12470,
2021. Cited on page 312.
[218] Zeke Xie, Issei Sato, and Masashi Sugiyama. A diffusion theory for deep learning
dynamics: Stochastic gradient descent escapes from sharp minima exponentially
fast. arXiv preprint arXiv:2002.03495, 2020. Cited on page 312.
[219] Stephan Wojtowytsch. Stochastic gradient descent with noise of machine learning
type. part ii: Continuous time analysis. arXiv preprint arXiv:2106.02588, 2021.
Cited on page 312.
[220] Ioannis Karatzas and Steven Shreve. Brownian motion and stochastic calculus,
volume 113. Springer, 2014. Cited on page 313.
[221] Patrick Billingsley. Convergence of probability measures. John Wiley & Sons,
2013. Cited on page 313.
[222] David Pollard. Convergence of stochastic processes. Springer Science & Business
Media, 2012. Cited on pages 313 & 319.
[223] Augustin Banyaga and David Hurtubise. Lectures on Morse homology, volume 29.
Springer Science & Business Media, 2013. Cited on pages 320 & 360.
[225] Boris T Polyak. Gradient methods for solving equations and inequalities. USSR
Computational Mathematics and Mathematical Physics, 4(6):17–32, 1964. Cited
on page 331.
[226] Joel A Tropp. Convex recovery of a structured signal from independent random
linear measurements. In Sampling Theory, a Renaissance, pages 67–101. Springer,
2015. Cited on pages 332, 363, 365, 367 & 368.
[229] Elton P Hsu. Stochastic analysis on manifolds. Number 38. American Mathe-
matical Soc., 2002. Cited on page 343.
[231] Andrew Holbrook. Differentiating the pseudo determinant. Linear Algebra and
its Applications, 548:293–304, 2018. Cited on page 355.
[232] Jeff Kahn, János Komlós, and Endre Szemerédi. On the probability that a
random ±1-matrix is singular. Journal of the American Mathematical Society, 8
(1):223–240, 1995. Cited on page 361.
[233] Lawrence M. Perko. Differential equations and dynamical systems. 2001. Cited
on page 378.
[234] Jeremy Cohen, Simran Kaur, Yuanzhi Li, J Zico Kolter, and Ameet Talwalkar.
Gradient descent on neural networks typically occurs at the edge of stability.
In International Conference on Learning Representations, 2021. URL https://
openreview.net/forum?id=jh-rTtvkGeM. Cited on pages 394, 395, 396 & 420.
[235] Kwangjun Ahn, Jingzhao Zhang, and Suvrit Sra. Understanding the unstable
convergence of gradient descent. arXiv preprint arXiv:2204.01050, 2022. Cited
on page 395.
[236] Prajit Ramachandran, Barret Zoph, and Quoc V Le. Searching for activation
functions. arXiv preprint arXiv:1710.05941, 2017. Cited on page 398.
[238] Yiding Jiang, Behnam Neyshabur, Hossein Mobahi, Dilip Krishnan, and
Samy Bengio. Fantastic generalization measures and where to find them.
In International Conference on Learning Representations, 2020. URL https:
//openreview.net/forum?id=SJgIPJBFvH. Cited on page 399.
[239] Laurent Dinh, Razvan Pascanu, Samy Bengio, and Yoshua Bengio. Sharp
minima can generalize for deep nets. In International Conference on Machine
Learning, pages 1019–1028. PMLR, 2017. Cited on page 399.
[240] Mingyang Yi, Qi Meng, Wei Chen, Zhi-ming Ma, and Tie-Yan Liu. Positively
scale-invariant flatness of relu neural networks. arXiv preprint arXiv:1903.02237,
2019. Cited on page 400.
[241] Mingyang Yi, Huishuai Zhang, Wei Chen, Zhi-Ming Ma, and Tie-Yan Liu. Bn-
invariant sharpness regularizes the training model to better generalization. In
Proceedings of the Twenty-Eighth International Joint Conference on Artificial
Intelligence, IJCAI-19, pages 4164–4170. International Joint Conferences on
Artificial Intelligence Organization, 7 2019. Cited on page 400.
[242] Yusuke Tsuzuku, Issei Sato, and Masashi Sugiyama. Normalized flat minima:
Exploring scale invariant definition of flat minima for neural networks using PAC-
Bayesian analysis. In Hal Daumé III and Aarti Singh, editors, Proceedings of the
37th International Conference on Machine Learning, volume 119 of Proceedings
of Machine Learning Research, pages 9636–9647. PMLR, 13–18 Jul 2020. Cited
on page 400.
[243] Akshay Rangamani, Nam H. Nguyen, Abhishek Kumar, Dzung Phan, Sang Peter
Chin, and Trac D. Tran. A scale invariant measure of flatness for deep network
minima. In ICASSP 2021 - 2021 IEEE International Conference on Acoustics,
Speech and Signal Processing (ICASSP), pages 1680–1684, 2021. Cited on
page 400.
[244] Mingyang Yi, Qi Meng, Wei Chen, and Zhi-Ming Ma. Towards accelerat-
ing training of batch normalization: A manifold perspective. arXiv preprint
arXiv:2101.02916, 2021. Cited on page 400.
[245] Jungmin Kwon, Jeongseop Kim, Hyunseo Park, and In Kwon Choi. Asam:
Adaptive sharpness-aware minimization for scale-invariant learning of deep
neural networks. In Marina Meila and Tong Zhang, editors, Proceedings of the
38th International Conference on Machine Learning, volume 139 of Proceedings
of Machine Learning Research, pages 5905–5914. PMLR, 18–24 Jul 2021. Cited
on page 400.
[246] Haowei He, Gao Huang, and Yang Yuan. Asymmetric valleys: Beyond sharp and
flat local minima. arXiv preprint arXiv:1902.00744, 2019. Cited on page 400.
[247] Suriya Gunasekar, Jason Lee, Daniel Soudry, and Nathan Srebro. Implicit bias
of gradient descent on linear convolutional networks. In Advances in Neural
Information Processing Systems, 2018. Cited on page 400.
[248] Lei Wu, Zhanxing Zhu, et al. Towards understanding generalization of deep
learning: Perspective of loss landscapes. arXiv preprint arXiv:1706.10239, 2017.
Cited on page 401.
[249] Chao Ma and Lexing Ying. On linear stability of SGD and input-smoothness
of neural networks. In A. Beygelzimer, Y. Dauphin, P. Liang, and J. Wortman
Vaughan, editors, Advances in Neural Information Processing Systems, 2021.
Cited on page 401.
[250] David Barrett and Benoit Dherin. Implicit gradient regularization. In Interna-
tional Conference on Learning Representations, 2021. Cited on page 401.
[251] Yuqing Wang, Minshuo Chen, Tuo Zhao, and Molei Tao. Large learning
rate tames homogeneity: Convergence and balancing effect. arXiv preprint
arXiv:2110.03677, 2021. Cited on page 401.
[252] Chi Jin, Rong Ge, Praneeth Netrapalli, Sham M. Kakade, and Michael I. Jordan.
How to escape saddle points efficiently. In Proceedings of the 34th International
Conference on Machine Learning, pages 1724–1732, 2017. Cited on page 415.
[253] Karen Simonyan and Andrew Zisserman. Very deep convolutional networks for
large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014. Cited on
page 417.
[254] Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton. CIFAR-10 (Canadian Insti-
tute for Advanced Research). URL http://www.cs.toronto.edu/~kriz/cifar.
html. Cited on page 417.
[255] Y-Lan Boureau, Jean Ponce, and Yann LeCun. A theoretical analysis of feature
pooling in visual recognition. In Proceedings of the 27th international conference
on machine learning (ICML-10), pages 111–118, 2010. Cited on page 417.
[256] Yann LeCun and Corinna Cortes. MNIST handwritten digit database. 2010.
URL http://yann.lecun.com/exdb/mnist/. Cited on page 419.
[259] Roger A Horn and Charles R Johnson. Matrix analysis. Cambridge University
Press, 2012. Cited on page 493.
[260] Chandler Davis and William Morton Kahan. The rotation of eigenvectors by a
perturbation. iii. SIAM Journal on Numerical Analysis, 7(1):1–46, 1970. Cited
on page 494.