
Bridging Theory and Practice

in Deep Learning:
Optimization and Generalization

Zhiyuan Li

A Dissertation
Presented to the Faculty
of Princeton University
in Candidacy for the Degree
of Doctor of Philosophy

Recommended for Acceptance


by the Department of
Computer Science
Adviser: Sanjeev Arora

September 2022
© Copyright by Zhiyuan Li, 2022.

All rights reserved.


Abstract

Deep learning has been hugely successful for several important applications in
the past decade, yet mathematical understanding has lagged behind its breathtaking
empirical success. Classic machine learning theory is insufficient to explain various
new phenomena in deep learning and to provide guidance on algorithmic choices,
largely due to an oversimplified black box view that ignores the interaction between
the model and the optimization algorithm. This dissertation presents a collection of
theoretical results that take the interplay between the model and the optimization
algorithm into account and aims to bridge the gaps between theory and practice in
deep learning for both generalization and optimization.
For optimization, we first illustrate the mismatches between traditional optimiza-
tion theory and deep networks with normalization layers by presenting an exponentially
increasing learning rate schedule that works well empirically. We explain this surprise
by establishing its equivalence to SGD with Weight Decay and proving that their
convergence rates are fast and insensitive to initialization scale. Based on this, we
design a variant of BERT named SIBERT, which is trainable by SGD and thus more
memory-efficient than adaptive algorithms like ADAM. Finally, we present the first
provable yet general setting where gradient descent decreases the loss in a non-monotone
way, as observed empirically.
For generalization, we study the implicit bias of optimization algorithms, which
refers to the phenomenon that the algorithm returns solutions with good generalization
despite the existence of solutions with poor generalization in overparametrized
models. We first give a rigorous justification of why convolutional networks are
more sample-efficient than fully-connected networks. Then we provide theoretical
justification for the empirical observation that deep linear networks, including matrix
factorization, trained by gradient descent from small initialization are implicitly biased
toward low-rank solutions. We also identify a condition under which gradient descent with
reparametrization is equivalent to mirror descent, which can be used to understand the
implicit bias of non-linear models and recovers several previous results. We further
show that gradient descent has an implicit bias toward ‘flatter’ solutions in the presence
of certain gradient noise or when its learning rate is larger than two over the sharpness
of the loss.

Acknowledgements

I would like to express my deepest appreciation to my advisor Sanjeev Arora for


his guidance and support. Over the last five years, he has been a continuous source
of knowledge, insights, and inspiration. His research philosophy and vision have
deeply influenced my research taste. His passion, persistence, and commitment to
research constantly encouraged me when I got stuck. I’m also extremely grateful for
his generosity in spending time discussing research and improving my writing and
presentation. He is always there to help, whether it means knocking on his office door
without an appointment or starting a Zoom call at midnight. I could not
ask for a better advisor.
I owe a lot of thanks to Elad Hazan, Behnam Neyshabur, Simon Du, Chi Jin and
Jason Lee for being wonderful mentors and collaborators. It is a great pleasure to work
with them on various interesting projects. From them, I learned both mathematical
techniques and high-level thinking. I would also like to thank Elad Hazan, Jason Lee, Chi
Jin and Danqi Chen for serving on my thesis committee and providing invaluable
support and feedback during my job search.
I very much appreciate the internship opportunity at Google working with my host
Sashank Reddi and two other colleagues Srinadh Bhojanapalli and Manzil Zaheer.
Their excellent guidance and the computational resources at Google allowed me to
verify my early theoretical intuition and make it practical and useful, finally becoming
a part of my dissertation.
I was very fortunate to collaborate with many brilliant researchers. I would like to
thank Sanjeev Arora, Srinadh Bhojanapalli, Simon Du, Yaqi Duan, Wei Hu, Rong Ge,
Elad Hazan, Chi Jin, Holden Lee, Jason Lee, Yuanzhi Li, Yuping Luo, Kaifeng Lyu,
Behnam Neyshabur, Sadhika Malladi, Abhishek Panigrahi, Sashank Reddi, Runzhe
Wang, Ruosong Wang, Tianhao Wang, Xiang Wang, Xiaoxia Wu, Dingli Yu, Manzil
Zaheer and Yi Zhang, among others. Many of them have become my friends and
helped me in various ways throughout my graduate study, both academically and
personally. I also want to thank my other friends at Princeton University, without
whom my graduate life would have been very different.
Many thanks to the Computer Science department staff for their support, warmth
and reliability over the years, particularly Nicki and Mitra.
Thanks to Andrew Chi-Chih Yao for creating the fantastic Yao’s special pilot class
where I did my undergraduate study. Thanks to Jian Li and Pingzhong Tang for
advising my undergraduate research. I’m also grateful to the CS department of Cornell
University, where I was an undergraduate visiting student under the supervision of
Karthik Sridharan, Carla Gomes and Yexiang Xue. Special thanks to Karthik for first
introducing me to research in theoretical machine learning.
I owe the most to my mom Xuanying Ping and dad Pangao Li for their unconditional
love and support and for developing my interest in Math in my early childhood.
Lastly, I would like to thank my fiancée Juechun for her love and company.
The research in this dissertation was supported in part by NSF, ONR, Simons
Foundation, Schmidt Foundation, Amazon Research, DARPA, SRC, the William G.
Bowen Merit Fellowship and a Microsoft Research Fellowship.

To my family.

Contents

Abstract iii
Acknowledgements v

1 Introduction 1
 1.1 The Black Box View in Existing Theory 2
 1.2 Gaps in Generalization Theory and Practice 4
 1.3 Gaps in Optimization Theory and Practice 7
 1.4 Our Contributions 8
 1.5 Previously Published Works 12

I Optimization Analysis for Normalized Deep Nets: The Power of Scale Invariance 13

2 An Exponentially Increasing Learning Rate Schedule for Normalized Networks 14
 2.1 Introduction 15
 2.2 Related Work 17
 2.3 Preliminaries and Notations 18
 2.4 Deriving Exponential Learning Rate Schedule 20
 2.5 Example Illustrating Interplay of Weight Decay and Normalization Layer 28
 2.6 Viewing Exponential Learning Rates via Canonical Optimization Framework 30
 2.7 Experiments 32
 2.8 Proofs 35
 2.9 Side Results on Parameter Norm Convergence 51
 2.10 Scale Invariance in Modern Network Architectures 53

3 Convergence Analysis for Gradient Descent on Normalized Networks 62
 3.1 Introduction 62
 3.2 Preliminary 63
 3.3 Convergence of GD+WD 64
 3.4 Convergence of SGD+WD 66
 3.5 Useful Lemmas 68
 3.6 Proofs for Convergence of GD+WD 71
 3.7 Proofs for Convergence of SGD+WD 74
 3.8 Convergence of SGD for Multi-group Scale Invariant Functions 81

4 Robust and Memory-efficient Optimization via Designing Scale Invariant Architectures 85
 4.1 Introduction 86
 4.2 Related Work and Background 90
 4.3 Methods 91
 4.4 Design Details of Scale Invariant BERT 96
 4.5 Convergence of SGD with Relative Global Clipping 99
 4.6 Experiments 101
 4.7 Proofs for Convergence of SGD with Relative Global Clipping 104

II Implicit Bias Along Entire Optimization Trajectory 115

5 Why do ConvNets Generalize Better Than Fully-connected Nets? 116
 5.1 Introduction 117
 5.2 Related Works 120
 5.3 Notation and Preliminaries 121
 5.4 Algorithmic Equivariance in Fully-connected Nets Trained by SGD 125
 5.5 Warm-up Examples and Proof Idea for Main Results 131
 5.6 Main Results: Sample Complexity Lower Bounds for Equivariant Algorithms 135
 5.7 Proofs 140

6 Low-Rank Implicit Bias of Matrix Factorization 157
 6.1 Introduction 158
 6.2 Related Works 160
 6.3 Preliminaries 161
 6.4 Warmup Examples 163
 6.5 Main Results: Equivalence between Gradient Descent and Greedy Low-Rank Learning (GLRL) 166
 6.6 Benefits of Depth: A View from GLRL 177
 6.7 The Marginal Value of Being Deeper 184
 6.8 Experiments 187
 6.9 Future Directions 192
 6.10 Preliminary Lemmas 192
 6.11 Proofs for Counter-example 193
 6.12 Proofs for Dynamical System 197
 6.13 Eigenvalues of Jacobians and Hessians 210
 6.14 Proofs for the Depth-2 Case 218
 6.15 Proofs for Deep Matrix Factorization 224
 6.16 Proof of Linear Convergence to Minimizer 238

7 Implicit Bias of Parametrization: On Equivalence to Mirror Descent 248
 7.1 Introduction 248
 7.2 Related work 252
 7.3 Preliminaries and notations 253
 7.4 Any gradient flow with commuting parametrization is a mirror flow 260
 7.5 Every mirror flow is a gradient flow with commuting parametrization 272
 7.6 Related basics for convex analysis 277
 7.7 Omitted proofs in Section 7.3 279
 7.8 Omitted proofs in Section 7.4 282
 7.9 Omitted proofs in Section 7.5 299

III Implicit Bias Around Manifold of Minimizers 304

8 Implicit Bias of Stochastic Gradient Descent 305
 8.1 Introduction 306
 8.2 Related Works 311
 8.3 Notations 313
 8.4 Preliminaries on Stochastic Processes 313
 8.5 Main Result: Limiting Diffusion of SGD on Manifold of Minimizers 319
 8.6 Implications and Examples 325
 8.7 Provable Generalization Benefit with Label Noise 328
 8.8 Derivation for Limiting Diffusion of SGD 333
 8.9 Explicit Formula of the Limiting Diffusion 348
 8.10 Proof of results in Section 8.6 355
 8.11 Proof of results in Section 8.7 358

9 Implicit Bias of Gradient Descent Operating on Edge of Stability: Sharpness Reduction 394
 9.1 Introduction 395
 9.2 Related Works 399
 9.3 Warm-up: Quadratic Loss Functions 401
 9.4 Notations 405
 9.5 Main Results: Sharpness Reduction 406
 9.6 Proof Overview 411
 9.7 Experiments 416
 9.8 Limitation and Future Work 419
 9.9 Proofs for Results for Quadratic Loss Functions 420
 9.10 Setups for General Loss Functions 435
 9.11 Analysis of Normalized GD on General Loss Functions 454
 9.12 Phase I, Proofs of the Main Lemmas 461
 9.13 Phase II, Proofs of the Main Lemmas 467
 9.14 Some Useful Lemmas About Eigenvalues and Eigenvectors 493
 9.15 Analysis of L 494
 9.16 Additional Experimental Details 503

Bibliography 506
Chapter 1

Introduction

Despite enormously successful applications, deep learning still lacks a good mathematical
understanding. Classic machine learning theory is often found incapable of explaining
or predicting various new phenomena in deep learning, not to mention aiding in the
design of better learning algorithms. One main issue behind this failure is that existing
theory usually holds a black box view which decouples the roles of the model and the
optimization algorithm, meaning that when analyzing one of them, the other is treated
as a black box. Such decoupling typically leads to oversimplified assumptions and
vacuous bounds for deep learning practice.
In recent years, an effort has emerged to develop a mathematical understanding
of deep learning by analyzing the trajectory of the optimization algorithm using
specific properties of the model. This dissertation follows this line of work and
presents some recent progress towards identifying and bridging the gap between
machine learning theory and deep learning practice, in both optimization
and generalization. Our approaches place a special focus on the interplay between the
model and the optimization algorithm.

1.1 The Black Box View in Existing Theory

We consider the setting of supervised learning. Given a set of training data $Z_n = \{z_i\}_{i=1}^{n}$
and labels $Y_n = \{y_i\}_{i=1}^{n}$, where the data and labels are jointly sampled from some
unknown distribution $P$, the goal of a machine learning algorithm $\mathcal{A}$ is to output a
function $h_n = \mathcal{A}(Z_n, Y_n)$ that can predict the label of unseen data, and the quality of
the prediction is measured by some given metric called the loss function, $\ell(\hat{y}_i, y_i)$, where $\hat{y}_i$ is
the prediction made by the learned function $h_n$. The notion of loss function can be
extended to the domain of prediction functions as well, where $\widehat{L}(h) := \frac{1}{n}\sum_{i=1}^{n} \ell(h(z_i), y_i)$
and $L(h) := \mathbb{E}_{(z,y)\sim P}\,\ell(h(z), y)$ are used to denote the training and population loss of
a given function $h$ respectively.
A typical machine learning algorithm consists of two parts: a model, which is a
function class $\mathcal{H}$, and an optimization algorithm, which picks a function from the given
function class attaining a small average loss $\widehat{L}(h_n)$ over the given training dataset
$(Z_n, Y_n)$. In the context of deep learning, the model is parametrized by real numbers
$x \in \mathbb{R}^D$ and takes the form of artificial neural networks, which can be viewed as
the composition of a sequence of parametrized transformations. The optimization
algorithms are typically iterative and based on first-order local search, including
(stochastic) gradient descent (GD) and its variants, e.g., ADAM [1], AdaGrad [2], etc.
The decoupling of the model and the optimization algorithm originates from the following
standard three-part decomposition of the error of a learning algorithm (compared to the
ground-truth function $h^*$) into approximation error, optimization error and generalization
error:

$$\mathbb{E}[L(h_n)] - L(h^*) = \mathbb{E}[L(h_n) - \widehat{L}(h_n)] + \mathbb{E}\big[\widehat{L}(h_n) - \inf_{h\in\mathcal{H}}\widehat{L}(h)\big] + \mathbb{E}\big[\inf_{h\in\mathcal{H}}\widehat{L}(h)\big] - L(h^*)$$
$$\le \underbrace{\mathbb{E}\big[\sup_{h\in\mathcal{H}} \big|L(h) - \widehat{L}(h)\big|\big]}_{\text{Generalization Error}} + \underbrace{\mathbb{E}\big[\widehat{L}(h_n) - \inf_{h\in\mathcal{H}}\widehat{L}(h)\big]}_{\text{Optimization Error}} + \underbrace{\inf_{h\in\mathcal{H}} L(h) - L(h^*)}_{\text{Approximation Error}}$$
Given a function class $\mathcal{H}$, the approximation error refers to the gap between the smallest
population loss achieved within the function class and the loss of the ground-truth function.
The optimization error is the gap between the training loss of the returned function and
the smallest training loss within the function class. The generalization error refers
to the difference between the loss of the function returned by the learning algorithm
evaluated on the training data and on new unseen data, and occurs due to the finiteness of
the training dataset.
Decoupling the model and the optimization algorithm is actually useful and eases the
analysis and design of some classical learning methods. In statistical learning theory,
the efficacy of the model can be evaluated without knowing details of the optimization
algorithm, i.e., by treating the optimization algorithm as a black box except for assuming
its ability to attain small training error. The common approach there is to relax the
generalization error of the learned function to the supremum of the generalization error over
all functions in the function class, and to further upper bound it by a certain complexity
measure of the function class via uniform convergence bounds such as VC dimension,
Rademacher complexity, etc. Thus the design principle of the model is to balance the
trade-off between approximation and generalization error.
Similarly, in optimization theory, many convergence results (and matching lower
bounds) have been derived in the oracle setting, where the optimization algorithm
can only access the training loss by querying an oracle for information like the loss
value and gradient at certain points, and the goal is to return parameters achieving
small optimization error using as few queries as possible. The entire training loss (and
therefore the model) is viewed as a black box except for assuming a few properties like
convexity and global smoothness [3]. Here global smoothness refers to the supremum
of the largest eigenvalue of the Hessian matrix of the training loss function over the entire
domain.
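For concreteness, and in our notation (not the dissertation’s), the global smoothness constant is $\beta := \sup_{x} \lambda_{\max}\big(\nabla^2 \widehat{L}(x)\big)$, and the classical descent lemma it implies states that gradient descent with learning rate $\eta \le 1/\beta$ satisfies

$$\widehat{L}\big(x - \eta\nabla\widehat{L}(x)\big) \;\le\; \widehat{L}(x) - \frac{\eta}{2}\,\big\|\nabla\widehat{L}(x)\big\|_2^2,$$

so the loss decreases monotonically; Section 1.3 discusses how this picture breaks down in deep learning.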

However, this black box view becomes incapable of explaining phenomena in deep
learning. There are gaps between theory and practice in both aspects of generalization
and optimization, and we discuss them below respectively.

1.2 Gaps in Generalization Theory and Practice

Modern neural networks are typically over-parametrized, meaning the number of
parameters can be much larger than the size of the dataset, and they can even fit datasets
with random labels (Zhang et al. [4]), suggesting that the uniform convergence bound must
be vacuous. Liu et al. [5] further show that stochastic gradient descent (SGD) does
converge to solutions with small training loss yet high generalization error if initialized
from certain “bad” initializations. Thus without using properties of the optimization
algorithm, there is no hope of explaining the small generalization error of the learned
network in practice, not to mention comparing the generalization benefit of different
algorithmic choices.
In the face of the inefficacy of uniform convergence bounds that use only properties of
the function class, researchers have turned to complexity measures of the learned network
instead, such as parameter norm, prediction margin, robustness against input perturbation
and sharpness of the loss landscape, etc., hoping to establish a bound on the generalization
error using the above features of the network that is small for networks learned in
practice (see [6] for a good reference). However, even in this manner, the achieved
generalization bounds are usually still vacuous, and they often fail to predict even the
trend of the change in generalization error given some change in network architecture,
which turns out to be a very challenging task [7]. Moreover, even if the bounds get
improved and become tight enough for deep learning in the future, they will not tell
us how standard training algorithms like stochastic gradient descent can find such
solutions without explicitly regularizing these complexity
measures. As a result, such bounds typically don’t provide strong guidance on how to
make algorithmic choices for better generalization.
One exception here is the sharpness of the loss landscape, where Foret et al. [8]
successfully reduced the generalization error by explicitly minimizing the $\epsilon$-sharpness
proposed by Keskar et al. [9]. Sharpness-based bounds originate from PAC-Bayesian
theory [10] and can be made non-vacuous by directly optimizing the bound for simple
tasks [11]. Jiang et al. [12] empirically found that sharpness along the worst direction
and average sharpness over all directions correlate best with the generalization error.¹
However, despite this empirical success, it is still open why normal training methods
would find solutions with low sharpness and why particular algorithmic choices lead
to flatter solutions than others, e.g., small-batch vs. large-batch SGD [9].
In recent years, there has been an emerging effort to understand how particular
optimization algorithms, e.g., gradient descent, can reach solutions with small values
of the complexity measures mentioned above or with other interesting properties, including
margin maximization in linear models [13] and homogeneous models [14] on separable
data, and norm control for infinitely wide neural networks [15]. Such a phenomenon is
called the implicit bias of the optimization algorithm. We do not aim to provide a
complete list of related works here, but defer the discussion to each chapter.
However, some of these results hold only in simplified or idealized settings and cannot
justify important algorithmic choices made in practice for better generalization. For
example, for infinitely wide neural networks (or networks in the Neural Tangent Kernel
regime [15]), stochastic gradient descent provably converges to the same solution
as full-batch deterministic gradient descent. In practice, however, the gradient noise in
stochastic gradient descent is observed to be beneficial for generalization,
suggesting that neural networks in practice do not completely operate in the NTK regime.

¹ In differential form these two sharpness notions are just the largest eigenvalue of the
Hessian of the loss, $\lambda_1(\nabla^2 L(x))$, and the sum (or average) of the eigenvalues of the Hessian, $\mathrm{Tr}[\nabla^2 L(x)]$. For
simplicity we will use sharpness to denote $\lambda_1(\nabla^2 L(x))$ in later sections.
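Both notions can be estimated without materializing the Hessian by using Hessian-vector products. The following is a minimal PyTorch sketch (ours, not from the dissertation), assuming the loss is a function of a single parameter tensor; note that power iteration converges to the eigenvalue of largest magnitude.

import torch

def sharpness(loss_fn, x, iters=100):
    # Estimate lambda_1 of the Hessian of loss_fn at x by power iteration
    # on Hessian-vector products, computed via double backpropagation.
    v = torch.randn_like(x)
    v /= v.norm()
    lam = 0.0
    for _ in range(iters):
        (g,) = torch.autograd.grad(loss_fn(x), x, create_graph=True)
        (hv,) = torch.autograd.grad(g, x, grad_outputs=v)   # hv = H @ v
        lam = torch.dot(v.flatten(), hv.flatten()).item()   # Rayleigh quotient
        v = hv / (hv.norm() + 1e-12)
    return lam

# Example: L(x) = x^T A x / 2 has Hessian A, so sharpness should be ~3.0.
A = torch.diag(torch.tensor([3.0, 1.0, 0.5]))
x = torch.zeros(3, requires_grad=True)
print(sharpness(lambda x: 0.5 * x @ A @ x, x))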
Thus the generalization part of the dissertation follows the above line of works
on implicit bias, but aims to mathematically understand the following previously
unexplained phenomena in more realistic settings:

• Why are convolutional networks (ConvNets) more sample-efficient than fully-connected
networks (FC Nets) on vision tasks? It is well known that ConvNets
can express functions related to vision using far fewer parameters than FC Nets
and thus reduce the approximation error. However, a highly over-parametrized
FC Net has negligible approximation error as it is able to simulate a ConvNet of
moderate size by freezing the unnecessary connections. Why can’t FC Nets
achieve small generalization error by realizing the small ConvNet hidden inside them
when trained by gradient descent? Can we come up with a setting where we can
rigorously justify the gap?

• How does parametrizing the same function class in a different way affect generalization?
It has been observed empirically [16] that gradient descent on matrix
factorization (writing a symmetric matrix $W$ as $UU^\top$ and doing gradient descent
with respect to $U$, which has the same shape as $W$) generalizes better
than plain gradient descent (doing gradient descent on $W$) when the ground
truth is low-rank. A recent line of works has established a correspondence between
mirror descent and gradient descent with a different parametrization and
explains the above phenomenon in a restrictive setting. What is the limit of this
approach using the equivalence between mirror descent and reparametrized gradient
descent? How can we resolve the implicit bias for general matrix factorization
problems?

• What is the generalization benefit of stochasticity and large learning rates in
stochastic gradient descent? Can it be explained by sharpness? Large learning
rates and small batches are known to improve generalization empirically
in deep learning, but the mechanism behind this is still under debate. Given the fact
that for linear models stochastic gradient descent with any learning rate learns the same
solution as gradient descent with an infinitesimal learning rate, it is
natural to ask how the implicit bias of stochastic gradient noise is affected by
the optimization landscape and the model, and in particular, its connection to
sharpness.

1.3 Gaps in Optimization Theory and Practice

The high-dimensional loss landscape is complicated, and the largest eigenvalue of the
Hessian, usually called sharpness or local smoothness, can vary drastically depending
on the position. It is often either too pessimistic to think about the convergence rate
of an optimization algorithm in the worst case under some given global smoothness
constant, or too optimistic (sometimes even invalid) to assume that such
a global smoothness constant exists. As a result, the optimization behavior of gradient-based
algorithms in practice is very different from theoretical predictions, and thus
optimization theory cannot give effective guidance on algorithmic choices towards faster
optimization. Indeed, hyperparameter tuning of optimization algorithms in
practice is based more on a trial-and-error principle.
For example, standard convergence analysis for gradient descent requires the
learning rate to be smaller than two over the sharpness, $2/\lambda_1(\nabla^2 L(x))$, to ensure that the loss
decreases. However, Cohen et al. [17] empirically showed that for all reasonably large
learning rates, gradient descent in deep learning does not decrease the loss in a monotone
way and violates the descent lemma. Instead, they found that gradient descent typically
operates in a regime named “Edge of Stability”, meaning the sharpness hovers
just above the value $2/\text{learning rate}$, and the training loss behaves non-monotonically
over short timescales, yet consistently decreases over long timescales.
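The threshold is easy to see on a quadratic, where the sharpness is constant; the toy script below (ours, for illustration only) shows GD with a learning rate just below $2/\lambda_1$ converging and just above it diverging.

# GD on the 1-D quadratic L(x) = a*x^2/2, whose sharpness lambda_1 equals a.
# The classical stability threshold for the learning rate is 2/a.
a, steps = 10.0, 20
for lr in (0.19, 0.21):          # 2/a = 0.2
    x = 1.0
    for _ in range(steps):
        x -= lr * a * x          # gradient step, since dL/dx = a*x
    print(f"lr={lr}: |x| = {abs(x):.2e}  ({'stable' if lr < 2/a else 'diverges'})")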

Another mystery in optimization for deep neural networks is the success of
normalization methods, including Batch Normalization [18], Layer Normalization [19],
etc., which make network training much more robust and efficient. There have been a
lot of debates on the mechanism, but no consensus has been reached. Part of the reason that
classical optimization theory fails on networks with normalization layers is that
the usage of normalization layers also makes the sharpness of the network vary
drastically, or more specifically, scale inversely with the squared norm of the parameters.
Thus, again, no single global smoothness constant can correctly be assumed, as it could
be either too optimistic or too pessimistic depending on the actual trajectory of the
optimization algorithm.
In the optimization part of the dissertation, we aim to answer the following
optimization questions emerging with the usage of modern network architectures:

• Under what settings can the Edge of Stability regime occur? How can gradient
descent decrease the training loss without the descent lemma and in a non-monotone
way?

• How is optimization on networks with normalization layers different from that on
standard ones? What is the mathematical mechanism behind the robust optimization
brought by normalization layers? Can we design more efficient algorithms by
leveraging the power of normalization layers?

1.4 Our Contributions

This dissertation presents a collection of recent progress towards bridging theory
and practice in deep learning and answering the questions mentioned in Sections 1.2
and 1.3. The dissertation is split into three parts.

In the first part, we study the unconventional optimization behavior of networks
with normalization layers and design more robust and memory-efficient training
methods for BERT as an application.

• In Chapter 2, we show that training can be done using SGD with momentum and
an exponentially increasing learning rate schedule, i.e., the learning rate increases by
some $(1+\alpha)$ factor in every epoch for some $\alpha > 0$, and we prove that this is equivalent
to the standard setting of BatchNorm + SGD + Standard Rate Tuning + Weight
Decay + Momentum. This equivalence holds for other normalization layers as
well, as long as their usage makes the loss function invariant to the scaling of
the parameters.

• In Chapter 3, we theoretically justify the efficacy of the above-mentioned exponential
learning rate schedule for SGD (or equivalently, the usage of Weight Decay for
scale invariant losses) by proving its convergence rate. In detail, we show that with
Weight Decay, GD and SGD converge to an $\epsilon$-approximate stationary point in $O(\frac{1}{\epsilon^2})$ and
$O(\frac{1}{\epsilon^4})$ steps respectively. Moreover, we show that using exponentially increasing learning rates or
Weight Decay makes optimization more robust, with only a logarithmic
dependency on the scaling of the parameters and the loss. In comparison, without Weight
Decay or exponentially increasing learning rates, the convergence rate has a
polynomial dependence on the initialization scale and can be significantly slowed
down by large initialization.

• In Chapter 4 we use the theoretical results in Chapter 3 to solve the training
instability issue in Transformers. As an application, we design a scale invariant
version of BERT, called SIBERT, which when trained simply by vanilla SGD
achieves performance comparable to BERT trained by adaptive methods like
ADAM on downstream tasks and thus is more memory-efficient. We further
propose a novel clipping method named Relative Global Clipping and show that
it enhances training stability both theoretically and experimentally.

In the second part, we present theoretical results related to generalization and the
implicit bias of optimization algorithms, obtained by analyzing the entire optimization trajectory
using specific properties of the models. The results in this part rely on the particular
initialization of the optimization algorithm.

• In Chapter 5, we construct a single natural distribution on $\mathbb{R}^d \times \{\pm 1\}$ on
which any orthogonal-invariant algorithm (e.g., fully-connected networks trained
with most gradient-based methods from Gaussian initialization) requires $\Omega(d^2)$
samples to generalize, while $O(1)$ samples suffice for convolutional architectures.
Furthermore, we demonstrate a single target function, learning which on all
possible distributions leads to an $O(1)$ vs. $\Omega(d^2/\epsilon)$ gap. Similar results are
achieved for $\ell_2$ regression and adaptive training algorithms, e.g., Adam and
AdaGrad, which are only permutation-equivariant.

• In Chapter 6, we extend the study of the implicit bias of gradient descent
for matrix factorization. We provide both theoretical and empirical evidence that
for depth-2 matrix factorization, gradient flow with infinitesimal initialization is
mathematically equivalent to a simple heuristic rank minimization algorithm,
Greedy Low-Rank Learning, under some reasonable assumptions. This result
enables us to construct counter-examples to refute the conjecture from Gunasekar
et al. [16]. We also extend the results to the case where depth $\ge 3$, and we
show that the benefit of being deeper is that the above convergence has a much
weaker dependence on the initialization magnitude, so that this rank minimization
is more likely to take effect for initializations of practical scale.

• In Chapter 7, we give a characterization of the phenomenon that the optimization
trajectory of gradient descent over an overparametrized model can be understood
as mirror descent over a different objective, under a notion termed commuting
parametrization, which encompasses all previous results in this setting. It
is shown that gradient flow with any commuting parametrization is equivalent
to continuous mirror descent with a related Legendre function. Conversely,
continuous mirror descent with any Legendre function can be viewed as a gradient
flow with a related commuting parametrization.

In the third part, we present theoretical results related to generalization and the
implicit bias of optimization algorithms, obtained by analyzing the optimization trajectory
when the training loss is small, and we show how optimization algorithms including SGD,
normalized GD and GD on a non-smooth training loss can drive the solution from ‘sharp’
areas to ‘flat’ areas. The results for GD in Chapter 9 also give the first provable setting
with an accompanying theoretical analysis where GD can decrease the training loss in a
non-monotone way, thus violating the descent lemma.

• In Chapter 8, we consider the implicit bias of stochastic gradient descent in
the setting where the minimizers of the loss form a smooth manifold. We give a
complete characterization of its implicit bias around the manifold based on
a stochastic differential equation, which allows arbitrary noise covariance.

As an application, we show that when the initialization is large, stochastic
gradient descent with a certain type of noise called label noise can generalize
better than deterministic gradient descent by implicitly minimizing the trace of
the Hessian, which can be viewed as the average-direction sharpness. In detail,
we show that SGD with label noise only requires $O(\kappa \ln d)$ samples for learning
a $\kappa$-sparse overparametrized linear model in $\mathbb{R}^d$ [20], while gradient descent with the
same initialization requires $\Omega(d)$ samples. This upper bound is minimax optimal
and improves the previous $\widetilde{O}(\kappa^2)$ upper bound [21].

• In Chapter 9, we mathematically analyze a new mechanism of implicit regularization
in the EoS phase, whereby GD updates due to the non-smooth loss landscape
turn out to evolve along some deterministic flow on the manifold of minimum
loss. This is in contrast to many previous results about implicit bias that rely
either on infinitesimal updates or on noise in the gradient. Formally, for any smooth
function $L$ with a certain regularity condition, this effect is demonstrated for (1)
Normalized GD, i.e., GD with a varying LR $\eta_t = \frac{\eta}{\|\nabla L(x(t))\|}$ and loss $L$; and (2) GD
with constant LR and loss $\sqrt{L - \min_x L(x)}$. Both provably enter the Edge of
Stability, with the associated flow on the manifold minimizing $\lambda_1(\nabla^2 L)$. The
above theoretical results are corroborated by an experimental study.

1.5 Previously Published Works

This dissertation is based on the following previously published works. Chapter 2


is based on the joint work with Sanjeev Arora [22]. Chapters 3 and 4 are based
on the joint work with Srinadh Bhojanapalli, Manzil Zaheer, Sashank Reddi and
Sanjiv Kumar [23]. Chapter 5 is based on the joint work with Yi Zhang and Sanjeev
Arora [24]. Chapter 6 is based on the joint work with Kaifeng Lyu and Yuping Luo [25].
Chapter 8 is based on the joint work with Tianhao Wang and Sanjeev Arora [26].
Chapter 9 is based on the joint work with Abhishek Panigrahi and Sanjeev Arora [27].

Part I

Optimization Analysis for Normalized Deep Nets:
The Power of Scale Invariance

Chapter 2

An Exponentially Increasing Learning Rate Schedule for Normalized Networks

Intriguing empirical evidence exists that deep learning can work well with exotic
schedules for varying the learning rate. This chapter suggests that the phenomenon
may be due to Batch Normalization, or BN [18], which is ubiquitous and provides
benefits in optimization and generalization across all standard architectures. The
following new results are shown about BN with weight decay and momentum (in other
words, the typical use case, which was not considered in earlier theoretical analyses of
stand-alone BN [18, 28, 29]):

• Training can be done using SGD with momentum and an exponentially increasing
learning rate schedule, i.e., the learning rate increases by some $(1+\alpha)$ factor in
every epoch for some $\alpha > 0$. (Precise statement in Theorem 2.1.1.) To the best of
our knowledge this is the first time such a rate schedule has been successfully
used, let alone for highly successful architectures. As expected, such training
rapidly blows up the network weights, but the network stays well-behaved due to
normalization.

• Mathematical explanation of the success of the above rate schedule: a rigorous
proof that it is equivalent to the standard setting of BN + SGD + Standard
Rate Tuning + Weight Decay + Momentum. This equivalence holds for other
normalization layers as well: Group Normalization [30], Layer Normalization [19],
Instance Norm [31], etc.

• A worked-out toy example illustrating the above linkage of hyper-parameters.
Using either weight decay or BN alone reaches the global minimum, but convergence
fails when both are used.

2.1 Introduction

Batch Normalization (BN) offers significant benefits in optimization and generalization
across architectures, and has become ubiquitous. Usually the best performance is attained
by adding weight decay and momentum in addition to BN.
Weight decay is usually thought to improve generalization by controlling the norm
of the parameters. However, it is fallacious to think separately of optimization
and generalization because we are dealing with a nonconvex objective with multiple
optima. Even slight changes to the training surely lead to a different trajectory in
the loss landscape, potentially ending up at a different solution! One needs trajectory
analysis to have a hope of reasoning about the effects of such changes.
In the presence of BN and other normalization schemes, including GroupNorm,
LayerNorm, and InstanceNorm, the optimization objective is scale invariant in the
parameters, meaning that rescaling the parameters does not change the prediction, except
for the parameters that compute the output, which do not have BN. However, [32] shows
that fixing the output layer randomly does not harm the performance of the network,
and thus the trainable parameters satisfy scale invariance (see more in Section 2.10).
The current chapter introduces new modes of analysis for such settings. This rigorous
analysis yields the surprising conclusion that the original learning rate (LR) schedule
and weight decay (WD) can be folded into a new exponential schedule for the learning
rate: in each iteration multiplying it by $(1+\alpha)$ for some $\alpha > 0$ that depends upon
the momentum and weight decay rate.

Theorem 2.1.1 (Main, Informal). SGD on a scale-invariant objective with initial
learning rate $\eta$, weight decay factor $\lambda$, and momentum factor $\gamma$ is equivalent to SGD
with momentum factor $\gamma$ and no weight decay ($\tilde{\lambda} = 0$), where at iteration $t$ the learning
rate $\tilde{\eta}_t$ in the new exponential learning rate schedule is defined as $\tilde{\eta}_t = \alpha^{-2t-1}\eta$,
with $\alpha$ a non-zero root of the equation

$$z^2 - (1 + \gamma - \lambda\eta)z + \gamma = 0. \qquad (2.1)$$

Specifically, when the momentum $\gamma = 0$, the above schedule can be simplified to
$\tilde{\eta}_t = (1 - \lambda\eta)^{-2t-1}\eta$.

The above theorem requires that the product of the learning rate and weight decay
factor, $\lambda\eta$, is smaller than $(1 - \sqrt{\gamma})^2$, which is almost always satisfied in practice. The
rigorous and most general version of the above theorem is Theorem 2.4.12, which deals
with multi-phase LR schedules, momentum and weight decay.
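The schedule is straightforward to compute; here is a short sketch (ours) that solves Equation (2.1) for the root $\alpha \in (0,1)$ that reduces to $1-\lambda\eta$ when $\gamma = 0$ and emits the growing learning rates.

import math

def exponential_lr(eta, lam, gamma, num_iters):
    # eta_t = alpha^(-2t-1) * eta per Theorem 2.1.1, where alpha is the root
    # of z^2 - (1 + gamma - lam*eta) z + gamma = 0 that equals 1 - lam*eta
    # when gamma = 0.
    b = 1 + gamma - lam * eta
    assert b * b >= 4 * gamma, "need lam*eta <= (1 - sqrt(gamma))^2"
    alpha = (b + math.sqrt(b * b - 4 * gamma)) / 2
    return [eta * alpha ** (-(2 * t + 1)) for t in range(num_iters)]

# Standard hyperparameters: LR 0.1, WD 5e-4, momentum 0.9.
print(exponential_lr(0.1, 5e-4, 0.9, 3))  # each step multiplies the LR by alpha**(-2)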
There are other recently discovered exotic LR schedules, e.g., the Triangular LR
schedule [33] and the Cosine LR schedule [34], and our exponential LR schedule is an
extreme example of the LR schedules that become possible in the presence of BN. Such an
exponential increase in learning rate seems absurd at first sight, and to the best of
our knowledge, no deep learning success has been reported using such an idea before.
It does highlight the above-mentioned viewpoint that in deep learning, optimization
and regularization are not easily separated. Of course, the exponent trumps the effect
of the initial LR very fast (see Figure 2.3), which explains why training with BN and
WD is not sensitive to the scale of initialization, since with BN, tuning the scale of
initialization is equivalent to tuning the initial LR $\eta$ while fixing the product of LR
and WD, $\eta\lambda$ (see Lemma 2.4.7).
Note that it is customary with BN to switch to a lower LR upon reaching a plateau in
the validation loss. According to the analysis in the above theorem, this corresponds
to an exponential growth with a smaller exponent, except for a transient effect when a
correction term is needed for the two processes to be equivalent (see the discussion around
Theorem 2.4.12).
Thus the final training algorithm is roughly as follows: start from a convenient LR
like 0.1, and grow it at an exponential rate with a suitable exponent. When the validation
loss plateaus, switch to an exponential growth of LR with a lower exponent. Repeat
the procedure until the training loss saturates.
In Section 2.5, we demonstrate on a toy example how weight decay and normalization
are inseparably involved in the optimization process. With either weight decay or
normalization alone, SGD will achieve zero training error. But with both turned on,
SGD fails to converge to the global minimum.
In Section 2.7, we experimentally verify our theoretical findings on CNNs and
ResNets. We also construct better exponential LR schedules by incorporating the
Cosine LR schedule on CIFAR-10, which opens the possibility of an even more general
theory of rate schedule tuning towards better performance.

2.2 Related Work

There have been other theoretical analyses of training models with scale-invariance.
Cho and Lee [35] proposed to run Riemannian gradient descent on the Grassmann manifold
$\mathcal{G}(1, n)$, since the loss function is invariant to scaling of the weight matrix. It has also been
observed that the effective step size is proportional to $\frac{\eta}{\|x_t\|_2^2}$. Arora et al. [36] show the gradient
is always perpendicular to the current parameter vector, which has the effect that the
norm of each scale invariant parameter group increases monotonically, yielding an
auto-tuning effect. Wu et al. [37] propose a new adaptive learning rate schedule
motivated by the scale-invariance property of Weight Normalization.
Previous work for understanding Batch Normalization. Santurkar et al.
[28] suggested that the success of BN does not derive from a reduction in Internal
Covariate Shift, but from making the landscape smoother. Kohler et al. [38] essentially show
that a linear model with BN could achieve an exponential convergence rate assuming Gaussian
inputs, but their analysis is for a variant of GD with an inner optimization loop rather
than GD itself. Bjorck et al. [39] observed that the higher learning rates enabled by
BN empirically improve generalization. Arora et al. [36] proved that under certain
mild assumptions, (S)GD with BN finds an approximate first-order stationary point with
any fixed learning rate. None of the above analyses incorporated weight decay, but
Zhang et al. [40], Hoffer et al. [41] and Van Laarhoven [42] argued qualitatively that
weight decay makes the parameters have smaller norms, and thus that the effective learning
rate, $\frac{\eta}{\|x_t\|_2^2}$, is larger. They described experiments showing this effect but did not have
a closed-form theoretical analysis like ours. None of the above analyses deals with
momentum rigorously.

2.3 Preliminaries and Notations

For a batch $\mathcal{B} = \{z_i\}_{i=1}^{B}$ and network parameter $x$, we denote the network by $f_x$ and the
loss function at iteration $t$ by $L_t(f_x) = L(f_x, \mathcal{B}_t)$. When there is no ambiguity, we also
write $L_t(x)$ for convenience.

We say a loss function $L(x)$ is scale invariant in its parameter $x$ if for any $c \in \mathbb{R}^+$,
$L(x) = L(cx)$. In practice, the source of scale invariance is usually various types of
normalization layers, including Batch Normalization [18], Group Normalization [30],
Layer Normalization [19], Instance Norm [31], etc.

Implementations of SGD with Momentum/Nesterov come with subtle variations
in the literature. We adopt the variant from [43], which is also the default in PyTorch [44]. $L^2$
regularization (a.k.a. Weight Decay) is another common trick used in deep learning.
Combining them together, we get one of the most widely used optimization algorithms,
below.

Definition 2.3.1 (SGD with Momentum and Weight Decay). At iteration $t$, with
randomly sampled batch $\mathcal{B}_t$, update the parameters $x_t$ and momentum $v_t$ as follows:

$$x_t = x_{t-1} - \eta_{t-1} v_t \qquad (2.2)$$
$$v_t = \gamma v_{t-1} + \nabla_x\Big(L_t(x_{t-1}) + \frac{\lambda_{t-1}}{2}\|x_{t-1}\|_2^2\Big), \qquad (2.3)$$

where $\eta_t, \lambda_t$ are the learning rate and weight decay factor at iteration $t$ respectively,
and $\gamma$ is the momentum coefficient. Usually, $v_0$ is initialized to be $0$.

For ease of analysis, we will use the following equivalent form of Definition 2.3.1:

$$\frac{x_t - x_{t-1}}{\eta_{t-1}} = \gamma\,\frac{x_{t-1} - x_{t-2}}{\eta_{t-2}} - \nabla_x\Big(L(x_{t-1}) + \frac{\lambda_{t-1}}{2}\|x_{t-1}\|_2^2\Big), \qquad (2.4)$$

where $\eta_{-1}$ and $x_{-1}$ must be chosen in a way such that $v_0 = \frac{x_0 - x_{-1}}{\eta_{-1}}$ is satisfied; e.g.,
when $v_0 = 0$, $x_{-1} = x_0$ and $\eta_{-1}$ can be arbitrary.
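In code, one step of Definition 2.3.1 looks as follows (a minimal sketch, ours, of the PyTorch-style variant for a single parameter tensor):

def sgd_step(x, v, grad, lr, wd, gamma):
    # One update of Definition 2.3.1: v_t = gamma*v_{t-1} + grad + wd*x_{t-1}
    # (the gradient of L + (wd/2)*||x||^2), then x_t = x_{t-1} - lr*v_t.
    v = gamma * v + grad + wd * x
    x = x - lr * v
    return x, v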

A key source of intuition is the following simple lemma about scale-invariant
networks [36]. The first property ensures that GD (with momentum) always increases the
norm of the weights (see Lemma 2.9.1 in Section 2.9), and the second property says
that the gradients are smaller for parameters with larger norm, thus preventing the
trajectory from diverging to infinity.

Lemma 2.3.2 (Scale Invariance). If for any $c \in \mathbb{R}^+$, $L(x) = L(cx)$, then
(1) $\langle \nabla_x L, x\rangle = 0$;
(2) $\nabla_x L\big|_{x=x_0} = c\,\nabla_x L\big|_{x=cx_0}$, for any $c > 0$.
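Both properties are easy to check numerically on a made-up scale-invariant loss (the sketch below is ours; the toy loss depends on $x$ only through $x/\|x\|$):

import torch

z = torch.tensor([1.0, -2.0, 0.5])
loss = lambda x: torch.dot(x / x.norm(), z) ** 2   # L(cx) = L(x)

x0 = torch.randn(3)
for c in (1.0, 3.0):
    x = (c * x0).clone().requires_grad_()
    (g,) = torch.autograd.grad(loss(x), x)
    print(torch.dot(g, x).item(),        # property (1): always ~0
          (g.norm() * x.norm()).item())  # property (2): ||g||*||x|| is scale-free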

2.4 Deriving Exponential Learning Rate Schedule

As a warm-up, in Section 2.4.1 we show that if momentum is turned off then Fixed
LR + Fixed WD can be translated into an equivalent Exponential LR. In Section 2.4.2
we give a more general analysis of the equivalence between Fixed LR + Fixed WD
+ Fixed Momentum Factor and Exponential LR + Fixed Momentum Factor. While
interesting, this still does not completely apply to real-life deep learning, where reaching
full accuracy usually requires multiple phases in training, with the LR fixed within a
phase and reduced by some factor from one phase to the next. Section 2.4.3 shows
how to interpret such a multi-phase LR schedule + WD + Momentum as a certain
multi-phase exponential LR schedule with Momentum.

2.4.1 Replacing WD by Exponential LR in Momentum-Free SGD

We use the notation of Section 2.3 and assume the LR is fixed over iterations, i.e., $\eta_t = \eta_0$,
and $\gamma$ (momentum factor) is set to $0$. We also use $\lambda$ to denote the WD factor and $x_0$ to
denote the initial parameters.
The intuition should be clear from Lemma 2.3.2, which says that shrinking the parameter
weights by a factor $\rho$ (where $\rho < 1$) amounts to making the gradient $\rho^{-1}$ times larger
without changing its direction. Thus in order to restore the ratio between the original
parameter and its update (LR $\times$ Gradient), the easiest way would be scaling the LR by $\rho^2$.
This suggests that scaling the parameter $x$ by $\rho$ at each step is equivalent to scaling
the LR $\eta$ by $\rho^{-2}$.
To prove this formally we use the following formalism. We will refer to the vector
$(x, \eta)$ as the state of a training algorithm and study how it evolves under various
combinations of parameter changes. We will think of each step in training as a
mapping from one state to another. Since mappings can be composed, any finite
number of steps also corresponds to a mapping. The following are some basic mappings
used in the proof.

1. Run GD with WD for a step: $\mathrm{GD}_t^{\rho}(x, \eta) = (\rho x - \eta\nabla L_t(x),\ \eta)$;

2. Scale the parameter $x$: $\Pi_1^c(x, \eta) = (cx, \eta)$;

3. Scale the LR $\eta$: $\Pi_2^c(x, \eta) = (x, c\eta)$.

For example, when $\rho = 1$, $\mathrm{GD}_t^{1}$ is the vanilla GD update without WD, also abbreviated as
$\mathrm{GD}_t$. When $\rho = 1 - \lambda\eta_0$, $\mathrm{GD}_t^{1-\lambda\eta_0}$ is the GD update with WD $\lambda$ and LR $\eta_0$. Here $L_t$ is
the loss function at iteration $t$, which is decided by the batch of training samples
$\mathcal{B}_t$ in the $t$-th iteration. Below is the main result of this subsection, showing our claim that
GD + WD $\Leftrightarrow$ GD + Exp LR (when momentum is zero). It will be proved after a
series of lemmas.

Theorem 2.4.1 (WD $\Leftrightarrow$ Exp LR). For every $\rho < 1$ and positive integer $t$, the following
holds:

$$\mathrm{GD}_{t-1}^{\rho} \circ \cdots \circ \mathrm{GD}_0^{\rho} = \Pi_1^{\rho^t} \circ \Pi_2^{\rho^{2t}} \circ \Big[\Pi_2^{\rho^{-1}} \circ \mathrm{GD}_{t-1} \circ \Pi_2^{\rho^{-2}} \circ \cdots \circ \mathrm{GD}_1 \circ \Pi_2^{\rho^{-2}} \circ \mathrm{GD}_0 \circ \Pi_2^{\rho^{-1}}\Big].$$

With the WD being $\lambda$, $\rho$ is set to $1 - \lambda\eta_0$, and thus the scaling factor of the LR per iteration
is $\rho^{-2} = (1 - \lambda\eta_0)^{-2}$, except for the first iteration, where it is $\rho^{-1} = (1 - \lambda\eta_0)^{-1}$.
We first show how to write the GD update with WD as a composition of the above defined
basic maps.

Lemma 2.4.2. $\mathrm{GD}_t^{\rho} = \Pi_2^{\rho} \circ \Pi_1^{\rho} \circ \mathrm{GD}_t \circ \Pi_2^{\rho^{-1}}$.

Below we will define the proper notion of equivalence such that (1) $\Pi_1^{\rho} \sim \Pi_2^{\rho^{-2}}$,
which implies $\mathrm{GD}_t^{\rho} \sim \Pi_2^{\rho^{-1}} \circ \mathrm{GD}_t \circ \Pi_2^{\rho^{-1}}$; and (2) the equivalence is preserved under future
GD updates.
We first extend the equivalence between weights (same direction) to an equivalence between
states, with the additional requirement that the ratio between the size of the GD update and
that of the parameter is the same among all equivalent states, which yields the notion of
Equivalent Scaling.

Definition 2.4.3 (Equivalent States). $(\tilde{x}, \tilde{\eta})$ is equivalent to $(x, \eta)$ iff $\exists c > 0$ such that
$(\tilde{x}, \tilde{\eta}) = [\Pi_1^c \circ \Pi_2^{c^2}](x, \eta) = (cx, c^2\eta)$, which is also denoted by $(\tilde{x}, \tilde{\eta}) \overset{c}{\sim} (x, \eta)$.
$\Pi_1^c \circ \Pi_2^{c^2}$ is called an Equivalent Scaling, for all $c > 0$.

The following lemma shows that equivalent scalings commute with the GD update
with WD, implying that equivalence is preserved under GD updates (Lemma 2.4.4).
This anchors the notion of equivalence — we can insert an equivalent scaling anywhere
in a sequence of basic maps (GD update, LR/parameter scaling) without changing
the final network.

Lemma 2.4.4. For any constants $c, \rho > 0$ and $t \ge 0$, $\mathrm{GD}_t^{\rho} \circ [\Pi_1^c \circ \Pi_2^{c^2}] = [\Pi_1^c \circ \Pi_2^{c^2}] \circ \mathrm{GD}_t^{\rho}$.
In other words, $(x, \eta) \overset{c}{\sim} (x', \eta') \implies \mathrm{GD}_t^{\rho}(x, \eta) \overset{c}{\sim} \mathrm{GD}_t^{\rho}(x', \eta')$.

Now we formally define the equivalence relationship between maps using equivalent
scalings.

Definition 2.4.5 (Equivalent Maps). Two maps $F, G$ are equivalent iff $\exists c > 0$ such that
$F = \Pi_1^c \circ \Pi_2^{c^2} \circ G$, which is also denoted by $F \overset{c}{\sim} G$.

Proof of Theorem 2.4.1. By Lemma 2.4.2, $\mathrm{GD}_t^{\rho} \overset{\rho}{\sim} \Pi_2^{\rho^{-1}} \circ \mathrm{GD}_t \circ \Pi_2^{\rho^{-1}}$. By Lemma 2.4.4,
the GD update preserves map equivalence, i.e., $F \overset{c}{\sim} G \Rightarrow \mathrm{GD}_t^{\rho} \circ F \overset{c}{\sim} \mathrm{GD}_t^{\rho} \circ G$, $\forall c, \rho > 0$.
Thus,

$$\mathrm{GD}_{t-1}^{\rho} \circ \cdots \circ \mathrm{GD}_0^{\rho} \overset{\rho^t}{\sim} \Pi_2^{\rho^{-1}} \circ \mathrm{GD}_{t-1} \circ \Pi_2^{\rho^{-2}} \circ \cdots \circ \mathrm{GD}_1 \circ \Pi_2^{\rho^{-2}} \circ \mathrm{GD}_0 \circ \Pi_2^{\rho^{-1}}. \qquad \square$$
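Theorem 2.4.1 is easy to verify numerically; the sketch below (ours) runs both updates on the toy scale-invariant loss from Section 2.3 and checks that the two trajectories coincide in function space, i.e., $\tilde{x}_t = \rho^{-t}x_t$.

import torch
torch.set_default_dtype(torch.float64)

z = torch.tensor([1.0, -2.0, 0.5])
def grad(x):                       # gradient of the scale-invariant toy loss
    x = x.clone().requires_grad_()
    return torch.autograd.grad(torch.dot(x / x.norm(), z) ** 2, x)[0]

eta, lam, T = 0.1, 0.01, 50
rho = 1 - lam * eta
x = y = torch.tensor([1.0, 0.3, -0.8])
for t in range(T):
    x = rho * x - eta * grad(x)                     # GD + WD, fixed LR
    y = y - eta * rho ** (-(2 * t + 1)) * grad(y)   # no WD, exponential LR
print(torch.allclose(y, rho ** (-T) * x))           # same direction, rescaled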
2.4.2 Replacing WD by Exponential LR: Case of Constant LR with Momentum

In this subsection the setting is the same as in Section 2.4.1, except that
the momentum factor is $\gamma$ instead of $0$. Suppose the initial momentum is $v_0$; we
set $x_{-1} = x_0 - v_0\eta$. The presence of momentum requires representing the state of the
algorithm with four coordinates, $(x, \eta, x', \eta')$, which stand for the current
parameters/LR and the buffered parameters/LR (from the last iteration) respectively.
Similarly, we define the following basic maps and equivalence relationships.

1. Run GD with WD for a step: $\mathrm{GD}_t^{\rho}(x, \eta, x', \eta') = \big(\rho x + \eta\big(\gamma\frac{x - x'}{\eta'} - \nabla L_t(x)\big),\ \eta,\ x,\ \eta\big)$;

2. Scale the current parameter $x$: $\Pi_1^c(x, \eta, x', \eta') = (cx, \eta, x', \eta')$;

3. Scale the current LR $\eta$: $\Pi_2^c(x, \eta, x', \eta') = (x, c\eta, x', \eta')$;

4. Scale the buffered parameter $x'$: $\Pi_3^c(x, \eta, x', \eta') = (x, \eta, cx', \eta')$;

5. Scale the buffered LR $\eta'$: $\Pi_4^c(x, \eta, x', \eta') = (x, \eta, x', c\eta')$.

Definition 2.4.6 (Equivalent States). $(x, \eta, x', \eta')$ is equivalent to $(\tilde{x}, \tilde{\eta}, \tilde{x}', \tilde{\eta}')$ iff
$\exists c > 0$ such that $(x, \eta, x', \eta') = \big[\Pi_1^c \circ \Pi_2^{c^2} \circ \Pi_3^c \circ \Pi_4^{c^2}\big](\tilde{x}, \tilde{\eta}, \tilde{x}', \tilde{\eta}') = (c\tilde{x}, c^2\tilde{\eta}, c\tilde{x}', c^2\tilde{\eta}')$, which is
also denoted by $(x, \eta, x', \eta') \overset{c}{\sim} (\tilde{x}, \tilde{\eta}, \tilde{x}', \tilde{\eta}')$. We call $\Pi_1^c \circ \Pi_2^{c^2} \circ \Pi_3^c \circ \Pi_4^{c^2}$ Equivalent
Scalings, for all $c > 0$.

Again by expanding the definition, we show that equivalent scalings commute with the GD
update.

Lemma 2.4.7. $\forall c, \rho > 0$ and $t \ge 0$, $\mathrm{GD}_t^{\rho} \circ \big[\Pi_1^c \circ \Pi_2^{c^2} \circ \Pi_3^c \circ \Pi_4^{c^2}\big] = \big[\Pi_1^c \circ \Pi_2^{c^2} \circ \Pi_3^c \circ \Pi_4^{c^2}\big] \circ \mathrm{GD}_t^{\rho}$.

Similarly, we can rewrite $\mathrm{GD}_t^{\rho}$ as a composition of the vanilla GD update and other
scalings by expanding the definition, when the current and buffered LR are equal
in the input of $\mathrm{GD}_t^{\rho}$.
Figure 2.1: Taking PreResNet32 with standard hyperparameters and replacing WD
during the first phase (fixed LR) by an exponential LR according to Theorem 2.4.9 yields
the schedule $\tilde{\eta}_t = 0.1 \times 1.481^t$ with momentum $0.9$. The plot on the right shows that the weight
norm $\|w_t\|$ of the first convolutional layer in the second residual block grows exponentially, satisfying
$\|w_t\|^2 / \tilde{\eta}_t = \text{constant}$. The reason is that, according to the proof, this quantity is essentially the norm
square of the weights when trained with Fixed LR + WD + Momentum, and the published
hyperparameters kept this norm roughly constant during training.

Lemma 2.4.8. For any input $(x, \eta, x', \eta)$, if $\alpha > 0$ is a root of $\alpha + \gamma\alpha^{-1} = \rho + \gamma$, then
$\mathrm{GD}_t^{\rho}(x, \eta, x', \eta) = \big[\Pi_4^{\alpha} \circ \Pi_2^{\alpha} \circ \Pi_1^{\alpha} \circ \mathrm{GD}_t \circ \Pi_2^{\alpha^{-1}} \circ \Pi_3^{\alpha} \circ \Pi_4^{\alpha}\big](x, \eta, x', \eta)$. In other words,

$$\mathrm{GD}_t^{\rho}(x, \eta, x', \eta) \overset{\alpha}{\sim} \big[\Pi_3^{\alpha^{-1}} \circ \Pi_4^{\alpha^{-1}} \circ \Pi_2^{\alpha^{-1}} \circ \mathrm{GD}_t \circ \Pi_2^{\alpha^{-1}} \circ \Pi_3^{\alpha} \circ \Pi_4^{\alpha}\big](x, \eta, x', \eta). \qquad (2.5)$$

Though it looks complicated, the RHS of Equation (2.5) is actually the desired
$\Pi_2^{\alpha^{-1}} \circ \mathrm{GD}_t \circ \Pi_2^{\alpha^{-1}}$ conjugated with some scaling on the momentum part: the
$\Pi_3^{\alpha^{-1}} \circ \Pi_4^{\alpha^{-1}}$ in the current update cancels with the $\Pi_3^{\alpha} \circ \Pi_4^{\alpha}$ in the next update. Now we
are ready to show the equivalence between WD and the Exp LR schedule when momentum
is turned on for both.

Theorem 2.4.9 (GD + WD $\Leftrightarrow$ GD + Exp LR; With Momentum). The following
two sequences of parameters, $\{x_t\}_{t=0}^{\infty}$ and $\{\tilde{x}_t\}_{t=0}^{\infty}$, satisfy $\tilde{x}_t = \alpha^{-t} x_t$, and thus they
correspond to the same networks in function space, i.e., $f_{x_t} = f_{\tilde{x}_t}$, $\forall t \in \mathbb{N}$, given
$\tilde{x}_0 = x_0$, $\tilde{x}_{-1} = \alpha x_{-1}$, and $\tilde{\eta}_t = \eta_0\alpha^{-2t-1}$:

1. $\dfrac{x_t - x_{t-1}}{\eta_0} = \gamma\,\dfrac{x_{t-1} - x_{t-2}}{\eta_0} - \nabla_x\big(L(x_{t-1}) + \frac{\lambda}{2}\|x_{t-1}\|_2^2\big)$;

2. $\dfrac{\tilde{x}_t - \tilde{x}_{t-1}}{\tilde{\eta}_{t-1}} = \gamma\,\dfrac{\tilde{x}_{t-1} - \tilde{x}_{t-2}}{\tilde{\eta}_{t-2}} - \nabla_x L(\tilde{x}_{t-1})$,

where $\alpha$ is a positive root of the equation $z^2 - (1 + \gamma - \lambda\eta_0)z + \gamma = 0$, which is always
smaller than $1$ (see Section 2.8.1). When $\gamma = 0$, $\alpha = 1 - \lambda\eta_0$ is the unique non-zero
solution.


Remark 2.4.10. Above we implicitly assume that $\lambda\eta_0 \le (1 - \sqrt{\gamma})^2$ so that the roots
are real, and this is always true in practice. For instance, for the standard hyper-parameters
$\gamma = 0.9$, $\eta_0 = 0.1$, $\lambda = 0.0005$, we have $\frac{\lambda\eta_0}{(1-\sqrt{\gamma})^2} \approx 0.019 \ll 1$.
Proof. Note that $(\tilde{x}_0, \tilde{\eta}_0, \tilde{x}_{-1}, \tilde{\eta}_{-1}) = \big[\Pi_2^{\alpha^{-1}} \circ \Pi_3^{\alpha} \circ \Pi_4^{\alpha}\big](x_0, \eta_0, x_0, \eta_0)$, so it suffices to
show that

$$\big[\Pi_3^{\alpha^{-1}} \circ \Pi_4^{\alpha^{-1}} \circ \Pi_2^{\alpha^{-1}} \circ \mathrm{GD}_{t-1} \circ \Pi_2^{\alpha^{-2}} \circ \cdots \circ \mathrm{GD}_1 \circ \Pi_2^{\alpha^{-2}} \circ \mathrm{GD}_0 \circ \Pi_2^{\alpha^{-1}} \circ \Pi_3^{\alpha} \circ \Pi_4^{\alpha}\big](x_0, \eta_0, x_0, \eta_0) \overset{\alpha^t}{\sim} \mathrm{GD}_{t-1}^{1-\lambda\eta_0} \circ \cdots \circ \mathrm{GD}_0^{1-\lambda\eta_0}(x_0, \eta_0, x_0, \eta_0), \quad \forall t \ge 0,$$

which follows immediately from Lemma 2.4.7 and Lemma 2.4.8 by induction. $\square$
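The momentum case can be checked the same way; the sketch below (ours) simulates both recursions of Theorem 2.4.9 with $v_0 = 0$ on the earlier toy scale-invariant loss and confirms $\tilde{x}_t = \alpha^{-t} x_t$.

import torch
torch.set_default_dtype(torch.float64)

z = torch.tensor([1.0, -2.0, 0.5])
def grad(x):
    x = x.clone().requires_grad_()
    return torch.autograd.grad(torch.dot(x / x.norm(), z) ** 2, x)[0]

eta, lam, gamma, T = 0.1, 0.01, 0.9, 40
b = 1 + gamma - lam * eta
alpha = (b + (b * b - 4 * gamma) ** 0.5) / 2
lr = lambda t: eta * alpha ** (-(2 * t + 1))        # lr(-1) = eta * alpha

x = x_prev = torch.tensor([1.0, 0.3, -0.8])         # WD trajectory, v0 = 0
y, y_prev = x.clone(), alpha * x                    # exp-LR trajectory
for t in range(1, T + 1):
    x, x_prev = x + gamma * (x - x_prev) - eta * (grad(x) + lam * x), x
    y, y_prev = y + gamma * lr(t-1) / lr(t-2) * (y - y_prev) - lr(t-1) * grad(y), y
print(torch.allclose(y, alpha ** (-T) * x))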

2.4.3 Replacing WD by Exponential LR: Case of Multiple LR Phases

Usual practice in deep learning shows that reaching full training accuracy requires
reducing the learning rate a few times.

Definition 2.4.11. Step Decay is the (standard) learning rate schedule in which training
has $K$ phases $I = 0, 1, \ldots, K-1$, where phase $I$ starts at iteration $T_I$ ($T_0 = 0$) and
all iterations in phase $I$ use a fixed learning rate $\eta_I^*$.

The algorithm state in Section 2.4.2 consists of four components, including the buffered
and current LR. When the LR changes, the buffered and current LR are not equal, and
thus Lemma 2.4.8 cannot be applied any more. In this section we show how to fix this
issue by adding an extra momentum correction. In detail, we show that the Exp
LR schedule defined below leads to the same trajectory of networks in function space, with a one-time
momentum correction at the start of each phase. We empirically find on CIFAR10
that ignoring the correction term does not change performance much.

Figure 2.2: PreResNet32 trained with standard Step Decay and its corresponding
Tapered-Exponential LR schedule. As predicted by Theorem 2.4.12, they have similar
trajectories and performances.

Theorem 2.4.12 (Tapered-Exponential LR Schedule). There exists a way to correct


the momentum only at the first iteration of each phase, such that the following
Tapered-Exponential LR schedule (TEXP) {e
ηt } with momentum factor γ and no
WD, leads the same sequence networks in function space as that of Step Decay LR
schedule(Definition 2.4.11) with momentum factor γ and WD λ.


)−2

ηe

t−1 × (αI−1 if TI−1 + 1 ≤ t ≤ TI − 1, I ≥ 1;
ηet = (2.6)
ηI∗
(αI∗ )−1 (αI−1

)−1

ηet−1 ×

∗ × if t = TI , I ≥ 1,
ηI−1

q
2
1+γ−ληI∗ + (1+γ−ληI∗ ) −4γ
where αI∗ = 2
, ηe0 = η0 · (α0∗ )−1 = η0∗ · (α0∗ )−1 .

The analysis in previous subsection give the equivalence within each phase, where
the same LR is used throughout the phase. To deal with the difference between
buffered LR and current LR when entering new phases, the idea is to pretend ηt−1 = ηt
xt −xt−1
and xt−1 becomes whatever it needs to maintain ηt−1
such that we can again apply
Lemma 2.4.8, which requires the current LR of the input state is equal to its buffered
LR. Because scaling α in RHS of Equation (2.5) is different in different phases, so
26
unlike what happens within each phase, they don’t cancel with each other at phase
transitions, thus remaining as a correction of the momentum. The proofs are delayed
to Section 2.8, where we proves a more general statement allowing phase-dependent
WD, {λI }K−1
I=0 .

Alternative interpretation of Step Decay to exponential LR sched-


ule:Below we present a new LR schedule, TEXP++, which is exactly equivalent to
Step Decay without the need of one-time correction of momentum when entering each
phase. We further show in Section 2.8.1 that when translating from Step Decay, the
TEXP++ we get is very close to the original TEXP (Equation (2.6)), i.e. the ratio
0
ηet+1 ηet+1
between the LR growth per round, ηet
/ ηe0 converges to 1 exponentially each phase.
t

For example, with WD 0.0005, max LR 0.1, momentum factor 0.9, the ratio is within
1 ± 0.0015 ∗ 0.9t−TI , meaning TEXP and TEXP++ are very close for Step Decay with
standard hyperparameters.

Theorem 2.4.13. The following two sequences of parameters ,{xt }∞ xt } ∞


t=0 and {e t=0 ,

define the same sequence of network functions, i.e. fxt = fxet , ∀t ∈ N, given the initial
conditions, x e0 = P0 x0 , xe−1 = P−1 x−1 .
 
1. xtη−x t−1
t−1
= γ xt−1 −xt−2
ηt−2
− ∇ x (L(xt−1 ) +
λt−1
2
kxt−1 k22 , for t = 1, 2, . . .;
et −e
x xt−1 −e
2. = γ xet−1ηet−2
ηet−1
xt−2
− ∇x L(e
xt−1 ), for t = 1, 2, . . .,
t
αi−1 , ∀t ≥ −1 and αt recursively defined as
Q
where ηet = Pt Pt+1 ηt , Pt =
i=−1

ηt−1 −1
αt = −ηt−1 λt−1 + 1 + γ(1 − αt−1 ), ∀t ≥ 1. (2.7)
ηt−2

ηt }∞
The LR schedule {e t=0 is called Tapered Exponential ++, or TEXP++.

27
2.5 Example Illustrating Interplay of Weight De-

cay and Normalization Layer


The paper so far has shown that effects of different hyperparameters in training are
not easily separated, since their combined effect on the trajectory is complicated. We
give a simple example to illustrate this, where convergence is guaranteed if we use
either BatchNorm or weight decay in isolation, but convergence fails if both are used.
(Momentum is turned off for clarity of presentation)
Setting: Suppose we are fine-tuning the last linear layer of the network, where the
input of the last layer is assumed to follow a standard Gaussian distribution N (0, Im ),
and m is the input dimension of last layer. We also assume this is a binary classification
task with logistic loss, l(u, y) = ln(1 + exp(−uy)), where label y ∈ {−1, 1} and u ∈ R
is the output of the neural network. The training algorithm is SGD with constant
LR and WD, and without momentum. For simplicity we assume the batch size B is
very large so we could assume the covariance of each batch Bt concentrates and is
approximately equal to identity, namely B1 B >
P
i=1 zt,b zt,b ≈ Im . We also assume the the

input of the last layer are already separable, and w.l.o.g. we assume the label is equal
to the sign of the first coordinate of z ∈ Rm , namely sign (x1 ) . Thus the training loss
and training error are simply

ln(1 + exp(−z > xy)) ,


 
L(x) = E
z∼N (0,Im ),y=sign(x1 )
 1 x1
z > xy ≤ 0 = arccos

err(x) = Pr
z∼N (0,Im ),y=sign(x1 ) π kxk

Case 1: WD alone: Since both the above objective with L2 regularization is strongly
convex and smooth in x, vanilla GD with suitably small learning rate could get
arbitrarily close to the global minimum for this regularized objective. In our case,

28
q
large batch SGD behaves similarly to GD and can achieve O( ηλ
B
) test error following
the standard analysis of convex optimization.
Case 2: BN alone: Add a BN layer after the linear layer, and fix scalar and bias term
to 1 and 0. The objective becomes

 
> x
LBN (x) = E [LBN (x, z)] = E ln(1 + exp(−z y)) .
z∼N (0,Im ),y=sign(x1 ) z∼N (0,Im ),y=sign(x1 ) kxk

From Section 2.8.6, there’s some constant C, such that ∀x ∈ Rm with constant
C
probability, k∇x LBN (x, z)k ≥ kxk
. By Pythagorean Theorem, kxt+1 k4 = (kxt k2 +
η 2 k∇x LBN (xt , z)k2 )2 ≥ kxt k4 + 2η 2 kxt k2 k∇x LBN (xt , z)k2 . As a result, for any fixed
learning rate, kxt+1 k4 ≥ 2 ti=1 η 2 kxk2 k∇x LBN (xi , z)k2 grows at least linearly with
P

high probability. Following the analysis by Arora et al. [36], this is like reducing the
effective learning rate, and when kxt k is large enough, the effective learning rate is
small enough, and thus SGD can find the local minimum, which is the unique global
minimum.
Case 3: Both BN and WD: When BN and WD are used together, no matter how
small the noise is, which comes from the large batch size, the following theorem shows

that SGD will not converge to any solution with error smaller than O( ηλ), which is
independent of the batch size (noise level).

Theorem 2.5.1. [Nonconvergence] Starting from iteration any T0 , with probability


ε
1 − δ over the randomness of samples, the training error will be larger than π
at least

1 k2 ε
64kwT0 B
once for the following consecutive 2(ηλ−2ε2 )
ln √
η m−2
+ 9 ln 1δ iterations.

Proof Sketch. (See full proof in Section 2.8.) The high level idea of this proof is that
if the test error is low, the weight is restricted in a small cone around the global
minimum, and thus the amount of the gradient update is bounded by the size of the
cone. In this case, the growth of the norm of the weight by Pythagorean Theorem is
not large enough to cancel the shrinkage brought by weight decay. As a result, the
29
norm of the weight converges to 0 geometrically. Again we need to use the lower bound
for size of the gradient, that k∇x Lt k = Θ( kxηt k m
p
B
) holds with constant probability.
Thus the size of the gradient will grow along with the shrinkage of kxt k until they’re
comparable, forcing the weight to leave the cone in next iteration.

2.6 Viewing Exponential Learning Rates via

Canonical Optimization Framework

This section tries to explain why the efficacy of exponential LR in deep learning is
mysterious to us, at least as viewed in the canonical framework of optimization theory.
Canonical framework for analysing 1st order methods This focuses on proving that
each —or most—steps of GD noticeably reduce the objective, by relying on some
assumption about the spectrum norm of the hessian of the loss, and most frequently,
the smoothness, denoted by β. Specifically, for GD update xt+1 = xt − η∇L(xt ), we
have

β βη
L(xt+1 ) − L(xt ) ≤ (xt+1 − xt )> ∇L(xt ) + kxt+1 − xt k2 = −η(1 − )k∇L(xt )k2 .
2 2

When β < η2 , the first order term is larger than the second order one, guaranteeing
the loss value decreases. Since the analysis framework treats the loss as a black box
(apart from the assumed bounds on the derivative norms), and the loss is non-convex,
the best one can hope for is to prove speedy convergence to a stationary point (where
gradient is close to 0). An increasing body of work proves such results.
Now we turn to difficulties in understanding the exponential LR in context of the
above framework and with scale-invariance in the network.

1. Since loss is same for x and c · x for all c > 0 a simple calculation shows that
along any straight line through the origin, smoothness is a decreasing function of
30
c, and is very high close to origin. (Note: it is also possible to one can show the
following related fact: In any ball containing the origin, the loss is nonconvex.)

Thus if one were trying to apply the canonical framework to argue convergence
to a stationary point, the natural idea would be to try to grow the norm
of the parameters until smoothness drops enough that the above-mentioned
Canonical Framework starts to apply. Arora et al. [36] showed this happens in
GD with fixed LR (WD turned off), and furthermore the resulting convergence
rate to stationary point is asymptotically similar to analyses of nonconvex
optimization with learning rate set as in the Canonical framework. Santurkar
et al. [28] observed similar phenomenon in experiments, which they described as
a smoothening effect of the objective due to BN.

2. The Canonical Framework can be thought of as a discretization of continuous


gradient descent (i.e., gradient flow): in principle it is possible to use arbitrarily
small learning rate, but one uses finite learning rate merely to keep the number
of iterations small. The discrete process approximates the continuous process
due to smoothness being small.

In case of gradient flow with weight decay (equivalently, with exponential LR


schedule) the discrete process cannot track the continuous process for very long,
which suggests that any explanation of the benefits of exponential LR may need
to rely on discrete process being somehow better. The reason being that for
gradient flow one can decouple the speed of the xt into the tangential and the
radial components, where the former one has no effect on the norm and the latter
one has no effect on the objective but scales the tangential gradient exponentially.
Thus the Gradient Flow with WD gives exactly the same trajectory as vanilla
Gradient Flow does, excepting a exponential reparametrization with respect to
time t.

31
2
3. It can be shown that if the local smoothness is upperbounded by η
(as stipulated
in Canonical Framework) during a sequence xt (t = 1, 2, . . .) of GD updates with
WD and constant LR then such sequence satisfies xt → 0. This contrasts with
the usual experimental observation that xt stays bounded away from 0. One
should thus conclude that in practice, with constant LR and WD, smoothness
doesn’t always stay small (unlike the above analyses where WD is turned off).

2.7 Experiments

The translation to exponential LR schedule is exact except for one-time momentum


correction term entering new phases. The experiments explore the effect of this
correction term. The Tapered Exponential(TEXP) LR schedule contains two parts
ηI
when entering a new phase I: an instant LR decay ( ηI−1 ) and an adjustment of the

growth factor (αI−1 → αI∗ ). The first part is relative small compared to the huge
exponential growing. Thus a natural question arises: Can we simplify TEXP LR
schedule by dropping the part of instant LR decay?
Also, previously we have only verified our equivalence theorem in Step Decay LR
schedules. But it’s not sure how would the Exponential LR schedule behave on more
rapid time-varying LR schedules such as Cosine LR schedule.
Settings: We train PreResNet32 on CIFAR10. The initial learning rate is 0.1 and
the momentum is 0.9 in all settings. We fix all the scalar and bias of BN, because
otherwise they together with the following conv layer grow exponentially, sometimes
exceeding the range of Float32 when trained with large growth rate for a long time.
We fix the parameters in the last fully connected layer for scale invariance of the
objective.

32
Figure 2.3: Instant LR decay has only temporary effect when LR growth ηet /eηt−1 − 1
is large. The blue line uses an exponential LR schedule with constant exponent. The
orange line multiplies its LR by the same constant each iteration, but also divide
LR by 10 at the start of epoch 80 and 120. The instant LR decay only allows the
parameter to stay at good local minimum for 1 epoch and then diverges, behaving
similarly to the trajectories without no instant LR decay.

2.7.1 The benefit of instant LR decay

We tried the following LR schedule (we call it TEXP--). Interestingly, up to correction


of momentum when entering a new phase, this schedule is equivalent to a constant
LR schedule, but with the weight decay coefficient reduced correspondingly at the
start of each phase. (See Theorem 2.8.2 and Figure 2.5)


)−2

ηet × (αI−1
 if TI−1 + 1 ≤ t ≤ TI − 1, I ≥ 1;
TEXP--: ηet+1 = (2.8)
ηet × (αI∗ )−1 (αI−1

)−1

 if t = TI , I ≥ 1,

q
2
1+γ−ληI∗ + (1+γ−ληI∗ ) −4γ
where αI∗ = 2
, ηe0 = η0 · (α0∗ )−1 = η0∗ · (α0∗ )−1 .

2.7.2 Better Exponential LR Schedule with Cosine LR

We applied the TEXP LR schedule (Theorem 2.4.12) on the Cosine LR schedule [34],
where the learning rate changes every epoch, and thus correction terms cannot be
1+cos( Tt π)
ignored. The LR at epoch t ≤ T is defined as: ηt = η0 2
. Our experiments

33
Figure 2.4: Instant LR decay is crucial when LR growth ηet /e ηt−1 − 1 is very small.
The original LR of Step Decay is decayed by 10 at epoch 80, 120 respectively. In the
third phase, LR growth ηet /e ηt−1 − 1 is approximately 100 times smaller than that in
the third phase, it would take TEXP-- hundreds of epochs to reach its equilibrium.
As a result, TEXP achieves better test accuracy than TEXP--. As a comparison, in
ηt−1 − 1 is only 10 times smaller than that in the first phase and
the second phase, ηet /e
it only takes 70 epochs to return to equilibrium.

Figure 2.5: The orange line corresponds to PreResNet32 trained with constant LR
and WD divided by 10 at epoch 80 and 120. The blue line is TEXP-- corresponding
to Step Decay schedule which divides LR by 10 at epoch 80 and 120. They have
similar trajectories and performances by a similar argument to Theorem 2.4.12.(See
Theorem 2.8.2 and its proof in Section 2.8)

34
Figure 2.6: Both Cosine and Step Decay schedule behaves almost the same as their
exponential counterpart, as predicted by our equivalence theorem. The (exponential)
Cosine LR schedule achieves better test accuracy, with a entirely different trajectory.

show this hybrid schedule with Cosine LR performs better on CIFAR10 than Step
Decay, but this finding needs to be verified on other datasets.

2.8 Proofs

2.8.1 Omitted Proof in Section 2.4

Lemma 2.8.1 (Some Facts about Equation (2.1)). Suppose z 1 , z 2 (z 1 ≥ z 2 ) are the
two real roots of the the following equation, we have

z 2 − (1 + γ − λη)z + γ = 0

√ √
1 1+γ−λη+ (1−γ)2 −2(1+γ)λη+λ2 η 2 2 1+γ−λη− (1−γ)2 −2(1+γ)λη+λ2 η 2
1. z = 2
, z = 2

√ 2
2. z 1 , z 2 are real ⇐⇒ λη ≤ (1 − γ) ;

3. z 1 z 2 = γ, z 1 + z 2 = (1 + γ − λη);

4. γ ≤ z 2 ≤ z 1 ≤ 1;

λη 1 λη
5. Let t = 1−γ
, we have z 1 ≥ 1+t
≥1−t=1− 1−γ
.

35
6. if we view z 1 (λη), z 2 (λη) as functions of λη, then z 1 (λη) is monotone decreasing,
z 2 (η) is monotone increasing.

Proof of Lemma 2.8.1.

4. Let f (x) = z 2 − (1 + γ − λη)z + γ, we have f (1) = f (γ) = λη ≥ 0. Note the


1+γ−λη
minimum of f is taken at x = 2
∈ [0, 1], the both roots of f (x) = 0 must
lie between 0 and 1, if exists.

5. It holds that
p
1 − γ + λη − (1 − γ)2 − 2(1 + γ)λη + λ2 η 2
1 − z1 =
q 2
1+γ
1+t− 1− 1−γ
t + t2
= (1 − γ)
2
1+γ
2t + 2 1−γ t
= (1 − γ) q
2(1 + t + 1 − 1+γ1−γ
t + t2 )
4
1−γ
t
≤ (1 − γ)
4(1 + t)
t
=
(1 + t)

6. Note that (z 1 − z 2 )2 = (z 1 + z 2 )2 − 4z 1 z 2 = (1 + γ − λη)2 − 4γ is monotone


decreasing, since z 1 (λη) + z 2 (λη) is constant, z 1 (λη) ≥ z 2 (λη), z 1 (λη) must be
decreasing and z 2 (λη) must be increasing.

2.8.2 Omitted proofs in Section 2.4.1

Proof of Lemma 2.4.2. For any (x, η), we have

η −1
GDρt (x, η) = (ρx − η∇Lt (x), η) = [Πρ1 ◦ Πρ2 ◦ GDt ](x, ) = [Πρ1 ◦ Πρ2 ◦ GDt ◦ Π2ρ ](x, η).
ρ

36
Proof of Lemma 2.4.4. For any (x, η), we have

h i
2 ∗
GDt ◦ Πc1 ◦ Πc2 (x, η) = GDt (cx, c2 η) = (cx − c2 x∇Lt (cx), c2 η) = (c(x − ∇Lt (x)), c2 η)
h 2
i
= Πc1 ◦ Πc2 ◦ GDt (x, η),


where = is because of Scale Invariance of Lt (Lemma 2.3.2).

2.8.3 Omitted proofs in Section 2.4.2

Proof of Lemma 2.4.7. For any input (x, η, x0 , η 0 ), it’s easy to check both composed
maps have the same outputs on the 2,3,4th coordinates, namely (c2 η, cx, c2 η 0 ). For
the first coordinate, we have

x − x0
 
0 2
 ρ 2
 2
GD (cx, c η, cx , c η) 1 = ρcx + c η γ − ∇Lt (cx)
η0
x − x0
  

=c x + η γ − ∇Lt (x)
η0
=c [GDρ (x, η, x0 , η)]1 ,


where = is because of Scale Invariance of Lt (Lemma 2.3.2).

Proof of Lemma 2.4.8. For any input (x, η, x0 , η 0 ), it’s easy to check both composed
maps have the same outputs on the 2,3,4th coordinates, namely (η, x, η). For the first
coordinate, we have

hh −1
i i
Πα3 Πα4 Πα2 Πα2 Πα3 Πα4 (x, η, x , η) = α GDt (x, α−1 η, αx0 , αη) 1
0
 
◦ ◦ ◦ ◦ GDt ◦ ◦
1
x − x0
 

=α x + α−1 η γ − ∇Lt (x)
η
x0
= α + γα−1 x − η∇Lt (x) − ηγ

η
= (ρ + γ) x − η∇Lt (x) − γx0 = [GDρt (x, η, x0 , η)]1

37
2.8.4 Omitted proofs of Theorem 2.4.12

In this subsection we will prove a stronger version of Theorem 2.4.12(restated below),


allowing the WD,λI changing each phase.

Theorem 2.8.2 (A stronger version of Theorem 2.4.12). There exists a way to correct
the momentum only at the first iteration of each phase, such that the following Tapered-
Exponential LR schedule (TEXP) {e
ηt } with momentum factor γ and no WD, leads
the same sequence networks in function space compared to that of Step Decay LR
schedule(Definition 2.4.11) with momentum factor γ and phase-dependent WD λ∗I in
phase I, where phase I lasts from iteration TI to iteration TI+1 , T0 = 0.


)−2

ηet × (αI−1
 if TI−1 + 1 ≤ t ≤ TI − 1, I ≥ 1
ηet+1 = , (2.9)
ηI∗
(αI∗ )−1 (αI−1

)−1

ηet ×

∗ × if t = TI , I ≥ 1
ηI−1

q
2
1+γ−λ∗I ηI∗ + (1+γ−λ∗I ηI∗ ) −4γ
where αI∗ = 2
, ηe0 = η0 (α0∗ )−1 = η0∗ (α0∗ )−1 .

Towards proving Theorem 2.4.12, we need the following lemma which holds by
expanding the definition, and we omit its proof.

Lemma 2.8.3 (Canonicalization). We define the Canonicalization map as


η
N (x, η, x0 , η 0 ) = (x, η, x − η0
(x − x0 ), η), and it holds that

1. GDρt ◦ N = GDρt , ∀ρ > 0, t ≥ 1.


h i h i
2 2 2 2
2. N ◦ Πc1 ◦ Πc2 ◦ Πc3 ◦ Πc4 = Πc1 ◦ Πc2 ◦ Πc3 ◦ Πc4 ◦ N , ∀c > 0.

Similar to the case of momentum-free SGD, we define the notion of equivalent


map below

Definition 2.8.4 (Equivalent Maps). For two maps F and G, we say F is equivalent
h i
c c2 c c2 c
to G iff ∃c > 0, F = Π1 ◦ Π2 ◦ Π3 ◦ Π4 ◦ G, which is also denoted by F ∼ G.

38
Note that for any (x, η, x0 , η 0 ), [N (x, η, x0 , η 0 )]2 = [N (x, η, x0 , η 0 )]4 . Thus as a direct
consequence of Lemma 2.4.8, the following lemma holds.
α −1 −1 −1 −1
Lemma 2.8.5. ∀ρ, α > 0, GDρt ◦ N ∼ Πα3 ◦ Πα4 ◦ Πα2 ◦ GDt ◦ Πα2 ◦ Πα3 ◦ Πα4 ◦ N .

Proof of Theorem2.4.12. Starting with initial state (x0 , η0 , x−1 , η−1 ) where η−1 = η0
and a given LR schedule {ηt }t≥0 , the parameters generated by GD with WD and
momentum satisfies the following relationship:

 ηt+1 
ηt 1−ηt λt
(xt+1 , ηt+1 , xt , ηt ) = Π2 ◦ GDt (xt , ηt , xt−1 , ηt−1 ).

b
Define Ft = Fb ◦ Fb−1 ◦ . . . ◦ Fa , for a ≤ b. By Lemma 2.8.3 and Lemma 2.8.5,
t=a
letting αt be the root of x2 − (γ + 1 − ηt−1 λt−1 )x + γ = 0, we have

 ηt+1 
T −1
Π2ηt
◦ GDt1−ηt λt
t=0
 ηt+1 
T −1
= Π2ηt
◦ GDt1−ηt λt ◦N
t=0
−1
TQ
αi T −1  ηt+1 
i=0 ηt α−1 α−1 α−1 α−1 α α (2.10)
∼ Π2 ◦ Π3 t+1 ◦ Π4 t+1 ◦ Π2 t+1 ◦ GDt ◦ Π2 t+1 ◦ Π3 t+1 ◦ Π4 t+1 ◦N
t=0
ηT
ηT −1 α−1 α−1 α−1
=Π2 ◦ Π3 T −1 ◦ Π4 T −1 ◦ Π2 T ◦ GDT −1 ◦
 i
T −1 h α−1 α−1
α−1
= Π2 t+1 t
◦ Ht ◦ GDt−1 ◦ Π2 1 ◦ Πα3 1 ◦ Πα4 1 ◦ N,
t=1

−1
TQ
αi
i=0
where ∼ is because of Lemma 2.8.5, and Ht is defined as

ηt−1 ηt
α α α−1 α−1 α−1 ηt−1
Ht = Πα2 t ◦ Π2 ηt
◦ Π3 t+1 ◦ Π4 t+1 ◦N ◦ Π3 t ◦ Π4 t ◦ Π2 t ◦ Π2 .

Since the canonicalization map N only changes the momentum part of the state, it’s
easy to check that Ht doesn’t touch the current parameter x and the current LR η.
Thus Ht only changes the momentum part of the input state. Now we claim that
39
Ht ◦ GDt−1 = GDt−1 whenever ηt = ηt−1 . This is because when ηt = ηt−1 , αt = αt+1 ,
thus Ht ◦ GDt−1 = GDt−1 . In detail,

Ht ◦ GDt−1
α−1 α−1 α−1
=Πα2 t ◦ Πα3 t ◦ Πα4 t ◦ N ◦ Π3 t ◦ Π4 t ◦ Π2 t ◦ GDt−1
∗ α−1 α−1 α−1
=Πα2 t ◦ Πα3 t ◦ Πα4 t ◦ Π3 t ◦ Π4 t ◦ Π2 t ◦ GDt−1

=GDt−1 ,


where = is because GD update GDt sets η 0 the same as η, and thus ensures the input
of N has the same momentum factor in buffer as its current momentum factor, which
makes N an identity map.
0
 Thus we could rewrite Equation (2.10) with a “sloppy”version of Ht , Ht =

Ht ηt 6= ηt−1 ;

:

Id o.w.

 ηt+1 
T −1
ηt 1−ηt λt
Π2 ◦ GDt
t=0
ηT  i
α−1 α−1 α−1 T −1 h α−1 α−1
ηT −1 α−1
=Π2 ◦ Π3 T −1 ◦ Π4 T −1 ◦ Π2 T ◦ GDT −1 ◦ Π2 t+1 t
◦ Ht0
◦ GDt−1 ◦ Π2 1 ◦ Πα3 1 ◦ Πα4 1 ◦ N
t=1
ηT  i
α−1 α−1 α−1 T −1 h α−1 −1
η t+1 αt 0 α−1
=Π2 T −1 ◦ Π3 T −1 ◦ Π4 T −1 ◦ Π2 T ◦ GDt ◦ Π2 ◦ Ht ◦ GD0 ◦ Π2 1 ◦ Πα3 1 ◦ Πα4 1 ◦ N,
t=1

(2.11)

Now we construct the desired sequence of parameters achieved by using the Tapered
Exp LR schedule 2.9 and the additional one-time momentum correction per phase.
Let (e
x0 , ηe0 , x
e−1 , ηe−1 ) = (x0 , η0 , x−1 , η0 ), and

40
α−1
h i
α1 α1
(e e0 , ηe0 ) = GD0 ◦ Π2 ◦ Π3 ◦ Π4 ◦ N (e
x1 , ηe1 , x 1
x0 , ηe0 , x
e−1 , ηe−1 )
α−1
h i
= GD0 ◦ Π2 1 ◦ Πα3 1 ◦ Πα4 1 (e x0 , ηe0 , x
e−1 , ηe−1 );
α−1 α−1
h i
(e et , ηet ) = GDt ◦ Π2 t+1 t ◦ Ht0 (e
xt+1 , ηet+1 , x xt , ηet , x
et−1 , ηet−1 ).

we claim {e
xt }t=0 is the desired sequence of parameters. We’ve already shown that
xt ∼ x
et , ∀t. Clearly {e
xt }t=0 is generated using only vanilla GD, scaling LR and
6 TI for any I, ηt = ηt−1 and thus
modifying the momentum part of the state. When t =
Ht0 = Id. Thus the modification on the momentum could only happen at TI (I ≥ 0).
Also it’s easy to check that αt = αI∗ , if TI + 1 ≤ t ≤ TI+1 .

2.8.5 Omitted proofs of Theorem 2.4.13

Proof of Theorem 2.4.13. We will prove by induction. By assumption S(t) : Pt xt = x


et
for t = −1, 0. Now we will show that S(t) =⇒ S(t + 1), ∀t ≥ 0.

41
 
xt − xt−1 xt−1 − xt−2 λt−1 2
=γ − ∇x (L(xt−1 ) + kxt−1 k2
ηt−1 ηt−2 2
Take gradient xt − xt−1 xt−1 − xt−2
=======⇒ =γ − ∇x L(xt−1 ) + λt−1 xt−1
ηt−1 ηt−2
Scale Invariance xt − xt−1 xt−1 − xt−2
=========⇒ =γ − Pt−1 ∇x L(ext−1 ) + λt−1 xt−1
ηt−1 ηt−2
Rescaling Pt (xt − xt−1 ) Pt−2 (xt−1 − xt−2 ) xt−1
=====⇒ =γ − ∇x L(e xt−1 ) − λt−1
Pt Pt−1 ηt−1 Pt−1 Pt−2 ηt−2 Pt−1
−1
Simplfying Pt xt − αt x et−1 et−1 − x
αt−1 x et−2 Pt xt−1
======⇒ =γ − ∇x L(e xt−1 ) − ηt−1 λt−1
ηet−1 ηet−2 ηt−1 Pt−1 Pt
−1
Simplfying Pt xt − αt x et−1 et−1 − x
αt−1 x et−2 α−1 x et−1
======⇒ =γ − ∇x L(e xt−1 ) − ηt−1 λt−1 t
ηet−1 ηet−2 ηet−1
−1
Simplfying Pt xt − αt (1 − ηt−1 λt−1 )e xt−1 et−1 − x
αt−1 x et−2
======⇒ =γ − ∇x L(e xt−1 )
ηet−1 ηet−2

To conclude that Pt xt = x
et , it suffices to show that the coefficients before x
et−1 is
the same to that in (2). In other words, we need to show

−1 + αt−1 (1 − ηt−1 λt−1 ) γ(1 − αt−1 )


= ,
ηet−1 ηet−2

which is equivalent to the definition of αt , Equation (2.7).

Lemma 2.8.6 (Sufficient Conditions for positivity of αt ). Let λmax = maxt λt , ηmax =
maxt ηt . Define zmin is the larger root of the equation z 2 − (1 + γ − λmax ηmax )z + γ = 0.

To guarantee the existence of zmax we also assume ηmax λmax ≤ (1 − γ)2 . Then we
have

∀α−1 , α0 = 1 =⇒ zmin ≤ αt ≤ 1, ∀t ≥ 0 (2.12)

42
Proof. We will prove the above theorem with a strengthened induction —

^ α−1 −1
0 t0 − 1 zmin −1
S(t) : ∀0 ≤ t ≤ t, zmin ≤ αt0 ≤ 1 ≤ .
ηt0 −1 ηmax

Since α0 = 1, S(0) is obviously true. Now suppose S(t) is true for some t ∈ N, we
will prove S(t + 1).
First, since 0 < αt ≤ 1, αt+1 = −ηt λt + 1 + ηt
ηt−1
γ(1 − αt−1 ) ≤ 1.
Again by Equation (2.7), we have

αt−1 − 1 z −1 − 1 −1
1 − αt+1 = ηt λt + ηt γ = ηt λt + min ηt γ ≤ ηt λt + (zmin − 1)γ = 1 − zmin ,
ηt−1 ηmax

which shows αt+1 ≥ zmin . Here the last step is by definition of zmin .
Because of αt+1 ≥ zmin , we have

−1
αt+1 −1 −1 1 − αt+1 −1 α−1 − 1
≤ zmin ≤ zmin (λt + t γ)
ηt ηt ηt−1
−1 z −1 − 1 −1 1 − zmin z −1 − 1
≤zmin (λmax + min γ) = zmin = min .
ηmax ηmax ηmax

Now we are ready to give the formal statement about the closeness of Equation (2.6)
and the reduced LR schedule by Theorem 2.4.13.

43
Theorem 2.8.7. Given a Step Decay LR schedule with {TI }K−1 ∗ K−1 ∗ K−1
I=0 , {ηI }I=0 , {λI }I=0 ,

the TEXP++ LR schedule in Theorem 2.4.13 is the following(α0 = α−1 = 1, T0 = 0):



−1
−ηI∗ λ∗I + 1 + γ(1 − αt−1

 ), ∀TI + 2 ≤ t ≤ TI+1 , I ≥ 0;
αt =
ηI∗ −1
−ηI∗ λ∗I + 1 +



ηI−1
γ(1 − αt−1 ), ∀t = TI + 1, I ≥ 0;
t
Y
Pt = αt−1 ;
i=−1

η̂t = Pt Pt+1 ηt .

It’s the same as the TEXP LR schedule({η˜t }) in Theorem 2.4.12 throughout each
phase I, in the sense that ∀TI + 1 ≤ t ≤ TI+1

  t−TI −1  (t−TI −1)


η̂t−1 ηet−1 λmax ηmax γ λmax ηmax λmax ηmax 2
−1 <3 2
≤3 γ(1 + ) .
η̂t ηet 1−γ zmin 1−γ 1−γ

where zmin is the larger root of z 2 − (1 + γ − λmax ηmax )z + γ = 0. In Section 2.8,


−1 ηmax λmax
we show that zmin ≤1+ 1−γ
. When λmax ηmax is small compared to 1 − γ, which
is usually the case in practice, one could approximate zmin by 1. For example, when
γ = 0.9, λmax = 0.0005, ηmax = 0.1, the above upper bound becomes


η̂t−1 ηet−1
− 1 ≤ 0.0015 × 0.9009t−TI −1 .
η̂t ηet

Proof of Theorem 2.8.7. Assuming zI1 and zI2 (zI1 ≥ zI2 ) are the roots of Equation (2.1)

with η = ηI and λ = λI , we have γ ≤ zI20 ≤ γ ≤ zmin ≤ zI1 ≤ 1, ∀I, I 0 ∈ [K − 1] by
Lemma 2.8.1.
We can rewrite the recursion in Theorem 2.4.13 as the following:

−1 −1
αt = −ηI λI + 1 + γ(1 − αt−1 ) = −(zI1 + zI2 ) + zI1 zI2 αt−1 . (2.13)

44
In other words, we have

zI2
αt − zI1 = (αt−1 − zI1 ), t ≥ 1. (2.14)
αt−1
z2I αt−1
By Lemma 2.8.6, we have αt ≥ zmin , ∀t ≥ 0. Thus | αz1t − 1| = |
αt−1 zI1
− 1| ≤
I

2
γ
zmin
| αzt−1
1 − 1| = 2
γ
zmin
| αzt−1
1 − 1| ≤ γ(1 + λη 2 αt−1
1−γ
) | z1 |, which means αt geometrically
I I I
ηet−1
converges to its stable fixed point zI1 . and ηet
= (zI1 )2 . Since that zmin ≤ αt ≤ 1,
αTI 1−zmin λmax ηmax
zmin ≤ zI1 ≤ 1, we have | zI1
− 1| ≤ zmin
= 1−γ
≤ 1 , and thus | αz1t − 1| ≤
I
λmax ηmax
1−γ
( z2γ )t−TI −1 ≤ 1, ∀TI + 1 ≤ t ≤ TI+1 .
min
η̂t−1
Note that αI∗ = zI1 , η̂t
= αt αt+1 By definition of TEXP and TEXP++, we have


1
)2

ηet−1 (zI−1
 if TI−1 + 1 ≤ t ≤ TI − 1
= (2.15)
ηet ∗
 ηI−1 z1z1


ηI∗ I I−1
if t = TI , I ≥ 1


η̂t−1 ηt−1 αt+1 αt
 if TI−1 + 1 ≤ t ≤ TI − 1
= αt+1 αt = (2.16)
η̂t ηt ∗
 ηI−1


η∗
αTI +1 αTI if t = TI , I ≥ 1
I

Thus we have when t = TI ,



η̂t−1 ηet−1 αTI +1 αTI αTI +1 αT αTI +1 αT
−1 ≤ 1 1
−1 ≤ 1
−1 + 1 I −1 + 1
−1 1 I −1
η̂t ηet zI zI−1 zI zI−1 zI zI−1
λmax ηmax
≤3 .
1−γ

When TI + 1 ≤ t ≤ TI+1 , we have



η̂t−1 ηet−1 αt+1 αt αt+1 αt αt+1 αt
−1 = 1 1
−1 ≤ 1 −1 + 1 −1 + 1 −1 1 −1
η̂t ηet zI−1 zI−1 zI−1 zI−1 zI−1 zI−1
λmax ηmax γ t−TI −1
≤3 ( 2 ) .
1 − γ zmin

45
Thus we conclude ∀I ∈ [K − 1], TI + 1 ≤ t ≤ TI+1 , we have

  t−TI −1
η̂t−1 ηet−1 λmax ηmax γ λmax ηmax t−TI −1 λmax ηmax 2(t−TI −1)
−1 ≤3 2
≤3 ·γ (1+ ) .
η̂t ηet 1−γ zmin 1−γ 1−γ

2.8.6 Proofs in Section 2.5

We will use ŵ to denote w


kwk
and ∠uw to arccos(û> ŵ). Note that training error ≤ ε
π

is equivalent to ∠e1 wt < ε.

Case 1: WD alone Since the objective is strongly convex, it has unique argmin
w∗ . By symmetry, w∗ = βe1 , for some β > 0. By KKT condition, we have

  r
|x1 | 2
λβ = E ≤ E [|x1 |] = ,
x1 ∼N (0,1) 1 + exp(β|x1 |) x1 ∼N (0,1) π

which implies kw∗ k = O( λ1 ).


η
By Theorem 3.1 of [45], for sufficiently large t, we have E kwt − w∗ k2 = O( Bλ ).
∗ t
Note that ∠e1 wt = ∠w∗ wt ≤ 2 sin ∠w∗ wt ≤ 2 kwkw−w
∗k
k
, we have E (∠e1 wt )2 = O( ηλ
B
),
q
so the expected error = E (∠e1 wt )/π ≤ E (∠e1 wt )2 /π = O( ηλ
p
B
).

Case 3: Both BN and WD We will need the following lemma when lower
bounding the norm of the stochastic gradient.

i.i.d.
Lemma 2.8.8 (Concentration of Chi-Square). Suppose X1 , . . . , Xk ∼ N (0, 1), then

" k
#
X  k2
Pr Xi2 < kβ ≤ βe1−β . (2.17)
i=1

Proof. This Chernoff-bound based proof is a special case of [46].

46
" k
# " k
! #
X k X
Pr Xi2 < kβ ≤ βe 1−β 2
= Pr exp ktβ − t Xi2 ≥1
i=1 i=1
" k
!#
X
≤ E exp ktβ − t Xi2 (Markov Inequality)
i=1

− k2
=ektβ (1 + 2t) .

(2.18)

The last equality uses the fact that E [tXi2 ] = √ 1 for t < 1
. The proof is
1−2t 2
1−β
completed by taking t = 2β
.

Setting for Theorem 2.5.1: Suppose WD factor is λ, LR is η, the width of the


last layer is m ≥ 3, Now the SGD updates have the form

B  
η X > wt λ 2
wt+1 =wt − ∇ ln(1 + exp(−xt,b yt,b) ) + kwt k
B b=1 kwt k 2
B
η X yt,b Π⊥
wt xt,b
=(1 − λη)wt − w ,
B b=1 1 + exp(xt,b > kwtt k yt,b ) kwt k

i.i.d. wt wt>
where xt,b ∼ N (0, Im ), yt,b = sign ([xt,b ]1 ), and Π⊥
wt = I − kwt k2
.

Proof of Theorem 2.5.1.



1 64kwT0 k2 ε B
Step 1: Let T1 = 2(ηλ−2ε2 )
ln √
η m−2
, and T2 = 9 ln 1δ . Thus if we assume the
training error is smaller than ε from iteration T0 to T0 + T1 + T2 , then by spherical
triangle inequality, ∠wt wt0 ≤ ∠e1 wt0 + ∠e1 wt = 2ε, for T0 ≤ t, t0 ≤ T0 + T1 + T2 .
Now let’s define wt0 = (1 − ηλ)wt and for any vector w, and we have the following
two relationships:

1. kwt0 k = (1 − ηλ)kwk.

kwt0 k
2. kwt+1 k ≤ cos 2ε
.

47
The second property is because by Lemma 2.3.2, (wt+1 − wt0 ) ⊥ wt0 and by
assumption of small error, ∠wt+1 wt0 ≤ 2ε.
Therefore

2T1  2T1
kwT1 +T0 k2

1 − ηλ 1 − ηλ 2 2T1 2
≤ e−2T1 (ηλ−2ε )

2
≤ ≤ 2
≤ 1 − (ηλ − 2ε )
kwT0 k cos 2ε 1 − 2ε
r
η m−2
= .
64kwT0 k2 ε B

q
η m−2
In other word, kwT0 +T1 k2 ≤ 64ε B
. Since kwT0 +t k is monotone decreasing,
q
η m−2
kwT0 +t k2 ≤ 64ε B
holds for any t = T1 , . . . , T1 + T2 .
Step 2: We show that the norm of the stochastic gradient is lower bounded
with constant probability. In other words, we want to show the norm of ξt =
PB yt,b Π⊥
wt xt,b
b=1 1+exp(xt,b > wt
yt,b ) kwt k
is lower bounded with high probability.
kwt k

Let Π⊥
wt ,e1 be the projection matrix for the orthogonal space spanned by wt and
e1 . W.L.O.G, we can assume the rank of Π⊥
wt ,e1 is 2. In case wt = e1 , we just exclude

a random direction to make Π⊥ ⊥


wt ,e1 rank 2. Now we have Πwt ,e1 xt,b are still i.i.d.

multivariate gaussian random variables, for b = 1, . . . , B, and moreover, Π⊥


wt ,e1 xt,b is
yt,b
independent to w
1+exp(xt,b > kwt k yt,b )
. When m ≥ 3, we can lower bound kξt k by dealing
t

with kΠ⊥
wt ,e1 ξt k.

It’s not hard to show that conditioned on {xt,b > kw


wt
tk
, [xt,b ]1 }B
b=1 ,

v !2
B
u B
X yt,b uX
d t yt,b
Π⊥ xt,b = Π⊥
wt ,e1 x,
b=1
1 + exp(xt,b kwt k yt,b ) wt
> wt
b=1
wt
1 + exp(xt,b > kw tk
yt,b )
(2.19)
where x ∼ N (0, Im ). We further note that kΠ⊥ 2 2
wt ,e1 xk ∼ χ (m − 2). By

Lemma 2.8.8,

48
 
m−2 1 m−2 1 1 1
Pr kΠ⊥
wt ,e1 xt k
2
≥ ≥ 1 − ( 7 ) 2 ≥ 1 − ( 7 )2 ≥ . (2.20)
8 8e 8 8e 8 3

 2
PB yt,b
Now we will give a high probability lower bound for b=1 w
1+exp(xt,b > kwt k yt,b )
.
t

Note that x> wt


t kwt k ∼ N (0, 1), we have

 
wt 1
Pr |x>
t,b |<1 ≥ , (2.21)
kwt k 2
h i
which implies the following, where At,b is defined as 1 |x> wt
t,b kwt k | <1≥ 1
2
:

" #
yt,b 1 1
E At,b = Pr k > wt
k≥ ≥ . (2.22)
1 + exp(xt,b kwt k yt ) 1+e 2
PB PB hP i
B B B
Note that b=1 At,b ≤ B, and E b=1 At,b ≥ 2
, we have Pr b=1 At,b < 4
≤ 23 .
Thus,

 !2 
" B #
B
X yt,b B X B 1
Pr 
wt ≥  ≥ Pr At,b ≥ ≥ . (2.23)
b=1
1 + exp(xt,b > kwtk
yt,b ) 4(1 + e)2
b=1
4 3

Thus w.p. at least 19 , Equation (2.23) and Equation (2.20) happen together, which
implies

B B
η X > wt η X yt,b Π⊥
wt xt,b
k ∇ ln(1 + exp(−xt,b yt,b ))k = k > wt k
B b=1 kwt k B b=1 1 + exp(xt,b kwt k yt ) kwt k
√ r
η m−2 η m−2
≥ ≥
1 + e 8kwt k 32kwt k B

Step 3. To stay in the cone {w|∠we1 ≤ ε}, the SGD update kwt+1 − wt0 k =
k Bη B > wt
P
b=1 ∇ ln(1 + exp(−xt,b kwt k yt,b ))k has to be smaller than kwt k sin 2ε for any

49
t = T0 + T1 , . . . , T0 + T1 + T2 . However, step 1 and 2 together show that k∇ ln(1 +
exp(−x> wt
t kwt k yt ))k ≥ 2kwt kε w.p.
1
per iteration. Thus the probability that wt always
9
T
stays in the cone for every t = T0 + T1 , . . . , T0 + T1 + T2 is less than 89 2 ≤ δ.

It’s interesting that the only property of the global minimum we use is that
the if both wt , wt+1 are ε−optimal, then the angle between wt and wt+1 is at
most 2ε. Thus we indeed have proved a stronger statement: At least once in every

1 64kwT0 k2 ε B
2(ηλ−2ε2 )
ln √
η m−2
+ 9 ln 1δ iterations, the angle between wt and wt+1 will be larger
than 2. In other words, if the the amount of the update stabilizes to some direction

in terms of angle, then the fluctuation in terms of angle must be larger than 2ηλ for
this simple model, no matter how small the noise is.

2.8.7 Proofs in Section 2.6

Lemma 2.8.9. Suppose loss L is scale invariant, then L is non-convex in the following
two sense:

1. The domain is non-convex: scale invariant loss can’t be defined at origin;

2. There exists no ball containing origin such that the loss is locally convex, unless
the loss is constant function.

Proof. Suppose L(x∗ ) = supx∈B L(x). W.L.O.G, we assume kx∗ k < 1. By convexity,
every line segment passing x∗ must have constant loss, which implies the loss is

constant over set B − {c kxx∗ k | −1 ≤ c ≤ 0}. Applying the above argument on any
other maximum point x0 implies the loss is constant over B − {0}.

Theorem 2.8.10. Suppose the momentum factor γ = 0, LR ηt = η is constant,


and the loss function L is lower bounded. If ∃c > 0 and T ≥ 0 such that ∀t ≥ T ,
f (xt+1 ) − f (xt ) ≤ −cηk∇L(xt )k2 , then limt→∞ kxt k = 0.

50
Proof in Item 3. By Lemma 2.3.2 and the update rule of GD with WD, we have

kxt k2 = k(1 − λη)kxkt−1 + η∇L(xt−1 )k2 = (1 − λη)2 kxt−1 k2 + η 2 k∇L(xt−1 )k2 ,

which implies

t−1
X
kxt k2 = (1 − λη)2(t−i−1) η 2 k∇L(xt−1 )k2 + (1 − λη)2(t−T ) kxT k2 .
i=T

Thus for any T 0 > T ,

T 0 0 −1
TX
! 0 −1
TX
!
X 1 1
kxt k2 ≤ k∇L(xt )k2 + kxT k2 ≤ k∇L(xt )k2 + kxT k2 .
t=T
1 − (1 − λη)2 t=T
λη t=T

P 0
Note that by assumption we have Tt=T−1 k∇L(xt )k2 = cη1 f (xT ) − f (xT 0 ).
P∞ 2
As a conclusion, we have 2
t=T kxt k ≤
f (xT )−minx f (x)
cη 2 λ
+ kxλη
Tk
, which implies
lim kxt k2 = 0.
t→∞

2.9 Side Results on Parameter Norm Convergence

Now we rigorously analyze norm growth in this algorithm. This greatly extends
previous analyses of effect of normalization schemes [29, 37] for vanilla SGD.

Theorem 2.9.1. Under the update rule 2.3.1 with λt = 0, the norm of scale invariant
parameter xt satisfies the following property:

• Almost Monotone Increasing: kxt+1 k2 − kxt k2 ≥ −γ t+1 ηη0t (kx0 k2 − kx−1 k2 ).

• Assuming ηt = η is a constant, then

t
2
X 1 − γ t−i+1  1 − γ t+1
kxt+1 k = kxi − xi+1 k2 + γkxi−1 − xi k2 −γ (kx0 k2 −kx−1 k2 )
i=0
1−γ 1−γ

51
Proof. Let’s use Rt , Dt , Ct to denote kxt k2 , kxt+1 − xt k2 , x>
t (xt+1 − xt ) respectively.

The only property we will use about loss is ∇x L>


t xt = 0.

Expanding the square of kxt+1 k2 = k(xt+1 − xt ) + xt k2 , we have

∀t ≥ −1 S(t) : Rt+1 − Rt = Dt + 2Ct .

We also have

Ct xt+1 − xt xt − xt−1 γ
= x>
t = x>
t (γ − λt xt ) = (Dt + Ct−1 ) − λt Rt ,
ηt ηt ηt−1 ηt−1

namely,

Ct γDt γ
∀t ≥ 0 P (t) : − = Ct−1 − λt Rt .
ηt ηt−1 ηt−1
S(t) γS(t−1)
Simplify ηt
− ηt−1
+ P (t), we have

Rt+1 − Rt Rt − Rt−1 Dt Dt−1


−γ = +γ − 2λt Rt . (2.24)
ηt ηt−1 ηt ηt−1

When λt = 0, we have

t
Rt+1 − Rt R0 − R−1 X t−i Di Di−1 R0 − R−1
= γ t+1 + γ ( +γ ) ≥ γ t+1 .
ηt η−1 i=0
η i ηi−1 η0

Further if ηt = η is a constant, we have

t
X 1 − γ t−i+1 1 − γ t+1
Rt+1 = R0 + (Di + γDi−1 ) − γ (R0 − R−1 ),
i=0
1−γ 1−γ

which covers the result without momentum in [36] as a special case:

t
X
Rt+1 = R0 + Di .
i=0

52
For general deep nets, we have the following result, suggesting that the mean square
of the update are constant compared to the mean square of the norm. The constant
is mainly determined by ηλ, explaining why the usage of weight decay prevents the
1
parameters to converge in direction.

Theorem 2.9.2. For SGD with constant LR η, weight decay λ and momentum γ,
P −1 P −1
when the limits R∞ = limT →∞ T1 Tt=0 kwt k2 , D∞ = limT →∞ T1 Tt=0 kwt+1 − wt k2
exist, we have
2ηλ
D∞ = R∞ .
1+γ

Proof of Theorem 2.9.2. Take average of Equation (2.24) over t, when the limits
P −1 P −1
R∞ = limT →∞ T1 Tt=0 kwt k2 , D∞ = limT →∞ T1 Tt=0 kwt+1 − wt k2 exists, we have

1+γ
D∞ = 2λR∞ .
η

2.10 Scale Invariance in Modern Network Archi-

tectures

In this section, we will discuss how Normalization layers make the output of the
network scale-invariant to its parameters. Viewing a neural network as a DAG, we
give a sufficient condition for the scale invariance which could be checked easily by
topological order, and apply this on several standard network architectures such as Fully
Connected(FC) Networks, Plain CNN, ResNet[48], and PreResNet[49]. For simplicity,
we restrict our discussions among networks with ReLU activation only. Throughout
this section, we assume the linear layers and the bias after last normalization layer are

1
? ] had a similar argument for this phenomenon by connecting this to the LARS[47], though
it’s not rigorous in the way it deals with momentum and equilibrium of norm.
53
fixed to its random initialization, which doesn’t harm the performance of the network
empirically[32].

2.10.1 Notations

Definition 2.10.1 (Homogeneous Functions). Suppose k is an integer and x is all


the parameters of the network, then f is said to be homogeneous of degree k, or
k-homogeneous, if ∀c > 0, f (cx) = ck f (x). The output of f can be multi-dimensional.
Specifically, scale invariance means degree of homogeneity is 0.

Definition 2.10.2. For a module with n inputs and m outputs, we say the module is
(a1 , ...an ; b1 , ..., bm )-homogeneous if the m outputs are bi -homogeneous to the network
parameters whenever the n inputs are ai -homogeneous to the network parameters. A
model is scale invariant iff its output is (; 0)-homogeneous. (A complete model doesn’t
take any input from another module)

Suppose the network only contains following modules, and we list the degree of
homogeneity of these basic modules, given the degree of homogeneity of its input.

(I) Input

(L) Linear Layer, e.g. Convolutional Layer or Fully Connected Layer

(B) Bias Layer(Adding Trainable Bias to the output of the previous layer)

(+) Addition Layer (adding the outputs of two layers with the same dimension2 .)

(N) Normalization Layer without affine transformation(including BN, GN, LN, IN


etc.)

(NA) Normalization Layer with affine transformation

2
Addition Layer(+) is mainly used in ResNet and other similar architectures. In this section, we
also use it as an alternative definition of Bias Layer(B). See Figure 2.7
54
Table 2.1: Table showing how degree of homogeneity of the output of basic modules
depends on the degree of homogeneity of the input. Input module doesn’t require
any input and thus the output are trivially scale invariant. ReLU, Pooling( and
other fixed linear maps) are ignored because they keep the degree of homogeneity, i.e.
(x; x)-homo, and thus can be omitted when creating the DAG in Theorem 2.10.4.

Symbol Module Homogeneity


I Input (;0)
B Adding Bias (1;1)
N Layer Normalization (no affine) (x;0)
L Linear Layer (x;x+1)
+ Addition Layer (x,x;x+1)
NA Layer Normalization with affine (x;1)

Remark 2.10.3. For the purpose of deciding the degree of homogeneity of a network,
we can ignore the difference among convolutional layers, fully connected layer and the
diagonal linear layer in the affine transformation of Normalization layer, since they’re
all linear and the degree of homogeneity is increased by 1 after applying them.
On the other hand, BN and IN has some benefit which GN and LN doesn’t have,
namely the bias term (per channel) immediately before BN or IN has zero effect on
the network output and thus can be removed. (See Figure 2.15)

We also demonstrate the homogeneity of the output of the modules via the following
figures, which will be reused to later to define network architectures.

Theorem 2.10.4. For a network only consisting of modules defined above and ReLU
activation, we can view it as a Directed acyclic graph and check its scale invariance
by Algorithm 1.

2.10.2 Networks without Affine Transformation and Bias

We start with the simple cases where all bias term(including that of linear layer and
normalization layer) and the scaling term of normalization layer are fixed to be 0 and
1 element-wise respectively, which means the bias and the scaling could be dropped

55
(a) Input(I) (b) Linear(L) (c) Addition(+) (d) Normaliza-
tion(N)

(e) Bias(B) (f) Alternative Definition of Bias(B)

(g) Normalization with (h) Definition of Normalization with Affine(NA)


Affine(NA)

Figure 2.7: Degree of homogeneity of the output of basic modules given degree of
homogeneity of the input.

Algorithm 1 Checking Scale Invariance of a Given Architecture


Require: DAG G = (V, E) translated from a neural network; the module type of
each node vi ∈ V .
for v in topological order of G do
Compute the degree of homogeneity of v using Table 2.1
if v is not homogeneous then
return False
if vouptut is 0-homogeneous then
return True
else
return False

56
from the network structure. We empirically find this doesn’t affect the performance of
network in a noticeable way. We will discuss the full case in the next subsection.

Plain CNN/FC networks: See Figure 2.8.

Figure 2.8: Degree of homogeneity for all modules in vanilla CNNs/FC networks.

Figure 2.9: An example of the full network structure of ResNet/PreResNet represented


by composite modules defined in Figure 2.10,2.11,2.13,2.14, where ‘S’ denotes the
starting part of the network, ‘Block’ denotes a normal block with residual link, ‘D-
Block’ denotes the block with downsampling, and ‘N’ denotes the normalization layer
defined previously. Integer x ∈ {0, 1, 2} depends on the type of network. See details
in Figure 2.10,2.11,2.13,2.14.

ResNet: See Figure 2.10. To ensure the scaling invariance, we add an additional
normalizaiton layer in the shortcut after downsampling. This implementation is
sometimes used in practice and doesn’t affect the performance in a noticeable way.

Preactivation ResNet: See Figure 2.11. Preactivation means to change the order
between convolutional layer and normalization layer. For similar reason, we add an
additional normalizaiton layer in the shortcut before downsampling.

57
(a) The starting part of ResNet

(b) A block of ResNet

(c) A block of ResNet with downsampling

Figure 2.10: Degree of homogeneity for all modules in ResNet without affine transfor-
mation in normalization layer. The last normalization layer is omitted.

2.10.3 Networks with Affine Transformation

Now we discuss the full case where the affine transformation part of normalization
layer is trainable. Due to the reason that the bias of linear layer (before BN) has 0
gradient as we mentioned in 2.10.3, the bias term is usually dropped from network
architecture in practice to save memory and accelerate training( even with other
normalization methods)(See PyTorch Implementation [44]). However, when LN or
GN is used, and the bias term of linear layer is trainable, the network could be scale
variant (See Figure 2.15).

Plain CNN/FC networks: See Figure 2.12.

58
(a) The starting part of PreResNet

(b) A block of PreResNet

(c) A block of PreResNet with downsampling

Figure 2.11: Degree of homogeneity for all modules in ResNet without affine transfor-
mation in normalization layer. The last normalization layer is omitted.

Figure 2.12: Degree of homogeneity for all modules in vanilla CNNs/FC networks.

59
ResNet: See Figure 2.13. To ensure the scaling invariance, we add an additional
normalizaiton layer in the shortcut after downsampling. This implementation is
sometimes used in practice and doesn’t affect the performance in a noticeable way.

(a) The starting part of ResNet

(b) A block of ResNet

(c) A block of ResNet with downsampling

Figure 2.13: Degree of homogeneity for all modules in ResNet with trainable affine
transformation. The last normalization layer is omitted.

Preactivation ResNet: See Figure 2.14. Preactivation means to change the order
between convolutional layer and normalization layer. For similar reason, we add an
additional normalizaiton layer in the shortcut before downsampling.

60
(a) The starting part of PreResNet

(b) A block of PreResNet

(c) A block of PreResNet with downsampling

Figure 2.14: Degree of homogeneity for all modules in PreResNet with trainable affine
transformation. The last normalization layer is omitted.

Figure 2.15: The network can be not scale variant if the GN or IN is used and the bias
of linear layer is trainable. The red ‘F’ means the Algorithm 1 will return False here.

61
Chapter 3

Convergence Analysis for Gradient


Descent on Normalized Networks

Recall in Chapter 2 we proved the equivalence between exponential increasing learning


rate schedule and constant learning rate with weight decay for SGD on scale invariant
loss functions. In this chapter we continue on this topic and complete the story with
a convergence analysis for SGD+WD on scale invariant functions.

3.1 Introduction

Given a loss function L : RD → R, the non-convex optimization problem aims to find


a parameter x with small gradient norm k∇L(x)k, as there is no hope to directly
minimize the loss itself L(x) efficiently without further assumptions and the hope
here is that small gradient norm will be a good surrogate for small loss. We call
x a -approximate stationary point if k∇L(x)k ≤ . However, this standard notion
of approximate stationary point is not useful for scale invariant loss, as one can
simply scale up the initialization x(0) to infinity and the gradient norm thus scales
inversely. A more reasonable notion of ‘stationary point’ is that the direction of
x
x, denoted by x := kxk2
, has small gradient norm, as first introduced in [36]. We
62
will use this definition of approximate stationary point throughout the chapter. The
main contributions of this section are convergence rates in terms of gradient norm of
direction reached by Gradient Descent and Stochastic Gradient Descent.
The key highlights of our theoretical analysis for SGD+WD include:

1
1. Parameter norm converges to Θ(( λη ) 4 ) in T1 = O(
e 1 ) steps with high probability
ηλ

where T1 is a function of loss L, initial norm kx(0)k2 , LR η and WD λ. Moreover,


ln |c|
T1 (L, kx(0)k2 , η, λ) changes most by ηλ
for operation (A1-3).

2. After step T1 , convergence to first order approximate stationary point happens


and the rate only depends on ηλ and is unaffected by operations (A1-3).

3.(A1). L → cL, for any c > 0.

(A2). x(0) → cx(0), for any c > 0.

(A3). (L, x(0)) → (L0 , cx(0)), where L0 is defined as L0 (x) := L( xc ) for any c > 0.

Properties (1) and (2) suggest our results are more robust to initialization scale (by
only having logarithmic dependence on it), showing the advantage of using scale
invariant functions while matching the standard convergence rates for non-convex
functions.

3.2 Preliminary

In this section we present the definition of scale invariant functions and some of their
x
useful properties. For x ∈ Rd , we define x := kxk2
. We say a function is C k iff it is
k-times continuously differentiable. We also assume the loss function L is a C 2 and
scale invariant function and ρ := max k∇2 L(x)k. Same to Chapter 2, we use λ to
kxk=1
denote weighrt decay factor and η to denote learning rate.

63
Definition 3.2.1. Given a cone U ⊂ Rd , we say a function f : U → R is (positively) k-
homogeneous or of homogeneity of degree k iff for any c > 0 and x ∈ U , f (cx) = ck f (x).
We say a function is scale invariant iff it is 0-homogeneous.

Now we present some useful properties of the derivatives of homogeneous functions.

Theorem 3.2.2 (Euler’s Homogeneous Function Theorem). For any k-homogeneous


C 1 function f , it holds that h∇f (x), xi = kf (x).

Lemma 3.2.3. For any k-homogeneous C l function f , ∇l f is k − l homogeneous.

Lemma 3.2.4 (Equivalent Scaling). The properties below hold:

1. For any loss L, LR η, WD λ and initialization x(0), rescaling (L, η, λ, x(0)) →


(cL, η/c, cλ, x(0)) doesn’t change GD iterate x(t) for any t ≥ 0.

2. For any scale invariant loss L, LR η, WD λ and initialization x(0), rescaling


(L, η, λ, x(0)) → (L, c2 η, λ/c2 , cx(0)) doesn’t change the direction of GD iterate
x(t) for any t ≥ 0. (Lemma 2.4.4)

3.3 Convergence of GD+WD

We first present the convergence result in the deterministic case, i.e., Gradient Descent
over L(x) + λ
2
kxk22 .

GD+WD: x(t + 1) = (1 − ηλ)x(t) − η∇L(x(t)) (3.1)

1
Theorem 3.3.1 (GD+WD). For ηλ ≤ 2
, let x(t) be defined by GD (3.1), and
kx(0)k2
l  m
1
T0 = 2ηλ ln ρπ2 η 2 + 3 . We have

min k∇L(x(t))k22 ≤ 8π 4 ρ2 λη. (3.2)


t=0,...,T0
64
This bound matches the standard O( √1T ) convergence rate to first order stationary
point for non-convex functions. Remarkably, for a given training budget T , once we
D
can set ηλ to be T
where D is a constant (e.g. 10), the convergence becomes robust
to the choice the hyperparameters due to just a logarithmic dependence on them. In
particular, GD+WD can work with any scaling of L (which affects the smoothness
kx(0)k22
on unit sphere, ρ), LR η and initial norm kx(0)k2 , as long as ρπ 2 η
∈ [e−D , eD ] .
This is in sharp contrast to GD on standard loss as it requires knowledge about the
smoothness to set the optimal LR.
However, one weakness of the above result is that with a fixed ηλ, longer training
does not guarantee further convergence. The intuition is that once the iterate converge
in direction and the gradient vanishes, Weight Decay will dominate the dynamics
and thus the norm approaches 0, which increases the sharpness. When the sharpness
gets larger than 2/η, the dynamics become unstable and results in divergence. This
phenomena is first observed in Li et al. [50] and verified by Lobacheva et al. [51]
in practical settings. This behavior can also be viewed as a special case of Edge of
Stability as described in Cohen et al. [17].

Proof Sketch of Theorem 3.3.1. Scale invariant functions do not have bounded smooth-
ness at 0 making it a challenge to use standard convergence analysis. Our key insight
is that for scale invariant loss function, even with a fixed LR η, GD can tune its
η
effective LR kx(t)k22
by changing the norm. Thus once GD passes the area of the
ρ
suitable norm, the smoothness of scale invariant loss function is upper bounded by r2

outside the ball with radius r centered at 0.


More concretely our proof consists of 2 steps. In the first step we show that

GD+WD iterates pass an area of suitable norm (≈ ρη). For large initial norm, WD
could bring the norm to correct scaling in log time and then converge (Theorem 3.6.2).
If the initial norm is too small and the direction is not approximately stationary, then
the large gradient due to the small norm will increase the parameter norm drastically
65
in a single step (Lemma 3.6.1), and again Weight Decay can bring the norm down
in log steps. In the second step we show that, once the norm reaches this suitable
value, the descent lemma (Lemma 3.3.2) starts to hold and the convergence analysis
is standard.

Lemma 3.3.2. Let x(t), x(t + 1) be defined as (3.1), we have

!
1 ρη
L(x(t)) − L(x(t + 1)) ≥ η − 2 k∇L(x(t))k22 .
1 − ηλ 2 kx(t)k2 (1 − ηλ)2

When ηλ ≤ 21 , the above can be simplified into

!
2ρη
L(x(t)) − L(x(t + 1)) ≥ η 1 − k∇L(x(t))k22 .
kx(t)k22

Remark 3.3.3. One might wonder why the upper bounds on loss and gradient norm
do not appear in Theorem 3.3.1. This is because we are working on a compact domain
(the unit sphere) and twice-differentiability implies those bounds implicitly. (See
Lemmas 3.5.3 and 3.5.4)

3.4 Convergence of SGD+WD

Below we present our convergence analysis for SGD+WD.

Setting: Let Γ be an index set and Lγ : Rd /{0} → R be a scale invariant loss


function for each γ ∈ Γ. We denote Eγ Lγ by L. We assume the largest possible
stochastic gradient norm is finite, i.e., M := supγ∈Γ max k∇Lγ (x)k. SGD is defined
kxk=1
as (3.3).
SGD+WD: x(t + 1) = (1 − ηλ)x(t) − η∇Lγt (x(t)), (3.3)

66
where γt ∈ Γ are i.i.d. random variables. We further assume there exists constants
σ and σ, such that σ 2 ≤ E k∇Lγ (x)k22 ≤ σ 2 , for any kxk2 = 1. We finally need the
following condition on ηλ to bound convergence.
q
σ2 2
Condition 3.4.1. M2
≥ 3e4ηλ λη ln 2Tδ .

The Condition 3.4.1 is useful for proving norm convergence in high probability. In
practice, typically ηλ is very small. Our experiments use η = 0.0008 and λ = 0.01.
Hence e4ηλ ≈ 1, and Condition 3.4.1 essentially requires the gradient norm square

cannot exceed its average multiplied by 1/ ηλ ≈ 350, which is reasonable for most
iterates.

Theorem 3.4.2 (SGD+WD). Let x(t) be defined by SGD (3.3). For ηλ ≤ 0.1,
under Condition 3.4.1, with probability 1 − 5δ,

σ2 2λ
∀T1 ≤ t ≤ T − 1, ≤ kx(t)k42 ≤ 4σ 2 , (3.4)
2 η

and

T −1
1 X π 2 ρσ p ρσ 3
k∇L(x(t))k22 ≤ √ + 4 ηλ 2
T − T1 t=T (T − T1 ) 2ηλ σ
1
s s (3.5)
ln 2δ πρM σ ln 2δ p M 2 ρσ
+ 4 + 4 λη ,
T − T1 σ T − T1 σ2

n 2
o
2e4 M 2
where T1 = 1
4ηλ
max ln Mσ2ηλ + ln kx(0)k4 −2
η
, 8 .
2

The proof of this theorem is presented in Section 3.7. Similar to our earlier
result for GD this bound matches the standard O(T −1/4 ) convergence rate of SGD
e 1 ). Further, it only has a logarithmic
for non-convex functions by setting T = O( ηλ

dependence on the initialization scale kx(0)k2 , and enjoys robustness to initialization

67
scale as discussed earlier for GD. We further extend this result to the case where the
scale invariant loss has multiple scale invariant parameter groups in Section 3.8.

3.5 Useful Lemmas

3.5.1 Scale Invariance

Lemma 3.5.1 (Smoothness). For any v, x ∈ Rd with hx, vi = 0, suppose L is scale-


invariant and twice differentiable with ρ := maxkxk2 =1 k∇2 L(x)k, we have

ρ kvk22
L(x + v) − L(x) ≤ hv, ∇L(x)i + .
2 kxk22

Proof of Lemma 3.5.1. Define γ(s) = x + sv, then we have L(γ(0)) = L(x) and
L(γ(1)) = L(x + v). Taking Taylor expansion of F (s) = L(γ(s)) at s = 0, we have

F 00 (s∗ )
F (1) − F (0) = F 0 (0) + , for some s∗ ∈ [0, 1].
2

Note F 0 (0) = hγ 0 (0), ∇L(γ(0))i = h∇L(x), vi and

ρ 0 ∗ 2
F 00 (s∗ ) =γ 0 (s∗ )∇2 L(γ(s∗ ))γ 0 (s∗ ) ≤ 2 kγ (s )k2 ,

kγ(s )k2

where the last inequality uses the fact that L is scale invariant. The proof is completed
by noting that kγ(s∗ )k2 ≥ kγ(0)k2 = kxk22 and that γ 0 (s∗ ) = v.

Lemma 3.5.2 (Smoothness, Multi-group). For any v, x ∈ Rd with hxk , vk i = 0 for


all k ∈ [K], suppose L is multi-group scale invariant (see Definition 3.8.1), we have

K
ρ X kvi k22
L(x + v) − L(x) ≤ hv, ∇L(x)i + .
2 k=1 kxi k22

68
Proof of Lemma 3.5.2. We first prove for the case where kxk k2 = 1, ∀k ∈ [K]. Similar
to the proof of Lemma 3.5.1, it suffices to show that the smoothness of L is at
most ρ along the line joining x and x + v. This holds because ∀s ∈ [0, 1], k ∈ [K],
kxi + svi k2 ≥ kxi k2 by assumption that hxk , vk i = 0 for all k ∈ [K].
x> x>
Now we turn to the general case. b = [ kx11k , . . . , kxKKk ]> and v 0 =
Define x
2 2
v> v>
[ kx11k , . . . , kxKKk ]> . Since L is multi-group scale invariant, we have L(x) = L(b
x)
2 2

x + v 0 ). The proof is completed by applying the previous argument


and L(x + v) = L(b
b and v 0 .
on x

π
Lemma 3.5.3. If L is scale invariant, k∇L(x)k2 ≤ kxk2
supkxk=1 k∇2 L(x)k2 .

Proof of Lemma 3.5.3. It suffices to prove the above bound for all x with kxk2 = 1.
Let x∗ be any local minimizer of L on Sd−1 and γ : [0, 1] → Sd−1 be the geodesic curve
satisfying that γ(0) = x∗ and γ(1) = x. We know the length of {γ(t)}1t=0 ≤ π and
thus

Z 1 Z 1
2 dγ(t) dγ(t)
k∇L(x)k = ∇ L(γ(t)) dt ≤ ∇2 L(γ(t)) 2
dt ≤ ρ · π
t=0 dt t=0 dt 2

π2
Lemma 3.5.4. If L is scale invariant, supx,x0 L(x) − L(x0 ) ≤ 2
supkxk=1 k∇2 L(x)k2 .

Proof of Lemma 3.5.4. Similar to the proof of Lemma 3.5.3.

3.5.2 Probablity

Definition 3.5.5. A random variable X ∈ R is said to be sub-Gaussian with variance


proxy σ 2 (denoted by X ∼ subG(σ 2 )) if its moment generating function satisfies

σ 2 s2
E[exp(sX)] ≤ exp( ), ∀s ∈ R.
2

69
In this work, we also use the following notion of conditional subgaussian. We say a
random variable X ∈ R is said to be sub-Gaussian with variance proxy σ 2 conditioned
on event E (denoted by X ∼ subG(σ 2 , E)) if its moment generating function satisfies

σ 2 s2
E[exp(sX)1[E]] ≤ exp( ), ∀s ∈ R.
2

Lemma 3.5.6 (Chernoff Bound with Conditioning). Let X ∼ subG(σ 2 , E). Then for
any t > 0, it holds that

t2 t2
P[X > t ∧ E] ≤ exp(− ), and P[X < −t ∧ E] ≤ exp(− )
2σ 2 2σ 2

When P[E] = 1, we get the standard Chernoff bound. Let X ∼ subG(σ 2 ). Then for
any t > 0, it holds that

t2 t2
P[X > t] ≤ exp(− ), and P[X < −t] ≤ exp(− )
2σ 2 2σ 2

Proof of Lemma 3.5.6. For any s > 0, we have

σ 2 s2
P[X > t ∧ E] = P[esX ≥ est ∧ E] ≤ e−st E[esX 1[E]] = exp(−st + ).
2

t
The proof is completed by picking s = σ2
.

We will use (Ω, Σ, P) to note the probability space and {Ft }t∈N to denote the
filtration.

Lemma 3.5.7 (Azuma Inequality with Conditioning). Let Et ∈ Ft and Et+1 ⊂ Et


for all t ≥ 0. Let {Xt }t≥1 be a martingale difference sequence and subG(σt2 , Et−1 )
s2 σt2
conditioned on Ft−1 , i.e., E[exp(sXt )1[Et−1 ] | Ft−1 ] ≤ exp( 2
) for all t ≥ 0. Then
PT PT −1 2
i=1 Xi is subG( t=0 σt , ET −1 ).

70
Proof. We will prove by induction on T . When T = 1, the statement is true by
assumption. Now suppose the statement holds for T − 1, we have for any s > 0

T
X T −1
X
E[exp(s Xi )1[ET −1 ]] =E[exp(s Xi )1[ET −1 ]E[exp(sXT )1[Et−1 ] | FT −1 ]]
i=1 i=1
T −1
X s2 σT2 −1
≤E[exp(s Xi )1[ET −1 ] exp( )]
i=1
2
T −1
X s2 σT2 −1
≤E[exp(s Xi )1[ET −2 ]] exp( )
i=1
2

PT −1
PT s2 σt2
Thus we have that E[exp(s i=1 Xi )1[ET −1 ]] ≤ exp( t=0
2
).

3.5.3 Others

Lemma 3.5.8. ∀t ∈ N, k ∈ N+ , 0 < x < 1,

t
X ekx
(1 − x)kτ ≤
τ =0
kx

Proof of Lemma 3.5.8.

t ∞ ∞
X X X 1 ekx

(1 − x) ≤ kτ
(1 − x) ≤ e−kxτ = ≤ ,
τ =0 τ =0 τ =0
1 − e−kx kx

where the last step is because ex ≥ 1 + x, ∀x ∈ R.

3.6 Proofs for Convergence of GD+WD

Proof of Lemma 3.3.2. This is a special case of Lemma 3.5.1 with x = (1 − ηλ)x(t)
and v = −η∇L(x(t)). Here we use the assumption that L is scale invariant, ∇L is
∇L(x(t))
−1-homogeneous. By Lemma 3.2.3, which means ∇L(x) = 1−ηλ
.

The following lemma deals with the case where kx(0)k22 < π 2 ρη.

71
Lemma 3.6.1. Let I = {T 0 ∈ N | ∀0 ≤ t ≤ T 0 , kx(t)k22 ≤ π 2 ρη ∧ k∇L(x(t))k22 >
2(π 2 ρη)2
8π 4 ρ2 λη}. Suppose 0 ∈ I and T = max I. Then T ≤ 1
6λη
and kx(T + 1)k22 ≤ kx(0)k22
.

Proof of Lemma 3.6.1. For any t ≤ T , we have

kx(t + 1)k22 − kx(t)k22 =((1 − λη)2 − 1) kx(t)k22 + η 2 k∇L(x(t))k22


η 2 k∇L(x(t))k22
≥ − 2λη kx(t)k22 +
kx(t)k22

≥ − 2π 2 ρλη 2 + 8π 2 ρλη 2

=6π 2 ρλη 2 .

Thus 6π 2 ρλη 2 · T ≤ kx(T )k22 − kx(0)k22 < kx(T )k22 ≤ π 2 ρη, which implies that
1
T < 6λη
. Moreover, we have that

kx(T + 1)k22 =(1 − ηλ)2 kx(T )k22 + η 2 k∇L(x(T ))k22


η 2 k∇L(x(T ))k22
≤ kx(T )k22 +
kx(T )k22
η 2 k∇L(x(T ))k22
≤ kx(T )k22 +
kx(0)k22
ρ2 π 2 η 2
≤π 2 ρη +
kx(0)k22
2(π 2 ρη)2
≤ .
kx(0)k22

This completes the proof.

Theorem 3.6.2 (convergence rate of GD+WD). Suppose ηλ ≤ 12 . Let x(t) be the


2kx(0)k2
l m
1
t-th iterate of GD (3.1), and T0 = 2ηλ ln ρπ2 η 2 . If kx(0)k22 ≥ π 2 ρη, we have

min k∇L(x(t))k22 ≤ 8π 4 ρ2 λη.


t=0,...,T0

Proof of Theorem 3.6.2. We first claim there’s 0 ≤ t ≤ T0 , such that kx(t)k22 < π 2 ρη.

72
Otherwise, by Lemma 3.3.2, for t = 0, . . . , T0 , we have L(x(t)) − L(x(t + 1)) ≤
η
2
k∇L(x(t))k22 . Note that kx(t + 1)k22 − (1 − ηλ)2 kx(t)k22 = η 2 k∇L(x(t))k22 .
Therefore, we have that

0 −1
TX
kx(T0 )k22 − (1 − ηλ)2T0
kx(0)k22 = η 2 (1 − ηλ)2(T0 −t) k∇L(x(t))k22
t=0
TX0 −1

≤ η 2 k∇L(x(t))k22
t=0
η
≤ (L(x(0)) − L(xT0 −1 ))
2
ηπ 2 ρ

2

ηπ 2 ρ
By the definition of T0 , we have (1 − ηλ)2T0 kx(0)k22 ≤ e−2ηλT0 kx(0)k22 ≤ 2
. Thus
kx(T0 )k ≤ π 2 ρη.
Without loss of generality, we let T be the smallest integer such that kx(T )k22 ≤
π 2 ρη. By assumption, T ≥ 1. Therefore kx(T − 1)k22 ≥ π 2 ρη. Because kx(T )k22 =
(1 − ηλ)2 kx(T − 1)k22 + η 2 k∇L(x(T − 1))k22 , we have that

k∇L(x(T − 1))k22 = k∇L(x(T − 1))k22 kx(T − 1)k22

≤η −2 kx(T )k22 − (1 − ηλ)2 kx(T − 1)k22 kx(T − 1)k22 .




kx(T )k22
Note that kx(T )k22 < π 2 ρη and (1−λη)2
≥ kx(T − 1)k22 ≥ π 2 ρη, we conclude that

 kx(T )k22
k∇L(x(T − 1))k22 ≤η −2 kx(T )k22 − (1 − ηλ)2 kx(T − 1)k22 )
(1 − λη)2
1 − (1 − λη)2 2 2
≤ 2 (π ρη)
η (1 − λη)2
≤8ληπ 4 ρ2 ,

which completes the proof.

73
Combining Lemma 3.6.1 and Theorem 3.6.2 removes the initial condition in
Theorem 3.6.2, and completes the proof of Theorem 3.3.1.

3.7 Proofs for Convergence of SGD+WD

We will use (Ω, Σ, P) to note the probability space and {Ft }t∈N to denote the filtration
where Ft := σ({γi | 0 ≤ i ≤ t}) is the σ-algebra generated by γ0 , . . . , γt .

4
Lemma 3.7.1. k∇Lγ (x)k22 − E k∇Lγ (x)k22 ∼ subG( 4kxk
M
4 ).
2

M2
Proof. Lemma 3.7.1 Note 0 ≤ k∇Lγ (x)k22 ≤ kxk22
. The proof is immediate by Hoeffding
Lemma (see Lemma 3.6 in [52]).

Given a integer T ≥ 0, let ET be the event that ∀0 ≤ t0 ≤ t ≤ T − 1,

s
t
X M2 1 2T 2
(1 − ηλ)4(t−τ ) k∇Lγτ (x(τ ))k22 − E[k∇Lγτ (x(τ ))k22 | x(τ )] ≤ e4ηλ ·

ln .
τ =t0
4 λη δ

(3.6)

Lemma 3.7.2. For any 0 ≤ t0 ≤ t ≤ T − 1,

t
X e8ηλ M 4
(1 − ηλ)4(t−τ ) k∇Lγτ (x(τ ))k22 − E[k∇Lγτ (x(τ ))k22 | x(τ )] ∼ subG(

)
τ =t0
32

Thus we have P[ET ] ≥ 1 − δ by Lemma 3.5.6.

8(t−τ ) M 4 e8ηλ
Pt
Proof of Lemma 3.7.2. Note that τ =t0 (1−ηλ) 4
≤ 32
by Lemma 3.5.8. Thus
by Azuma Inequality and Lemma 3.7.1, we have that the martingale

t
X
(1 − ηλ)4(t−τ ) k∇Lγτ (x(τ ))k22 − E[k∇Lγτ (x(τ ))k22 | x(τ )]

τ =t0

e8ηλ
is 32
-subgaussian.

74
By Lemma 3.5.6, we have for any ∀0 ≤ t0 ≤ t ≤ T − 1, Equation (3.6) holds with
δ
probability at least T2
. The proof is completed by applying union bound.

Lemma 3.7.3 (Norm Lower Bound). Under Condition 3.4.1 and additionally assume
ηλ ≤ 21 . On ET , it holds that for any t ≥ 0,

s
1 − ηλ 1 1 2T 2
η −2 kx(t)k42 ≥ (1 − e−4tηλ(1−ηλ) )σ 2 − (1 − ηλ)2 M 2 e4ηλ ln (3.7)
2ηλ 2 λη δ

q
σ2 M 2 4ηλ 1 2
When 12ηλ
≥ 2
e λη
ln 2Tδ , the above condition is simplified into the following:
1
on ET for any ηλ
≤ t ≤ T,

5(1 − ηλ)2 σ 2 (1 − ηλ)2 σ 2 (1 − ηλ)2 σ 2


η −2 kx(t)k42 ≥ − = , (3.8)
12ηλ 6ηλ 4ηλ

In the above inequality, we also used the fact that 1 − e−4(1−ηλ) ≥ 56 , which is
implied by ηλ ≤ 0.5.

Proof of Lemma 3.7.3. Since Lγ is scale invariant, by Theorem 3.2.2, we have

k∇Lγt (x(t))k22
kx(t + 1)k22 = (1 − ηλ)2 kx(t)k22 + η 2 . (3.9)
kx(t)k22

Squaring both sides of Equation (3.9), we have

η 4 k∇Lγt (x(t))k42
kx(t + 1)k42 = (1 − ηλ)4 kx(t)k42 + 2(1 − ηλ)2 η 2 k∇Lγ (x(t))k22 + .
kx(t)k42
(3.10)

75
Thus

t
X
η −2
kx(t + 1)k42 ≥2 (1 − ηλ)4(t−τ )+2 k∇Lγτ (x(τ ))k22
τ =0
t
X
≥2 (1 − ηλ)4(t−τ )+2 E k∇Lγτ (x(τ ))k22
τ =0
t
X
(1 − ηλ)4(t−τ )+2 k∇Lγτ (x(τ ))k22 − E k∇Lγτ (x(τ ))k22 .

+2
τ =0

We also have that

t t
X X 1 − e−4tηλ(1−ηλ) 1 − e−4tηλ(1−ηλ)
(1 − ηλ) 4(t−τ )
≥ e−4(t−τ )ηλ(1−ηλ) = ≥ .
τ =0 τ =0
1 − e−4ηλ(1−ηλ) 4ηλ(1 − ηλ)

Therefore, it holds that for any t ≥ 0, conditioned on ET ,

s
1 − ηλ 1 1 2T 2
η −2 kx(t)k42 ≥ (1 − e−4tηλ(1−ηλ) )σ 2 − (1 − ηλ)2 M 2 e4ηλ ln
2ηλ 2 λη δ

This completes the proof.

Lemma 3.7.4 (Norm upper bound). Under Condition 3.4.1 and additionally assume
1
ηλ ≤ 0.1. Let T0 = d ηλ e. Let t∗ be the earliest step t in {0, . . . , T0 − 1} that
e8 (1−ηλ)2 σ 2
η −2 kx(t)k42 ≥ 4ηλ
and we denote t∗ = T0 if this doesn’t happen in {0, . . . , T0 −1}.
(1−ηλ)2 σ 2
For the case t∗ = T0 , we have η −2 kx(T0 )k42 ≤ 4ηλ
. On ET , for any t ≥ t∗ ,

( )
2e4 M 2
−4λη(t−t∗ ) ln 4σ
2
σ2
−2
1)k42 kx(0)k4
2 −2
η kx(t + ≤e max 2M e 2η ,e . + . (3.11)
ηλ ηλ

n 2
o
2e4 M 2
Thus, there exists T1 = T0 + 1
4ηλ
max ln Mσ2ηλ + ln kx(0)k4 −2 , 4 , such that ∀t ≥ T1 ,
η 2
2σ 2
η −2
kx(t + 1)k42 ≤ ηλ
.

76
Proof of Lemma 3.7.4. If t∗ < T0 , it holds that conditioned on ET , for any t∗ ≤ t < T0 ,

∗ (1 − ηλ)2 σ 2
η −2 kxt k42 ≥ (1 − ηλ)4(t−t ) η −2 kx(t∗ )k42 ≥ (1 − ηλ)4(T0 −1) η −2 kx(t∗ )k42 ≥
4ηλ

Therefore, for any t ≥ t∗ , we have

η −2 kx(t + 1)k42
k∇Lγt (x(t))k42
=(1 − ηλ)4 η −2 kx(t)k42 + 2(1 − λη)2 k∇Lγ (x(t))k22 +
kx(t)k42 η −2
t
4(t+1−t∗ ) −2
X
=(1 − ηλ) η kx(t ∗
)k42 +2 (1 − ηλ)4(t−τ )+2 E[k∇Lγτ (x(τ ))k22 | x(τ )]
τ =t∗
| {z }
(A)
t
X
(1 − ηλ)4(t−τ )+2 k∇Lγτ (x(τ ))k22 − E[k∇Lγτ (x(τ ))k22 | x(τ )]

+2
τ =t∗
| {z }
(B)
t 4
4(t−τ ) k∇Lγτ (x(τ ))k2
X
+ (1 − ηλ) .
τ =t∗
kx(τ )k42 η −2
| {z }
(C)

(3.12)

Below we will upper-bound the terms (A), (B) and (C) on ET respectively.

(A). By Lemma 3.5.8, we have

t
X (1 − ηλ)2 e4ηλ 2 e0.2 2
(A) ≤ 2 (1 − ηλ)4(t−τ )+2 σ 2 ≤ σ ≤ σ , (3.13)
τ =t∗
2ηλ 2ηλ

where in the last step we used ηλ ≤ 0.1 and ex (1 − x) ≤ 1 for any 0 ≤ x ≤ 1.

(B). By the definition of event ET , we have

s
M 2 4ηλ 1 2T 2 (1 − ηλ)2 2
(B) ≤ (1 − ηλ)2 e ln ≤ σ (3.14)
2 λη δ 6ηλ

77
(C). Combining the above analysis and Lemma 3.7.3, we know conditioned on ET ,
(1−ηλ)2 σ 2
for any t ≥ t∗ , it holds kx(t)k42 /η 2 ≥ 4ηλ
.

Therefore, by Lemma 3.5.8, we have

t
4ηλM 4 X 4(t−τ )−2 e4ηλ M 4
(C) ≤ (1 − ηλ) ≤ (3.15)
σ 2 τ =t∗ (1 − ηλ)2 σ 2

σ2
Under Condition 3.4.1, we can further upper bound (C) by 9ηλe4ηλ (1−ηλ)2

σ2 σ2
9× 98 × 87 ηλ
= 7ηλ
, where we used the fact that ηλ ≤ 0.1.

What is left to do is to upper bound η −2 kx(t∗ )k42 . We proceed by discussing the


following three cases respectively:

• t∗ = 0. Then η −2 kx(t∗ )k42 = η −2 kx(0)k42 .

• 1 ≤ t∗ ≤ T0 − 1. In this case, we have

∗ −1)
η −1 kxt∗ −1 k22 ≥ (1−ηλ)2(t η −1 kx(0)k22 ≥ e−4(T0 −1)ηλ η −1 kx(0)k22 ≥ e−4 kx(0)k22 η −1 .

Thus it holds that

2
∇Lγt∗ −1 (x(t∗ − 1))
η −1
kx(t ∗
)k22 =(1 − ηλ) η 2 −1
kxt∗ −1 k22 + 2
kxt∗ −1 k22 η −1
s
e8 (1 − ηλ)2 σ 2 M2
≤(1 − ηλ)2 + e4
4ηλ kx(0)k22 η −1
s
e8 σ 2 4 M2
≤2 max{ ,e }
4ηλ kx(0)k22 η −1

(1−ηλ)2 σ 2
• t∗ = T0 . Then we have η −2 kx(t∗ )k42 ≤ 4ηλ
.

Taking maximum over three cases, we have


( )
2e4 M 2
ln σ2
η −2 kx(t∗ )k42 ≤ max 2e4 M 2 e kx(0)k4 −2
2η , e8 . (3.16)
ηλ

78
Plugging (3.16) back into (3.12), we got for any t ≥ t∗

η −2 kx(t + 1)k42

=(1 − ηλ)4ηλ(t+1−t ) η −2 kx(t∗ )k42 + (A) + (B) + (C) (3.17)
( )
2e4 M 2
−4λη(t−t∗ ) 2 ln
kx(0)k4 η −2 4σ
2
σ2
≤e max 2M e 2 ,e . + ,
ηλ ηλ

where we used the fact that (0.5e0.2 + 16 + 1


≈ 0.9202 < 1) in the last step.
7
n 2
o
2e4 M 2
1
Therefore there exists T1 = T0 + 4ηλ max ln Mσ2ηλ + ln kx(0)k4 −2 , 4 , such that
η 2
2σ 2
for all t ≥ T1 , η −2
kx(t)k42 ≤ ηλ
.

Theorem 3.4.2 (SGD+WD). Let x(t) be defined by SGD (3.3). For ηλ ≤ 0.1,
under Condition 3.4.1, with probability 1 − 5δ,

σ2 2λ
∀T1 ≤ t ≤ T − 1, ≤ kx(t)k42 ≤ 4σ 2 , (3.4)
2 η

and

T −1
1 X π 2 ρσ p ρσ 3
k∇L(x(t))k22 ≤ √ + 4 ηλ 2
T − T1 t=T (T − T1 ) 2ηλ σ
1
s s (3.5)
ln 2δ πρM σ ln 2δ p M 2 ρσ
+ 4 + 4 λη ,
T − T1 σ T − T1 σ2

n o
1 M 2 ηλ 2e4 M 2
where T1 = 4ηλ
max ln σ2 + ln kx(0)k4 η−2 , 8 .
2

Proof. By Lemma 3.5.1, we have

η h∇L(x(t)), ∇Lγt (x(t))i ρη 2 k∇Lγt (x(t))k22


L(x(t + 1)) − L(xt ) ≤ − +
1 − ηλ kx(t)k22 2(1 − ηλ)2 kx(t)k42

79
Summing up for t = T1 to T − 1, we have

T −1
X T −1
X
η k∇L(x(t))k22 kx(t)k−2
2 = η k∇L(x(t))k22
t=T1 t=T1
T −1
X ρη 2 E[k∇Lγt (x(t))k22 | x(t)]
≤(1 − ηλ) (L(xT1 ) − L(xT )) +
t=T1
2(1 − ηλ) kx(t)k42
| {z }
(A)
T −1
X η h∇L(x(t)), ∇L(x(t)) − ∇Lγt (x(t))i
+
t=T1
kx(t)k22
| {z }
(B)
T −1
ρη 2 k∇Lγt (x(t))k22 − E[k∇Lγt (x(t))k22 | x(t)]

X
+
t=T1
2(1 − ηλ) kx(t)k42
| {z }
(C)

Below we will give high-probability bounds for (A), (B) and (C) respectively. For
convenience, we will use A(t), B(t), C(t) to denote the tth term in (A), (B) and (C).
√ 2
Claim 3.7.5. ET =⇒ ∀T1 ≤ t ≤ T, A(t) ≤ 2 2ρηλ σσ2
PT −1 2 ληρ2 M 2
Claim 3.7.6. (B) = t=T1 B(t) is subG((T − T1 ) 4π σ2
, ET )
PT −1 2 λ2 η 2 M 4
Claim 3.7.7. (C) = t=T1 C(t) is subG((T − T1 ) 4ρ σ4
, ET )

Here Claim 3.7.5 follows from that 2(1 − ηλ) ≥ 2 and Lemma 3.7.3. Note by
the choice of T1 , we can upper and lower bound kx(t)k2 by Lemmas 3.7.3 and 3.7.4,
σ2 2σ 2
that is 4ηλ
≤ η −2 kx(t)k22 ≤ ηλ
. Thus Claims 3.7.6 and 3.7.7 is a direct consequence
of Lemma 3.5.7.
Thus we conclude w.p. 1 − 5δ,

T −1
r
λη 1 X 2 L(x(T1 )) − minx L(x) √ σ2
k∇L(x(t))k2 ≤ + 2 2ρηλ
2σ 2 T − T1 t=T T − T1 σ2
1
s s
2
8λη ln δ πρM 8 ln 2δ M 2ρ
+ + λη 2 ,
T − T1 σ T − T1 σ

80
rearranging it and applying Lemma 3.5.4, we get

T −1
1 X π 2 ρσ p ρσ 3
k∇L(x(t))k22 ≤ √ + 4 ηλ 2
T − T1 t=T (T − T1 ) 2ηλ σ
1
s s
ln 2δ 4πρM σ ln 2δ p M 2 ρσ
+ + 4 λη .
T − T1 σ T − T1 σ2

q
σ2
By Condition 3.4.1, we have M2
≥ 3 λη ln 2δ , and thus we have

T −1
s r
1 X π 2 ρσ p ρσ 3 4 1 1 4ρσ
k∇L(x(t))k22 ≤ √ + 4 ηλ 2 + πρσ + .
T − T1 t=T (T − T1 ) 2ηλ σ 3 (T − T1 )ηλ T − T1 3
1

This completes the proof.

3.8 Convergence of SGD for Multi-group Scale In-

variant Functions

In this section we extend our results to the multi-group scale invariant setting, which
is quite common in practice, e.g. a feedforward network with normalization after each
layer. By Definition 3.8.1, multi-group scale invariant function is also scale invariant.
However, it violates the assumption that the smoothness and the expectation of
stochastic gradient norm square is lower bounded on unit sphere (indeed the loss
function is not defined at everywhere on unit sphere), and thus needs to be treated
x y
separately. A simple example would be L(x, y) = L( kxk , kyk ), the loss L is undefined
2 2

at any point where kxk2 = 1 and y = 0. Yet our analysis for single scale invariant
parameter group can still extend to this case, with a similar assumption that the
expected gradient norm square is lower bounded.

81
Let d1 , . . . , dK be positive integers with d = K d d1 dK
P
k=1 dk . For x ∈ R = R ×. . .×R ,

we use sk to denote i≤k di and xk to denote the vector [xsk−1 , . . . , xsk −1 ]> . For
P

∂f (x)
convenience, we define ∇k f (x) = ∂xk
for any 1 ≤ k ≤ K.

Definition 3.8.1. Given d1 , . . . , dK and a cone U ⊂ Rd , we say a function f : U → R


is multi-group scale invariant iff f (x1 , . . . , xK ) = f (c1 x1 , . . . , cK xK ) for any x ∈ U
and ck > 0 for 1 ≤ k ≤ K.

Setting: Similarly, we assume there exists constants σ k and σ k , such that σ 2k ≤


E k∇k Lγ (x)k22 ≤ σ 2k , for any x such that kxk k2 = 1. In this subsection, we define
ρ := max λmax (∇2 L(x)).
kxk k2 =1,∀k
q
σ 2k 2
Condition 3.8.2. Mk2
≥ 3e4ηλ λη max{ln 2Tδ , 1}.

Theorem 3.8.3 (SGD+WD, Multi-group Scale Invariance). With probability 1 −


(K + 2)δ, it holds that

p T −1
λη/2 1 X
PK k∇L(x(t))k22
k=1 σ k
T − T1
t=T1
K
π2ρ √ X σ 2k
≤ + 2 2ρηλ (3.18)
T − T1 σ2
k=1 k
s s
K K
8λη ln 2δ X Mk 8 ln 2δ X Mk2
+ πρ + ληρ 2
,
T − T1 k=1
σ k T − T 1
k=1
σ k

n o
1 Mk2 ηλ 2e4 Mk2
where T1 = 4ηλ
maxk ln σ2 + ln kx (0)k4 η−2 , 8 .
k k 2

Following the same strategy, we can prove the multi-group counterpart of norm
convergence result, Lemma 3.7.2. Given a integer T ≥ 0, let ET,k be the event that
∀0 ≤ t0 ≤ t ≤ T − 1,

s
t
X M2 1 2T 2
(1 − ηλ)4(t−τ ) k∇k Lγτ (x(τ ))k22 − E[k∇k Lγτ (x(τ ))k22 | x(τ )] ≤ e4ηλ · k

ln .
τ =t0
4 λη δ
82
Lemma 3.8.4. For any 0 ≤ t0 ≤ t ≤ T − 1, 1 ≤ k ≤ K, it holds that

t
X e8ηλ Mk4
(1 − ηλ)4(t−τ ) k∇k Lγτ (x(τ ))k22 − E[k∇k Lγτ (x(τ ))k22 | x(τ )] ∼ subG(

)
τ =t0
32

Thus we have P[ET,k ] ≥ 1 − δ by Lemma 3.5.6.

The following theorem is a restatement of Lemmas 3.7.3 and 3.7.4 in the context
of multi-group scale invariance.
n o
1 Mk2 ηλ 2e4 Mk2
Lemma 3.8.5. Under Condition 3.8.2, there exists T1 = 4ηλ
maxk ln σ2 + ln kx (0)k4 η−2 , 8 ,
k k 2
σ 2k 2σ 2k
such that ∀t ≥ T1 , 4ηλ
≤ η −2 kx(t)k42 ≤ ηλ
, conditioned on ∪K
k=1 ET,k .

The proof of Theorem 3.8.3 is a natural generalization of Theorem 3.4.2.

Proof of Theorem 3.8.3. Setting x = (1 − ηλ)x(t) in Lemma 3.5.2, we have

K
η X ρη 2 k∇k Lγt (x(t))k22
L(x(t + 1)) − L(xt ) ≤ − h∇L(x(t)), ∇Lγt (x(t))i +
1 − ηλ k=1
2(1 − ηλ)2 kxk (t)k42

x> x>
b = [ kx11k , . . . , kxKKk ]> . Summing up for t = T1 to T − 1,
For convenience we define x
2 2

we have

T −1
X T −1
X
η k∇L(x(t))k22 kx(t)k−2
2 = η k∇L(x(t))k22
t=T1 t=T1
T −1 X
K
X ρη 2 E[k∇k Lγt (x(t))k22 | x(t)]
≤(1 − ηλ) (L(xT1 ) − L(xT )) +
t=T1 k=1
2(1 − ηλ) kxk (t)k42
| {z }
(A)
T −1 XK
X η h∇k L(b
x(t)), ∇k L(bx(t)) − ∇k Lγt (b
x(t))i
+ 2
t=T1 k=1
kxk (t)k2
| {z }
(B)
T −1 X
K
ρη 2 k∇k Lγt (x(t))k22 − E[k∇k Lγt (x(t))k22 | x(t)]

X
+
t=T1 k=1
2(1 − ηλ) kxk (t)k42
| {z }
(C)

83
Below we will give high-probability bounds for (A), (B) and (C) respectively. For
convenience, we will use A(t), B(t), C(t) to denote the tth term in (A), (B) and (C).
√ PK σ2k
Claim 3.8.6. ∪K
k=1 ET,k =⇒ ∀T1 ≤ t ≤ T, A(t) ≤ 2 2ρηλ k=1 σ 2k

PT −1 2 2
 PK M k  2 K
Claim 3.8.7. (B) = t=T1 B(t) is subG(4π ληρ (T − T1 ) k=1 σ k , ∪k=1 ET,k )
P 2
P −1 Mk2
Claim 3.8.8. (C) = Tt=T 1
C(t) is subG(4ρ2 2 2
λ η (T − T1 ) K
k=1 σ 2 , ∪Kk=1 ET,k )
k


Here Claim 3.8.6 follows from that 2(1 − ηλ) ≥ 2 and Lemma 3.7.3. Note by
the choice of T1 , we can upper and lower bound kx(t)k2 by Lemma 3.8.5, that is
σ 2k 2σ 2k
4ηλ
≤ η −2 kxk (t)k22 ≤ ηλ
. Thus Claims 3.8.7 and 3.8.8 is a direct consequence of
Lemma 3.5.7.
Thus by Chernoff bound (Lemma 3.5.6), with probability at least 1 − (K + 2)δ,
Equation (3.18) holds.

84
Chapter 4

Robust and Memory-efficient


Optimization via Designing Scale
Invariant Architectures

In contrast to SGD, adaptive gradient methods like Adam allow robust training
of modern deep networks, especially large language models. However, the use of
adaptivity not only comes at the cost of extra memory but also raises the fundamental
question: can non-adaptive methods like SGD enjoy similar benefits? In this chapter,
we provide an affirmative answer to this question by proposing to achieve both
robust and memory-efficient training via the following general recipe: (1) modify the
architecture and make it scale invariant, i.e. the scale of parameter doesn’t affect
the output of the network, (2) train with SGD and weight decay, and optionally
q
(3) clip the global gradient norm proportional to weight norm multiplied by 2λ
η
,
where η is learning rate and λ is weight decay. We show that this general approach is
robust to rescaling of parameter and loss by proving that its convergence only depends
logarithmically on the scale of initialization and loss, whereas the standard SGD
might not even converge for many initializations. Following our recipe, we design a

85
scale invariant version of Bert, called Sibert, which when trained simply by vanilla
SGD achieves performance comparable to Bert trained by adaptive methods like
Adam on downstream tasks.

4.1 Introduction

Neural architectures like transformers are the cornerstone for modern machine learning
applications. However, training them is difficult and often results in training instability
[53, 54]. To enable stable training, one typically requires adaptive and carefully tuned
learning rates. However, the reason behind this issue is not very well-understood and
lacks a formal treatment.
In this chapter, we hypothesize that a primary cause of such behavior is the
k-homogeneous (k ≥ 2) nature of the network i.e., property where network’s output is
scaled by sk when its parameters are scaled by s. To illustrate our point, we consider
the following instructive toy model.

Example 4.1.1. Consider logistic regression with 1-dimensional non-separable data,


{(zi , yi )}ni=1 ∈ (R × {±1})n . The loss is defined as L(x1 , , . . . , x2k ) = L(X)
e :=
− ni=1 ln(1 + e−zi yi X ) where X = x1 . . . x2k and k ≥ 2. Since the data is non-
P

separable, the global optimum X ∗ must be finite and with out loss of generality, we
assume it positive.
Since L
e is convex with bounded smoothness in X, there exists step size that are

independent of any initialization that allow GD to converge to the optimal solution. In


sharp contrast, the reparametrized loss L(x1 , , . . . , x2k ) with 2k-homogeneous structure
does not enjoy this nice stability property — the learning rate has to be tuned according
1
−1
to the initialization. In particular, when η > 2
|∇L(X(0))|
e (X(0)) k and X(0) > X ∗
where X ∗ > 0 is the global minimizer, X(t) will monotonically increase and explode,
if all xi are initialized to be the same.

86
This is because |∇L(X)|
e is positive and monotone increases among all X > X ∗ .
Since all xi are initialized equally, they must be the same at any iteration. It
 
X(t) X(t)
holds that xi (t + 1) = xi (t) − η xi (t) ∇L(X(t)) = xi (t) 1 − η x2 (t) ∇L(X(t)) , where
e e
i
 2k
X(t)
2k
X(t) = Πj=1 xj (t). This implies X(t + 1) = X(t) 1 − η k √ ∇L(X(t))
e ≥
X(t)
 2k
X(0) 1
X(t) 1 − η k √ ∇L(X(0))
e > X(t). Thus we conclude if η ≥ |∇L(X(0))|2
(X(0)) k −1
X(0) e
X(t) X(0)
and X(0) > X ∗ , η √
k
∇L(X(t))
e −1 ≥ η√
k
∇L(X(0))
e − 1 > 1 and thus X(t)
X(t) X(0)

will increase monotonically and explode.

In the above example, the success of optimization is very sensitive to the right
choice of the learning rate that depends on the initialization. Furthermore, the training
cannot recover once the norm explodes due to large gradient update.
Still it is possible to find a small workable learning rate by extensive grid search
that depends on the initial point in the above one-dimensional example. However,
the situation can get worse when the k-homogeneous structure has an unbalanced
initialization as below.

Example 4.1.2. Consider solving low-rank matrix decomposition by Gradient Descent.


2
Let L(A, B) = 1
2
AB > − Y 2
where A, B ∈ Rd×r are both initialized i.i.d. gaussian
with covariance σA2  σB2 ≈ σA−2 , Y ∈ Rd×d and d  r.
Solving this optimization problem requires A and B learning the column and row
space of Y respectively, but the unbalanced initialization will force the learning rate
to be small enough such that B does not explode and, thus, A is almost frozen. To
see this, note in the standard convergence analysis of GD, we need LR smaller than
2/ k∇2 Lk to ensure the Descent Lemma holds, i.e., loss decreases in a single step. Here
we have that the smoothness w.r.t A (fixing B) is λmax (BB T ) and the smoothness
w.r.t. B (fixing A) is λmax (AAT ). Thus, LR can be at most O( σ12 ), but the gradient
A

of A is only of magnitude O(σB ), resulting in A learning the column space slowly.

87
4
Specifically, when d = 1 and Y = 0 and for any r ≥ 1, choosing η > k∇2 L(A(0),B(0))k

will cause GD to explode [55].

Similar issues can exist in deep neural networks as the k-homogeneous structure
is quite common. For instance, Liu et al. [53] identified the gradient norm varies
with depth and that no single learning rate is globally optimal for all layers. To this
end, one has to resort to adaptive methods like Adam to handle the k-homogeneous
structure of deep networks and allow for its robust training. However, this not only
comes at the expense of higher memory, but also raises the key question of our interest:
Can non-adaptive methods like SGD enjoy fast and robust convergence without
training instability?
Answering this question, requires us to first define our notion of robustness. In
this chapter, we primarily aim for three aspects of robustness by preventing: explosion
of parameters (e.g. due to frequent large gradient updates), slow progress in training
(e.g. due to loss plateaus) and loss explosion or spikes (e.g. due to possibly infrequent
large magnitude updates). In this chapter, we propose a simple yet powerful general
approach for achieving such fast and robust convergence. At a high level, our recipe
for robust training includes three key ingredients:

1. Designing architectural scale invariance which allows for improved training


stability and prevents explosion of the parameters. We show that by using scale
invariance in the architecture (i.e., making the network 0-homogeneous), one
can effectively control the gradient updates when the parameter norm is large.

2. Using SGD with weight decay for training, wherein enabling weight decay im-
proves training efficiency under rescaling of loss and initialization. While scale
invariance prevents explosion of parameters, the training convergence has strong
dependence on initialization scale and learning rate, which can make training

88
inefficient in face of parameter and initialization rescaling. Use of SGD with
weight decay circumvents this issue.

3. Using a novel Relative Global Clipping to prevent spikes in training loss and
improve overall convergence speed. Although scale invariance in the archi-
tecture already guarantees the training stability, it does not prevent severe
non-monotonic loss explosion. By using a new global clipping approach, we show
that one can prevent such loss explosions effectively.

We show that this surprisingly simple training recipe can not only improve the
memory efficiency over adaptive methods but also achieves robust training. In light of
the above background, we list our main contributions below.

• In Section 4.3, we propose a new general recipe for memory efficient, robust
training using (1) scale invariant architecture; (2) SGD+WD for training and (3)
a novel clipping rule, called Relative Global Clipping, for clipping the updates.
Following this recipe, we design a new variant of Bert called Scale Invariant
Bert (Sibert).

• In Section 4.5, we show SGD+WD with Relative Global Clipping has better
parameter norm convergence via a novel analysis. With assumptions that the
clipping does not bring too much bias in expected gradients, we show similar
convergence result to SGD+WD.

• In our empirical analysis in Section 4.6, we demonstrate that Sibert trained


using simple SGD can achieve performance comparable to standard Bert
trained with Adam. Furthermore, we also verify our theoretical claims. To our
knowledge, this is the first time a Bert-like model has been effectively trained
using vanilla SGD.

89
4.2 Related Work and Background

The literature on adaptive methods and scale invariance in neural networks is vast, so
we only discuss works that are most relevant to our paper.

Adaptive Methods & Clipping Methods. Adaptive learning rates have long
been studied [56]. In machine learning, adaptive learning rates have been popu-
larized by Adagrad, which particularly benefits from sparse stochastic gradients
[57]. Inspired by Adagrad, several adaptive methods, like Adam, RMSprop and
its variants have been proposed in the deep learning community [1, 58–61]. These
approaches have been crucial in the success of many deep learning applications [62–
64]. Several works have studied the benefits of adaptive methods in deep learning
settings (e.g. [53, 54]). However, as mentioned earlier, these benefits come at the
cost of computational and memory efficiency. Anil et al. [65] proposed a variant of
Adagrad requiring fewer parameters for adaptivity, but still requires momentum.
Adafactor [61] removes momentum and uses much fewer adaptivity parameters, but
for large models, Adafactor still needs momentum to ensure training stability [66].
Our approach is also related to normalized and projected gradient descent, which has
been studied for quasi-convex and non-convex settings (e.g. see [67–69]). However,
these methods have seen very limited success.
Clipping based optimization methods, especially gradient clipping, are widely used
in deep learning applications to improve training stability or ensure privacy [70–72].
These approaches typically use a constant threshold to clip the gradients before the
update. However, choosing this threshold is difficult and requires careful tuning.
Adaptive variants of clipping methods partially alleviate this issue and are closely
related to adaptive methods [54]; however, they again incur additional computation
and memory costs.

90
Scale Invariance in deep networks. Various normalization schemes are the main
source of scale invariance in deep learning, e.g., BatchNorm [18], LayerNorm [19],
Weight Normalization [73], GroupNorm [30], InstanceNorm [74]. Scale invariance
from normalization allows GD and SGD to converge to stationary points from any
initialization and with any learning rate, in O(T −1/2 ) and O(T
e −1/4 ) rates respectively

[36]. The interplay between SGD, scale invariance and WD has also been well studied.
It was shown that the effect of WD for normalized networks can be replaced by
LR schedules [40, 41]. Li and Arora [22] formally builds the equivalence between
SGD+WD and SGD with an exponential increasing LR schedule for scale invariant
loss. Van Laarhoven [42] first proposed the notion of effective LR, η/ kxk22 , for normal-
ized networks, and showed that the unique stationary value of kxk42 is proportional to
λ/η, where η is LR and λ is WD. Li et al. [50] proved that the parameter norm always
converges to the above value by modeling SGD as Stochastic Differential Equation.
Wan et al. [75] proved the parameter norm converges to the same value directly for
SGD+WD, but only in expectation.

4.3 Methods

In this section, we provide a more detailed description of our recipe for robust and
memory-efficient network training, which includes three building blocks: (1) scale
invariant architecture (Section 4.3.1), (2) SGD with Weight Decay (Section 4.3.2) and
optionally (3) the Relative Global Clipping (Section 4.3.3 and Algorithm 2).

Algorithm 2 C-Clipped SGD + WD
Input: Total steps T , Scale invariant loss {Lt }Tt≥1 , initialization x(0), LR η, WD λ,
clipping factor C > 1 (C = ∞ ⇔ no clipping).
T − 1 do
for t = 0 to nq o
2Cλ
Nt ← min η
kx(t)k 2 , k∇L t (x(t))k 2 .
∇Lt (x(t))
x(t + 1) ← (1 − ηλ)x(t) − ηNt k∇Lt (x(t))k
.
2

91
4.3.1 Designing Scaling Invariant Architectures

We first revisit an approach for introducing scale invariance in neural networks, which
is presented in [22]. Viewing the neural network computation as a directed graph, the
high level idea is to ensure same homogeneity degree of different edges reaching a node.
For example in a ResNet block, the output from an affine transform is added back to
the input z from the previous layer yielding z + Aff(z). Now if we scale all the network
parameters by c, both z and Aff(z) should have the same degree of homogeneity and
scale as ck . Otherwise the network is no longer homogeneous and, hence, cannot be
scale invariant.
In this chapter, we apply the above design philosophy to develop a scale invariant
version of Bert [63] — a transformer based model. A transformer has two main
building blocks that need to be made scale invariant – residual block and Attention [62].
For residual block, Li and Arora [22] already demonstrated how to make both the
PreNorm and PostNorm version of ResNet scale invariant (see Appendix of their
paper for more details). In this chapter, we use their PreNorm variant (see Figure 4.2).
Furthermore, we design a novel scale invariant version of Attention block in transformer,
as described below.

Scale Invariant Attention: Recall the standard self attention block computes the
following for a given input Q, K, V ∈ Rn×dmodel :

QW Q (KW K )>
Attention(Q, K, V ) = Softmax( √ )V W V .
dk

Here W Q , W K ∈ Rdmodel ×dk and W V ∈ Rdmodel ×dv are affine transformations and, hence,
are all 1-homogeneous transformations. The Softmax function computes row wise
softmax normalization. It is easy to see that standard attention is not homogeneous
as softmax is itself not homogeneous.

92
We design a novel Scale Invariant Attention (SI Attention) in the following way:
(also see Figure 4.4)

SI-Attention(Q, K, V ) = N(ReLU(QW Q (KW K )> )V W V ,

a
where N denotes the row-wise normalization by sum, i.e., [N(A)]ij = P ij and
j aij

ReLU(A) denote the element-wise max between matrix A and 0. Notably we replace
the softmax with a ReLU activation followed by normalization. Both ReLU and
normalization are homogeneous operations; thus, making the overall attention score
computation (N(ReLU(ZQK > Z > ))) scale invariant to the concatenation of all param-
eters x, assuming Q, K, V are already positive homogeneous to x. The full design of
Scale Invariant Bert (Sibert) is presented to Section 4.4.

4.3.2 Training Algorithm: SGD + WD

Although scale invariance can prevent parameter divergence after a large gradient
update by eliminating the positive feedback between gradient and parameter norm, it
alone does not ensure SGD trains the network in a robust and efficient way. This is
because, as shown in [36], the parameter norm monotonically increases when SGD is
used to optimize a scale invariant loss. As a result, once the norm becomes too large
(e.g due to large gradient in some step) the training can slow down drastically as the
η
effective LR kxt k22
is too small; thus, preventing effective recovery from even minor
training instabilities.
To tackle this issue we propose to use Weight Decay(WD) as a way to reduce the
parameter norm; thereby, allowing the network to recover from slow training induced
by infrequent updates of large norm. Under mild assumptions that the expectation
of squared norm of stochastic gradient does not vary too much on the unit sphere,
1
[50, 75] show that the parameter norm will stabilize in O( ηλ ) steps and the learning

93
dynamics is equivalent to one on unit sphere with effective learning rate proportional

to Θ( λη).
Leveraging the advantage of quick norm convergence, it is shown in Chapter 3 that
the convergence of SGD+WD is insensitive to the following three operations: loss
rescaling (A1), initialization rescaling (A2) and re-parametrization (A3), meaning the
| log c|
same convergence rate (independent of scaling c) can be achieved, in up to λη
more
steps. (See formal statement in Theorems 3.3.1 and 3.4.2 This property reduces the
effort of hyperparameter tuning and also makes training more robust when switching
between different codebases and frameworks, which is likely to have different default
scaling or parametrization. Also note by scale invariance of loss L, (A2) is equivalent
to (A3).

(A1). L → cL, for any c > 0.

(A2). x(0) → cx(0), for any c > 0.

(A3). (L, x(0)) → (L0 , cx(0)), where L0 is defined as L0 (x) := L( xc ) for any c > 0.

As a comparison, previous work [36] showed that GD converges to  approximate


stationary point of a scale invariant loss in O( 12 ) and SGD converges in O(1/
e 4
) steps
with any initialization. However, the constant in O(·) scales linearly or inversely to the
above scalings (c in (A1-3)). This is far from satisfying, and indeed their experiments
show that either large or small LR could substantially slowdown the training progress.

4.3.3 Relative Global Clipping

Gradient clipping is a widely used effective strategy to stabilize neural network training.
However, often the clipping threshold need to be tuned based on the optimization
problem and the specific gradient distribution. Furthermore, simply using a constant
threshold can severely degrade the performance [54]. Thus, it is unclear how the

94
clipping threshold needs to be set for SGD+WD on scale invariant functions such
that it is insensitive to rescaling of loss and reparametrization, e.g., (A1-3).
To this end, we propose a clipping strategy named Relative Global Clipping which
allows consistent and robust training behavior for SGD+WD on scale invariant loss
under the aforementioned operations. In particular, we propose to set the clipping
q √
threshold as 2Cλ
η
kxk 2 , where C ≥ 1 is a hyperparamer with default value C = 2.
The high level design idea is that (1) the clipping rule should be invariant to the
scalings (L, η, λ) → (cL, η/c, cλ) and (x, η, λ) → (cx, c2 η, λ/c2 ) for any c > 0, to which
SGD+WD is invariant (see Lemma 3.2.4); (2) the clipping rule should only remove
the extremely large gradients and should not trigger too often to ensure that gradient
after clipping remains almost unbiased.
Intuitively, the derivation of Relative Global Clipping involves the following line of
reasoning: Suppose the norm of the stochastic gradient k∇Lγ (x)k2 is constant, say
σ, for all data and every parameter x on the unit sphere. In this case, we expect
our clipping strategy to not be triggered since there are no extremely high stochastic
gradients. Since Lγ is scale invariant, Theorem 3.2.2 implies that h∇Lγ (x), xi = 0.
That is,

kx(t + 1)k22 =(1 − ηλ)2 kx(t)k22 + η 2 k∇Lγ (x(t))k22

=(1 − ηλ)2 kx(t)k22 + η 2 σ 2 / kx(t)k22 . (4.1)

It is not difficult to show the iteration (4.1) has a unique stationary point, kx(t)k22 =
q

λ(2−ηλ)
σ[42]. In other words, at norm equilibrium, it holds

s
σ λ(2 − ηλ)
k∇Lγ (x(t))k2 = = kx(t)k2 . (4.2)
kx(t)k2 η

95
The above calculation suggests the clipping threshold should be at least
q

kx(t)k2 . 1 Furthermore, it is not difficult to check that the clipping
η
q

threshold η
kx(t)k2 is indeed invariant to the above mentioned scalings
(L, η, λ) → (cL, η/c, cλ) and (x, η, λ) → (cx, c2 η, λ/c2 ). For each hyperparame-
ter C > 1, the behavior of SGD+WD is consistent for different scalings (A1-3) and
it also improves the norm convergence (reducing undesirable spikes in norm while
training) for SGD+WD (see Theorem 4.5.3). Under mild assumptions that such
clipping does not introduce too much bias in gradients, we show that our recipe
enables convergence to approximate stationary points. Furthermore, the rate only
depends logarithmically on the initialization and loss scale, as shown in the following
section.

4.4 Design Details of Scale Invariant BERT

Following Section 2.10, we view the computation graph as a directed acyclic graph,
where each module is a node and each tensor (including inputs, intermediate compu-
tation results and final output) as an edge. Each edge can be viewed as a function
of parameters, and we can decide the homogeneity by doing induction over the com-
putation graph by its topological order. In detail, we know the jth output edge
of some (a1 , . . . , an ; b1 , . . . , bn )- homogeneous module is bj homogeneous if for each
1 ≤ i ≤ n, the ith input edge is ai -homogeneous. For convenience, we allow ai ,bi to be
functions of free variable x, meaning the module is (a1 (x), . . . , an (x); b1 (x), . . . , bm (x))-
homogeneous for every x ∈ Rd .
In Table 4.1, we summarize the homogeneity of building blocks in our design.

1
We drop −ηλ for convenience. This doesn’t lead to any practical difference as ηλ is typically
very small, e.g. less than10−4 .
96
Overview of SIBERT structure: Our SIBERT has two main parts — encoder
and classification head, which is the same to standard BERT. We only make encoder
part scale invariant and train it by SGD+WD. We leave the classification head not
scale invariant and train it by Lamb. Note the classification head is only used in
pretraining and is not used in the downstream task.

(2;2)-homogeneous encoder layer: As mentioned in Section 4.4, residual block


and attention are the two main building blocks that needs to be made scale invariant.
Following Li and Arora [22], we choose to use PreNorm structure for residual block
and make it (2; 2)-homogeneous. We also replace GeLU [76] in BERT by ReLU for
homogeneity. Since ReLU is (1; 1) homogeneous, we omit ReLU from the design,
without affecting the final scale invariance.

Table 4.1: Homogeneity of building blocks of SIBERT.

Symbol Module Homogeneity


I Input (;0)
B Adding Bias (1;1)
N Layer Normalization (no affine) (x;0)
L Linear Layer (x;x+1)
Embed Embedding Layer (x;x+1)
NA Layer Normalization with affine (x;1)
FF 2-layer feedforward network (0;2)
ATTN Scale Invariant Attention (x,x,x;x+2)
Encoder Our Encoder Layer (2;2)

Figure 4.1: Encoder and Classification Head (CLS). ‘x12/24’ means to stack 12 our
(2; 2)-homogeneous encoder layer for base SIBERT (or 24 for large SIBERT)

97
Figure 4.2: The (2; 2)-homogeneous encoder layer. ‘ATTN’ denotes our Scale Invariant
Attention (see Figure 4.4). ‘FF’ denotes the 2-layer feedforward structure, which is
(0; 2)-homogeneous.

Figure 4.3: The (0; 2)-homogeneous FeedForward layer

Figure 4.4: The (x, x, x; x + 2)-homogeneous Attention, which is defined as


P Q K > V O
Multi-Head-SI-Attention(Q, K, V ) = i N(ReLU(QWi (KWi ) )V Wi Wi , where
Q
Wi , WiK ∈ Rdmodel ×dk , WiV ∈ Rdk ×dv and WiO ∈ Rdv ×dmodel That is, if Q, K, V are
k-homogeneous functions of parameter x, then Multi-Head-SI-Attention(Q, K, V ) is
k + 2-homogeneous, for any k ∈ R. We also call it Scale Invariant Attention because
its attention score is scale invariant.

98
small init=0.002 medium init=0.02 large init=0.2
8 SIBERT, SGD 8 SIBERT, SGD 8 SIBERT, SGD
SIBERT, SGD+WD SIBERT, SGD+WD SIBERT, SGD+WD
Training Loss

Training Loss

Training Loss
SIBERT,2-clipped SIBERT,2-clipped SIBERT,2-clipped
6 SGD+WD 6 SGD+WD 6 SGD+WD

4 4 4
2 2 2
0.0 0.2 0.4 0.6 0.8 1.0 0.0 0.2 0.4 0.6 0.8 1.0 0.0 0.2 0.4 0.6 0.8 1.0
Steps 1e6 Steps 1e6 Steps 1e6

Figure 4.5: SGD+WD optimizes the scale invariant training loss of Sibert robustly
for all initialization scales, and thus for loss scalings and different learning rates
(with λη fixed). Here the default initialization for parameters in Sibert encoder is
a truncated normal distribution with standard deviation equal to 0.02 (the same as
Bert).

4.5 Convergence of SGD with Relative Global

Clipping

Now we will present our analysis for the clipped SGD. Recall the clipped SGD update
from Algorithm 2 has the following norm dynamics.
Norm dynamics of clipped SGD:
( )
k∇Lγ (x(t))k22 2λC
kx(t + 1)k22 = (1 − ηλ)2 kx(t)k22 + η 2 min 2 , kx(t)k22 .
kx(t)k2 η

To present our bound we need the following definitions.

Definition 4.5.1 (C-clipped mean). Given a distribution P on R≥0 and constant


C > 1, we define FP,C (µ) = Et∼P [min{t, Cµ}], and define the C-clipped mean of P ,
µP,C as the largest positive real number satisfying that FP,C (CµP,C ) = µP,C . Such a
definition is valid because FP,C (0) = 0 and thus 0 is always a solution.
For convenience, we also define GP,C (µ) := FP,C (Cµ) − µ and MP, 1 is defined
C

as the C median of P , that is, MP,C := sup M ≥ 0 | Pt∼P [t ≥ M ] ≥ C1 . Since


1


the cumulative density function Pt∼P [t ≥ M ] is left continuous in M , it holds that


1
Pt∼P [t ≥ MP,C ] ≥ C
.

99
Let Px denote the distribution of k∇Lγ (x)k22 . Below is a mild assumption saying
Px is universally well-concentrated from below in the sense that the mean of the
smallest (1 − C1 ) part of Px is at least a constant fraction of the C-clipped mean of Px .
Since µPx ,C ≤ µx , the assumption below holds whenever αC µx ≤ Et∼Px [t1[t < MPx , 1 ]].
C

Assumption 4.5.2. ∃αC > 0, such that for all x 6= 0, αC · µPx ,C ≤ Et∼Px [t1[t <
MPx , 1 ]].
C

We further define µC := min µPx ,C and µC := max µPx ,C and have the following
kxk2 =1 kxk2 =1
theorem:
√ √
Theorem 4.5.3 ( C-Clipped SGD+WD). Let x(t) be defined by C-Clipped SGD
+WD (Algorithm 2). Under Assumption 4.5.2, for ηλ = O(min{1, C lnαTC/δ2 }), with
probability 1 − 5δ, we have

µC 2λ
∀T 0 ≤ t ≤ T − 1, ≤ kx(t)k42 ≤ 2µC . (4.3)
2 η

and

T −1 D 2
√ 3
1 X E π ρ µ C
p ρµ 2
C
∇L(x(t)), ∇L(x(t))
g ≤ √ + 4 ηλ
T − T 0 t=T 0 (T − T 0 ) 2ηλ µC
s s (4.4)
2 2 2 3
ln δ πρµC ln δ p ρµ
+ 0
8 + 0
16 λη 2C .
T −T µC T −T µC

kxk22
n o h nq oi
R2 µ
where T 0 = 1
αC ηλ
max ln µ 0 , ln RC2 +O(1) and ∇L(x)
g := E ∇Lγ (x) min 2Cλ
η k∇Lγ (x)k
, 1 .
C 0 2

The proof of this theorem is presented in Section 4.7. Note that with clipping
Theorem 4.5.3 shows that the norm convergence (4.3) is more robust as it doesn’t need
to make any assumption about the maximum gradient norm M , unlike Theorem 3.4.2.
Indeed, from the definition of C-clipped mean, for each x, we can allow all the
gradients with norm larger than C · µPx ,C to become infinity, and yet not affect the
norm convergence, as µPx ,C and the condition in Assumption 4.5.2 do not change.
100
107
small init=0.002 107
medium init=0.02 107
large init=0.2
SIBERT, SGD SIBERT, SGD
SIBERT, SGD+WD SIBERT, SGD+WD
106 SIBERT,2-clipped 106 SIBERT,2-clipped 106
SGD+WD SGD+WD SIBERT, SGD
Norm Sqaure

Norm Sqaure

Norm Sqaure
SIBERT, SGD+WD
105 105 105 SIBERT,2-clipped
SGD+WD
104 104 104

103 103 103


0.0 0.2 0.4 0.6 0.8 1.0 0.0 0.2 0.4 0.6 0.8 1.0 0.0 0.2 0.4 0.6 0.8 1.0
Steps 1e6 Steps 1e6 Steps 1e6

Figure 4.6: The robust optimization performance of SGD+WD over the scale invariant
training loss of Sibert originates from its ability to fast adjust the parameter norm.
In contrast, when the initial norm is too large, SGD w.o. WD optimizes slowly.
Relative Global Clipping reduces the spikes in the norm curve, which verifies our
theoretical result Theorem 4.5.3 that clipping leads to better norm convergence. Here,
only the norm of the scale invariant part, i.e., the encoder part is plotted.

D E
Under the additional assumption that ∇L(x(t)), ∇L(x(t)
g = Ω(k∇L(x(t))k22 ), we
can use Equation (4.4) to show convergence to stationary points. This is a reasonable
assumption if the clipping frequency is low, e.g., it’s 1.5% in our experiments for
Sibert.

4.6 Experiments

We now conduct a comprehensive empirical study in order to demonstrate the following


key aspects of our recipe: (i) yields competitive training performance using significantly
low memory footprint, (ii) training becomes highly robust to initialization scale, and
(iii) provides better convergence of norm with clipping.

Experimental Setup. We consider the standard task of pretraining a transformer


model and fine-tuning it on benchmark datasets, following Devlin et al. [63]. We
compare its performance with Sibert, a scale invariant version of Bert as described
in Sec. 4.3.1. For both these models, we use their base size versions unless specified
otherwise. For Sibert, the scale invariant portion is trained using SGD+WD with a
piecewise constant LR schedule and WD of 1e − 2. We use Lamb optimizer for the

101
5.0
4.5
4.0
3.5

Training Loss
3.0 BERT, SGD, small LR
SIBERT,2-clipped
2.5 SGD+WD
SIBERT, AdamW
2.0 BERT, AdamW

1.5
1.00.0 0.2 0.4 0.6 0.8 1.0
Steps 1e6

Figure 4.7: Our recipe (Sibert, SGD+WD and Relative Global Clipping) significantly
improves the optimization performance compared to the baseline, Bert trained by
SGD with small LR. The final training loss is close to Bert trained by Adam.

non-scale invariant parts. The initial LR for SGD is 8e − 4 without warmup and is
divided by 10 at step 600k and 900k. Default training is for 1M steps. For Lamb we
use a linear decay schedule with initial learning rate 8e − 4 and a linear warmup of
10k steps.

Performance. We begin by establishing that proposed Sibert with SGD+WD


training performs competitively. In this regard, we first look at pretraining loss between
standard training of Bert with Adam and our Sibert trained by SGD+WD with

or without clipping (the clipping factor is set as C = 2). From Figure 4.7, one can
see that our training curve closely follows that of Bert trained by Adam, but without
the need for extra memory for keeping track of first and second order momentum. If
we use SGD on standard Bert architecture, then either we have to use small learning
rates, which slows down training, or the loss diverges. This further highlights the
importance of the scale invariant architecture, which improves training stability by
eliminating the k-homogeneous structure. To our knowledge, this is the first work
that shows effective training of Bert-like model using simple SGD (even without any
momentum).

102
Next, we compare the downstream performance on three benchmark datasets
(SQuADv1.1 [77], SQuADv2 [78] and MNLI [79]). We tried to follow standard setup,
e.g. Bert is finetuned by Adam. However for Sibert we had to use LAMB, as
Adam is very sensitive to the scale. We observe comparable performance and when
trained longer it can even outperform conventional Bert.

Table 4.2: Downstream Performance of Sibert trained by SGD+WD +clipping is


close to that of Bert trained Adam- which uses 3X more memory than SGD. The
gap is further reduced by doubling the training budget of Sibert.

MNLI SQuAD1 SQuAD2 Pretraining


Acc F1 F1 Loss
Bert 84.4 90.3 78.8 1.479
81.1 88.1 74.8 1.672
Base

Sibert
+ clipping 82.6 89.3 76.8 1.58
+ 2x training 83.3 90.3 80.0 1.495
Bert 86.8 92.4 84.1 1.181
Large

Sibert 83.7 90.6 79.3 1.404


+ clipping 85.3 91.6 81.3 1.322
+ 2x training 86.4 92.4 83.1 1.194

Training Stability: Insensitivity to the scale of initialization. To showcase


ease of optimization offered by our recipe, we consider different initialization scales
spanning two orders of magnitude. The results for the pretraining task in Figure 4.5
show good convergence across the board for our approach, whereas SGD on its own
struggles even with the scale invariant architecture.
Further note that these experiments simultaneously showcase robustness to rescal-
ing of loss, parameterization, or LR. This is because in a scale invariant model trained
by SGD+WD (+clipping), it holds that all of following scalings are equivalent:
(c1 L, c2 x(0), c3 η, c4 λ) ←→ (L, √cc12c3 x(0), η, c3 c4 λ) for any c1 , c2 , c3 , c4 > 0.

Training Stability: Improvement in parameter norm convergence. Finally,


we look at parameter norms during training in experiments. We observe that even

103
when starting from very different initialization scale, SGD+WD (+clipping) quickly
brings parameter norm to desired ranges. In contrast, SGD struggles when initial
norm and learning rate are not aligned - see the rightmost plot with large initialization
in Figure 4.6. This shows that our recipe has the ability to quickly adapt to different
initialization scales, in-line with our theoretical result (Theorem 4.5.3) showing better
norm convergence of SGD+WD (+clipping).

4.7 Proofs for Convergence of SGD with Relative

Global Clipping

Norm dynamics of clipped SGD:


( )
k∇Lγ (x(t))k22 2λC
kx(t + 1)k22 = (1 − ηλ)2 kx(t)k22 + η 2 min , kx(t)k22 . (4.5)
kx(t)k22 η

Lemma 4.7.1 (General Properties of GP,C ). For any C > 1 and measure P supported
on R≥0 , it holds that

1. GP,C is continuous and concave;

2. supµ≥0 GP,C (µ) = GP,C ( C1 MP, 1 );


C

1
3. C
MP, 1 ≤ µP,C ≤ µP , where µP is the expectation of P .
C

Proof of Lemma 4.7.1. (1). Note min{x, ·} is a continuous and concave function
for any x, we know GP,C is a concave function. (2). When GP,C is differentiable,
we have G0P,C (µ) = CFP,C
0
(Cµ) − 1. Let G0P,C (µ) = 0 implies that FP,C
0
(Cµ) =
0
1
C
. Note FP,C (Cµ) = Pt∼P [t > FP,C ], we know G0P,C ( C1 MP, 1 ) = 0. By concavity,
C

supµ≥0 GP,C (µ) = GP,C ( C1 MP, 1 ). This argument can be easily generalized to non-
C

differentiable case by using GP,C (µ) must be larger than GP,C (µ ± δ) for infinitesimal

104
δ. (3). First note that FP,C (MP, 1 ) = Et∼P [min{t, MP, 1 }] ≥ MP, 1 · Pt∼P [t ≥ MP, 1 ] =
C C C C

1
C
MP, 1 . In other words, GP,C ( C1 MP, 1 ) ≥ 0.
C C

1
Now suppose C
MP, 1 > µP,C . If GP,C ( C1 MP, 1 ) = 0, then by definition, 1
C
MP, 1 ≤
C C C

µP,C . If GP,C ( C1 MP, 1 ) > 0, by concavity, GP,C (µP,C ) > 0, contradiction!


C

Theorem 4.7.2. [Classifications of solutions of FP,C (Cµ) = µ]

1. If P[x = 0] < 1 − C1 , then FP,C (Cµ) = µ has exact two solutions which are 0 and
µP,C > 0;

1 1
2. If P[x = 0] = 1 − C
, then FP,C (Cµ) = µ for all 0 ≤ µ ≤ C
MP,C and µP,C =
1
C
MP,C ;

3. If P[x = 0] > 1 − C1 , then FP,C (Cµ) = µ has only one solution which is µP,C = 0.

Proof. Suppose there are two solutions 0 < µ1 < µ2 . By concavity, we have ∀0 ≤ µ ≤
µ2 , GP,C (µ) = 0. Thus 0 = GP,C (0) + GP,C (µ2 ) = 2g( µ22 ), which implies that

Cµ2
Et∼P [min{t, Cµ2 }] = 2Et∼P [min{t, }] = Et∼P [min{2t, Cµ2 }],
2

that is, Pt∼P [t ≥ Cµ2 ∨ t = 0] = 1. Thus for any 0 ≤ µ ≤ µ2 , we have GP,C (µ) =
1
CµP[x ≥ Cµ2 ] − µ = 0, which implies µ2 = C
MP, 1 and P[x = 0] = 1 − C1 !
C

Lemma 4.7.3. Under Assumption 4.5.2, it holds that GP,Cx ( C1 MPx , 1 ) ≥ αC µPx ,C for
C

all x 6= 0.

Proof of Lemma 4.7.3. By definition,

1 1
GP,Cx ( MPx , 1 ) = Et∼Px [t1[t < MPx ,C ]] + (Pt∼Px [t ≥ MPx ,C ] − ) · MPx ,C . (4.6)
C C C

1
By the definition of the C
-median, the second term is non-negative. The proof is
completed by applying Assumption 4.5.2.

105
Lemma 4.7.4 (Lower and upped bounds for GPx ,C ). Under Assumption 4.5.2, it
holds that
µPx ,C
1. GPx ,C (µ) ≥ αC µ, for 0 ≤ µ ≤ 2
;

µPx ,C
2. GPx ,C (µ) ≥ αC (µPx ,C − µ), for 2
≤ µ ≤ µPx ,C ;

3. GPx ,C (µ) ≤ −αC (µ − µPx ,C ), for µ ≥ µPx ,C .

Proof of Lemma 4.7.4. By Lemma 4.7.3, Assumption 4.5.2 implies that GP,Cx ( C1 MPx , 1 ) ≥
C

6 0. Further note that GP,Cx (0) = GP,Cx (µPx , C) = 0. The claims


αC µPx ,C for all x =
(a), (b) and (c) are immediate by concavity of GP,Cx .

The above inequalities also directly imply the following version using µC and µC
as thresholds.

Lemma 4.7.5 (Uniform Lower and upped bounds for GPx ,C ). Under Assumption 4.5.2,
it holds that for kxk2 = 1,
µC
1. GPx ,C (µ) ≥ αC µ, for 0 ≤ µ ≤ 2
;

µC
2. GPx ,C (µ) ≥ αC (µC − µ), for 2
≤ µ ≤ µC ;

3. GPx ,C (µ) ≤ −αC (µ − µC ), for µ ≥ µC .

αC µ 4µC
4. GPx ,C (µ) ≥ 4
, for 0 ≤ µ ≤ 5
; (4. follows from Property 1. and 2.)

For convenience, we define Rt := kx(t)k22 , gt := k∇Lγt (x(t))k22 , gbt :=



η
2
k∇Lγt (x(t))k2
min{CRt , gt }, get := Rt gbt = min{CRt2 , k∇Lγt (x(t))k22 } and g t := Rgbtt = min{C, Rt
}.
gt | x(t)] = µPx(t) ,C . We further define βl := 1 − 2λ2 η 2 + η 4 λ4 −
Thus we have E[b
4ηλαC (1 − ηλ)2 = 1 − 4ηλαC + O(η 2 λ2 ) and βu := 1 − 2λ2 η 2 + η 4 λ4 − 4ηλαC (1 −
ηλ)2 + 4C 2 η 2 λ2 = 1 − 4ηλαC + O(η 2 λ2 ).
Given an integer T ≥ 0, let ET1 be the event that ∀0 ≤ t0 ≤ t ≤ T,

s
t
X h i √ 1 2T 2
βl t−s (e gs | x(s)]) 1 Rs2 ≤ µC
gs − E[e ≤ CµC ln .
s=t0
1 − βl 2 δ
106
Let ET2 be the event that ∀0 ≤ t0 ≤ t ≤ T,

s
t
X √ 1 2T 2
βl t−s (e gs | x(s)]) 1 Rs2 ≤ 2µC
 
gs − E[e ≤ 2 CµC ln .
s=t0
1 − βl 2 δ

Let ET3 be the event that ∀0 ≤ t0 ≤ t ≤ T,

t
r
X 2T 2
g s − E[g s | x(s)] ≤ C T ln .
s=t0
δ

Lemma 4.7.6. P[ETi ] ≥ 1 − δ, for i = 1, 2, 3.

Proof of Lemma 4.7.6. Note the sequence in ETi are martingales whose differences
are uniformly bounded (µC , µC and C). The lemma follows directly from Hoeffding
Inequality and Azuma Inequality.

Theorem 4.7.7 (Norm lower bound with clipping: Warm Start). Suppose Assump-
tion 4.5.2 holds, with probability at least 1 − δ (or whenever ET1 holds), if Rt2 ≥ 34 µC ,
then for any t0 ≥ t, we have

s !
t0 −t
βl p 2C T2
Rt20 ≥ 1− − O( ηλ) − ηλ ln (1 + O(ηλ)) µC (4.7)
4 αC δ

µC
Proof. We first claim for any t ≤ t0 ≤ T , conditioned on ET1 , it holds that Rt20 ≥ 2
.
µC
Below we prove by contradiction. If not, let t0 be the smallest step such that Rt20 < 2
.
We let t∗ be the largest step between t and t0 such that Rt2∗ ≥ µC (t∗ = t − 1 is no
such t∗ exists) Thus if t∗ ≥ t then Rt2∗ +1 is at least (1 − ηλ)4 Rt2 = (1 − O(ηλ))µC .

Otherwise t∗ = t and it implies that Rt2∗ +1 = Rt2 = ( 34 − O( ηλ))µC . By the definition,
we know for any t∗ + 1 ≤ s ≤ t0 , Rs2 ≤ µC .

107
Similar to Equation (3.10), we have

2
Rs+1 =Rs2 (1 − ηλ)4 + 4ηλ(1 − ηλ)2 ges + 4η 2 λ2 get2

≥Rs2 ((1 − ηλ)4 + 4ηλ(1 − ηλ)2 + 4C 2 η 2 λ2 ) (4.8)

+4ηλ(1 − ηλ)2 (E[e


gs | x(s)] − Rs2 ) + 4ηλ(1 − ηλ)2 (e
gs − E[e
gs | x(s)])

Thus for any s such that µC ≤ Rs2 ≤ 2µC , by Lemma 4.7.5, it holds that

GPx(s) ,C (Rs2 ) = E[e


gs | x(s)] − Rs2 ≤ αC (µC − Rs2 ).

Thus, we have that

2
Rs+1 ≥Rs2 (1 − 2η 2 λ2 + η 4 λ4 )

+4ηλαC (1 − ηλ)2 (µC − Rs2 ) + 4ηλ(1 − ηλ)2 (e


gs − E[e
gs | x(s)])

=βl Rs2 + 4ηλαC (1 − ηλ)2 µC + 4ηλ(1 − ηλ)2 (e


gs − E[e
gs | x(s)]).

That is,

2
4ηλαC (1 − ηλ)2 µC
Rs+1 −
1 − βl
4ηλαC (1 − ηλ)2 µC
≥βl (Rs2 − )
1 − βl
+4ηλ(1 − ηλ)2 (e
gs − E[e
gs | x(s)])

108
Applying the above inequality for s = t∗ + 1, . . . , t0 − 1, we have that
!
t0 −t∗ −1
4ηλαC (1 − ηλ)2 µC
Rt20 ≥ βl Rt2∗ +1 −
1 − βl
| {z }
(A)

4ηλαC (1 − ηλ)2 µC
+
1 − βl
| {z }
(B)
t 0
X h i
2 t−s
+ 4ηλ(1 − ηλ) βl gs − E[e
(e gs | x(s)]) 1 Rs2 ≤ µC .
s=t∗ +1
| {z }
(C)

For term (B), we have 1 − βu = 4ηλαC (1 − ηλ)2 (1 + O(ηλ)) and thus (B) =
0 ∗ √
µC (1 + O(ηλ)). Since Rt∗ +1 ≥ 43 µC , it holds that (A) ≥ −βl t −t −1 ( 14 + O( λη))µC ≥

−( 41 + O( λη))µC . Since ET1 holds, we have

s s
√ 1 2T 2 2C T2
|(C)| ≤ 4ηλ(1 − ηλ)2 · CµC ln = µC ηλ ln (1 + O(ηλ))
1 − βl 2 δ αC δ

Thus there’s some constant ι, such for ηλ ≤ min{ι, 64C αlnCT 2 /δ }, (A) + (B) + (C) ≥
√ √ µ
( 6−8 2 − O( ηλ))µC ≥ 2C . This leads to a contradiction to the definition of t0 . Thus
µ
for any t ≤ t0 ≤ T , conditioned on ET1 , it holds that Rt20 ≥ 2C . Furthermore, if t∗ 6= t,
√ √
then Rt∗ +1 ≥ (1 − O( ηλ))µC . Thus (A) ≥ −O( ηλ)µC . Otherwise if t∗ = t, then
0 √
(A) ≥ −βl t −t ( 14 + O( λη))µC . Combine the bounds in these two cases, we conclude
that
s !
t0 −t
βl p 2C T2
Rt20 ≥ 1− − O( ηλ) − ηλ ln (1 + O(ηλ)) µC
4 αC δ

Theorem 4.7.8 (Norm upper bound with clipping: Warm Start). Suppose Assump-
tion 4.5.2 holds, with probability at least 1 − δ (or whenever ET2 holds), if Rt2 ≤ 32 µC ,

109
then for any t0 ≥ t, we have

0
s !
βl t −t p 2C T2
Rt20 ≤ 1+ + O( ηλ) + ηλ ln (1 + O(ηλ)) µC
2 αC δ

Proof of Theorem 4.7.8. We first claim for any t ≤ t0 ≤ T , conditioned on ET2 , it


holds that Rt20 ≤ 2µC . Below we prove by contradiction. If not, let t0 be the
largest step such that Rt20 > 2µC . We let t∗ be the largest step between t and t0
such that Rt2∗ ≤ µC (t∗ = t − 1 is no such t∗ exists) Thus if t∗ ≥ t then Rt2∗ +1 is
at most (1 + 2Cηλ)2 Rt2 = (1 + 2Cηλ)2 µC . Otherwise t∗ = t and it implies that
Rt2∗ +1 = Rt2 ≤ 32 µC . By the definition, we know for any t∗ + 1 ≤ s ≤ t0 , Rs2 ≥ µC .
Similar to Equation (3.10), we have

2
Rs+1 ≤Rs2 (1 − ηλ)4 + 4ηλ(1 − ηλ)2 ges + 4η 2 λ2 gbs2

≤Rs2 ((1 − ηλ)4 + 4ηλ(1 − ηλ)2 + 4η 2 λ2 C 2 ) (4.9)

+4ηλ(1 − ηλ)2 (E[e


gs | x(s)] − Rs2 ) + 4ηλ(1 − ηλ)2 (e
gs − E[e
gs | x(s)])

Thus for any s such that µC ≤ Rs2 , by Lemma 4.7.5, it holds that

GPx(s) ,C (Rs2 ) = E[e


gs | x(s)] − Rs2 ≥ αC (µC − Rs2 ).

Thus, we have that

2
Rs+1 ≤Rs2 (1 − 2η 2 λ2 + η 4 λ4 + 4η 2 λ2 C 2 )

+4ηλαC (1 − ηλ)2 (µC − Rs2 ) + 4ηλ(1 − ηλ)2 (e


gs − E[e
gs | x(s)])

=βu Rs2 + 4ηλαC (1 − ηλ)2 µC + 4ηλ(1 − ηλ)2 (e


gs − E[e
gs | x(s)]).

110
That is,

2 4ηλαC (1 − ηλ)2 µC
Rs+1 −
1 − βu
4ηλαC (1 − ηλ)2 µC
≤βu (Rs2 − ) + 4ηλ(1 − ηλ)2 (e
gs − E[e
gs | x(s)])
1 − βu

Applying the above inequality for s = t∗ + 1, . . . , t0 − 1, we have

4ηλαC (1 − ηλ)2 µC
 
t0 −t∗ −1
Rt20 ≤ βu Rt2∗ +1 −
1 − βu
| {z }
(A)
2
4ηλαC (1 − ηλ) µC
+
1 − βu
| {z }
(B)
t 0
X
2
βu t−s (e gs | x(s)]) 1 Rs2 ≤ 2µC .
 
+ 4ηλ(1 − ηλ) gs − E[e
s=t∗ +1
| {z }
(C)

For term (B), we have 1 − βu = 4ηλαC (1 − ηλ)2 (1 + O(ηλ)) and thus (B) =
0 ∗ √
µC (1 + O(ηλ)). Since Rt∗ +1 ≤ 23 µC , it holds that (A) ≤ βu t −t −1 ( 12 + O( λη))µC ≤

( 21 + O( λη))µC . Since ET2 holds, we have that

s s
√ 1 2T 2 2C T2
|(C)| ≤ 8ηλ(1 − ηλ)2 · CµC 2 ln = 2µC ηλ ln (1 + O(ηλ))
1 − βu δ αC δ

Thus there’s some constant ι, such for ηλ ≤ min{ι, 64C αlnCT 2 /δ }, (A) + (B) + (C) ≤
√ √
( 6+4 2 + O( ηλ))µC ≤ 2µC . This leads to a contradiction to the definition of t0 . Thus
for any t ≤ t0 ≤ T , conditioned on ET1 , it holds that Rt20 ≥ 2µC . Furthermore, if t∗ 6= t,
√ √
then Rt∗ +1 ≤ (1 + O( ηλ))µC . Thus (A) ≤ O( ηλ)µC . Otherwise if t∗ = t, then
0 √
(A) ≤ βu t −t ( 12 + O( λη))µC . Combine the bounds in these two cases, we conclude

111
that
s !
t0 −t
βl p 2C T2
Rt20 ≤ 1+ + O( ηλ) + ηλ ln (1 + O(ηλ)) µC
2 αC δ

Theorem 4.7.9 (Norm Convergence of clipped SGD). Suppose Assumption 4.5.2


holds, for ηλ = O(min{1, C lnαTC/δ2 }), with probability 1 − 3δ (when ET1 ,ET2 and ET3
2
 
R0 µ
max ln ,ln C +O(1)
µC R02
happens), there is a T 0 = αC ηλ
, such that for all T 0 ≤ t ≤ T , we have

µC
≤ Rt2 ≤ 2µC .
2

More concretely, we have that

0 0
p p
Rt2 ∈ [(1 − βlt−T )µC − O(
e λη), µC (1 + βut−T ) + O(
e λη)].

Proof of Theorem 4.7.9. We will prove the desired inequality always holds when ETi
holds, for i = 1, 2, 3. We have already proved the result for the case where 34 µC ≤
Rt2 ≤ 32 µC in Theorems 4.7.7 and 4.7.8. Now we turn to the case where R02 ≥ 32 µC
and R02 ≤ 12 µC . Our goal is to prove with high probability, that Rt2 ∈ [ 34 µC , 23 µC ] for
at least some t < T 0 .
Below we first show ∃0 < t < T 0 , Rt2 ≤ 23 µC . Otherwise, similar to Equation (4.9),

2
Rs+1 ≤Rs2 (1 − ηλ)4 + 4ηλ(1 − ηλ)2 ges + 4η 2 λ2 gbs2

≤Rs2 ((1 − ηλ)4 + 4ηλ(1 − ηλ)2 + 4η 2 λ2 C 2 ) (4.10)

+4ηλ(1 − ηλ)2 (E[e


gs | x(s)] − Rs2 ) + 4ηλ(1 − ηλ)2 (e
gs − E[e
gs | x(s)])

112
Thus for any s such that 23 µC ≤ Rs2 , by Lemma 4.7.5, it holds that

αC 2
GPx(s) ,C (Rs2 ) = E[e
gs | x(s)] − Rs2 ≥ αC (µC − Rs2 ) ≥ − R .
3 s

Thus,

2
Rs+1 ≤Rs2 (1 − 2η 2 λ2 + η 4 λ4 + 4η 2 λ2 C 2 )
4
− ηλαC (1 − ηλ)2 Rs2 + 4ηλ(1 − ηλ)2 (e
gs − E[e
gs | x(s)])
3  
2 2 2 4 4 2 2 2 4 2 2
=Rs 1 − 2η λ + η λ + 4η λ C − ηλαC (1 − ηλ) + 4ηλ(1 − ηλ) (g s − E[g s | x(s)])
3

Note that g s ≤ C, we have

2 4
ln Rs+1 − ln Rs2 ≤ − ηλαC + ηλ(g s − E[g s | x(s)]) + O(η 2 λ2 )
3

Since we assume ∀0 ≤ t ≤ T 0 , Rt2 ≥ 23 µC , conditioned on ET3 , we have

r
3 2 2 2 4T 2T 2
ln + ln µC − ln R0 ≤ ln RT 0 − ln R0 ≤ − ηλαC + Cηλ T ln + O(η 2 λ2 T ),
4 3 δ

2
 
R0 µ
max ln ,ln C +O(1)
µC R02
which is in contradiction with the definition of T 0 = αC ηλ
.
Now we show ∃0 < t < T 0 , Rt2 ≥ 43 µC . Otherwise, similar to Equation (4.9),

2
Rs+1 =Rs2 (1 − ηλ)4 + 4ηλ(1 − ηλ)2 ges + 4η 2 λ2 get2

≥Rs2 ((1 − ηλ)4 + 4ηλ(1 − ηλ)2 + 4C 2 η 2 λ2 ) (4.11)

+4ηλ(1 − ηλ)2 (E[e


gs | x(s)] − Rs2 ) + 4ηλ(1 − ηλ)2 (e
gs − E[e
gs | x(s)])

Thus for any s such that Rs2 ≤ 54 µC , by Lemma 4.7.5, it holds that

αC 2
GPx(s) ,C (Rs2 ) = E[e
gs | x(s)] − Rs2 ≥ R .
4 s

113
Thus, we have that

2
Rs+1 ≥Rs2 (1 − 2η 2 λ2 + η 4 λ4 )

+ηλαC (1 − ηλ)2 Rs2 + 4ηλ(1 − ηλ)2 (e


gs − E[e
gs | x(s)])

=Rs2 1 − 2η 2 λ2 + η 4 λ4 + ηλαC (1 − ηλ)2 + 4ηλ(1 − ηλ)2 (g s − E[g s | x(s)])




Note that g s ≤ C, we have that

2
ln Rs+1 − ln Rs2 ≥ ηλαC + ηλ(g s − E[g s | x(s)]) + O(η 2 λ2 )

Since we assume ∀0 ≤ t ≤ T 0 , Rt2 ≥ 32 µC , conditioned on ET3 , we have

r
2T 2
ln µC − ln R02 ≥ ln RT2 0 − ln R02 ≥ T ηλαC − Cηλ T ln + O(η 2 λ2 T ),
δ

2
 
R0 µ
max ln ,ln C +O(1)
µC R02
which is in contradiction with the definition of T 0 = αC ηλ
.

Proof of Theorem 4.5.3. The proof of Algorithm 2 is almost identical to that of


Theorem 3.4.2, except replacing M by 2µC , σ by µC , σ by µC since the clipped
stochastic gradient has smaller maximum norm, maximum covariance and smaller
covariance.

114
Part II

Implicit Bias Along Entire


Optimization Trajectory

115
Chapter 5

Why do ConvNets Generalize


Better Than Fully-connected Nets?

Convolutional neural networks often dominate fully-connected counterparts in gener-


alization performance, especially on image classification tasks. This is often explained
in terms of “better inductive bias.” However, this has not been made mathematically
rigorous, and the hurdle is that the sufficiently wide fully-connected net can always
simulate the convolutional net. Thus the training algorithm plays a role.
This chapter describes a natural task on which a provable sample complexity gap can
be shown, for standard training algorithms. We construct a single natural distribution
on Rd × {±1} on which any orthogonal-invariant algorithm (e.g. fully-connected
networks trained with most gradient-based methods from gaussian initialization)
requires Ω(d2 ) samples to generalize while O(1) samples suffice for convolutional
architectures. Furthermore, we demonstrate a single target function, learning which on
all possible distributions leads to an O(1) vs Ω(d2 /ε) gap. The proof relies on the fact
that SGD on fully-connected network is orthogonal equivariant. Similar results are
achieved for `2 regression and adaptive training algorithms, e.g. Adam and AdaGrad,
which are only permutation equivariant.

116
5.1 Introduction

Deep convolutional nets (“ConvNets”) are at the center of the deep learning revo-
lution [48, 80, 81]. For many tasks, especially in vision, convolutional architectures
perform significantly better their fully-connected (“FC”) counterparts, at least given
the same amount of training data. Practitioners explain this phenomenon at an
intuitive level by pointing out that convolutional architectures have better “inductive
bias”, which intuitively means the following: (i) ConvNet is a better match to the
underlying structure of image data, and thus are able to achieve low training loss with
far fewer parameters (ii) models with fewer total number of parameters generalize
better.
Surprisingly, the above intuition about the better inductive bias of ConvNets over
FC nets has never been made mathematically rigorous. The natural way to make
it rigorous would be to show explicit learning tasks that require far more training
samples on FC nets than for ConvNets. (Here “task”means, as usual in learning theory,
a distribution on data points, and binary labels for them generated given using a fixed
labeling function.) Surprisingly, the standard repertoire of lower bound techniques in
ML theory does not seem capable of demonstrating such a separation. The reason is
that any ConvNet can be simulated by an FC net of sufficient width, since a training
algorithm can just zero out unneeded connections and do weight sharing as needed.
Thus the key issue is not an expressiveness per se, but the combination of architecture
plus the training algorithm. But if the training algorithm must be accounted for, the
usual hurdle arises that we lack good mathematical understanding of the dynamics of
deep net training, whether FC or ConvNet. How then can one establish such limitation
of “FC nets + current training algorithms”? (Indeed, many lower bound techniques
in PAC learning theory are information theoretic and ignore the training algorithm.)
The current paper makes significant progress on the above problem by exhibiting
simple tasks that require Ω(d2 ) factor more training samples for FC nets than for
117
ConvNets, where d is the data dimension. (In fact this is shown even for 1-dimensional
ConvNets; the lowerbound easily extends to 2-D ConvNets.) The lower bound holds
for FC nets trained with vanilla SGD with Gaussian initialization of network weights,
with the optional use of momentum, `2 regularization, and various learning rate
schedules. Our proof relies on the fact that these popular algorithms lead to an
orthogonal-equivariance property on the trained FC nets, which says that at the end
of training the FC net —no matter how deep or how wide — will make the same
predictions even if we apply orthogonal transformation on all datapoints (i.e., both
training and test). This notion is inspired by Ng [82] (where it is named “orthogonal
invariant”), which showed the power of logistic regression with `1 regularization versus
other learners. For a variety of learners (including kernels and FC nets) that paper
described explicit tasks where the learner has Ω(d) higher sample complexity than
logistic regression with `1 regularization. The lower bound example and technique
can also be extended to show a (weak) separation between FC nets and ConvNets.
(See Section 5.5.2)
Our separation is quantitatively stronger than the results by Ng [82] because the
sample complexity gap is Ω(d2 ) vs O(1), and not Ω(d) vs O(1). But in a more subtle
way our result is conceptually far stronger: the technique by Ng [82] seems incapable
of exhibiting a sample gap of more than O(1) between Convnets and FC nets in our
framework. The reason is that the technique by Ng [82] can exhibit a hard task for FC
nets only after fixing the training algorithm. But there are infinitely many training
algorithms once we account for hyperparameters associated in various epochs with
LR schedules, `2 regularizer and momentum, etc.. Thus the technique by Ng [82]
cannot exclude the possibility that the hard task for “FC net + Algorithm 1” is easy
for “FC net + Algorithm 2”. Note that we do not claim any issues with the results
claimed by Ng [82]; merely that the technique cannot lead to a proper separation
between ConvNets and FC nets, when the FC nets are allowed to be trained with any

118
1.0 Gauss 1.0 cifar-10

0.9 0.9

0.8 0.8

test acc

test acc
2-layer cnn w/ quadratic
3-layer cnn w/ relu
0.7 0.7 resnet14 cnn
hybrid w/ quadratic
hybrid w/ relu
2-layer fc w/ quadratic
0.6 0.6 3-layer fc w/ quadratic
3-layer fc w/ relu
3-layer fc w/ relu + bn
0.5 2 3 4 5 6
0.5 2 3 4 5
10 10 10 10 10 10 10 10 10
# training data # training data

Figure 5.1: Comparison of generalization performance of convolutional versus fully-connected


models trained by SGD. The grey dotted lines indicate separation, and we can see convo-
lutional networks consistently outperform fully-connected networks. Here the input data
are 3 × 32 × 32 RGB images and the binary label indicates for each image whether the
first channel has larger `2 norm than the second one. The input images are drawn from
entry-wise independent Gaussian (left) and CIFAR-10 (right). In both cases, the 3-layer
convolutional networks consist of two 3 × 3 convolutions with 10 hidden channels, and a 3 × 3
convolution with a single output channel followed by global average pooling. The 3-layer
fully-connected networks consist of two fully-connected layers with 10000 hidden channels
and another fully-connected layer with a single output. The 2-layer versions have one less
intermediate layer and have only 3072 hidden channels for each layer. The hybrid networks
consist of a single fully-connected layer with 3072 channels followed by two convolutional
layers with 10 channels each. bn stands for batch-normalization [18].

of the infinitely many training algorithms. (Section 5.5.2 spells out in more detail the
technical difference between our technique and Ng’s idea.)
The reader may now be wondering what is the single task that is easy for ConvNets
but hard for FC nets trained with any standard algorithm? A simple example is the
following: data distribution in Rd is standard Gaussian, and target labeling function
is the sign of d/2
P 2
Pd 2
i=1 zi − i=d/2+1 zi . Figure 5.1 shows that this task is indeed much

more difficult for FC nets. Furthermore, the task is also hard in practice for data
distributions other than Gaussian; the figure shows that a sizeable performance gap
exists even on CIFAR images with such a target label.
Extension to broader class of algorithms. The orthogonal-equivariance property
holds for many types of practical training algorithms, but not all. Notable exceptions
are adaptive gradient methods (e.g. Adam and AdaGrad), `1 regularizer, and initial-

119
ization methods that are not spherically symmetric. To prove a lower bound against
FC nets with these algorithms, we identify a property, permutation-invariance, which
is satisfied by nets trained using such algorithms. We then demonstrate a single and
natural task on Rd × {±1} that resembles real-life image texture classification, on
which we prove any permutation-invariant learning algorithm requires Ω(d) training
examples to generalize, while Empirical Risk Minimization with O(1) examples can
learn a convolutional net.
Structure of this chapter. In Section 5.2 we discuss about related works. In
section 5.3, we define the notation and cover some preliminaries in PAC learning. In
Section 5.4, we define algorithmic equivariance and prove the orthogonal equivariance
of FC-Net trained by gradient descent. In Section 5.5, we give two warmup examples
and an overview for the proof technique for the main theorem. In Section 5.6, we
present our main results on the lower bound of orthogonal and permutation equivariant
algorithms.

5.2 Related Works

Du et al. [83] attempted to investigate the reason why convolutional nets are more
sample efficient. Specifically they prove O(1) samples suffice for learning a convolutional
filter and also proved a Ω(d) min-max lower bound for learning the class of linear
classifiers. Their lower bound is against learning a class of distributions, and their
work fails to serve as a sample complexity separation, because their upper and lower
bounds are proved on different classes of tasks.
Arjevani and Shamir [84] also considered the notion of distribution-specific hardness
of learning neural nets. They focused on proving running time complexity lower
bounds against so-called ”orthogonally invariant” and ”linearly invariant” algorithms.
However, here we focus on sample complexity.

120
Recently, there has been progress in showing lower bounds against learning with
kernels. Wei et al. [85] constructed a single task on which they proved a sample
complexity separation between learning with neural networks vs. with neural tangent
kernels. Notably the lower bound is specific to neural tangent kernels [15]. Relatedly,
Allen-Zhu and Li [86] showed a sample complexity lower bound against all kernels for
a family of tasks, i.e., learning k-XOR on the hypercube.

5.3 Notation and Preliminaries

We will use Z = Rd , Y = {−1, 1} to denote the domain of the data and label and H =
{h | h : Z → Y} to denote the hypothesis class. Formally, given a joint distribution
P , the error of a hypothesis h ∈ H is defined as errP (h) := Pz,y∼P [h(z) 6= y]. If h is a
random hypothesis, we define errP (h) := Pz,y∼P,h [h(z) 6= y] for convenience. A class
of joint distributions supported on Z × Y is referred as a problem, P.
We use k·k2 to denote the spectrum norm and k·kF to denote the Frobenius norm
of a matrix. We use A ≤ B to denote that B − A is a semi-definite positive matrix.
We also use O(d) and GL(d) to denote the d-dimensional orthogonal group and general
2
linear group respectively. We use Bpd to denote the unit Schatten-p norm ball in Rd×d .
We use N (µ, Σ) to denote Gaussian distribution with mean µ and covariance
Σ. For random variables X and Y , we denote X is equal to Y in distribution by
d
X = Y . In this work, we also always use PZ to denote the distributions on Z
and P to denote the distributions supported jointly on Z × Y. Given an input
distribution PZ and a hypothesis h, we define PZ  h as the joint distribution on
Z × Y, such that (PZ  h)(S) = PZ ({z|(z, h(z)) ∈ S}), ∀S ⊂ Z × Y. In other
words, to sample (Z, Y ) ∼ PZ  h means to first sample Z ∼ PZ , and then set
Y = h(Z). For a family of input distributions PZ and a hypothesis class H, we define

121
PZ  H = {PZ  h | PZ ∈ PZ , h ∈ H}. In this work all joint distribution P can be
written as PZ  h for some h, i.e. PY|Z is deterministic.
For set S ⊂ Z and bijection g : Z → Z, we define g(S) = {g(x)|x ∈ S}. We use
◦ to denote function composition. (f ◦ g)(x) is defined as f (g(x)), and for function
classes F, G, F ◦ G = {f ◦ g | f ∈ F, g ∈ G}. For any distribution PZ supported on
Z , we define PZ ◦ g as the distribution such that (PZ ◦ g)(S) = PZ (g(S)). In other
words, if Z ∼ PZ ⇐⇒ g −1 (Z) ∼ PZ ◦ g, because

 −1 
∀S ⊆ Z, P g (Z) ∈ S = P [Z ∈ g(S)] = [PZ ◦ g](S).
Z∼PZ Z∼PZ

For any joint distribution P of form P = PZ  h, we define P ◦ g = (PZ ◦ g)  (h ◦ g).


In other words, (Z, Y ) ∼ P ⇐⇒ (g −1 (Z), Y ) ∼ P ◦ g. For any distribution class P
and group G acting on Z, we define P ◦ G as {P ◦ g | P ∈ P, g ∈ G}.

Definition 5.3.1. A deterministic supervised Learning Algorithm A is a mapping from


a sequence of training data, {(zi , yi )}ni=1 ∈ (Z × Y)n , to a hypothesis A({(zi , yi )}ni=1 ) ∈
H ⊆ Y Z . The algorithm A could also be randomized, in which case the output
A({(zi , yi )}ni=1 ) is a distribution on hypotheses. Two randomized algorithms A and
A0 are the same if for any input, their outputs have the same distribution in function
d
space, which is denoted by A({zi , yi }ni=1 ) = A0 ({zi , yi }ni=1 ).

5.3.1 Sample Complexity and VC Theory

Definition 5.3.2 (Sample Complexity). Given a distribution class P and a learning


algorithm A, δ, ε ∈ [0, 1], we define the (ε, δ)-sample complexity, denoted N (A, P, ε, δ),
as the smallest integer n such that ∀P ∈ P, w.p. 1 − δ over the randomness of
{zi , yi }ni=1 and A, errP (A({zi , yi }ni=1 )) ≤ ε. We also define the ε-expected sample
complexity for a problem P, denoted N ∗ (A, P, ε), as the smallest integer n such that

122
Algorithm 3 Iterative algorithm A
Require: Initial parameter distribution Pinit supported in W = Rm , total iterations
T , training dataset {zi , yi }ni=1 , parametric model M : W → H, iterative update
rule F (x, M, {zi , yi }ni=1 )
Ensure: Hypothesis h : Z → Y.
Sample x(0) ∼ Pinit .
for t = 0 to T − 1 do
n
x(t+1) = F (x(t) , M, {zi , y(Ti }) i=1 ).

return h(·) = sign M[x ](·) .

∀P ∈ P, E [errP (A({zi , yi }ni=1 ))] ≤ ε.


(zi ,yi )∼P
By definition, we have N ∗ (A, P, ε + δ) ≤ N (A, P, ε, δ) ≤ N ∗ (A, P, εδ), ∀ε, δ ∈ [0, 1].

For function class H, we use ΠH (n) to denote the growth function of H, i.e.
ΠH (n) := sup |{(h(z1 ), h(z2 ), . . . , h(zn )) | h ∈ H}| . The Vapnik–Chervonenkis(VC)
z1 ,...,zn ∈Z
dimension of H is defined as the largest integer such that ΠH (n) = 2n and we denote
 VCdim(H)
en
it by VCdim(H). By Sauer-Shelah Lemma, we know ΠH (n) ≤ VCdim(H) for
n ≥ VCdim(H).

Theorem 5.3.3. [87] If learning algorithm A is consistent and ranged in H, i.e.


A({zi , yi }ni=1 )(zi ) = yi , ∀i ∈ [n] and A({zi , yi }ni=1 ) ∈ H, then for any distribution PZ
and 0 < ε, δ < 1, we have that

VCdim(H) ln 1ε + ln 1δ
 
N (A, PZ  H, ε, δ) = O . (5.1)
ε

Meanwhile, there’s a distribution PZ supported on any subsets {z0 , . . . , zd−1 } which


can be shattered by H, such that for any 0 < ε, δ < 1 and any algorithm A, it holds

VCdim(H) + ln 1δ
 
N (A, PZ  H, ε, δ) = Ω . (5.2)
ε

123
5.3.2 Parametric Models and Iterative Algorithms

A parametric model M : W → H is a functional mapping from weight x to a


hypothesis M[x]. Given a specific parametric model M, a general iterative algorithm
is defined as Algorithm 3. In this work, we will only use the two parametric models
below, FC-NN and CNN.
FC Nets: A L-layer Fully-connected Neural Network parameterized by its weights
x = (W1 , W2 , . . . , WK ) is a function FC-NN[x](·) : Rd → R, where Wi ∈ Rdi−1 ×di ,
d0 = d, and dK = 1:

FC-NN[x](z) = WK σ(WK−1 · · · σ(W2 σ(W1 z))).

Here, σ : R → R can be any function, and we abuse the notation such that σ is also
defined for vector inputs, in the sense that [σ(z)]i = σ(zi ).
ConvNets (CNN): In this chapter we will only use two layer Convolutional
Neural Networks with one channel. Suppose d = d0 r for some integers d0 and r, a
2-layer CNN parameterized by its weights x = (w, a, b) ∈ Rk × Rr × R is a function
CNN[x](·) : Rd → R:

r
X
CNN[x](z) = ai σ([w ∗ z]d0 (i−1)+1:d0 i ) + b,
i=1

where ∗ : Rk × Rd → Rd is the convolution operator, defined as [w ∗ z]i =


d0
Pk
j=1 wj z[i−j−1 mod d]+1 , and σ : R → R is the composition of pooling and element-

wise non-linearity.

124
5.4 Algorithmic Equivariance in Fully-connected

Nets Trained by SGD

In this section, we first give the formal definition of equivariant algorithms. Then we
start with an informal sketch of why FC nets trained with standard algorithms have
certain equivariance properties and then give the formal proof.

Definition 5.4.1 (Equivariant Algorithms). A learning algorithm is equivariant under


group GZ (or GZ -equivariant) if and only if for any dataset {zi , yi }ni=1 ∈ (Z × Y)n and
∀g ∈ GZ , z ∈ Z, A({g(zi ), yi }ni=1 ) ◦ g = A({zi , yi }ni=1 ), or A({g(zi ), yi }ni=1 )(g(z)) =
[A({zi , yi }ni=1 )](z). 1

The high level idea here is if update rule of the network, or more generally,
the parametrized model, exhibits certain symmetry per step, i.e., property 2 in
Theorem 5.4.2, then by induction it will hold till the last iteration.
Taking linear regression as an example, let zi ∈ Rd , i ∈ [n] be the data and y ∈ Rn
2
be the labels, the GD update for L(w) = 12 ni=1 (z> 2 1 >
P
i w − yi ) = 2 Z w − y 2 would

be wt+1 = F (wt , Z, y) := wt − ηZ(Z> wt − y). Now suppose there’s another person


trying to solve the same problem using GD with the same initial linear function, but
he observes everything in a different basis, i.e., Z0 = U Z and w00 = U w0 , for some
orthogonal matrix U . Not surprisingly, he would get the same solution for GD, just in a
different basis. Mathematically, this is because wt0 = U wt =⇒ wt+1
0
= F (wt0 , U Z, y) =
U F (wt , Z, y) = U wt+1 . In other words, he would make the same prediction for
unseen data. Thus if the initial distribution of w0 is the same under all basis (i.e.,
d
under rotations), e.g., gaussian N (0, Id ), then w0 = U w0 =⇒ F t (w0 , U Z, y) =
U F t (w0 , Z, y), for any iteration t, which means GD for linear regression is orthogonal
invariant.
1 n d n
For randomized algorithms, the condition becomes A({g(zi ), yi }i=1 ) ◦ g = A({zi , yi }i=1 ), which
n d n
is stronger than A({g(zi ), yi }i=1 )(g(z)) = [A({zi , yi }i=1 )](z), ∀z ∈ Z.
125
Matrix Group Diagonal Matrices, Permutation Orthogonal General
(Symmetry) Diagonals ∈ {±1} Group Group Linear Group
Algorithms AdaGrad, Adam AdaGrad, Adam SGD Newton’s method
Initialization Symmetric i.i.d. i.i.d. Gaussian All zero
Regularization `p norm `p norm `2 norm None

Table 5.1: Examples of gradient-based equivariant training algorithms for FC networks.


The initialization requirement is only for the first layer of the network.

To show orthogonal equivariance for gradient descent on general deep FC nets, it


suffices to apply the above argument on each neuron in the first layer of the FC nets.
Below we give sufficient conditions for an iterative algorithm to be equivariant (as
defined in Algorithm 3) and its proof based on the same idea of induction. Equivariance
for both GD on FC nets and other training algorithms (see Table 5.1) can be derived
using Theorem 5.4.2.

Theorem 5.4.2. Suppose GZ is a group acting on Z = Rd , the iterative algorithm A


is GZ -equivariant (as defined in Algorithm 3) if the following conditions are met:

1. There’s a group GW acting on W and a group isomorphism τ : GZ → GW , such


that M[τ (g)(x)](g(z)) = M[x](z), ∀z ∈ Z, x ∈ W, g ∈ G. (One can think g as
the rotation U applied on data z in linear regression and τ (U ) as the rotation U
applied on w.)

2. Update rule F is invariant under any joint group action (g, τ (g)), ∀g ∈ G. In
other words, [τ (g)](F (x, M, {zi , yi }ni=1 )) = F ([τ (g)](x), M, {g(zi ), yi }ni=1 ).

3. The initialization Pinit is invariant under group GW , i.e. ∀g ∈ GW , Pinit =


Pinit ◦ g −1 .

Here we want to address that the three conditions in Theorem 5.4.2 are natural and
almost necessary. Condition 1 is the minimal expressiveness requirement for model
M to allow equivariance. Condition 3 is required for equivariance at initialization.
Condition 2 is necessary for induction.
126
Proof of Theorem 5.4.2. ∀g ∈ GZ , we sample x(0) ∼ Pinit , and x̃(0) = τ (g)(x(0) ).
d
By property (3), x̃(0) = x(0) ∼ Pinit . Let x(t+1) = F x(t) , M, {zi , yi }ni=1 and


x̃(t+1) = F x̃(t) , M, {g(zi ), yi }ni=1 for 0 ≤ t ≤ T − 1, we can show x̃(t) = τ (g)x(t) ) by




induction using property (2). By definition of Algorithm 3, we have

d
A {zi , yi }ni=1 = M[x(T ) ],

and
d
M[x̃(T ) ] ◦ g = A({g(zi ), yi }ni=1 ) ◦ g.

By property (1), we have M[x̃(T ) ](g(z)) = M[τ (g)(x(T ) ](g(z)) = M[x(T ) ](z).
d d
Therefore, A({zi , yi }ni=1 ) = M[x(T ) ] = M[x̃(T ) ] ◦ g = A({g(zi ), yi }ni=1 ) ◦ g, meaning A
is GZ -equivariant.

Remark 5.4.3. Theorem 5.4.2 can be extended to the stochastic case and the adaptive
case which allows the algorithm to use information of the whole trajectory, i.e., the
update rule could be generalized as x(t+1) = Ft ({x(s) }ts=1 , M, {zi , yi }ni=1 ), as long as
(the distribution of) each Ft is invariant under joint transformations.

Below are two example applications of Theorem 5.4.2. Other results in Table 5.1
could be achieved in the same way.
For classification tasks, optimization algorithms often work with a differentiable
surrogate loss ` : R → R instead the 0-1 loss, such that `(yh(z)) ≥ 1 [yh(z) ≤ 0],
and the total loss for hypothesis h and training, L(M[x]; {zi , yi }ni=1 ) is defined as
Pn
i=1 `(M[x](zi )yi ). It’s also denoted by L(x) when there’s no confusion.

Definition 5.4.4 (Gradient Descent for FC nets, Algorithm 4). We call Algorithm 3
Gradient Descent for FC nets if M = FC-NN and F = GDL , where GDL (x) =
x − η∇L(x) is called the one-step Gradient Descent update and η > 0 is the learning
rate.

127
Algorithm 4 Gradient Descent for FC-NN (FC networks)
Require: Initial parameter distribution Pinit , total iterations T , training dataset
{zi , yi }ni=1 , loss function `
Ensure: Hypothesis h : Z → Y.
Sample x(0) ∼ Pinit .
for t = 0 to T − 1 do
n
x(t+1) = x(t) − η ∇`(FC-NN(x(t) )(zi ), yi )
P
i=1 
return h = sign FC-NN[x(T ) ] .

Corollary 5.4.5. Fully-connected networks trained with (stochastic) gradient descent


from i.i.d. Gaussian initialization is equivariant under the orthogonal group.

Proof of Corollary 5.4.5. We will verify the three conditions required in Theorem 5.4.2
one by one.
Condition 1: This is the only place we use the FC structure.

Lemma 5.4.6. There’s a subgroup GW of O(m), and a group isomorphism τ : GZ =


O(d) → GW , such that FC-NN[τ (R)(x)] ◦ R = FC-NN[x], ∀x ∈ W, R ∈ GZ .

Proof of Lemma 5.4.6. By definition, FC-NN[x](z) could be written FC-NN[x2:L ](σ(W1 z)),
which implies FC-NN[x](z) = FC-NN[W1 R−1 , x2:L ](Rz), ∀R ∈ O(d), and thus we can
pick τ (R) = O ∈ O(m), where O(x) = [W1 R−1 , x2:L ], and GW = τ (O(d)).

Condition 2: A notable property of Gradient Descent is that it is invariant


under orthogonal re-parametrization. Formally, given loss function L : Rm → R and
parameters W ∈ Rm , an orthogonal re-parametrization of the problem is to replace
(L, W ) by (L ◦ O−1 , OW ), where O ∈ Rm×m is an orthogonal matrix.

Lemma 5.4.7 (Gradient Descent is invariant under orthogonal re-parametization).


For any L, W and orthogonal matrix O ∈ Rm×m , we have O · GDL (x) = GDL◦O−1 (Ox).

128
Proof of Lemma 5.4.7. By definition, it suffices to show that for each i ∈ [n], and
every x and x0 = Ox,

O∇x `(FC-NN[x](zi ), yi ) = ∇x0 `(FC-NN[O−1 x0 ](zi ), yi ),

which holds by chain rule.

For any R ∈ O(d), and set O = τ (R) by Lemma 5.4.6, (L ◦ O−1 )[x] =
Pn −1
Pn
i=1 `(FC-NN[O (x)](zi ), yi ) = i=1 `(FC-NN[x](Rzi ), yi ). The second condition in

Theorem 5.4.2 is satisfied by plugging above equality into Lemma 5.4.7.


Condition 3: The third condition is also satisfied since the initialization distribu-
tion is i.i.d. Gaussian, which is known to be orthogonal invariant. In fact, from the
proof, it suffices to have the initialization of the first layer invariant under GZ .

Corollary 5.4.8. FC nets trained with newton’s method from zero initialization for
the first layer and any initialization for the rest parameters is GL(d)-equivariant, or
equivariant under the group of invertible linear transformations.
Here, Netwon’s method means to use NT(x) = x − η(∇2 L(x))−1 ∇L(x) as the
update rule and we assume ∇2 L(x) is invertible.

Proof of Corollary 5.4.8. The proof is almost the same as that of Corollary 5.4.5,
except the following modifications.
Condition 1: If we replace the O(d), O(m) by GL(d), GL(m) in the statement
and proof Lemma 5.4.6, the lemma still holds.
Condition 2:By chain rule, one can verify the update rule Newton’s method is
invariant under invertible linear re-parametization, i.e. ONT[W ] = NTL◦O−1 [OW ], for
all invertible matrix O.
Condition 3: Since the first layer is initialized to be 0, it is invariant under any
linear transformation.

129
Remark 5.4.9. The above results can be easily extended to the case of momentum and
Lp regularization. For momentum, we only need to ensure that the following update
rule, x(t+1) = GDM(x(t) , x(t−1) , M, {zi , yi }ni=1 ) = (1 + γ)x(t) − γx(t−1) − η∇L(x(t) ),
also satisfies the property in Lemma 5.4.7. For Lp regularization, because kxkp is
independent of {zi , yi }ni=1 , we only need to ensure kxkp = kτ (R)(x)kp , ∀R ∈ GZ ,
which is easy to check when GZ only contains permutation or sign-flip.

5.4.1 Examples of Equivariance for Non-iterative Algorithms

To demonstrate the wide application of our lower bounds, we give two more examples
of algorithmic equivariance where the algorithm is not iterative. The proofs are
folklore.

Definition 5.4.10. Given a positive semi-definite kernel K, the Kernel Regression


algorithm REGK is defined as:

REGK ({zi , yi }ni=1 )(z) := 1 K(z, ZN ) · K(ZN , ZN )† y ≥ 0


 

where K(ZN , ZN ) ∈ Rn×n , [K(ZN , ZN )]i,j = K(zi , zj ), y = [y1 , y2 , . . . , yN ]> and


K(z, ZN ) = [K(z, z1 ), . . . , K(z, zN )].

Kernel Regression: If kernel K is GZ -equivariant, i.e., ∀g ∈ GZ , z, z0 ∈ Z,


K(g(z), g(z0 )) = K(z, z0 ), then algorithm REGK is GZ -equivariant.
ERM: If F = F ◦ GZ , and argminh∈F ni=1 1 [h(zi ) 6= yi ] is unique, then ERMF is
P

GZ -equivariant.

130
5.5 Warm-up Examples and Proof Idea for Main

Results

5.5.1 Example 1: Ω(d) Lower Bound Against oOthogonal

Equivariant Methods

We start with a simple but insightful example to how equivariance alone could suffice
for some non-trivial lower bounds.
We consider a task on Rd × {±1} which is a uniform distribution on the set
{(ei y, y)|i ∈ {1, 2, . . . , d}, y = ±1}, denoted by P . Each sample from P is a one-hot
vector in Rd and the sign of the non-zero coordinate determines its label. Now imagine
our goal is to learn this task using an algorithm A. After observing a training set of n
labeled points S := {(zi , yi )}ni=1 , the algorithm is asked to make a prediction on an
unseen test data z, i.e., A(S)(z). Here we are concerned with orthogonal equivariant
algorithms ——the prediction of the algorithm on the test point remains the same
even if we rotate every zi and the test point x by any orthogonal matrix U , i.e.,

d
A({(U zi , yi )}ni=1 )(U z) = A({(zi , yi )}ni=1 )(z)

Now we show this algorithm fails to generalize on task P , if it observes only d/2
training examples. The main idea here is that, for a fixed training set S, the prediction
A({(zi , yi )}ni=1 )(z) is determined solely by the inner products between z and zi ’s due
to orthogonal equivariance, i.e., there exists a random function f (which may depend
on S) such that2

d
A({(zi , yi )}ni=1 )(z) = f (z> z1 , . . . , z> zn )

2
this can be made formal using the fact that Gram matrix determine a set of vectors up to an
orthogonal transformation.
131
But the input distribution for this task is supported on 1-hot vectors. Suppose n < d/2.
Then at test time the probability is at least 1/2 that the new data point (z, y) ∼ P ,
is such that z has zero inner product with all n points seen in the training set S. This
fact alone fixes the prediction of A to the value f (0, . . . , 0) whereas y is independently
and randomly chosen to be ±1. We conclude that A outputs the wrong answer with
probability at least 1/4.

5.5.2 Example 2: Ω(d2 ) Lower Bound in the Weak Sense

This warm up example illustrates the main insight of Ng [82], namely, that when
an orthogonal equivariant algorithm is used to do learning on a certain task, it is
actually being forced to simultaneously learn all orthogonal transformations of this
task. Intuitively, this should make the learning much more sample-hungry compared
to even Simple SGD on ConvNets, which is not orthogonal equivariant. Now we sketch
why the obvious way to make this intuition precise using VC dimension (Theorem 5.3.3)
does not give a proper separation between ConvNets and FC nets, as mentioned in
the introduction.
hP i
d P2d
We first fix the ground truth labeling function h∗ (z) = sign 2
i=1 zi − 2
i=d+1 i .
z
Algorithm A is orthogonal equivariant (Definition 5.4.1) means that for any task
P = PZ  h∗ , where PZ is the input distribution and h∗ is the labeling function, A must
have the same performance on P and its rotated version P ◦ U = (PZ ◦ U )  (h∗ ◦ U ),
where U can be any orthogonal matrix. Therefore if there is an orthogonal equivariant
learning algorithm A that learns h∗ on all distributions, then A will also learn
every the rotated copy of h∗ , h∗ ◦ U , on every distribution PZ , simply because A
learns h∗ on distribution PZ ◦ U −1 . Thus A learns the class of labeling functions
h∗ ◦ O(2d) := {h | h(z) = h∗ (U z), ∀U ∈ O(2d)} on all distributions. (See formal
statement in Theorem 5.6.2) By the standard lower bounds with VC dimension (See
Theorem 5.3.3), it takes at least Ω( VCdim(H◦O(2d))
ε
) samples for A to guarantee 1 − ε

132
accuracy. Thus it suffices to show the VC dimension VCdim(H ◦ O(2d)) = Ω(d2 ),
towards a Ω(d2 ) sample complexity lower bound. (Ng [82] picks a linear thresholding
function as h∗ , and thus VCdim(h∗ ◦ O(2d)) is only O(d).)
Formally, we have the following theorem, whose proof is deferred into Section 5.7.2:

Theorem 5.5.1 (All distributions, single hypothesis). Let P = {all distributions} 


{h∗ }. For any orthogonal equivariant algorithm A, N (A, P, ε, δ) = Ω((d2 + ln 1δ )/ε),
while there’s a 2-layer ConvNet architecture, such that N (ERMCNN , P, ε, δ) =
O 1ε log 1ε + log 1δ .


As noted in the introduction, this doesn’t imply there is some task hard for every
training algorithm for the FC net. The VC dimension based lower bound implies for
each algorithm A the existence of a fixed distribution PZ ∈ P and some orthogonal
matrix UA such that the task (PZ ◦ UA−1 )  h∗ is hard for it. However, this does not
preclude (PZ ◦ UA−1 )  h∗ being easy for some other algorithm A0 .

5.5.3 Proof Idea for Fixed Distribution Lower Bounds

At first sight, the issue highlighted above (and in the Introduction) seems difficult to
get around. One possible avenue is if the hard input distribution PZ in the task were
invariant under all orthogonal transformations, i.e., PZ = PZ ◦ U for all orthogonal
matrices U . Unfortunately, the distribution constructed in the proof of lower bound
with VC dimension is inherently discrete and cannot be made invariant to orthogonal
transformations.
Our proof uses a fixed PZ , the standard Gaussian distribution, which is indeed
invariant under orthogonal transformations. The proof also uses the Benedek-Itai’s
lower bound, Theorem 5.5.2, and the main technical part of our proof is the lower
bound for the the packing number D(H, ρ, ε) defined below (also see Equation (5.4)).
Let ρ be a metric on H, We define N (H, ρ, ε) as the ε-covering number of H w.r.t.
ρ, and D(H, ρ, ε) as the ε-packing number of H w.r.t. ρ. For distribution PZ , we use
133
ρZ (h, h0 ) := PX∼PZ [h(X) 6= h0 (X)] to denote the discrepancy between hypothesis h
and h0 w.r.t. PZ .

Theorem 5.5.2. [Benedek-Itai’s lower bound [88]] For any algorithm A that (ε, δ)-
learns H with n i.i.d. samples from a fixed distribution PZ , it must hold for every

ΠH (n) ≥ (1 − δ)D(H, ρZ , 2ε) (5.3)

Since ΠH (n) ≤ 2n , we have N (A, PZ  H, ε, δ) ≥ log2 D(H, ρZ , 2ε) + log2 (1 − δ), which
is the original bound by Benedek and Itai [88]. Later Long [89] improved this bound
for the regime n ≥ VCdim(H) using Sauer-Shelah lemma, i.e.,

VCdim(H) 1
N (A, PZ , ε, δ) ≥ ((1 − δ)D(H, ρZ , 2ε)) VCdim(H) . (5.4)
e

Intuition behind Benedek-Itai’s lower bound. We first fix the data distribu-
tion as PZ . Suppose the 2ε-packing is labeled as {h1 , . . . , hD(H,ρZ ,2ε) } and ground truth
is chosen from this 2ε-packing, (ε, δ)-learns the hypothesis H means the algorithm
is able to recover the index of the ground truth w.p. 1 − δ. Thus one can think this
learning process as a noisy channel which delivers log2 D(H, ρZ , 2ε) bits of information.
Since the data distribution is fixed, unlabeled data is independent of the ground
truth, and the only information source is the labels. With some information-theoretic
inequalities, we can show the number of labels, or samples (i.e., bits of information)
N (A, PZ  H, ε, δ) ≥ log2 D(H, ρZ , 2ε) + log2 (1 − δ). A more closer look yields Equa-
tion (5.4), because when VCdim(H) < ∞, then only log2 ΠH (n) instead of n bits
information can be delivered.

134
5.6 Main Results: Sample Complexity Lower

Bounds for Equivariant Algorithms

Below we first present a reduction from a special subclass of PAC learning to equivariant
learning (Theorem 5.6.2), based on which we prove our main separation results,
Theorem 5.5.1, 5.6.4, 5.6.5 and 5.6.6.

Lemma 5.6.1. Let A be the set of all algorithms and AGZ be the set of all GZ -
equivariant algorithms, the following inequality holds. The equality is attained when
GZ is a compact group.

inf N ∗ (A, P, ε) ≥ inf N ∗ (A, P ◦ GZ , ε) (5.5)


A∈AGZ A∈A

Proof of Lemma 5.6.1. Given GZ -equivariant algorithm A, by definition, N ∗ (A, P, ε) =


N ∗ (A, P ◦ g −1 , ε), ∀g ∈ GZ . Consequently, we have

N ∗ (A, P, ε) = N ∗ (A, P ◦ GZ , ε). (5.6)

Take infimum over AGZ over the both side of Equation (5.6), and note that AGZ ⊂ A,
Inequality (5.5) is immediate.
Suppose the group GZ is compact and let µ be the Haar measure on it, i.e.
∀S ⊂ GZ , g ∈ GZ , µ(S) = µ(g ◦ S). We claim for each algorithm A, the sample
complexity of the following equivariant algorithm A0 is no higher than that of A on
P  GZ :
A0 ({zi , yi }ni=1 ) = A({g(zi ), yi }ni=1 ) ◦ g, where g ∼ µ.

135
By the definition of Haar measure, A0 is GZ -equivariant. Moreover, for any fixed
n ≥ 0, we have

inf E [errP (A0 ({zi , yi }ni=1 ))] = inf E E [errP (A({zi , yi }ni=1 ))]
P ∈P (zi ,yi )∼P P ∈P g∼µ (zi ,yi )∼P ◦g −1

≥ inf inf E [errP (A({zi , yi }ni=1 ))] = inf E [errP (A({zi , yi }ni=1 ))] ,
P ∈P g∈GZ (zi ,yi )∼P ◦g −1 P ∈P◦GZ (zi ,yi )∼P

which implies inf A∈AGZ N ∗ (A, P, ε) ≤ inf A∈A N ∗ (A, P ◦ GZ , ε).

Theorem 5.6.2. If PZ is a set of data distributions that is invariant under group


GZ , i.e., PZ ◦ GZ = PZ , then the following inequality holds. Furthermore it becomes
an equality when GZ is a compact group.

inf N ∗ (A, PZ  H, ε) ≥ inf N ∗ (A, PZ  (H ◦ GZ ), ε) (5.7)


A∈AGZ A∈A

Proof of Theorem 5.6.2. Simply note that (PZ  H) ◦ GZ = ∪g∈GZ (PZ ◦ g)  (H ◦ g −1 ) =


∪g∈GZ PZ  (H ◦ g −1 ) = PZ  (H ◦ GZ ), the theorem is immediate from Lemma 5.6.1.

Remark 5.6.3. The sample complexity in standard PAC learning is usually defined
again hypothesis class H only, i.e., PZ is the set of all the possible input distributions.
In that case, PZ is always invariant under group GZ , and thus Theorem 5.6.2 says
that GZ -equivariant learning against hypothesis class H is as hard as learning against
hypothesis H ◦ GZ without equivariance constraint.

5.6.1 Ω(d2 ) Lower Bound for Orthogonal Equivariance With

a Fixed Distribution

In this subsection we show Ω(d2 ) vs O(1) separation on a single task in our main
theorem (Theorem 5.6.4). With the same proof technique, we further show we can
2
get correct dependency on ε for the lower bound, i.e., Ω( dε ), by considering a slightly

136
larger function class, which can be learnt by ConvNets with O(d) samples. We also
generalize this Ω(d2 ) vs O(d) separation to the case of `2 regression with a different
proof technique.
hP i
d P2d
Theorem 5.6.4. There’s a single task, PZ h∗ , where h∗ = sign 2
i=1 zi − 2
i=d+1 zi

and PZ = N (0, I2d ) and a constant ε0 > 0, independent of d, such that for any
orthogonal equivariant algorithm A, we have

N ∗ (A, PZ  h∗ , ε0 ) = Ω(d2 ), (5.8)

while there’s a 2-layer ConvNet architecture CNN, such that N (ERMCNN , PZ 


h∗ , ε, δ) = O 1ε log 1ε + log 1δ . Moreover, ERMCNN could be realized by gradient


descent (on the second layer only).

Proof of Theorem 5.6.4. Upper bound: implied by upper bound in Theorem 5.5.1.
Lower bound: Note that the PZ = N (0, I2d ) is invariant under O(2d), by The-
orem 5.6.2, it suffices to show that there’s a constant ε0 > 0 (independent of d),
for any algorithm A, it takes Ω(d2 ) samples to learn the augmented function class
h∗ ◦ O(2d) w.r.t. PZ = N (0, I2d ). Define hU = sign z>
  d×d
1:d U zd+1:2d , ∀U ∈ R , and
by Lemma 5.7.4, we have H = {hU | U ∈ O(d)} ⊆ h∗ ◦ O(2d). Thus it suffices to
show a Ω(d2 ) sample complexity lower bound for the function subclass H, i.e.,

N ∗ (A, N (0, I2d )  {sign z> 2


 
1:d U zd+1:2d | U ∈ O(d) }, ε0 ) = Ω(d ). (5.9)

By Benedek&Itai’s lower bound, [88] (Equation (5.3)), we know

N (A, P, ε0 , δ) ≥ log2 ((1 − δ)D(H, ρZ , 2ε0 )) . (5.10)

d(d−1)
By Lemma 5.7.6, there’s some constant C, such that D(H, ρZ , ε) ≥ ( Cε ) 2 , ∀ε > 0.

137
kU −V kF
The high-level idea for Lemma 5.7.6 is to first show that ρZ (hU , hV ) ≥ Ω( √
d
),
and then we show the packing number of orthogonal matrices in a small neighborhood
k·kF
of Id w.r.t. √
d
is roughly the same as that in the tangent space of orthogonal manifold
d(d−1)
at Id , i.e., the set of skew matrices, which is of dimension 2
and has packing
d(d−1)
number ( Cε ) 2 . The advantage of working in the tangent space is that we can apply
the standard volume argument.
d(d−1)
Setting δ = 12 , we have that N ∗ (A, P, ε0 ) ≥ N (A, P, 12 , 2ε0 ) ≥ 2
log2 C
4ε0
−1 =
Ω(d2 ).

Indeed, we can improve the above lower bound by applying Equation (5.4), and
get
  12   21 − 2d1
1 d2 1 d C 1 1
N (A, P, ε, ) ≥ = Ω(d2 ε− 2 + 2d ). (5.11)
2 e 2 ε
1 1
Note that the dependency in ε in Equation (5.11) is ε− 2 + 2d is not optimal, as
opposed to ε−1 in upper bounds and other lower bounds. A possible reason for
this might be that Theorem 5.5.2 (Long’s improved version) is still not tight and it
might require a tighter probabilistic upper bound for the growth number ΠH (n), at
least taking PZ into consideration, as opposed to the current upper bound using VC
2
dimension only. We left it as an open problem to show a single task P with Ω( dε )
sample complexity to achieve ε error for all orthogonal equivariant algorithms.
However, if the hypothesis class is of VC dimension O(d), using a similar idea, we
can prove a Ω(d2 /ε) sample complexity lower bound for equivariant algorithms, and
O(d) upper bounds for ConvNets.

Theorem 5.6.5 (Single distribution, multiple functions). There is a problem with


hP i
d 2
single input distribution, P = {PZ }H = {N (0, Id )}{z → sign i=1 αi zi | αi ∈ R},

such that for any orthogonal equivariant algorithms A and ε > 0, N ∗ (A, P, ε) =
Ω(d2 /ε), while there’s a 2-layer ConvNets architecture, such that N (ERMCNN , P, ε, δ) =
d log 1ε +log 1
O( ε
δ
).

138
Interestingly, we can show an analog of Theorem 5.6.5 for `2 regression, i.e., the
algorithm not only observes the signs but also the values of labels yi . Here we define
the `2 loss of function h : Rd → R as `P (h) = E [(h(z) − y)2 ] and the sample
(z,y)∼P
complexity N ∗ (A, P, ε) for `2 loss similarly as the smallest number n ∈ N such that
∀P ∈ P, E [`P (A({zi , yi }ni=1 ))] ≤ ε E [y 2 ]. The last term E [y 2 ] is added
(zi ,yi )∼P (x,y)∼P (x,y)∼P
for normalization to avoid the scaling issue and thus any ε > 1 could be achieved
trivially by predicting 0 for all data.

Theorem 5.6.6 (Single distribution, multiple functions, `2 regression). There is a


problem with single input distribution, P = {PZ }  H = {N (0, Id )}  { di=1 αi zi2 | αi ∈
P

R} , such that for any orthogonal equivariant algorithms A and ε > 0, N ∗ (A, P, ε) ≥
d(d+3)
2
(1 − ε) − 1, while there’s a 2-layer ConvNet architecture CNN, such that
N ∗ (ERMCNN , P, ε) ≤ d for any ε > 0.

5.6.2 Ω(d) Lower Bound for Permutation Equivariance

In this subsection we will present Ω(d) lower bound for permutation equivariance via
a different proof technique — direct coupling. The high-level idea of direct coupling
is to show with constant probability over (Zn , z), we can find a g ∈ GZ , such that
g(Zn ) = Zn , but z and g(z) has different labels, in which case no equivariant algorithm
could make the correct prediction.

Theorem 5.6.7. Let ti = ei +ei+1 and si = ei +ei+2 3 and P be the uniform distribution
on {(si , 1)}ni=1 ∪ {(ti , −1)}ni=1 , which is the classification problem for local textures in
a 1-dimensional image with d pixels. Then for any permutation equivariant algorithm
A, N (A, P, 18 , 18 ) ≥ N ∗ (A, P, 14 ) ≥ d
10
. Meanwhile, N (ERMCN N , P, 0, δ) ≤ log2 1δ + 2,
where ERMCN N stands for ERMCN N for function class of 2-layer ConvNets.

3
For vector z ∈ Rd , we define zi = z(i−1) mod d+1 if i ∈
/ [d].
139
Remark 5.6.8. The task could be understood as detecting if there are two consecutive
white pixels in the black background. For proof simplicity, we take texture of length
2 as an illustrative example. It is straightforward to extend the same proof to
more sophisticated local pattern detection problem of any constant length and to
2-dimensional images.

5.7 Proofs

5.7.1 Some Basic Inequalities

Lemma 5.7.1.
p
∀x ∈ [−1, 1], arccos x ≥ 2(1 − x).

Proof of Lemma 5.7.1. Let x = cos(t). If t = 0, then both sides are equal to 0 and
the inequality holds. Otherwise if t ∈ (0, π], we have the following

arccos(x) t t √
√ =p =√ ≥ 2,
1−x 1 − cos(t) 2 sin(t/2)

which completes the proof.

Lemma 5.7.2. There is C > 0, such that ∀d ∈ N+ , M ∈ Rd×d ,

√ √
C kM kF / d ≤ E [kM zk2 ] ≤ kM kF / d. (5.12)
z∼Sd−1

Proof of Lemma 5.7.2. Upper Bound: By Cauchy-Schwarz inequality, we have

s  r
tr[M M > ]

kM k
r
kM zk22 = = √ F.
 
E [kM zk2 ] ≤ E tr M E [zz> ] M > =
z∼Sd−1 z∼Sd−1 z∼Sd−1 d d

Lower Bound: Let M = U ΣV > be the singular value decomposition of M ,


where U, V are orthogonal matrices and Σ is diagonal. Since kM kF = kΣkF , and

140
E [kM zk2 ] = E [kΣzk2 ], w.l.o.g., we only need to prove the lower bound for
z∼Sd−1 z∼Sd−1
all diagonal matrices.
By Proposition 2.5.1 in [90], there’s some constant C, such that

v v
u d u d
uX uX
C kΣkF = C t σi2 ≤ E t zi2 σi2 = E [kM zk]2 .
z∼N (0,Id ) z∼N (0,Id )
i=1 i=1
r  √
kzk22 = d.

By Cauchy-Schwarz Inequality, we have E [kzk2 ] ≤ E
z∼N (0,Id ) z∼N (0,Id )
Therefore, we have that

C kΣkF ≤ E [kM zk]2


z∼N (0,Id )

= E [kM ẑk]2 E [kzk2 ] (5.13)


ẑ∼Sd−1 z∼N (0,Id )

≤ E [kM ẑk]2 d,
ẑ∼Sd−1

which completes the proof.

Lemma 5.7.3. For any z > 0, we have that

2 z
Pr (|x| ≤ z) ≤ √
x∼N (0,σ) πσ

Proof of Lemma 5.7.3. By definition of normal distribution, we have that

z
r
x2
Z  
1 2z
Pr (|x| ≤ z) = √ exp − 2 dx ≤
x∼N (0,σ) −z 2π σ 2σ πσ

5.7.2 Proof of Theorem 5.5.1

Lemma 5.7.4. Define hU = sign z>


 
1:d U zd+1:2d , ∀U ∈ Rd×d , we have H = {hU | U ∈
hP i
∗ ∗ d 2
P2d 2
O(d)} ⊆ h ◦ O(2d), where h (z) = sign i=1 zi − i=d+1 zi .

141
Proof of Lemma 5.7.4. Note that

 
     
 0 U  Id 0   0 Id  Id 0 
 = · · ,
U> 0 0 U> Id 0 0 U

and

       
√ √ √ √
2 2 2 2
 0 Id   2 Id − I
2 d
I
 d 0   2 Id I
2 d
 = √ √ · · √ √ ,
2 2
Id 0 2 d
I 2 d
I 0 −I d − 22 Id 2
2 d
I

thus for any U ∈ O(d), ∀z ∈ R2d , we have that

   
 >  0 U 
hU (z) = sign z>
 
1:d U zd+1:2d = sign z   z
U> 0
    (5.14)
Id 0 
=sign gU (z)>   gU (z) ,
 
0 −Id

   
√ √
2 2
 2 Id − I
2 d Id 0 
where gU (z) = √ √  ·   · z is an orthogonal transformation on R2d .
2 2
2 d
I 2 d
I 0 U
Thus we conclude that hU ∈ h∗ ◦ O(2d).

Lemma 5.7.5. Define hU = sign z>


  d×d
1:d U zd+1:2d , ∀U ∈ R , and H = {hU | U ∈
O(d)}, we have
d(d − 1)
VCdim(H) ≥ .
2

Proof of Lemma 5.7.5. Now we claim H shatters {ei + ed+j }1≤i<j≤d , i.e. O(d) can
shatter {ei e>
j }1≤i<j≤d , or for any sign pattern {σij }1≤i<j≤d , there exists U ∈ O(d),

such that sign U, ei e> = σij , which implies VCdim(H) ≥ d(d−1)


 
j 2
.
Let so(d) = {M | M = −M > , M ∈ Rd×d }, we know

u2
exp(u) = Id + u + + · · · ∈ SO(d), ∀u ∈ so(d).
2
142
σij (ei e> > +
P
Thus for any sign pattern {σij }1≤i<j≤d , let u = j −ej ei ) and λ → 0 ,
1≤i<j≤d
it holds that sign exp(λu), ei e>
 
j = sign [0 + λσij + O(λ2 )] = sign [σij + O(λ)] =
σij .

Theorem 5.5.1 (All distributions, single hypothesis). Let P = {all distributions} 


{h∗ }. For any orthogonal equivariant algorithm A, N (A, P, ε, δ) = Ω((d2 + ln 1δ )/ε),
while there’s a 2-layer ConvNet architecture, such that N (ERMCNN , P, ε, δ) =
O 1ε log 1ε + log 1δ .


Proof of Theorem 5.5.1. Lower bound: Suppose d = 2d0 for some integer d0 , we
construct P = PZ  H, where PZ is the set of all possible distributions on Z = R3k ,
hP 0 P2d0 i
d 2 2 0
 > 
and H = {sign z
i=1 i − i=d0 +1 i }. By Lemma 5.7.4, H = {sign z1:d U zd+1:2d |
z
U ∈ O(d0 )} ⊆ H ◦ O(d). By Theorem 5.6.2, we have that

inf N ∗ (A, PZ H, ε) ≥ inf N ∗ (A, PZ (H◦GZ ), ε) ≥ inf N ∗ (A, PZ H0 , ε) (5.15)
A∈AGZ A∈A A∈A

By the lower bound in Theorem 5.3.3, we have inf A∈A N ∗ (A, PZ  H0 , ε) ≥


1
VCdim(H0 )+ln d0 (d0 −1)
ε
δ
. By Lemma 5.7.5 VCdim(H0 ) ≥ 2
= Ω(d2 ).
Upper Bound: Take CNN as defined in Section 5.3.2 with d = 2d0 , r = 2, k =
0 Pd0 2
1, σ : Rd → R, σ(z) = i=1 zi (square activation + average pooling), we have
n h P0 i o
FCNN = sign z → 2i=1 ai dj=1 x2(i−1)d0 +j w12 + b |a1 , a2 , w1 , b ∈ R .
P

Note that min errP (h) = 0, ∀P ∈ P, and the VC dimension of F is 3, by


h∈FCNN

Theorem 5.3.3, we have ∀P ∈ P, w.p. 1 − δ, errP (ERMFCNN ({zi , yi }ni=1 )) ≤ ε, if


n = Ω 1ε log 1ε + log 1δ ) .


Convergence guarantee for Gradient Descent: We initialize all the parameters


by i.i.d. standard gaussian and train the second layer by gradient descent only, i.e. set
the LR of w1 as 0. (Note training the second layer only is still a orthogonal-equivariant
algorithm for FC nets, thus it’s a valid separation.)

143
For any convex non-increasing surrogate loss of 0-1 loss l satisfying l(0) ≥
1, limx→∞ l(x) = 0 e.g. logistic loss, we define the loss of the weight x as (zk,i is
the kth coordinate of zi )

n n 2 d0 ! !
X X X X
L(x) = l(FCNN [x](zi )yi ) = l ai x2(k−1)d0 +j,i w12 + b yi ,
i=1 i=1 k=1 j=1

6 0 with probability 1, which means the data


which is convex in ai and b. Note w1 =
are separable even with fixed first layer, i.e. mina,b L(x) = L(x) |a=a∗ ,b=0 = 0, where a∗
is the ground truth. Thus with sufficiently small step size, GD converges to 0 loss
solution. By the definition of surrogate loss, L(x) < 1 implies for zi , l(zi yi ) < 1 and
thus the training error is 0.

5.7.3 Proofs of Lemmas for Theorem 5.6.4

Lemma 5.7.6. Define hU = sign z>


 
1:d U zd+1:2d , H = {hU | U ∈ O(d)}, and

ρ(U, V ) := ρZ (hU , hV ) = Pz∼N (0,I2d ) [hU (z) 6= hV (z)]. There exists a constant C,
 d(d−1)
such that the packing number D(H, ρZ , ε) = D(O(d), ρ, ε) ≥ Cε 2
.

Proof of Lemma 5.7.6. The key idea here is to first lower bound ρZ (U, V ) by

kU − V kF / d and apply volume argument in the tangent space of Id in O(d). We

144
have that

ρZ (hU , hV ) = P [hU (z) 6= hV (z)]


z∼N (0,I2d )

z>
  >  
= P 1:d U zd+1:2d z1:d V zd+1:2d < 0
z∼N (0,I2d )
 >
z1:d U V > z1:d
 
1
= E arccos
π z1:d ∼N (0,Id ) kz1:d k2
"s #
> >
1 z U V z1:d (5.16)
≥ E 2 − 2 1:d (by Lemma 5.7.1)
π z1:d ∼N (0,Id ) kz1:d k2
1 hp
>U V >z
i
= E 2 − 2z
π z∼Sd−1
1
(U > − V > )z F
 
= E
π z∼Sd−1

≥C1 kU − V kF / d (by Lemma 5.7.2)

Below we show it suffices to pack in the 0.4 `∞ neighborhood of Id . Let so(d) be


the Lie algebra of SO(d), i.e., {M ∈ Rd×d | M = −M > }. We also define the matrix
A2 A3
exponential mapping exp : Rd×d → Rd×d , where exp(A) = A + 2!
+ 3!
+ · · · . It holds
that exp(so(d)) = SO(d) ⊆ O(d). The benefit of covering in such neighborhood is
that it allows us to translate the problem into the tangent space of Id by the following
lemma.

Lemma 5.7.7. [91, Implication of Lemma 4] For any matrix A, B ∈ so(d), satisfying
that kAk∞ ≤ π4 , kBk∞ ≤ π4 , we have

0.4 kA − BkF ≤ kexp(A) − exp(B)kF ≤ kA − BkF . (5.17)

Therefore, we have that

√ π d2 √
D(H, ρZ , ε) ≥ D(O(d), C1 k·kF / d, ε) ≥ D(so(d)∩ B∞ , C1 k·kF / d, 2.5ε). (5.18)
4

145
d(d−1) 2
Note that so(d) is a 2
-dimensional subspace of Rd , by Inverse Santalo’s
inequality (Lemma 3, [92]), we have that

2
! d(d−1)
2
p
d
vol(so(d) ∩ B∞ ) dim(so(d))
d2
≥ C2  .
vol(so(d) ∩ B2 ) E Πso(d) (G) ∞
G∼N (0,Id2 )

d(d−1) G−G>
where vol(·) is the 2
volume defined in the space of so(d) and Πso(d) (G) = 2

is the projection operator onto the subspace so(d). We further have that

G − G> √
 
 
E Πso(d) (G) ∞
= E ≤ E [kGk∞ ] ≤ C3 d,
G∼N (0,Id2 ) G∼N (0,Id2 ) 2 ∞ G∼N (0,Id2 )

where the last inequality is by Theorem 4.4.5, [93].


Finally, we have that

π d2 √
D(so(d) ∩ B∞ , C1 k·kF / d, 2.5ε)
4 √
d2 10 dε
=D(so(d) ∩ B∞ , k·kF , )
C1 π
d2
  d(d−1)
vol(so(d) ∩ B∞ ) C1 π 2
≥ 2 × √
vol(so(d) ∩ B2d ) 10 dε (5.19)
 q  d(d−1)
2
C1 C2 π d(d−1)
2
≥ 
10dε
  d(d−1)
C 2
:=
ε

This completes the proof.

5.7.4 Proof of Theorem 5.6.5

Theorem 5.6.5 (Single distribution, multiple functions). There is a problem with


hP i
d 2
single input distribution, P = {PZ }H = {N (0, Id )}{z → sign i=1 i i | αi ∈ R},
α z
such that for any orthogonal equivariant algorithms A and ε > 0, N ∗ (A, P, ε) =

146
Ω(d2 /ε), while there’s a 2-layer ConvNets architecture, such that N (ERMCNN , P, ε, δ) =
d log 1ε +log 1
O( ε
δ
).

Proof of Theorem 5.6.5. Lower bound: Note P = {N (0, Id )}  H, where H =


hP i
d 2
{sign i=1 αi zi | αi ∈ R}. Since N (0, Id ) is invariant under all orthogonal transfor-

mations, by Theorem 5.6.2, N ∗ (A, N (0, Id ) ◦ H, ε0 ) = inf N ∗ (A, N (0, Id ) 


inf
equivariant A A
hP i
(H ◦ O(d)), ε0 ). Furthermore, it can be show that H ◦ O(d) = {sign i,j β ij zi zj |
βij ∈ R}, the sign functions of all quadratics in Rd . Thus it suffices to show learning
quadratic functions on Gaussian distribution needs Ω(d2 /ε) samples for any algorithm
(see Lemma 5.7.8, where we assume the dimension d can be divided by 4).
Upper bound:Take CNN as defined in Section 5.3.2 with d = d0 , r = 1, k =
1, σ : R → R, σ(x) = x2 (square activation + no pooling), we have FCNN =
n hP i o n hP i o
d 2 2 d 2
sign i=1 ai w1 zi + b |ai , w1 , b ∈ R = sign i=1 ai zi + b |ai , b ∈ R .

Note that min errP (h) = 0, ∀P ∈ P, and the VC dimension of F is d + 1,


h∈FCNN

by Theorem 5.3.3, we have ∀P ∈ P, w.p. 1 − δ, errP (ERMFCNN ({zi , yi }ni=1 )) ≤ ε, if


n = Ω 1ε d log 1ε + log 1δ ) .


Convergence guarantee for Gradient Descent: We initialize all the parameters


by i.i.d. standard gaussian and train the second layer by gradient descent only, i.e. set
the LR of w1 as 0. (Note training the second layer only is still a orthogonal-equivariant
algorithm for FC nets, thus it’s a valid separation.)
For any convex non-increasing surrogate loss of 0-1 loss l satisfying l(0) ≥
1, limx→∞ l(x) = 0 e.g. logistic loss, we define the loss of the weight x as (zk,i is
the kth coordinate of zi )

n n d
!
X X X
L(x) = l(CNN[x](zi )yi ) = l ( 2
w12 ai zk,i + b)yi ,
i=1 i=1 k=1

6 0 with probability 1, which means the data


which is convex in ai and b. Note w1 =
are separable even with fixed first layer, i.e. mina,b L(x) = L(x) |a=a∗ ,b=0 = 0, where a∗
147
is the ground truth. Thus with sufficiently small step size, GD converges to 0 loss
solution. By the definition of surrogate loss, L(x) < 1 implies for zi , l(zi yi ) < 1 and
thus the training error is 0.

5.7.5 Proof of Theorem 5.7.8


 
A 0 
Lemma 5.7.8. For A ∈ Rd×d , we define MA ∈ R2d×2d as MA =  , and
0 Id
hA : R4d → {−1, 1} as hA (z) = sign z>
 
1:2d MA z2d+1:4d . Then for H = {hA | ∀A ∈
Rd×d } ⊆ {sign z> Az]|∀A ∈ R4d×4d }, satisfies that it holds that for any d, algorithm
 

A and ε > 0,
d2
N ∗ (A, {N (0, I4d )}  H, ε) = Ω( ).
ε
2
1 d
Proof of Lemma 5.7.8. Below we will prove a Ω( ε
) lower bound for packing num-
ber, i.e. D(H, ρZ , 2ε0 ) = D(Rd×d , ρ, 2ε0 ), where ρ(U, V ) = ρZ (hU , hV ). Then we can
apply Long’s improved version Equation (5.4) of Benedek-Itai’s lower bound and get a
Ω(d2 /ε) sample complexity lower bound. The reason that we can get the correct rate
of ε is that the VCdim(H) is exactly equal to the exponent of the packing number. (cf.
the proof of Theorem 5.6.4)
Similar to the proof of Theorem 5.6.4, the key idea here is to first lower bound

ρ(U, V ) by kU − V kF / d and apply
 volume
 argument. Recall for A ∈ Rd×d , we

A 0 
define MA ∈ R2d×2d as MA =  , and hA : R4d → {−1, 1} as hA (z) =
0 Id
 > 
sign z1:2d MA z2d+1:4d . Then for H = {hA | ∀A ∈ Rd×d } . Below we will see it
2
suffices to lower bound the packing number of a subset of Rd×d , i.e. Id + 0.1B∞
d
,
d 2 d 2
where B∞ is the unit spectral norm ball. Clearly ∀z, kzk2 = 1, ∀U ∈ Id + 0.1B∞ ,
0.9 ≤ kU zk2 ≤ 1.1.

148
d 2
Thus, ∀U, V ∈ Id + 0.1B∞ we have that,

ρZ (hU , hV ) = P [hU (z) 6= hV (z)]


z∼N (0,I4d )

z> z>
   
= P 1:2d MV z2d+1:4d < 0
1:2d MU z2d+1:4d
z∼N (0,I4d )
" !#
1 z> M
1:2d U M >
z
V 1:2d
= E arccos
π z1:2d ∼N (0,I2d ) MU z1:2d 2 MV> z1:2d 2
>
"s #
1 z> M
1:2d U M >
z
V 1:2d
≥ E 2−2 (by Lemma 5.7.1)
π z1:2d ∼N (0,I2d ) MU z1:2d 2 MV> z1:2d 2
>
√ q 
2 > > > >
≥ E MU z1:2d 2 MV z1:2d 2 − z1:2d MU MV z1:2d
1.1π z1:2d ∼N (0,I2d )
q 
1 > > 2 > >
2
= E (MU − MV )z1:2d 2 − MU z1:2d 2 − MV z1:2d 2
1.1π z1:2d ∼N (0,I2d )
1
(MU> − MV> )z1:2d 2
 
≥ ( E
1.1π z1:2d ∼N (0,I2d )

MU> z1:2d 2 − MV> z1:2d 2 )


 
− E
z1:2d ∼N (0,I2d )
C0
(MU> − MV> )z1:2d
 
≥ E 2
(by Lemma 5.7.9)
1.1π z1:2d ∼N (0,I2d )

≥C1 kMU − MV kF / d (by Lemma 5.7.2)

=C1 kU − V kF / d

It remains to lower bound the packing number. We have the following for some
constant C:

d2  d2
√ d2

d2 vol(B∞ ) 0.1C1 C
M(0.1B∞ , C1 k·kF / d, ε) ≥ × √ ≥ , (5.20)
d2 ε
vol(B2 ) dε

The proof is completed by plugging the above bound and VCdim(H) = d2 into
Equation (5.4).

Lemma 5.7.9. Suppose z ∼ N (0, Id ), then ∀R, S ∈ Rd×d , we have

q q 
E [k(R − S)zk2 ] − E kRzk22 + kyk22 2 2
− kSzk2 + kyk2 ≥ C0 E [k(R − S)zk2 ] ,
z z,y z
149
for some constants C0 independent of R, S and d.

Proof of Lemma 5.7.9. Note that

q q
kRzk2 + kyk2 − kSzk22 + kyk22
2 2

kRzk2 + kSzk2
= |kRzk2 − kSzk2 | q q
kRzk2 + kyk2 + kSzk22 + kyk22
2 2

kRzk2 + kSzk2
≤ k(R − S)zk2 q q
kRzk22 + kyk22 + kSzk22 + kyk22

Let F (x, d) be the cdf of chi-square distribution, i.e. F (x, d) = Pz kzk22 ≤ x . Let
 

z = xd , we have F (zd, d) ≤ (ze1−z )d/2 ≤ (ze1−z )1/2 . Thus Py kyk22 ≤ d/2 < 1, which
 

implies for any kzk2 ≤ 10 d,

q q 
2 2 2 2
E kRzk2 + kyk2 − kSzk2 + kyk2
y
 
kRzk2 + kSzk2
≤ k(R − S)zk2 E  q q 
y 2 2 2 2
kRzk2 + kyk2 + kSzk2 + kyk2

≤(1 − α1 ) k(R − S)zk2 ,

for some 0 < α1 .


Therefore, we have

q q 
2 2 2 2
E [k(R − S)zk2 ] − E kRzk2 + kyk2 − kSzk2 + kyk2
z z,y
h h √ ii
≥ E k(R − S)zk2 1 kzk ≤ 10 d
z
√ i
q q h 
2 2 2 2
−E kRzk2 + kyk2 − kSzk2 + kyk2 1 kzk2 ≤ 10 d
z,y
h h √ ii
≥α1 E k(R − S)zk2 1 kzk2 ≤ 10 d
z

≥α1 α2 E [k(R − S)zk2 ] ,


z

150
for some constant α2 > 0. Here we use the other side of the tail bound of cdf of
chi-square, i.e. for z > 1, 1 − F (zd, d) < (ze1−z )d/2 < (ze1−z )1/2 .

5.7.6 Proofs of Theorem 5.6.6


2
M +M >
 > 
Lemma 5.7.10. Let M ∈ Rd×d , we have E (z M z)2 = 2
+ (tr[M ])2 .
z∼N (0,Id ) F

Proof of Lemma 5.7.10. It holds that

 >
(z M z)2

E
z∼N (0,Id )
" #
X
= E zi zj zi0 zj 0 Mij Mi0 j 0
z∼N (0,Id )
i,j,i0 j 0

 2 2 X 2
X  
(Mij2
 4
= + Mij Mji + Mii Mjj ) E x + Mii E x
x∼N (0,1) x∼N (0,1)
i6=j i
X X
= (Mij2 + Mij Mji + Mii Mjj ) + 3 Mii2
i6=j i
> 2
M +M
= + (tr[M ])2
2 F

where in the third inequality we use the fact that E [x4 ] = 3.


x∼N (0,1)

Theorem 5.6.6 (Single distribution, multiple functions, `2 regression). There is a


problem with single input distribution, P = {PZ }  H = {N (0, Id )}  { di=1 αi zi2 | αi ∈
P

R} , such that for any orthogonal equivariant algorithms A and ε > 0, N ∗ (A, P, ε) ≥
d(d+3)
2
(1 − ε) − 1, while there’s a 2-layer ConvNet architecture CNN, such that
N ∗ (ERMCNN , P, ε) ≤ d for any ε > 0.

Proof of Theorem 5.6.6. Lower bound: Similar to the proof of Theorem 5.6.5, it
suffices to for any algorithm A, N ∗ (A, H ◦ O(d), ε) ≥ d(d+3) 2
(1 − ε) − 1. Note that
P
H◦O(d) = { i,j βij zi zj | βij ∈ R} is the set of all quadratic functions. For convenience
we denote hM (z) = z> M z, ∀M ∈ Rd×d . Now we claim quadratic functions such that
151
d(d+1)
any learning algorithm A taking at most n samples must suffer 2
− n loss if the
ground truth quadratic function is sampled from i.i.d. gaussian. Moreover, the loss
d(d+3)
is at most 2
for the trivial algorithm always predicting 0. In other words, if the
d(d+1)
−n
expected relative error ε ≤ 2
d(d+3) , we must have the expected sample complexity
2
d(d+3)
N ∗ (A, P, ε) ≥ n. That is N ∗ (A, P, ε) ≥ 2
(1 − ε) − 1.
(1). Upper bound for E [y 2 ]. By Lemma 5.7.10,

" 2
#
 2 M + M>
E E y = E + (tr[M ])2
M ∼N (0,Id2 ) z∼PZ ,y=z> M z M ∼N (0,Id2 ) 2 F

d(d − 1) d(d + 3)
=d + d + = .
2 2

(2). Lower bound for expected loss.


The infimum of the test loss over all possible algorithms A is

 
inf E E [`P (A({zi , yi }ni=1 ))]
A M ∼N (0,I 2 ) (zi ,yi )∼PZ hM
d
  
([A({zi , yi }ni=1 )](z) 2
 
= inf E E E − y)
A M ∼N (0,I 2 ) (zi ,yi )∼PZ hM z,y∼PZ ◦hM
d
  
n 2
 
= inf E E E ([A({zi , hM (zi )}i=1 )](z) − hM (z))
A M ∼N (0,I 2 ) zi ∼PZ z∼PZ
d

≥ E [Varz,zi ,M [hM (z) | {zi , hM (zi )}ni=1 , z]]


zi ,z∼PZ
M ∼N (0,Id2 )

= E [VarM [hM (z) | {hM (zi )}ni=1 ]] ,


zi ,z∼PZ
M ∼N (0,Id2 )

where the inequality is achieved when [A({zi , yi }ni=1 )](z) = E [hM (z) | {zi , yi }ni=1 ].
M
n
Thus it suffices to lower bound VarM [hM (z) | {hM (zi )}i=1 ], for fixed {zi }ni=1 and
z. For convenience we define Sd = {A ∈ Rd×d | A = A> } be the linear space
of all d × d symmetric matrices, where the inner product hA, Bi := tr[A> B] and
Πn : Rd×d → Rd×d as the projection operator for the orthogonal complement of the

152
n-dimensional space spanned by zi z> d
i in S . By definition, we can expand

n
X
>
zz = αi zi z> >
i + Πn (zz ).
i=1

Thus even conditioned on {zi , yi }ni=1 and z,

n
X
>
hM (z) = tr[zz ] = αi tr[zi z> >
i M ] + tr[Πn (zz )M ],
i=1

2
still follows a gaussian distribution, N (0, Πn (zz> ) F
).
Note we can always find symmetric matrices Ei with kEi kF = 1 and tr[Ei> Ej ] = 0
such that Πn (A) = ki=1 Ei tr[Ei> A], where the rank of Πn , is at least d(d+1)
P
2
− n. Thus
we have that
 2

h i k
2
X
E Πn (zz> ) F
=E Ei tr[Ei> zz> ] 
z z
i=1 F
k h i
2
X
= E Ei tr[Ei> zz> ] F
z
i=1
Xk
 > > 2
= E (z Ei z) (by Lemma 5.7.10)
z
i=1
Xk
≥ kEi kF2 ≥ k
i=1
d(d + 1)
≥ −n
2

Thus the infimum of the expected test loss is

 
inf E E [`P (A({zi , yi }ni=1 ))]
A M ∼N (0,I 2 ) (zi ,yi )∼PZ hM
d

≥ E [VarM [hM (z) | {hM (zi )}ni=1 ]] .


zi ,z∼PZ
M ∼N (0,Id2 )
 h i
> 2
= E E Πn (zz ) F
.
zi ∼PZ z
M ∼N (0,Id2 )

153
d(d + 1)
≥ − n.
2

Upper bound: We use the same CNN construction as in the proof of The-
nP o
d 2 2
orem 5.6.5, i.e., the function class is FCNN = i=1 a i w z
1 i + b|a i , w 1 , b ∈ R =
nP o
d 2 2 2 2
i=1 ai zi + b|ai , b ∈ R . Thus given d + 1 samples, w.p. 1, (z1 , z2 , . . . , zd , 1) will be

linear independent, which means ERMCNN could recover the ground truth and thus
have 0 loss.

5.7.7 Proof of Theorem 5.6.7

Theorem 5.6.7. Let ti = ei +ei+1 and si = ei +ei+2 4 and P be the uniform distribution
on {(si , 1)}ni=1 ∪ {(ti , −1)}ni=1 , which is the classification problem for local textures in
a 1-dimensional image with d pixels. Then for any permutation equivariant algorithm
A, N (A, P, 18 , 18 ) ≥ N ∗ (A, P, 14 ) ≥ d
10
. Meanwhile, N (ERMCN N , P, 0, δ) ≤ log2 1δ + 2,
where ERMCN N stands for ERMCN N for function class of 2-layer ConvNets.

Proof of Theorem 5.6.7. Lower Bound: We further define permutation gi as gi (z) =


z − (ei+1 − ei+2 )> (ei+1 − ei+2 )z for i ∈ [d]. Clearly, gi (ti ) = si , gi (si ) = ti . For
i, j ∈ {1, 2, . . . , d}, we define d(i, j) = min{(i − j) mod d, (j − i) mod d}. It can be
verified that if d(i, j) ≥ 3, then gi (sj ) = sj , gi (tj ) = tj . For z = si or ti , z0 = sj or tj ,
we define d(z, z0 ) = d(i, j).

4
For vector z ∈ Rd , we define zi = z(i−1) mod d+1 if i ∈
/ [d].
154
Given Zn , yn , we define B := {d(z, zk ) ≥ 3, ∀k ∈ [n]} and we have P [B] =
d
d− 10 ∗5
Pz [d(z, zk ) ≥ 3, ∀k ∈ [n]] ≥ d
= 12 . Therefore, we have

errP (A(Zn , yn )) = P [A(Zn , yn )(z) 6= y] ≥ P [A(Zn , yn )(z) 6= y | B] P [B]


z,y,A z,y,A
1
≥ P [A(Zn , yn )(z) 6= y | B]
2 z,y,A
1 1
= P [A(Zn , yn )(si ) 6= 1 | B] + P [A(Zn , yn )(ti ) 6= −1 | B]
4 i,A 4 i,A
∗ 1 1
= P [A(gi (Zn ), yn )(gi (si )) 6= 1 | B] + P [A(Zn , yn )(ti ) 6= −1 | B]
4 i,A 4 i,A
1 1 1
= P [A(Zn , yn )(ti ) 6= 1 | B] + P [A(Zn , yn )(ti ) 6= −1 | B] = ,
4 i,A 4 i,A 4


where = uses the Definition 5.4.1.
Thus for any permutation equivariant algorithm A, N ∗ (A, {P }, 14 ) ≥ d
10
.
Upper Bound: Take CNN as defined in Section 5.3.2 with d0 = d, r = 1, k = 2, σ :
Rd → R, σ(z) = di=1 zi2 , we have
P

( " d
#)
X
FCNN = z → sign a1 (w1 zi−1 + w2 zi−2 )2 + b|a1 , w1 , w2 , b ∈ R .
i=1

Note that ∀h ∈ FCNN , h has the following form: ∀1 ≤ i ≤ d, h(si ) = a1 (2w12 +


2w22 ) + b, h(ti ) = a1 (w12 + w22 + (w1 + w2 )2 ) + b, thus the probability of ERMFCNN not
achieving 0 error is at most the probability that all data in the training dataset are ti
or si : (note the training error of ERMFCNN is 0)

−n −n+1
P [∀i ∈ [n], zi ∈ {sj | j ∈ [d]}]+ P [∀i ∈ [n], zi ∈ {tj | j ∈ [d]}] = 2 × 2 = 2 .

Convergence guarantee for Gradient Descent: We initialize all the parameters


by i.i.d. standard gaussian and train the second layer by gradient descent only, i.e.
set the LR of w1 , w2 as 0. (Note training the second layer only is still a permutation-

155
equivariant algorithm for FC nets, thus it’s a valid separation even we restrict the
discussion to training algorithm freezing the first layer.)
For any convex non-increasing surrogate loss of 0-1 loss l satisfying l(0) ≥
1, limx→∞ l(x) = 0 e.g. logistic loss, we define the loss of the weight x as

n
X
L(x) = l(CNN[x](zi )yi )
i=1

=NS × l a1 (2w12 + 2w22 ) + b + Nt × l −a1 (w12 + w22 + (w1 + w2 )2 ) + b .


 

6 0 with probability 1, which means the data are separable even with
Note w1 w2 =
fixed first layer, i.e. inf a1 ,b L(x) = 0. Further note L(x) is convex in a1 and b, which
implies with sufficiently small step size, GD converges to 0 loss solution. By the
definition of surrogate loss, L(x) < 1 implies for zi , l(zi yi ) < 1 and thus the training
error is 0.

156
Chapter 6

Low-Rank Implicit Bias of Matrix


Factorization

Matrix factorization is a simple and natural test-bed to investigate the implicit


regularization of gradient descent. Gunasekar et al. [16] conjectured that Gradient Flow
with infinitesimal initialization converges to the solution that minimizes the nuclear
norm, but as we have discussed in Chapter 7, matrix factorization is not a commuting
parametrization an therefore the existing analysis leveraging the equivalence between
reparametrized gradient descent and mirror descent cannot apply.
In this chapter, we develop new techniques for matrix factorization and provide
both theoretical and empirical evidence that for depth-2 matrix factorization, gradient
flow with infinitesimal initialization is mathematically equivalent to a simple heuristic
rank minimization algorithm, Greedy Low-Rank Learning, under some reasonable
assumptions. This generalizes the rank minimization view from previous works
to a much broader setting and enables us to construct counter-examples to refute
the conjecture from Gunasekar et al. [16]. We also extend the results to the case
where depth ≥ 3, and we show that the benefit of being deeper is that the above
convergence has a much weaker dependence over initialization magnitude so that this

157
rank minimization is more likely to take effect for initialization with practical scale.
Interestingly, despite there is a separation between depth equal to 2 and depth larger
than 3, it turns out that being deeper than 3 (e.g., increasing depth to infinity) has
only marginal value on the implicit bias.

6.1 Introduction

There are usually far more learnable parameters in deep neural nets than the number
of training data, but still deep learning works well on real-world tasks. Even with
explicit regularization, the model complexity of state-of-the-art neural nets is so large
that they can fit randomly labeled data easily [4]. Towards explaining the mystery of
generalization, we must understand what kind of implicit regularization does Gradient
Descent (GD) impose during training. Ideally, we are hoping for a nice mathematical
characterization of how GD constrains the set of functions that can be expressed by a
trained neural net.
As a direct analysis for deep neural nets could be quite hard, a line of works
turned to study the implicit regularization on simpler problems to get inspirations, for
example, low-rank matrix factorization, a fundamental problem in machine learning
and information process. Given a set of observations about an unknown matrix
W ∗ ∈ Rd×d of rank r∗  d, one needs to find a low-rank solution W that is compatible
with the given observations. Examples include matrix sensing, matrix completion,
phase retrieval, robust principal component analysis, just to name a few (see Chi
et al. 94 for a survey). When W ∗ is symmetric and positive semidefinite, one way
to solve all these problems is to parameterize W as W = U U > for U ∈ Rd×r and
optimize L(U ) := 12 f (U U > ), where f ( · ) is some empirical risk function depending
on the observations, and r is the rank constraint. In theory, if the rank constraint
is too loose, the solutions do not have to be low-rank and we may fail to recover

158
W ∗ . However, even in the case where the rank is unconstrained (i.e., r = d), GD
with small initialization can still get good performance in practice. This empirical
observation reveals that the implicit regularization of GD exists even in this simple
matrix factorization problem, but its mechanism is still on debate. Gunasekar et al.
[16] proved that Gradient Flow (GD with infinitesimal step size, a.k.a., GF) with
infinitesimal initialization finds the minimum nuclear norm solution in a special case
of matrix sensing, and further conjectured this holds in general.

Conjecture 6.1.1 (Gunasekar et al. 16, informal). With sufficiently small initializa-
tion, GF converges to the minimum nuclear norm solution of matrix sensing.

Subsequently, Arora et al. [95] challenged this view by arguing that a simple
mathematical norm may not be a sufficient language for characterizing implicit
regularization. One example illustrated in Arora et al. [95] is regarding matrix sensing
with a single observation. They showed that GD with small initialization enhances
the growth of large singular values of the solution and attenuates that of smaller
ones. This enhancement/attenuation effect encourages low-rank, and it is further
intensified with depth in deep matrix factorization (i.e., GD optimizes f (U1 · · · UL )
for L ≥ 2). However, these are not captured by the nuclear norm alone. Gidel et al.
[96], Gissin et al. [97] further exploited this idea and showed in the special case of
full-observation matrix sensing that GF learns solutions with gradually increasing
rank. Razin and Cohen [98] showed in a simple class of matrix completion problems
that GF decreases the rank along the trajectory while any norm grows towards infinity.
More aggressively, they conjectured that the implicit regularization can be explained
by rank minimization rather than norm minimization.

Our Contributions. In this chapter, we move one further step towards resolving
the implicit regularization in the matrix factorization problem. Our theoretical results
show that GD performs rank minimization via a greedy process in a broader setting.
159
Specifically, we provide theoretical evidence that GF with infinitesimal initialization
is in general mathematically equivalent to another algorithm called Greedy Low-Rank
Learning (GLRL). At a high level, GLRL is a greedy algorithm that performs rank-
constrained optimization and relaxes the rank constraint by 1 whenever it fails to
reach a global minimizer of f ( · ) with the current rank constraint. As a by-product,
we refute Conjecture 6.1.1 by demonstrating an counterexample (Example 6.5.9).
We also extend our results to deep matrix factorization Section 6.6, where we
prove that the trajectory of GF with infinitesimal identity initialization converges to
a deep version of GLRL, at least in the early stage of the optimization. We also use
this result to confirm the intuition achieved on toy models [97], that benefits of depth
in matrix factorization is to encourage rank minimization even for initialization with
a relatively larger scale, and thus it is more likely to happen in practice. This shows
that describing the implicit regularization using GLRL is more expressive than using
the language of norm minimization. We validate all our results with experiments
in Section 6.8.

6.2 Related Works

Norm Minimization. The view of norm minimization, or the closely related view of
margin maximization, has been explored in different settings. Besides the nuclear norm
minimization for matrix factorization [16] discussed in the introduction, previous works
have also studied the norm minimization/margin maximization for linear regression
[13, 99–103], deep linear neural nets [104, 105], homogeneous neural nets [106, 107],
ultra-wide neural nets [108–110].

Small Initialization and Rank Minimization. The initialization scale can


greatly influence the implicit regularization. A sufficiently large initialization can
make the training dynamics fall into the lazy training regime defined by Chizat
160
et al. [111] and diminish test accuracy. Using small initialization is particularly
important to bias gradient descent to low-rank solutions for matrix factorization, as
empirically observed by Gunasekar et al. [16]. Arora et al. [95], Gidel et al. [96], Gissin
et al. [97], Razin and Cohen [98] studied how gradient flow with small initialization
encourages low-rank in simple settings, as discussed in the introduction. Li et al.
[112] proved recovery guarantees for gradient flow solving matrix sensing under
Restricted Isometry Property (RIP), but the proof cannot be generalized easily to
the case without RIP. Belabbas [113] made attempts to prove that gradient flow is
approximately rank-1 in the very early phase of training, but it does not exclude
the possibility that the approximation error explodes later and gradient flow is not
converging to low-rank solutions. Compared to these works, the current paper studies
how GF encourages low-rank in a much broader setting.

6.3 Preliminaries

Notations. For two matrices A, B, we define hA, Bi := Tr(AB > ) as their inner
product. We use kAkF , kAk∗ and kAk2 to denote the Frobenius norm, nuclear norm
and the largest singular value of A respectively. For a matrix A ∈ Rd×d , we use
λ1 (A), . . . , λd (A) to denote the eigenvalues of A in decreasing order (if they are all
reals). We define Sd as the set of symmetric d × d matrices and S+
d ⊆ Sd as the set of

positive semidefinite (PSD) matrices. We write A  B or B  A if A − B is PSD.


We use S+ +
d,r , Sd,≤r to denote the set of d × d PSD matrices with rank equal to r and

small than r respectively.

Matrix Factorization. Matrix factorization problem asks one to optimize


L(U, V ) := 12 f (U V > ) among U, V ∈ Rd×r , where f : Rd×d → R is a convex func-
tion. A notable example is matrix sensing. There is an unknown rank-r∗ matrix
W ∗ ∈ Rd×d with r∗  d. Given m measurements X1 , . . . , Xm ∈ Rd×d , one can
161
observe yi := hXi , W ∗ i through each measurement. The goal of matrix sensing is to
reconstruct W ∗ via minimizing f (W ) := 12 m 2
P
i=1 (hW, Xi i − yi ) . Matrix completion

is a notable special case of matrix sensing in which every measurement has the form
Xi = epi e>
qi , where {e1 , · · · , ed } stands for the standard basis (i.e., exactly one entry is

observed through each measurement).


For technical simplicity, in this chapter we focus on the symmetric case as in
previous works [16]. Given a C 3 -smooth convex function f : Rd×d → R, we aim to find
a low-rank solution for the convex optimization problem (P):

min f (W ) s.t. W  0 (P)

For this, we parameterize W as W = U U > for U ∈ Rd×r and optimize L(U ) :=


1
2
f (U U > ).
We assume WLOG throughout this chapter that f (W ) = f (W > ); oth-
erwise, we can set f 0 (W ) = 12 f (W ) + f (W > ) so that f 0 (W ) = f 0 (W > ) while


L(U ) = 12 f 0 (U U > ) is unaffected. This assumption makes ∇f (W ) symmetric for every


symmetric W .
Note that matrix factorization in the general case can be reduced to this symmetric
case: let U 0 = [ VU ] ∈ R2d×r , f 0 ([ CA D
B ]) = 1 f (B) + 1 f (C), then f (U V > ) = f 0 (U 0 U 0> ).
2 2

So focusing on the symmetric case does not lose generality.

Gradient Flow. In this chapter, we analyze Gradient Flow (GF) on symmetric


matrix factorization, which is defined by the following ODE for U (t) ∈ Rd×r :

dU
= −∇L(U ) = −∇f (U U > )U. (6.1)
dt

Let W (t) = U (t)U (t)> ∈ Rd×d . Then the following end-to-end dynamics holds for
W (t):
dW
= −W ∇f (W ) − ∇f (W )W =: g(W ). (6.2)
dt

162
We use φ(W0 , t) to denote the matrix W (t) in (6.2) when W (0) = W0  0. Throughout
this chapter, we assume φ(W0 , t) exists for all t ∈ R, W0  0. It is easy to prove that
U is a stationary point of L( · ) (i.e., ∇L(U ) = 0) iff W = U U > is a critical point of
(6.2) (i.e., g(W ) = 0); see Lemma 6.10.1 for a proof. If W is a minimizer of f ( · ) in
S+
d (i.e., W is a minimizer of (P)), then W is a critical point of (6.2), but the reverse

may not be true, e.g., g(0) = 0, but 0 is not necessarily a minimizer.


In this chapter, we particularly focus on the overparameterized case, where r = d,
to understand the implicit regularization of GF when there is no rank constraint for
the matrix W .

6.4 Warmup Examples

Before introducing our main results, we illustrate how GD performs greedy learning
using two warmup examples.

Linearization Around the Origin. In general, for a loss function L(U ) =


1
2
f (U U > ), we can always apply Taylor expansion f (W ) ≈ f (0) + hW, ∇f (0)i around
the origin to approximate it with a linear function. This motivates us to study the
linear case: f (W ) := f0 − hW, Qi for some symmetric matrix Q. In this case, the
dU
matrix U follows the ODE, dt
= QU , which can be understood as a continuous
version of the classical power iteration method for solving the top eigenvector. Let
Q := di=1 µi vi vi> be the eigendecomposition of Q, where µ1 ≥ µ2 ≥ · · · ≥ µd and
P

v1 , . . . , vd are orthogonal to each other. Then we can write the solution as:

Xd 
U (t) = etQ U (0) = eµi t vi vi> U (0). (6.3)
i=1

163
When µ1 > µ2 , the ratio between eµ1 t and eµi t for i 6= 1 increases exponentially fast.
As t → +∞, U (t) and W (t) become approximately rank-1 as long as vi> U (0) 6= 0, i.e.,

lim e−µ1 t U (t) = v1 v1> U (0), lim e−2µ1 t W (t) = (v1> W (0)v1 )v1 v1> . (6.4)
t→∞ t→∞

The analysis for the simple linear case reveals that GD encourages low-rank through
a process similar to power iteration. However, f (W ) is non-linear in general, and the
linear approximation is close to f (W ) only if W is very small. With sufficiently small
initialization, we can imagine that GD still resembles the above power iteration in the
early phase of the optimization. But what if W (t) grows to be so large that the linear
approximation is far from the actual f (W )?

Full-observation Matrix Sensing. To understand the dynamics of GD when


the linearization fails, we now consider a well-studied special case [97]: L(U ) =
1
2
f (U U > ), f (W ) = 12 kW − W ∗ k2F for some unknown PSD matrix W ∗ . GF in this case
can be written as:

dU dW
= (W ∗ − U U > )U, = (W ∗ − W )W + W (W ∗ − W ). (6.5)
dt dt

Pd
Let W ∗ := i=1 µi vi vi> be the eigendecomposition of W ∗ . Our previous analysis
shows that the dynamics is approximately dU
dt
= W ∗ U in the early phase and thus
encourages low-rank.

To get a sense for the later phases, we simplify the setting by specifying U (0) = αI
for a small number α. We can write W (0) and W ∗ as diagonal matrices W (0) =
diag(α, α, · · · , α), W ∗ = diag(µ1 , µ2 , · · · , µd ) with respect to the basis v1 , . . . , vd . It is
easy to see that W (t) is always a diagonal matrix, since the time derivatives of non-
diagonal coordinates stay 0 during training. Let W (t) = diag(σ1 (t), σ2 (t), · · · , σd (t)),
d
then σi (t) satisfies the dynamical equation σ (t)
dt i
= 2σi (t)(µi − σi (t)), and thus

164
αµi
σi (t) = α+(µi −α)e−2µi t
. This shows that every σi (t) increases from α to µi over time.
As α → 0, every σi (t) has a sharp transition from near 0 to near µi at time roughly
( 2µ1 i + o(1)) log α1 , which can be seen from the following limit:


c ∈ (− 2µ1 i , 0),

  αµi 0

lim σi ( 2µ1 i + c) log(1/α) = lim =
α→0 α→0 α + (µi − α)α1+2cµi 
µ i
 c ∈ (0, +∞).

This means for every q ∈ ( 2µ1 i , 2µ1i+1 ) for i = 1, . . . , d − 1 (or q ∈ ( 2µ1 i , +∞) for
i = d), limα→0 W (q log(1/α)) = diag(µ1 , µ2 , . . . , µi , 0, 0, · · · , 0). Therefore, when the
initialization is sufficiently small, GF learns each component of W ∗ one by one,
according to the relative order of eigenvalues. At a high level, this shows a greedy
nature of GD: GD starts learning with simple models; whenever it underfits, it increases
the model complexity (which is rank in our case). This is also called sequential learning
or incremental learning in the literature [96, 97].
However, it is unclear how and why this sequential learning/incremental learning
can occur in general. Through the first warmup example, we may understand why GD
learns a rank-1 matrix in the early phase, but does GD always learn solutions with
rank 2, 3, 4, . . . sequentially? If true, what is the mechanism behind this? The current
paper answers the questions by providing both theoretical and empirical evidence that
the greedy learning behavior does occur in general with a similar reason as for the
first warmup example.

165
6.5 Main Results: Equivalence between Gradi-

ent Descent and Greedy Low-Rank Learning

(GLRL)

In this section, we present a trajectory-based analysis for the implicit bias of GF on


matrix factorization. Our main result is that GF with infinitesimal initialization is
generically the same as that of a simple greedy algorithm, Greedy Low-Rank Learning
(GLRL, Algorithm 5).
The GLRL algorithm consists of several phases, numbered from 1. In phase r,
GLRL increases the rank constraint to r and optimizes L(Ur ) := 12 f (Ur Ur> ) among
Ur ∈ Rd×r via GD until it reaches a stationary point Ur (∞), i.e., ∇L(Ur (∞)) = 0. At
convergence, Wr := Ur (∞)Ur> (∞) is a critical point of (6.2), and we call it the r-th
critical point of GLRL. If Wr is further a minimizer of f ( · ) in S+
d , or equivalently,

λ1 (−∇f (Wr )) ≤ 0 (see Lemma 6.10.2), then GLRL returns Wr ; otherwise GLRL
enters phase r + 1.

Algorithm 5 Greedy Low-Rank Learning (GLRL)


Require: step size η > 0; small  > 0
r ← 0, W0 ← 0 ∈ Rd×d , and U0 (∞) ∈ Rd×0 is an empty matrix
while λ1 (−∇f (Wr )) > 0 do
r ←r+1
ur ← unit top eigenvector of −∇f (Wr−1 )

Ur (0) ← [Ur−1 (∞) ur ] ∈ Rd×r
for t = 0, 1, . . . do
Ur (t + 1) ← Ur (t) − η∇L(Ur (t))
Wr ← Ur (∞)Ur> (∞) 1

return Wr

166
To set the initial point of GD in phase r, GLRL appends a small column vector
δr ∈ Rd to the resulting stationary point Ur−1 (∞) from the last phase, i.e., Ur (0) ←
[Ur−1 (∞) δr ] ∈ Rd×r (in the case of r = 1, U1 (0) ← [δ1 ] ∈ Rd×1 ). In this way,
Ur (0)Ur> (0) = Wr−1 + δr δr> is perturbed away from the (r − 1)-th critical point. In

GLRL, we set δr = ur , where ur is the top eigenvector of −∇f (Wr ) with unit
norm kur k2 = 1, and  > 0 is a parameter controlling the magnitude of perturbation
(preferably very small). Note that it is guaranteed that λ1 (−∇f (Wr−1 )) > 0; otherwise
Wr−1 is a minimizer of the convex function f ( · ) in S+
d and GLRL exits before phase

r. Expanding f ( · ) around Wr−1 shows that the loss is decreasing in this choice of δr .

1 1
L(Ur (0)) = f (Wr−1 + δr δr> ) = L(Ur−1 (∞)) + δr> ∇f (Wr−1 )δr + O(kδr k42 )
2 2

= L(Ur−1 (∞)) − λ1 (−∇f (Wr−1 )) + O(2 ).
2

Trajectory of GLRL. We define the (limiting) trajectory of GLRL by taking the


learning rate η → 0. The goal is to show that the trajectory of GLRL is close to that
of GF with infinitesimal initialization. Recall that φ(W0 , t) stands for the solution
W (t) in (6.2) when W (0) = W0 .

Definition 6.5.1 (Trajectory of GLRL). Let W 0, := 0 be the 0th critical point
of GLRL. For every r ≥ 1, if the (r − 1)-th critical point W r−1, exists and is
not a minimizer of f ( · ) in S+ G >
d , we define Wr, (t) := φ(W r−1, + ur, ur, , t), where

ur, is a top eigenvector of ∇f (W r−1, ) with unit norm, kur, k2 = 1. We define


G
W r, := limt→+∞ Wr, (t) to be the r-th critical point of GLRL if the limit exists.

Throughout this chapter, we always focus on the case where the top eigenvalue of
every ∇f (W r−1, ) is unique. In this case, the trajectory of GLRL is unique for every
 > 0, since the normalized top eigenvectors can only be ±ur, , and both of them lead
G
to the same Wr, (t).

167
Comparison to existing greedy algorithms for rank-constrained optimiza-
tion. The most related one to GLRL (Algorithm 5) is probably Rank-1 Matrix
Pursuit (R1MP) proposed by Wang et al. [114] for matrix completion, which was later
generalized to general convex loss in [115]. R1MP maintains a set of rank-1 matrices
as the basis, and in phase r, R1MP adds the same ur u> r as defined in Algorithm 5 into
Pr
its basis and solve minα f ( i=1 αi ui u>
i ) for rank-r estimation. The main difference

between R1MP and GLRL is that the optimization in each phase of R1MP is performed
on the coefficients α, while the entire Ur evolves with GD in each phase of GLRL. In
Figure 6.5, we provide empirical evidence that GLRL generalizes better than R1MP
when ground truth is low-rank, although GLRL may have a higher computational
cost depending on η, .
Similar to R1MP, Greedy Efficient Component Optimization (GECO, Shalev-
Shwartz and Singer 116) also chooses the r-th component of its basis as the top
eigenvector of −∇f (Wr ), while it solves minβ f ( 1≤i,j≤r βij ui u>
P
j ) for the rank-r es-

timation. Khanna et al. [117] provided convergence guarantee for GECO assuming
strong convexity. Haeffele and Vidal [118] proposed a local-descent meta algorithm, of
which GLRL can be viewed as a specific realization.

6.5.1 The Limiting Trajectory: A General Theorem for Dy-

namical System

To prove the equivalence between GF and GLRL, we first introduce our high-level
idea by analyzing the behavior of a more general dynamical system around its critical
point, say 0. A specific example is (6.2) if we set x to be the vectorization of W .

dx
= g(x), where g(0) = 0. (6.6)
dt

168
We use φ(x0 , t) to denote the value of x(t) in the case of x(0) = x0 . We assume that g(x)
is C 2 -smooth with J(x) being the Jacobian matrix and φ(x0 , t) exists for all x0 and t.
For ease of presentation, in the main text we assume J(0) is diagonalizable over R and
defer the same result for the general case into Section 6.12.3. Let J(0) = Ṽ D̃Ṽ −1 be the
eigendecomposition, where Ṽ is an invertible matrix and D̃ = diag(µ̃1 , . . . , µ̃d ) is the
diagonal matrix consisting of the eigenvalues µ̃1 ≥ µ̃2 ≥ · · · ≥ µ̃d . Let Ṽ = (ṽ1 , . . . , ṽd )
and Ṽ −1 = (ũ1 , . . . , ũd )> , then ũi , ṽi are left and right eigenvectors associated with µ̃i
Pd
and ũ>
i ṽj = δij . We can rewrite the eigendecomposition as J(0) =
>
i=1 µ̃i ṽi ũi .

We also assume the top eigenvalue µ̃1 is positive and unique. Note µ̃1 > 0 means
the critical point x = 0 is unstable, and in matrix factorization it means 0 is a strict
saddle point of L( · ).
The key observation is that if the initialization is infinitesimal, the trajectory is
almost uniquely determined. To be more precise, we need the following definition:

Definition 6.5.2. For any x0 ∈ Rd and u ∈ Rd , we say that {xα }α∈(0,1) converges to
D E
−x0
x0 with positive alignment with u if lim xα = x0 and lim inf kxxαα−x 0 k2
, u > 0.
α→0 α→0

xα −x0
A special case is that the direction of xα − x0 converges, i.e., x̄ := limα→0 kxα −x0 k2

exists. In this case, {xα } has positive alignment with either u or −u except for a
zero-measure subset of x̄. This means any convergent sequence generically falls into
either of these two categories.
The following theorem shows that if the initial point xα converges to 0 with
positive alignment with ũ1 as α → 0, the trajectory starting with xα converges
1
to a unique trajectory z(t) := φ(αṽ1 , t + µ̃1
log α1 ). By symmetry, there is another
unique trajectory for sequences {xα } with positive alignment to −ũ1 , which is z 0 (t) :=
1
φ(−αṽ1 , t + µ̃1
log α1 ). This is somewhat surprising: different initial points should lead
to very different trajectories, but our analysis shows that generically there are only
two limiting trajectories for infinitesimal initialization. We will soon see how this
theorem helps in our analysis for matrix factorization in Sections 6.5.2 and 6.5.3.
169
1
Theorem 6.5.3. Let zα (t) := φ(αṽ1 , t + µ̃1
log α1 ) for every α > 0, then z(t) :=
limα→0 zα (t) exists and is also a solution of (6.6), i.e., z(t) = φ(z(0), t). If δα converges
to 0 with positive alignment with ũ1 as α → 0, then ∀t ∈ R, there is a constant C > 0
such that
  γ̃
1 1 µ̃1 +γ̃
φ δα , t + µ̃1
log hδα ,ũ1 i
− z(t) ≤ C · kδα k2 , (6.7)
2

for every sufficiently small α, where γ̃ := µ̃1 − µ̃2 > 0 is the eigenvalue gap.

Proof sketch. The main idea is to linearize the dynamics near origin as we have done
for the first warmup example. For sufficiently small x, by Taylor expansion of g(x),
dx
the dynamics is approximately dt
≈ J(0)x, which can be understood as a continuous
version of power iteration. If the linear approximation is exact, then x(t) = etJ(0) x(0).
For large enough t0 , et0 J(0) = di=1 eµ̃i t0 ṽi ũ> µ̃1 t0
ṽ1 ũ> µ̃2 t0
P
i = e 1 + O(e ). Therefore, as
long as the initial point x(0) has a positive inner product with ũ1 , x(t0 ) should be very
close to ṽ1 for some  > 0, and the rest of the trajectory after t0 should be close to
the trajectory starting from ṽ1 . However, here is a tradeoff: we should choose t0 to be
large enough so that the power iteration takes effect; but if t0 is so large that the norm
of x(t0 ) reaches a constant scale, then the linearization fails unavoidably. Nevertheless,
if the initialization scale is sufficiently small, we show via a careful error analysis that
there is always a suitable choice of t0 such that x(t0 ) is well approximated by ṽ1 and
the difference between x(t0 + t) and φ(ṽ1 , t) is bounded as well. We defer the details
to Section 6.12.

6.5.2 Equivalence Between GD and GLRL: Rank-One Case

Now we establish the equivalence between GF and GLRL in the first phase. The main
idea is to apply Theorem 6.5.3 on (6.2). For this, we need the following lemma on the
eigenvalues and eigenvectors.

170
Lemma 6.5.4. Let g(W ) := −W ∇f (W ) − ∇f (W )W and J(W ) be its Jacobian.
Then J(0) is symmetric and thus diagonalizable. Let −∇f (0) = di=1 µi u1[i] u>
P
1[i] be

the eigendecomposition of the symmetric matrix −∇f (0), where µ1 ≥ µ2 ≥ · · · ≥ µd .


Then J(0) has the form:

d X
X d
J(0)[∆] = (µi + µj ) ∆, u1[i] u> >
1[j] u1[i] u1[j] , (6.8)
i=1 j=1

where J(0)[∆] stands for the resulting matrix produced by left-multiplying J(0) to the
vectorization of ∆. For every pair of 1 ≤ i ≤ j ≤ d, µi + µj is an eigenvalue of J(0)
and u1[i] u> >
1[j] + u1[j] u1[i] is a corresponding eigenvector. All the other eigenvalues are 0.

We simplify the notation by letting u1 := u1[1] . A direct corollary of Lemma 6.5.4


is that u1 u>
1 is the top eigenvector of J(0). According to Theorem 6.5.3, now there are

only two types of trajectories, which correspond to infinitesimal initialization Wα → 0


with positive alignment with u1 u> >
1 or −u1 u1 . As the initialization must be PSD, Wα →

0 cannot have positive alignment with −u1 u>


1 . For the former case, Theorem 6.5.6

below states that, for every fixed time t, the GF solution φ(Wα , T (Wα ) + t) after
shifting by a time offset T (Wα ) := 1
2µ1
log(hWα , u1 u> −1
1 i ) converges to the GLRL

solution W1G (t) as Wα → 0. The only assumption for this result is that 0 is not a
minimizer of f ( · ) in S+
d (which is equivalent to λ1 (−∇f (0)) > 0) and −∇f (0) has

an eigenvalue gap. In the full observation case, this assumption is satisfied easily if
the ground-truth matrix has a unique top eigenvalue. The proof for Theorem 6.5.6 is
deferred to Section 6.14.1.

Assumption 6.5.5. µ1 > max{µ2 , 0}, where µ1 := λ1 (−∇f (0)), µ2 := λ2 (−∇f (0)).

Theorem 6.5.6. Under Assumption 6.5.5, the following limit W1G (t) exists and is a
solution of (6.2).

   
W1G (t) := lim W1,
G 1
2µ1
log 1 + t = lim φ u1 u> , 1
1 2µ1 log 1

+ t . (6.9)
→0 →0
171
Let {Wα } ⊆ S+ >
d be PSD matrices converging to 0 with positive alignment with u1 u1

as α → 0, that is, limα→0 Wα = 0 and ∃α0 , q > 0 such that Wα , u1 u>


1 ≥ q kWα kF
for all α < α0 . Then ∀t ∈ R, there is a constant C > 0 such that

  γ̃

φ Wα , 2µ1 1 log 1
+t − W1G (t) 2µ1 +γ̃
≤ C kWα kF (6.10)
hWα ,u1 u>1 i F

for every sufficiently small α, where γ̃ := 2µ1 − (µ1 + µ2 ) = µ1 − µ2 .

It is worth to note that W1G (t) has rank ≤ 1 for any t ∈ R, since every W1,
G
(t) has
rank ≤ 1 and the set S+
d,≤1 is closed. This matches with the first warmup example:

GD does start learning with rank-1 solutions. Interestingly, in the case where the limit
W 1 := limt→+∞ W1G (t) happens to be a minimizer of f ( · ) in S+
d , GLRL should exit

with the rank-1 solution W 1 after the first phase, and the following theorem shows
that this is also the solution found by GF.

Assumption 6.5.7. f (W ) is locally analytic at each point.

Theorem 6.5.8. Under Assumptions 6.5.5 and 6.5.7, if kW1G (t)kF is bounded for all
t ≥ 0, then the limit W 1 := limt→+∞ W1G (t) exists. Further, if W 1 is a minimizer of
f ( · ) in S+ +
d , then for PSD matrices {Wα } ⊆ Sd converging to 0 with positive alignment

with u1 u>
1 as α → 0, it holds that limα→0 limt→+∞ φ(Wα , t) = W 1 .

Assumption 6.5.7 is a natural assumption, since f ( · ) in most cases of matrix


factorization is a quadratic or polynomial function (e.g., matrix sensing, matrix
completion). In general, it is unlikely for a gradient-based optimization process to
get stuck at saddle points [119, 120]. Thus, we should expect to see in general that
GLRL finds the rank-1 solution if the problem is feasible with rank-1 matrices. This
means at least for this subclass of problems, the implicit regularization of GD is rather
unrelated to norm minimization. Below is a concrete example:

172
Example 6.5.9 (Counter-example of Conjecture 6.1.1, Gunasekar et al. 16). Theo-
rem 6.5.8 enables us to construct counterexamples of the implicit nuclear norm regular-
ization conjecture in [16]. The idea is to construct a problem where every rank-1 station-
ary point of L(U ) (i.e., ∇L(U ) = 0 and U ∈ Rd×d is rank-1) attains the global minimum
but none of them is minimizing the nuclear norm. Below we give a concrete matrix
completion problem that meets the above requirement. Let M be a partially observed
matrix to be recovered, where the entries in Ω = {(1, 3), (1, 4), (2, 3), (3, 1), (3, 2), (4, 1)}
are observed and the others (marked with “?”) are unobserved. The optimization
problem is defined formally by L(U ) = 12 f (U U > ), f (W ) = 21 (i,j)∈Ω (Wij − Mij )2 .
P

     
? ?1 R R 1 1 R  1 R 1 R 
     
? 2 2
? R ? 1 R R 1
R R R R 

M =  , Mnorm =   , Mrank =  .
  
1 R ? ? 1 R R 1 1 R 1 R 
     
     
2 2
R ? ? ? R 1 1 R R R R R

Here R > 1 is a large constant, e.g., R = 100. The minimum nuclear norm solution
is the rank-2 matrix Mnorm , which has kMnorm k∗ = 4R (which is 400 when R = 100).
Mrank is a rank-1 solution with much larger nuclear norm, kMnorm k∗ = 2R2 + 2 (which
is 20002 when R = 100). We can verify that f ( · ) satisfies Assumptions 6.5.5 and 6.5.7
and W1G (t) converges to the rank-1 solution Mrank . Therefore, GF with infinitesimal
initialization converges to Mrank rather than Mnorm , which refutes the conjecture in
[16]. See Section 6.11 for a formal statement.

6.5.3 Equivalence between GD and GLRL: General Case

Theorem 6.5.6 shows that for any fixed time t, the trajectory of GLRL in the first
phase approximates GF with infinitesimal initialization, i.e., W1G (t) = limα→0 W
cα (t),
cα (t) := φ(Wα , 1 log(hWα , u1 u>
where W −1 G
1 i ) + t). However, W1 (∞) 6= limα→0 Wα (∞)
c
2µ1

does not hold in general, unless the prerequisite in Theorem 6.5.8 is satisfied, i.e.,
173
unless W 1 = W1G (∞) is a minimizer of f ( · ) in S+
d . This is because of the well-known

result that GD converges to local minimizers [119, 121]. We adapt Theorem 2 of Lee
et al. [119] to the setting of GF (Theorem 6.14.5) and obtain the following result
(Theorem 6.5.10); see Section 6.14.4 for the proof.

Theorem 6.5.10. Let f : Rd×d → R be a convex C 2 -smooth function. (1). All


stationary points of L : Rd×d → R, L(U ) = 12 f (U U > ) are either strict saddles or global
minimizers; (2). For any random initialization, GF (6.1) converges to strict saddles
of L(U ) with probability 0.

Therefore, for convex f ( · ) such as matrix sensing and completion, suppose f ( · )


has no rank-1 PSD minimizer, then no matter how small α is, W
cα (∞) (if exists) is a

minimizer of f ( · ) with a higher rank and thus away from the rank-1 matrix W 1 . In
other words, W1G (t) only describes the limiting trajectory of GF in the first phase, i.e.,
when GF goes from near 0 to near W 1 . After a sufficiently long time (which depends
on α), GF escapes the critical point W 1 , but this part is not described by W1G (t).
To understand how GF escapes W 1 , a priori, we need to know how GF approaches
W 1 . Using a similar argument for Theorem 6.5.3, Theorem 6.5.11 shows that generically
GF only escapes in the direction of v1 v1> , where v1 is the (unique) top eigenvector of
−∇f (W 1 ), and thus the limiting trajectory exactly matches with that of GLRL in
the second phase until GF gets close to another critical point W 2 ∈ S+
d,≤2 . If W 2 is

still not a minimizer of f ( · ) in S+ +


d (but it is a local minimizer in Sd,≤2 generically),

then GF escapes W 2 and the above process repeats until W K is a minimizer in S+


d for

some K. Here by “generically” we hide some technical assumptions and we elaborate


on them in Section 6.5.4. See Figure 6.1 and Figure 6.2 for experimental verification
of the equivalence between GD and GLRL. We end this section with the following
characterization of GF:

Theorem 6.5.11 (Theorem 6.14.2, informal). Let W be a critical point of (6.2)


satisfying that W is a local minimizer of f ( · ) in S+
d,≤r for some r ≥ 1 but not a
174
Pd
minimizer in S+
d . Let −∇f (W ) = i=1 µi vi vi> be the eigendecomposition of −∇f (W ).
If µ1 > µ2 and if there exists time Tα ∈ R for every α so that φ(Wα , Tα ) converges
to W with positive alignment with the top principal component v1 v1> as α → 0, then
1 1
for every fixed t, lim φ(Wα , Tα + log + t) exists and is equal to
α→0 2µ1 hφ(Wα ,Tα ),v1 v1> i
W G (t) := lim→0 φ(W + v1 v1> , 2µ1 1 log 1 + t).

Characterization of the trajectory of GF. Generically, the trajectory of GF


with small initialization can be split into K phases by K + 1 critical points of (6.2),
{W r }K
r=0 (W 0 = 0), where in phase r GF escapes from W r−1 in the direction of the

top principal component of −∇f (W r−1 ) and gets close to W r . Each W r is a local
minimizer of f ( · ) in S+ +
d,≤r , but none of them is a minimizer of f ( · ) in Sd except W K .

The smaller the initialization is, the longer GF stays around each W r . Moreover,
{W r }K K
r=0 corresponds to {W r, }r=0 in Definition 6.5.1 with infinitesimal  > 0.

101
W(0) F = 10 3
101
W(0) F = 10 6
101
W(0) F = 10 12
101
W(0) F = 10 24
101
W(0) F = 10 48
101
W(0) F = 10 96

100 100 100 100 100 100


10 1 10 1 10 1 10 1 10 1 10 1
distance

10 2 10 2 10 2 10 2 10 2 10 2

10 3 10 3 10 3 10 3 10 3 10 3

10 4 10 4 10 4 10 4 10 4 10 4
0.0 0.2 0.4 0.6 0.8 1.0 0.0 0.2 0.4 0.6 0.8 1.0 0.0 0.2 0.4 0.6 0.8 1.0 0.0 0.2 0.4 0.6 0.8 1.0 0.0 0.2 0.4 0.6 0.8 1.0 0.0 0.2 0.4 0.6 0.8 1.0
Continuous Time 1e4 Continuous Time 1e4 Continuous Time 1e4 Continuous Time 1e4 Continuous Time 1e4 Continuous Time 1e4

Figure 6.1: The trajectory of depth-2 GD, WGD (t), converges to the trajectory of GLRL,
WGLRL (t), as the initialization scale goes to 0. We plot dist(t) = mint0 ∈T kWGD (t) −
WGLRL (t0 )kF for different initialization scale kW (0)kF , where T is a discrete subset of R
that δ-covers the entire trajectory of GLRL: maxt mint0 ∈T kWGLRL (t) − WGLRL (t0 )kF ≤
δ for δ ≈ 0.00042. For each kW (0)kF , we run 20 random seeds and plot them
separately. The ground truth W ∗ ∈ R20×20 is a randomly generated rank-3 matrix
with kW ∗ kF = 20. 30% entries are observed. See more in Section 6.8.1.

6.5.4 Equivalence Between GF and GLRL

In this section we elaborate on the theoretical evidence that GF and GLRL are
equivalent generically, including the case where GLRL does not end in the first phase.
The word “generically” used when we want to assume one of the following regularity
conditions:
175
1. We want to assume that GF converges to a local minimizer (i.e., GF does not
get stuck on saddle points);

2. We want to assume that the top eigenvalue λ1 (−∇f (W )) is unique for a critical
point W of (6.2) that is not a minimizer of f ( · ) in S+
d;

3. We want to assume that a convergent sequence of PSD matrices Wα → W has


positive alignment with vv > for some fixed vector v with hW , vv > i = 0, i.e.,
for a convergent sequence of PSD matrices Wα → W , it holds for sure that
hWα ,vv> i
D E
Wα −W >
lim inf kWα −W k , vv = lim inf kWα −W k ≥ 0, and we further assume that the
α→0 F α→0 F

inequality is strict generically.

Theorem 6.14.2 uncovers how GF with infinitesimal initialization generically be-


haves. Let W 0 := 0. For every r ≥ 1, if W r−1 is a local minimizer in S+
d,≤r−1

but not a minimizer in S+


d , then λ1 (−∇f (W r−1 )) > 0 by Lemma 6.10.2. Generi-

cally, the top eigenvalue λ1 (−∇f (W r−1 )) should be unique, i.e., λ1 (−∇f (W r−1 )) >
λ2 (−∇f (W r−1 )). This enables us to apply Theorem 6.14.2 and deduce that the
limiting trajectory

 
1 1
WrG (t) := lim φ W r−1 + ur u>
r , log + t
→0 2λ1 (−∇f (W r−1 )) 

exists, where ur is the top eigenvector of −∇f (W r−1 ). This WrG ( · ) is exactly the
trajectory of GLRL in phase r as  → 0.
Note that WrG ( · ) corresponds to a trajectory of GF minimizing L( · ) in Rd×r ,
which should generically converge to a local minimizer of L( · ) in Rd×r . This means
the limit W r := limt→+∞ WrG (t) should generically be a local minimizer of f ( · ) in
S+ +
d,≤r . If W r is further a minimizer in Sd , then λ1 (−∇f (W r )) ≤ 0 and GLRL exits

with W r ; otherwise GLRL enters phase r + 1.


If GF aligns well with GLRL in the beginning of phase r (defined below), then by
Theorem 6.14.2, as α → 0, the minimum distance from GF to WrG (t) converges to
176
0 for every t ∈ R. Therefore, GF can get arbitrarily close to the r-th critical point
(r) (r)
W r of GLRL, i.e., there exists a suitable choice Tα so that limα→0 φ(Wα , Tα ) = W r .
(r)
D E
> φ(Wα ,Tα )−W r >
Note that W r , ur ur = 0 by (6.31) and thus lim inf (r) , ur ur =
D E α→0 kφ(Wα ,Tα )−W r kF
(r)
φ(Wα ,Tα ),ur u>
r (r)
lim inf (r) ≥ 0. Generically, there should exist a suitable choice of Tα so
α→0 kφ(Wα ,Tα )−W r kF
(r)
that φ(Wα , Tα ) not only converges to W r but also has positive alignment with ur u>
r ,

that is, GF should generically align well with GLRL in the beginning of phase r + 1.

Definition 6.5.12. We say that GF aligns well with GLRL in the beginning of phase
(r) (r)
r if there exists Tα for every α > 0 such that φ(Wα , Tα ) converges to W r−1 with
positive alignment with ur u>
r as α → 0.

If the initialization satisfies that Wα converges to 0 with positive alignment with


u1 u>
1 as α → 0, then GF aligns well with GLRL in the beginning of phase 1, which
(1)
can be seen by taking Tα = 0. Now assume that GF aligns well with GLRL in the
beginning of phase r − 1, then the above argument shows that GF should generically
align well with GLRL in the beginning of phase r, if GLRL does not exit in phase
r − 1. In the other case, we can use a similar argument as in Theorem 6.5.8 to show
that GF converges to a solution near the minimizer W r of f ( · ) as t → ∞, and the
distance between the solution and W r converges to 0 as α → 0. By this induction we
prove that GF with infinitesimal initialization is equivalent to GLRL generically.

6.6 Benefits of Depth: A View from GLRL

In this section, we consider matrix factorization problems with depth L ≥ 3. Our goal
is to understand the effect of the depth-L parametrization W = U1 U2 · · · UL on the
implicit bias — how does depth encourage GF to find low rank solutions? We take the
standard assumption in existing analysis for the end-to-end dynamics that the weight
matrices have a balanced initialization, i.e. Ui> (0)Ui (0) = Ui+1 (0)Ui+1
>
(0), ∀1 ≤ i ≤
L − 1. Arora et al. [29] showed that if {Ui }Li=1 is balanced at initialization, then we
177
have the following end-to-end dynamics. Similar to the depth-2 case, we use φ(W (0), t)
to denote W (t), where

L−1
dW X i i+1
=− (W W > ) L ∇f (W )(W > W )1− L . (6.11)
dt i=0

The lemma below is the foundation of our analysis for the deep case, which greatly
simplifies (6.11). We defer its derivations and applications into Section 6.15.

Lemma 6.6.1. If W (t) is a symmetric solution of (6.11), then for M (t) := W (t)2/L ,
we have
dM
= −∇f (M L/2 )M L/2 − M L/2 ∇f (M L/2 ). (6.12)
dt

Algorithm 6 Deep Greedy Low-Rank Learning (Deep GLRL)


Require: step size η > 0; small  > 0
0 ← 1/L , L (U1 , · · · , UL ) := f (W1 · · · WL ).
W0 ← 0 ∈ Rd×d , and U0,1 (∞), . . . , U0,L (∞) ∈ Rd×0 are empty matrices
while λ1 (−∇f (Wr )) > 0 do
r ←r+1
let ur be a top (unit) eigenvector of −∇f (Wr−1 )
0
 d×r
Ur,1 (0) ← Ur−1,1 (∞)  ur ∈ R
U (∞) 0
Ur,k (0) ← r−1,k 0 ∈ R
r×r
for all 2 ≤ k ≤ L − 1
 0  
Ur−1,L (∞)
Ur,L (0) ← ∈ Rr×d
0 u>
r
for t = 0, 1, . . . do
Ur,i (t + 1) ← Ur,i (t) − η∇Ui L (Ur,1 (t), · · · , Ur,L (t)), ∀1 ≤ i ≤ L.
Wr ← Ur,1 (∞) · · · Ur,L (∞)
return Wr

Our main result, Theorem 6.6.2, gives a characterization of the limiting trajectory
for deep matrix factorization with infinitesimal identity initialization. Here W (t) :=
−(1−1/P )
limα→0 WαG (t) is the trajectory of deep GLRL, where WαG (t) := φ(αe1 e> α
1 , 2µ1 (P −1) + t)

(see Algorithm 6). The dynamics for general initialization is more complicated. Please
see discussions in Section 6.6.1.

178
L
Theorem 6.6.2. Let P = 2
, L ≥ 3. Suppose k∇f (0)k2 = λ1 (−∇f (0)) >
max{λ2 (−∇f (0)), 0},2

  1
α−(1−1/P )
for every fixed t ∈ R, φ αI, 2µ1 (P −1)
+ t − W (t) = O(α P (P +1) ), (6.13)
F

and for any 2 ≤ k ≤ d,

  −(1−1/P )

for every fixed t ∈ R, λk φ αI, α2µ1 (P −1) + t = O(α). (6.14)

So how does depth encourage GF to find low-rank solutions? When the


ground truth is low-rank, say rank-k, our experiments (Figure 6.2) suggest that GF
with small initialization deep matrix sensing finds solutions with smaller k-low-rankness
compared to the depth-2 case, thus achieving better generalization. At first glance,
this is contradictory to what Theorem 6.6.2 suggests, i.e., the convergence rate of deep
GLRL at a constant time gets slower as the depth increases. However, it turns out
the uniform upper bound for the distance between GF and GLRL is not the ideal
metric for the eventual k-low-rankness of learned solution. Below we will illustrate
why the r-low-rankness of GF within each phase r is a better metric and how they are
different.

Definition 6.6.3 (r-low-rankness). For matrix M ∈ Rd×d , we define the r-low-


qP
d 2
rankness of M as i=r+1 σi (M ), where σi (M ) is the i-th largest singular value of

M.

Suppose we run GF from αI for both depth-2 and depth-L cases. Intuitively, the
1-low-rankness of the depth-2 solution is Ω(α1−µ2 /µ1 ), which can be seen from the
second warmup example in Section 6.4. For the depth-L solution, though it may
diverge from the trajectory of deep GLRL more than the depth-2 solution does, its
2
k∇f (0)k2 = λ1 (−∇f (0)) is a technical assumption which we believe could be removed with a
more refined analysis.
179
1-low-rankness is only O(α), as shown in Theorem 6.6.4. The key idea is to show
that there is a basin in the manifold of rank-1 matrices around W0 such that any GF
starting within the basin converges to W0 . Based on this, we can prove that starting
from any matrix O(α)-close to the basin, GF converges to a solution O(α)-close to
W0 . See Section 6.16 for more details.

Theorem 6.6.4. In the same settings as Theorem 6.6.2, if W (∞) exists and is a
minimizer of f ( · ) in S+
d,≤1 , under the additional regularity assumption 6.16.1, we have

inf φ (αI, t) − W (∞) F


= O(α). (6.15)
t∈R

Interpretation for the advantage of depth with multiple phases. For depth-
2 GLRL, the low-rankness is raised to some power less then 1 per phase (depending on
the eigengap). For deep GLRL, we show the low-rankness is only multiplied by some
constant for the first phase and speculate it to be true for later phases. This conjecture
is supported by our experiments; see Figure 6.2. Interestingly, our theory and
experiments (Figure 6.7) suggest that while being deep is good for generalization,
being much deeper may not be much better: once L ≥ 3, increasing the depth
does not improve the order of low-rankness significantly. While this theoretical result
is only for identity initialization, Theorem 6.7.1 and Corollary 6.7.2 further show that
the dynamics of GF (6.11) with any initialization pointwise converges as L → ∞,
under a suitable time rescaling. See Figure 6.4 for experimental verification.

6.6.1 Escaping Direction for Deep Matrix Factorization

For deep matrix factorization, recall that we only prove that GF with infinitesimal
identity initialization escapes in the direction of the top eigenvector. The main burden
for us to generalize this proof to general initialization is that we don’t know how to
analyze the early phase dynamics of (6.12), i.e., the analytical solution of (6.16) is

180
102
L = 2, W(0) F = 10 6
102
L = 2, W(0) F = 10 12
102
L = 2, W(0) F = 10 24
102
L = 2, W(0) F = 10 48
102
L = 2, W(0) F = 10 96

10 2 10 2 10 2 10 2 10 2

10 6 10 6 10 6 10 6 10 6

10 10 10 10 10 10 10 10 10 10
0 1000 2000 3000 4000 0 1000 2000 3000 4000 0 1000 2000 3000 4000 0 1000 2000 3000 4000 0 1000 2000 3000 4000

102
L = 2, W(0) F = 10 3
102
L = 2, W(0) F = 10 4
102
L = 2, W(0) F = 10 5
102
L = 2, W(0) F = 10 6
102
L = 2, W(0) F = 10 7

10 2 10 2 10 2 10 2 10 2

10 6 10 6 10 6 10 6 10 6

10 10 10 10 10 10 10 10 10 10
102 103 102 103 102 103 102 103 102 103

102
L = 3, W(0) F = 10 3
102
L = 3, W(0) F = 10 4
102
L = 3, W(0) F = 10 5
102
L = 3, W(0) F = 10 6
102
L = 3, W(0) F = 10 7

10 2 10 2 10 2 10 2 10 2

10 6 10 6 10 6 10 6 10 6

10 10 10 10 10 10 10 10 10 10
0.00 0.25 0.50 0.75 1.00 0.0 0.5 1.0 1.5 0.0 0.5 1.0 1.5 2.0 0 1 2 3 0 1 2 3 4
1e4 1e4 1e4 1e4 1e4

102
L = 4, W(0) F = 10 3
102
L = 4, W(0) F = 10 4
102
L = 4, W(0) F = 10 5
102
L = 4, W(0) F = 10 6
102
L = 4, W(0) F = 10 7

10 2 10 2 10 2 10 2 10 2

10 6 10 6 10 6 10 6 10 6

10 10 10 10 10 10 10 10 10 10
0 1000 2000 3000 0.00 0.25 0.50 0.75 1.00 0 1 2 3 0.00 0.25 0.50 0.75 1.00 0 1 2 3
1e4 1e4 1e5 1e5
Continuous Time Continuous Time Continuous Time Continuous Time Continuous Time
distance grad norm r-low-rankness

Figure 6.2: GD passes by the same set of critical points as GLRL when the initialization
scale is small, and gets much closer to the critical points when L ≥ 3. Depth-2 GD
requires a much smaller initialization scale to maintain small low-rankness. Here the
ground truth matrix W ∗ ∈ R20×20 is of rank 3 as stated in Section 6.8.1. In this case,
GLRL has 3 phases and 4 critical points {W r }3r=0 , where W 0 = 0 and W 3 = W ∗ .
For each depth L and initialization scale kW (0)kF , we plot the distance between the
current step of GD and the closest critical point of GLRL, min0≤r≤3 kWGD (t) − W r kF ,
the norm of full gradient, k∇U1:L L(U1:L )kF and the (r + 1)-low-rankness of WGD (t)
with r := argmin0≤r≤3 kWGD (t) − W r kF .

difficult to compute, when L ≥ 3. Intuitively, the direction that the infinitesimal


M (t)
initialization escapes 0 is exactly M := limt→∞ kM (t)kF
, where M (t) is the solution of
(6.16). Showing M = v1 v1> is a critical step in our analysis towards convergence to
GLRL.
dM
= −∇f (0)M L/2 − M L/2 ∇f (0). (6.16)
dt

However, unlike the depth-2 case, M can be different from v1 v1> even if v1> M (0)v1 > 0.
We here give an example for diagonal M (0) and ∇f (0) at Section 6.6.3. Nevertheless,

181
we still conjecture that except for a zero measure set of M (0), M = v1 v1> , based on
the following theoretical and experimental evidences:

• If v1> M (0)v1 > 0 and rank(M (0)) = 1, we prove that M = v1 v1> . (See Theo-
rem 6.6.5)

• For the counter-example, we show experimentally, even with perturbation of only


magnitude 10−5 , M = v1 v1> . The results are shown at Figure 6.3. The y-axis
indicates hv1 , u1 (t)i where u1 (t) is the top eigenvector of M (t). As kW (t)kF
becomes larger, u1 (t) aligns better with v1 , which means the noise helps M
escaping from v1 . The larger the noise is, the faster u1 (t) converges to v1 .

6.6.2 Rank-one Case

Theorem 6.6.5 (rank-1 initialization escapes along the top eigenvector). When
M (t)
rank(M (0)) = 1, limt→∞ kM (t)kF
= v1 v1> , if v1> M (0)v1 > 0.

Proof. Let u(0) be the vector such that M (0) = u(0)u(0)> and u(t) ∈ Rd be the
solution of
du(t)
= ku(t)kL−2
2 ∇f (0)u(t).
dt

It is easy to check that M (t) = u(t)u(t)> is the solution of (6.16), because

dM du > du >
= u +u = − ∇f (0)M (t) ku(t)kL−2
2 − M (t)∇f (0) ku(t)kL−2
2
dt dt dt
= − ∇f (0)M L/2 − M L/2 ∇f (0).

Rt
Let τ (t) = 0
ku(s)kL−2
2 ds. Then

du du dt 1
= =− dτ
kukL−2
2 ∇f (0)u = −∇f (0)u.
dτ dt dτ dt

182
That is, under time rescaling t → τ (t), the trajectory of u(t) still follows the power
iteration, regardless of the depth L.

6.6.3 Counter-example for Escaping Direction

Let ∇f (0) = diag(2, 0.9, 0.8, . . . , 0.1) ∈ R10×10 be diagonal. Let W (0) be also diagonal
and W (0)i,i ∼ Unif[0.9, 1.1] · α for i ∈ [10] \ {2}, W (0)2,2 = 16α, where α = 10−16 is a
small constant. Let the depth be 4.

Lemma 6.6.6. With ∇f (0) and W (0) constructed above, v1 M (0)v1> > 0 and M 6=
v1 v1> .

Proof. It is easy to check that v1 = e1 , so v1 M (0)v1> > 0. Now we prove that


M (∞) 6= v1 v1> .
As both W (0) and ∇f (0) are diagonal, W (t) is always diagonal and has dynamics

dM (t)i,i
= −2∇f (0)i,i M (t)2i,i , ∀i ∈ [10],
dt

therefore we have closed form of M (t):

M (t)−1 −1
i,i = M (0)i,i − 2∇f (0)i,i t, ∀i ∈ [10].

For i ∈ [10], the time for M (t)i,i going to infinity is (2M (0)i,i ∇f (0)i,i )−1 . By simple
calculation, M (t)2,2 goes to infinity the fastest, thus M = e2 e> >
2 6= v1 v1 .

We remark that the scales of W (0) and ∇f (0) do not matter as in gradient flow,
as scaling ∇f (0) is equivalent to scaling time (by Lemma 6.6.7 below). And for this
kW (t)kF
reason, the x-axis is the chosen as kW (0)kF
, the relative growth rate.

183
dynamics ddtM = f(0)M2 M2 f(0)
10 1

| v1, ut(t) |
10 3

= 10 4
10 5 = 10 3
= 10 2

102 105 108 1011 1014 1017


W(t) F
W(0) F

Figure 6.3: Dynamics of dM dt


= −∇f (0)M L/2 − M L/2 ∇f (0) plotted, where L = 4,
u1 (t) is the top eigenvector of W (t) and  is the relative magnitude of noise. The
initialization we use in this experiment is Wnoise (0) = W (0) + α
2
(Z + Z > ), where W (0)
is what we construct at Section 6.6.3, and Z is a matrix where entries are i.i.d. samples
from the standard Gaussian distribution N (0, 1). We run 5 fixed random seeds (the
noise matrix) for each . The trajectory of W is calculated by simulating gradient
flow on M with small timestep and RMSprop [58] for faster convergence.

Lemma 6.6.7. Suppose g : Rd → Rd is a P -homogeneous function, that is, g(αx) =


dx0 (t)
λP g(α) for any α > 0, and dt
= g(x0 (t)). Then αx0 (αP −1 t) is the solution of

dx(t)
= g(x(t)), x(0) = αx0 (0). (6.17)
dt

Proof. Simply plug in x(t) = αx0 (αP −1 t), then we have

dx(t) dαx0 (αP −1 t) dx0 (αP −1 t)


= = αP = αP g(x0 (αP −1 t)) = g(αx0 (αP −1 t)) = g(x(t)).
dt dt d(αP −1 t)

6.7 The Marginal Value of Being Deeper

Theorem 6.7.1 shows that the end-to-end dynamics (6.18) converges point-wise while
L → ∞ if the product of learning rate and depth, ηL, is fixed as constant. Interestingly,
(6.18) also allows us to simulate the dynamics of W (t) for all depths L while the
computation time is independent of L. In Figure 6.4, we compare the effect of depth
184
while fixing the initialization and ηL. We can see that deeper models converge faster.
The difference between L = 1, 2, and 4 is large, while difference among L ≥ 16 is
marginal.

Towards infinite depth


1.0 L=1
0.8 L=2
L=4
0.6 L=8
L = 16
Loss

0.4 L = 32
L = 64
0.2 L = 128
0.0
0 250 500 750 1000
Normalized Continuous Time
Figure 6.4: The marginal value of being deeper. The trajectory of GD converges when
depth goes to infinity. Solid (dotted) curves correspond to test (train) loss. The x-axis
stands for the normalized continuous time t (multiplied by L).

Theorem 6.7.1. Suppose W = Ũ Σ̃Ṽ > is the SVD decomposition of W , where Σ̃ =


diag(σ1 , . . . , σd ). The dynamics of L-layer linear net is the following, ◦ denotes the
entry-wise multiplication:

dW   
= −LŨ Ũ > ∇f (W )Ṽ ◦ K (L) Ṽ > , (6.18)
dt

(L) 2−2/L (L) σi2 −σj2


where Ki,i = σi , Ki,j = 2/L 2/L for i 6= j.
Lσi −Lσj

185
Proof. We start from (6.11):

L−1
dW X l L−1−l
=− (W W > ) L ∇f (W )(W > W ) L
dt l=0
L−1
X 2l 2(L−1−l)
=− Ũ Σ̃ L Ũ > ∇f (W )Ṽ Σ̃ L Ṽ
l=0
" L−1
#
X 2l 2(L−1−l)
= −LŨ L−1 Σ̃ (Ũ > ∇f (W )Ṽ )Σ̃
L L Ṽ.
l=0

Note that Σ̃ is diagonal, so

2l 2(L−1−l)
Σ̃ L (Ũ > ∇f (W )Ṽ )Σ̃ L = (Ũ > ∇f (W )Ṽ ) ◦ H (l) ,

2l 2(L−1−l)
(l)
where Hi,j = σiL σj L
. Therefore,

L−1 L−1
X 2l 2(L−1−l) X
−1 > −1
L Σ̃ (Ũ ∇f (W )Ṽ )Σ̃
L L =L (Ũ > ∇f (W )Ṽ ) ◦ H (l)
l=0 l=0

= (Ũ > ∇f (W )Ṽ ) ◦ K (L) ,

PL−1
where K (L) = L−1 l=0 H (l) . Hence,

dW h
> (L)
i
= −LŨ (Ũ ∇f (W )Ṽ ) ◦ K Ṽ.
dt

The entries of K (L) can be directly calculated by


σ 2−2/L ,

L−1
X 2l 2(L−1−l)

i i = j,
(L) −1
Ki,j =L σi σj
L L
=
 σi2 −σj2
l=0 
 2/L 2/L , i 6= j.
Lσi −Lσj

186
σi2 −σj2
Corollary 6.7.2. As L → ∞, K (L) converges to K ∗ , where Ki,i
∗ ∗
= σi2 , Ki,j = ln σi2 −ln σj2

for i 6= j.

Experiment details. We follow the general setting in Section 6.8.1. The ground
truth W ∗ is different but is generated in the same manner and has the same shape of
20 × 20 and p = 0.3 is used for observation generation. We directly apply (6.18), in
which we compute Ṽ and Ũ through SVD, to simulate the trajectory together with a
10−3
constant learning rate of L
for depth L. W (0) is sampled from 10−3 × N (0, Id ).

6.8 Experiments

6.8.1 General Setup

The code is written in Julia [122] and PyTorch [123].


The ground-truth matrix W∗ is low-rank by construction: we sample a random
orthogonal matrix U , a diagonal matrix S with Frobenius norm kSkF = 1 and set
W∗ = U SU > . Each measurement X in X1 , . . . , Xm is generated by sampling two
one-hot vectors u and v uniformly and setting X = 12 uv > + 12 vu> .
In Figures 6.1 to 6.3 and 6.5 to 6.7, the ground truth matrix W∗ has shape 20 × 20
and rank 3, where kW ∗ kF = 20, λ1 (W ∗ ) = 17.41, λ2 (W ∗ ) = 8.85, λ3 (W ∗ ) = 4.31 and
λ1 (−∇f (0)) = 6.23, λ2 (−∇f (0)) = 5.41. p = 0.3 is used for generating measurements,
except p = 0.25 in Figure 6.5, i.e., each pair of entries of Wij∗ and Wji∗ is observed with
probability p.

Gradient Descent. Let ˜ > 0 be the Frobenius norm of the target random initial-
ization. For the depth-2 case, we sample 2 orthogonal matrices V1 , V2 and a diagonal
matrix D with Frobenius norm ˜, and we set U = V1 D1/2 V2> ; for the depth-L case
with L ≥ 3, we sample L orthogonal matrices V1 , . . . , VL and a diagonal matrix D

187
Depth (L) Simulation method
2 Constant LR, η = 10−3 for 106 iterations
3 Adaptive LR, η = 2 × 10−5 and ε = 10−4 for 106 iterations
4 Adaptive LR, η = 3 × 10−4 and ε = 10−3 for 106 iterations

Table 6.1: Choice of hyperparameters for simulating gradient flow. For L = 2,


gradient descent escapes saddles in O(log 1 ) time, where  is the distance between the
initialization and the saddle.

>
with Frobenius norm ˜, and we set Ui := Vi D1/L Vi+1 (VL+1 = V1 ). In this way,
we can guarantee that the end-to-end matrix W = U1 · · · UL is symmetric and the
initialization is balanced for L ≥ 3.
We discretize the time to simulate gradient flow. When L > 2, gradient flow
stays around saddle points for most of the time, therefore we use full-batch GD with
adaptive learning rate η̃t , inspired by RMSprop [58], for faster convergence:

vt+1 = αvt + (1 − α) k∇L(xt )k22 ,


η
η̃t = q ,
vt+1
1−αt+1

xt+1 = xt − η̃t ∇L(xt ),

where α = 0.99, η is the (unadjusted) learning rate. The choices of hyperparameters


are summarized in Table 6.1. The continuous time for xt is measured as t−1
P
i=0 η̃i .

GLRL. In Figures 6.1, 6.2, 6.5 and 6.6, the GLRL’s trajectory is obtained by running
Algorithm 5 with  = 10−7 and η = 10−3 . The stopping criterion is that if the loop
has been iterated for 107 times.

188
6.8.2 Experimental Equivalence between GLRL and Gradi-

ent Descent

Here we provide experimental evidence supporting our theoretical claims about the
equivalence between GLRL and GF for both cases, L = 2 and L ≥ 3.
In Figure 6.1, we show the distance from every point on GF (simulated by GD)
from random initialization is close to the trajectory of GLRL. In Figure 6.2, we first
run GLRL and obtain the critical points {W r }3r=0 passed by GLRL. We also define
the distance of a matrix W to the critical points to be min0≤r≤3 kW − W r kF .

6.8.3 How well does GLRL work?

We compare GLRL with gradient descent (with not-so-small initialization), nuclear


norm minimization and R1MP [114]. We use CVXPY [124, 125] for finding the nuclear
norm solution. The results are shown in Figure 6.5. GLRL can fully recover the
ground truth, while others have difficulty doing so.

d = 20, W(0) F = 10 2
GLRL, L = 2
100
GD, L = 2
nuclear norm
10 1 R1MP (rank 3)
R1MP (rank 10)
loss

10 2

10 3
0.0 0.5 1.0 1.5 2.0 2.5
Continuous Time 1e4

Figure 6.5: GD with small initialization outperforms R1MP and minimal nuclear
norm solution on synthetic data with low-rank ground truth. Solid (dotted) curves
correspond to test (training) loss. Here the loss f (W ) := d12 kW − W ∗ k2F and f (0) = 1.
We run 10 random seeds for GD and plot them separately (most of them overlap).

189
6.8.4 How does initialization affect the convergence rate to

the rank-1 GLRL trajectory?

We use the general setting in Section 6.8.1. In these experiments, we use the constant
learning rate 10−5 for 4 × 107 iterations. The reference matrix Wref is obtained by
running the first stage of GLRL with kW (0)kF = 10−48 and we pick one matrix in the
trajectory with kWref kF about 0.6.
For every  = 10i , i ∈ {−1, −2, −3, −4, −5}, we run both gradient descent and
the first phase of GLRL with kW (0)kF = . For gradient descent, we use random
initialization so kW (0)kF is full rank w.p. 1. The distance of a trajectory to Wref
is defined as mint≥0 kW (t) − Wref kF . In practice, as we discretized time to simulate
gradient flow, we check every t during simulation to compute the distance. As a result,
the estimation might be inaccurate when a trajectory is really close to Wref .
The result is shown at Figure 6.6. We observe that GLRL trajectories are closer
to the reference matrix Wref by magnitudes. Thus the take home message here is that
GLRL is in general a more computational efficient method to simulate the trajectory
of GF (GD) with infinitesimal initialization, as one can start GLRL with a much
larger initialization, while still maintaining high precision.

6.8.5 Benefit of Depth: polynomial vs exponential depen-

dence on initialization

To verify the our theory in Section 6.6, we run gradient descent with different depth and
initialization. The results are shown in Figure 6.7. We can see that as the initialization
becomes smaller, the final solution gets closer to the ground truth. However, a
depth-2 model requires exponentially small initialization, while deeper models require
polynomial small initialization, though it takes much longer to converge.

190
100

t 0 Wref WGD(t)
10 2

rank 1
10 4 rank d
min

10 6

10 5 10 4 10 3 10 2 10 1
WGD(0) F

Figure 6.6: Using v1 v1> (denoted by “rank 1”) as initialization makes GD much closer
to GLRL compared to using random initialization (denoted by “rank d”), where v1 is
the top eigenvector of −∇f (0). We take a fixed reference matrix on the trajectory of
GLRL with constant norm and plot the distance of GD with each initialization to it
respectively..

L=2 L=2 L=3 L=4


0 0 2.5
10 3 10 3 0 10 3 10 3
10 6 10 4 10 4 0.0 10 4
2 2
10 9
1 10 5 10 5
2.5 10 5

4 10 12 10 6 4 10 6 10 6
10 15 10 7 10 7 5.0 10 7
log10 test loss

2 6
6 10 18 7.5
8
8 3 10.0
10
12.5
10 12
4 15.0
12 14
17.5
5 16
0 1 2 3 4 0 2 4 6 0 2 4 6 0 1 2 3
Continuous Time 1e4 Continuous Time 1e4 Continuous Time 1e4 Continuous Time 1e5

Figure 6.7: Deep matrix factorization encourages GF to find low rank solutions at a
much practical initialization scale, e.g. 10−3 . Here the ground truth is rank-3. For
each setting, we run 5 different random seeds. The solid curves are the mean and
the shaded area indicates one standard deviation. We observe that performance of
GD is quite robust to its initialization. Note that for L > 2, the shaded area with
initialization scale 10−7 is large, as the sudden decrement of loss occurs at quite
different continuous times for different random seeds in this case.

191
6.9 Future Directions

Our result on the equivalence between gradient flow with infinitesimal initialization
and GLRL is based on some regularity conditions that we expect to hold generically.
We leave it a future work to justify these condition, possibly through a smoothed
analysis on the objective f ( · ). Another interesting future direction is to find the
counterpart of GLRL in training deep neural nets. This could be one way to go beyond
the view of norm minimization in the study of the implicit regularization of gradient
descent.

6.10 Preliminary Lemmas

Lemma 6.10.1. For U0 ∈ Rd×r and W0 := U0 U0> , the following statements are
equivalent:

(1). U0 is a stationary point of L(U ) = 12 f (U U > );

(2). ∇f (W0 )W0 = 0;

(3). W0 := U0 U0> is a critical point of (6.2).

Proof. (2) ⇒ (3) is trivial. We only prove (1) ⇒ (2), (3) ⇒ (1).

Proof for (1) ⇒ (2). If U0 is a stationary point, then 0 = ∇L(U0 ) = ∇f (W0 )U0 .
So
∇f (W0 )W0 = (∇f (W0 )U0 ) U0> = 0.

Proof for (3) ⇒ (1). If W0 is a critical point, then

0 = hg(W0 ), ∇f (W0 )i = −2 Tr(∇f (W0 )W0 ∇f (W0 )) = −2k∇f (W0 )U0 k2F ,

which implies ∇L(U0 ) = 0.


192
Lemma 6.10.2. For a stationary point U0 ∈ Rd×r of L(U ) = 12 f (U U > ) where f ( · )
is convex, W0 := U0 U0> attains the global minimum of f ( · ) in S+
d := {W : W  0} iff

∇f (W0 )  0.

Proof. Since f (W ) is a convex function and S+


d is convex, we know that W0 is a global

minimizer of f (W ) in S+
d iff

h∇f (W0 ), W − W0 i ≥ 0, ∀W  0. (6.19)

Note that h∇f (W0 ), W0 i = Tr(∇f (W0 )W0 ). By Lemma 6.10.1, h∇f (W0 ), W0 i = 0.
Combining this with (6.19), we know that W0 is a global minimizer iff

h∇f (W0 ), W i ≥ 0, ∀W  0. (6.20)

It is easy to check that this condition is equivalent to ∇f (W0 )  0.

6.11 Proofs for Counter-example

Conjecture 6.11.1 (Formal Statement, Gunasekar et al. 16). Suppose f : Rd×d →


R is a quadratic function and min f (W ) = 0. Then for any Winit  0 if W 1 =
W 0

lim lim φ(αWinit , t) exists and f (W 1 ) = 0, then kW 1 k∗ = min kW k∗ s.t. f (W ) = 0.


α→0 t→+∞ W 0

Proposition 6.11.2 (Formal Statement for Example 6.5.9). For constant R > 1, let

     
? ? 1 R R 1 1 R 1 R 1 R 
     
? ? R ? 1 R R 1 R R 2 R R 2 
M =  , Mnorm =   , and Mrank =  .
     
1 R ? ? 1 R R 1 1 R 1 R 
   
 
     
2 2
R ? ? ? R 1 1 R R R R R

193
and
1 1 X
L(U ) = f (U U > ), f (W ) = (Wij − Mij )2
2 2
(i,j)∈Ω

where Ω = {(1, 3), (1, 4), (2, 3), (3, 1), (3, 2), (4, 1)}.
Then for any Winit  0, s.t. u>
1 Winit u1 > 0,

lim lim φ(αWinit , t) = Mrank .


α→0 t→+∞

Moreover, we have

kMrank k∗ = 2R2 + 2 > 4R = kMnorm k∗ = min kW k∗ .


W 0,f (W )=0

G
Proof. We define W1, (t), W1G (t) in the same way as in Definition 6.5.1, Theorem 6.5.6.

G
(t) := φ u1 u>

W1, 1 , t ,

W1G (t) := lim W1,


G
( 2µ1 1 log 1 + t).
→0

Below we will show

1. Assumption 6.5.7 and Assumption 6.5.5 are satisfied.

2. W1G (t) F
bounded for t ≥ 0;

3. limt→+∞ W1G (t) = Mrank ;

4. Mnorm = argminW 0,f (W )=0 kW k∗ .

Thus Since Mrank is a global minimizer of f ( · ), applying Theorem 6.5.8 finishes


the proof.

194
Proof for Item 1. Let M0 := ∇f (0), then

 
0 0 1 R
 
0 0 R 0
M0 =  .
 
1 R 0 0
 
 
R 0 0 0

√ √
1+ 1+R2 1− 1+R2
Let A := [ R1 R0 ], then we have λ1 (A) = 2
, λ2 (A) = 2
, thus λ1 (A) >
|λ2 (A)| > 0 > λ2 (A). As a result, λ1 (A) = kAk2 . Let v1 ∈ R2 be the top eigenvector
of A. We claim that u1 = [ vv11 ] ∈ R4 is the top eigenvector of ∇f (0). First by definition
 2 
it is easy to check that M0 u1 = λ1 (A)u1 . Further noticing that M02 = A0 A02 , we
know λ2i (M0 ) ∈ {λ21 (A), λ22 (A)} for all eigenvalues λi (M0 ). That is, λ1 (M0 ) = λ1 (A),
λ2 (M0 ) = −λ2 (A), λ3 (M0 ) = λ2 (A), and λ4 (M0 ) = −λ1 (A). Thus Assumption 6.5.5
is satisfied. Also note that f is quadratic, thus analytic, i.e., Assumption 6.5.7 is also
satisfied.

Proof for Item 2. Let (x (t), y (t)) ∈ R2 be the gradient flow of g(x, y) = 12 (x2 −

1)2 + (xy − R)2 starting from (x (0), y (0)) = v1 .

dx(t)
= (1 − x(t)2 )x(t) − 2y(t)(x(t)y(t) − R)
dt (6.21)
dy(t)
= −2x(t)(x(t)y(t) − R)
dt

Let W (t) be the following matrix:

 
x (t)
 
 y (t)   
W (t) :=   x (t) y (t) x (t) y (t) .
 
x (t) 
  
 
y (t)

195
G
Then it is easy to verify that W (0) = W1, (0) and W (t) satisfies (6.2). Thus by the
G
existence and uniqueness theorem, we have W (t) = W1, (t) for all t. Taking the limit
 → 0, we know that W1G (t) can also be written in the following form:

 
x(t)
 
y(t)  
G
W1 (t) =   x(t) y(t) x(t) y(t) ,
 
x(t)
 
 
y(t)

and (x (t), y (t)) ∈ R2 is a gradient flow of g(x, y) = 12 (x2 − 1)2 + (xy − R)2 .
Since g(x(t), y(t)) is non-increasing overtime, and lim g(x(−t), y(−t)) =
t→−∞

g(x(−∞), y(−∞)) = g(0, 0) = R2 + 0.5, we know |x(t)y(t)| ≤ 3R for all t.


9R2 9R2
So whenever y 2 (t) − x2 (t) ≥ 9R2 , we have x2 (t) ≤ y 2 (t)
≤ y 2 (t)−x2 (t)
≤ 1.
d(y 2 (t)−x2 (t))
In this case, dt
= 2x2 (t)(x2 (t) − 1) ≤ 0. Combining this with
y(−∞)2 − x(−∞)2 = 0 ≤ 9R2 , we have y 2 (t) − x2 (t) ≤ 9R2 for all t, which
also implies that y(t) is bounded. Noticing that 9R2 ≥ g(x(t), y(t)) ≥ (x2 (t) − 1)2 , we
know x2 (t) is also bounded. Therefore, W1G (t) is bounded.

Proof for Item 3. Note that (x(∞), y(∞)) is a stationary point of g(x, y). It is
clear that g(x, y) only has 3 stationary points — (0, 0), (1, R) and (−1, −R). Thus W 1
can only be 0 or Mrank . However, since for all t, f (W1G (t)) < f (0), W 1 = limt→∞ W1G (t)
cannot be 0. So W 1 must be Mrank .

Proof for Item 4. Let mij be (i, j)th element of M . Suppose M  0, we have

(e1 − e4 )> M (e1 − e4 ) ≥ 0 =⇒ m11 + m44 ≥ m14 + m41 = 2R

(e2 − e3 )> M (e2 − e3 ) ≥ 0 =⇒ m22 + m33 ≥ m23 + m32 = 2R

196
Thus 4R = minW 0,f (W )=0  kW k∗ , where
 the equality is
 only attained at mii = R, i =
m11 m14  m22 m23 
1, 2, 3, 4. Otherwise, either   or   will have negative eigenvalues.
m41 m44 m32 m33
Contradiction to that M  0.
Below we will show the rest unknown off-diagonal entries must be 1. Let
 
1 −1 0 0
 
V =
0 0 1 0

 
0 0 0 1

, then we have that


 
 0 m13 − m23 m14 − m24 
M  0 =⇒ V M V >
 
 0 =⇒ 
m31 − m32 R R   0,

 
m41 − m42 R R

which implies m13 = m23 , m14 = m24 


. 
1 0 0 0
 
With the same argument for V = 
0 1 0 0 , we have m13 = m14 , m23 = m24 .
 
0 0 1 −1
Also note M is symmetric and m13 = 1, thus mij = mji = 1, ∀i = 1, 2, j = 3, 4. Thus
Mnorm = argminW 0,f (W )=0 kW k∗ , which is unique.

6.12 Proofs for Dynamical System

In this section, we prove Theorem 6.5.3 in Section 6.5.1. In Section 6.12.1, we show
how to reduce Theorem 6.5.3 to the case where J(0) is exactly a diagonal matrix, then
we prove this diagonal case in Section 6.12.2. Finally, in Section 6.12.3, we discuss
how to extend it to the case where J(0) is non-diagonalizable.

197
6.12.1 Reduction to the Diagonal Case

Theorem 6.12.1. If J(0) = diag(µ̃1 , . . . , µ̃d ) is diagonal, then the statement in


Theorem 6.5.3 holds.

Proof for Theorem 6.5.3. We show how to prove Theorem 6.5.3 based on Theo-
dx
rem 6.12.1. Let dt
= g(x) be the dynamical system in Theorem 6.5.3. Let J(0) =
Ṽ D̃Ṽ −1 be the eigendecomposition, where Ṽ is an invertible matrix and D̃ =
diag(µ̃1 , . . . , µ̃d ). Now we define the following new dynamics by changing the ba-
sis:
x̂(t) = Ṽ −1 x(t).

dx̂(t)
Then dt
= ĝ(x̂) for ĝ(x̂) := Ṽ −1 g(Ṽ x̂), and the associated Jacobian matrix is
ˆ
J(x̂) ˆ = diag(µ̃1 , . . . , µ̃d ).
:= Ṽ −1 J(Ṽ x̂)Ṽ , and thus J(0)
Now we apply Theorem 6.12.1 to x̂(t). Then ẑα (t) := Ṽ −1 zα (t) converges to the
limit ẑ(t) := lim ẑα (t). This shows that the limit z(t) = Ṽ ẑ(t) exists in Theorem 6.5.3.
α→0

We can also verify that z(t) is a solution of (6.6).


Given δα converging to 0 with positive alignment with ũ1 as α → 0, we can define
δ̂α := Ṽ −1 δα , then δ̂α converges to 0 with positive alignment with e1 , where e1 is the
ˆ
first vector in the standard basis and is also the top eigenvector of J(0). Therefore,
for every t ∈ (−∞, +∞), there is a constant C > 0 such that

  γ̃
−1 1 1
Ṽ φ δα , t + log − ẑ(t) ≤ C · kδ̂α k2µ̃1 +γ̃ (6.22)
µ̃1 hδα , ũ1 i 2

for every sufficiently small α. As Ṽ are invertible, this directly implies (6.7).

198
6.12.2 Proof for the Diagonal Case

Now we only need to prove Theorem 6.12.1. Let e1 , . . . , ed be the standard basis.
Then ũ1 = ṽ1 = e1 in this diagonal case. We only use e1 to stand for ũ1 and ṽ1 in the
rest of our analysis.
Let R > 0. Since g(x) is C 2 -smooth, there exists β > 0 such that

kJ(x) − J(x + h)k2 ≤ βkhk2 (6.23)

for all kxk2 , kx + hk2 ≤ R. Then the following can be proved by integration:

Z 1 
g(x + h) − g(x) = J(x + ξh)dξ h, (6.24)
0

kg(x + h) − g(x) − J(x)hk2 ≤ βkhk22 . (6.25)

By (6.25), we also have

kg(x) − J(0)xk2 = kg(x) − g(0) − J(0)xk2 ≤ βkxk22 . (6.26)

Let κ := β/µ̃1 . We assume WLOG that R ≤ 1/κ. Let F (x) = log x − log(1 + κx).
It is easy to see that F 0 (x) = 1
x+κx2
and F (x) is an increasing function with range
(−∞, log(1/κ)). We use F −1 (y) to denote the inverse function of F (x). Define
Tα (r) := µ̃11 (F (r) − F (α)) = µ̃11 log αr − log 1+κα
1+κr

.
Our proof only relies on the following properties of J(0) (besides that µ̃1 , e1 are
the top eigenvalue and eigenvector of J(0)):

Lemma 6.12.2. For J(0) := diag(µ̃1 , . . . , µ̃d ), we have

1. For any h ∈ Rd , h> J(0)h ≤ µ̃1 khk22 ;

2. For any t ≥ 0, etJ(0) − eµ̃1 t e1 e>


1 2
= eµ̃2 t .

199
Pd
Proof. For Item 1, h> J(0)h = i=1 µ̃i h2i ≤ µ̃1 khk22 . For Item 2, etJ(0) − eµ̃1 t e1 e>
1 2
=
diag(0, eµ̃2 t , . . . , eµ̃d t ) 2
= eµ̃2 t .

Lemma 6.12.3. For x(t) = φ(x0 , t) with kx0 k2 ≤ α and t ≤ Tα (r),

1 + κr
kx(t)k2 ≤ α · eµ̃1 t ≤ r.
1 + κα

Proof. By (6.26) and Lemma 6.12.2, we have

1 dkx(t)k22
= hx(t), g(x(t))i ≤ hx(t), J(0)x(t)i + βkx(t)k32 ≤ µ̃1 kx(t)k22 + βkx(t)k32 .
2 dt

dkx(t)k2
This implies dt
≤ µ̃1 (kx(t)k2 + κkx(t)k22 ). Since F 0 (x) = 1
x+κx2
, we further have

d
F (kx(t)k2 ) ≤ µ̃1 .
dt

So F (kx(t)k2 ) ≤ F (α) + µ̃1 t. By definition of Tα (r), we then know that kx(t)k2 ≤ r


for all t ≤ Tα (r). So

log kx(t)k2 ≤ F (kx(t)k2 ) + log(1 + κr) ≤ F (α) + µ̃1 t + log(1 + κr).

Expending F (α) proves the lemma.

Lemma 6.12.4. For x(t) = φ(x0 , t) with kx0 k2 ≤ α and t ≤ Tα (r), we have

x(t) = etJ(0) x0 + O(r2 ).

200
Proof. Let x̂(t) = etJ(0) x0 . Then we have

1d
kx(t) − x̂(t)k22 ≤ hg(x(t)) − J(0)x̂(t), x(t) − x̂(t)i
2 dt
= hg(x(t)) − J(0)x(t), x(t) − x̂(t)i + (x(t) − x̂(t))> J(0)(x(t) − x̂(t))

≤ kg(x(t)) − J(0)x(t)k2 · kx(t) − x̂(t)k2 + µ̃1 kx(t) − x̂(t)k22 ,

where the last inequality is due to Lemma 6.12.2. By (6.26) and Lemma 6.12.3, we
have
 2
1 + κr
kg(x(t)) − J(0)x(t)k2 ≤ βkx(t)k22 ≤β α · e2µ̃1 t .
1 + κα
d 1+κr
2
So we have dt
kx(t) − x̂(t)k2 ≤ β 1+κα
α · e2µ̃1 t + µ̃1 kx(t) − x̂(t)k2 . By Grönwall’s
inequality,
Z t  2
1 + κr
kx(t) − x̂(t)k2 ≤ β α · e2µ̃1 τ eµ̃1 (t−τ ) dτ.
0 1 + κα

Evaluating the integral gives

2 2
eµ̃1 t − 1
 
1 + κr µ̃1 t 1 + κr
kx(t) − x̂(t)k2 ≤ β α e · ≤κ α · eµ̃1 t ≤ κr2 ,
1 + κα µ̃1 1 + κα

which proves the lemma.

Lemma 6.12.5. Let x(t) = φ(x0 , t), x̂(t) = φ(x̂0 , t). If max{kx0 k2 , kx̂0 k2 } ≤ α, then
for t ≤ Tα (r),
kx(t) − x̂(t)k2 ≤ eµ̃1 t+κr kx0 − x̂0 k2 .

Proof. For t ≤ Tα (r), by (6.24),

1d
kx(t) − x̂(t)k22 = hg(x(t)) − g(x̂(t)), x(t) − x̂(t)i
2 dt Z  1
= (x(t) − x̂(t))> J(xξ (t))dξ (x(t) − x̂(t)),
0

201
where xξ (t) := ξx(t) + (1 − ξ)x̂(t). By Lemma 6.12.3, max{kx(t)k2 , kx̂(t)k2 } ≤
1+κr 1+κr
1+κα
α · eµ̃1 t for all t ≤ Tα (r). So kxξ (t)k2 ≤ 1+κα
α · eµ̃1 t . Combining these with (6.23)
and Lemma 6.12.2, we have

 
> > > 1 + κr µ̃1 t
h J(xξ (t))h = h J(0)h + h (J(xξ (t)) − J(0))h ≤ µ̃1 + β · α·e khk22 ,
1 + κα

d 1+κr

for all h ∈ Rd . Thus, dt
kx(t) − x̂(t)k2 ≤ µ̃1 + β · 1+κα
α · eµ̃1 t kx(t) − x̂(t)k2 . This
implies

Z t 
kx(t) − x̂(t)k2 1 + κr µ̃1 τ
log ≤ µ̃1 + β · α·e dτ
kx(0) − x̂(0)k2 0 1 + κα
1 + κr µ̃1 t
≤ µ̃1 t + κ · αe
1 + κα
≤ µ̃1 t + κr.

Therefore, kx(t) − x̂(t)k2 ≤ eµ̃1 t+κr kx(0) − x̂(0)k2 .

Lemma 6.12.6. For every t ∈ (−∞, +∞), z(t) exists and zα (t) converges to z(t) in
the following rate:
kzα (t) − z(t)k2 = O(α),

where O hides constants depending on g(x) and t.

Proof. We prove the lemma in the cases of t ∈ (−∞, F (R)/µ̃1 ] and t > F (R)/µ̃1
respectively.

α̃
Case 1. Fix t ∈ (−∞, F (R)/µ̃1 ]. Let α̃ be the unique number such that 1+κα̃

(i.e., F (α̃) = log α). Let α0 be an arbitrary number less than α. Let t0 := 1
µ̃1
log αα0 .
Then t0 = 1
µ̃1
(F (α̃) − log α0 ) ≤ Tα0 (α̃). By Lemma 6.12.4, we have

kφ (α0 e1 , t0 ) − αe1 k2 = φ (α0 e1 , t0 ) − et0 J(0) α0 e1 2


= O(α̃2 ).

202
Let r := F −1 (µ̃1 t) ≤ R. Then t + 1
µ̃1
log α1 = Tα̃ (r) if α̃ < r.
By Lemma 6.12.3, kφ (α0 e1 , t0 )k2 ≤ α̃. Also, kαe1 k2 = α̃
1+κα̃
≤ α̃. By Lemma 6.12.5,

   
0 1 1 1 1
kzα (t) − zα0 (t)k2 = φ α e1 , t + log 0 − φ αe1 , t + log
µ̃1 α µ̃1 α 2
   
0 1 1 1 1
= φ φ(α e1 , t0 ), t + log − φ αe1 , t + log
µ̃1 α µ̃1 α 2
µ̃1 (t+ µ̃1 log 1
)+κr
≤ O(α̃2 · e 1 α
)
 2
α̃
≤O .
α

For α small enough, we have α̃ = O(α), so for any α0 ∈ (0, α),

kzα (t) − zα0 (t)k2 = O(α).

This implies that {zα (t)} satisfies Cauchy’s criterion for every t, and thus the limit
z(t) exists for t ≤ F (R)/µ̃1 . The convergence rate can be deduced by taking limits
for α0 → 0 on both sides.

Case 2. For t = F (R)/µ̃1 + τ with τ > 0, φ(x, τ ) is locally Lipschitz with respect
to x. So

kzα (t) − zα0 (t)k2 = kφ(zα (F (R)/µ̃1 ), τ ) − φ(zα0 (F (R)/µ̃1 ), τ )k2

= O(kzα (F (R)/µ̃1 ) − zα0 (F (R)/µ̃1 )k2 )

= O(α),

which proves the lemma for t > F (R)/µ̃1 .


 
1 1
Proof for Theorem 6.12.1. The existence of z(t) := limα→0 zα (t) = limα→0 φ αe1 , t + µ̃1
log α

has already been proved in Lemma 6.12.6, where we show kzα (t) − z(t)k2 = O(α).

203
By the continuity of φ( · , t) for every t ∈ R, we have

     
1 1 1 1
z(t) = lim φ αṽ1 , t + log = φ lim φ αṽ1 , log , t = φ (z(0), t) .
α→0 µ̃1 α α→0 µ̃1 α

Now it is only left to prove (6.7). WLOG we can assume that kδα k2 is decreasing and
α
2
≤ kδα k2 ≤ α (otherwise we can do reparameterization). Then our goal becomes
proving
 γ̃ 
kxα (t) − z(t)k2 = O α µ̃1 +γ̃ . (6.27)
 
1
where xα (t) := φ δα , t + µ̃1
log hδα1,e1 i . We prove (6.27) in the cases of t ∈
(−∞, F (R)/µ̃1 ] and t > F (R)/µ̃1 respectively.

γ̃
α̃1
Case 1. Fix t ∈ (−∞, (F (R) + log q)/µ̃1 ]. Let α̃1 = α µ̃1 +γ̃ . Let α1 := eF (α̃1 ) = 1+κα̃1
.
1
Let t0 := µ̃1
(F (α̃1 ) − log α) ≤ Tkδα k2 (α̃1 ). At time t0 , by Lemma 6.12.2 we have

µ̃2  α  µ̃µ̃2
(F (α̃1 )−log α) 1
e t0 J(0)
−eµ̃1 t0
e1 e>
1 2 =e µ̃2 t0
=e µ̃1
= 1
. (6.28)
α

δα
Let qα := α
, e1 . By Definition 6.5.2, there exists q > 0 such that qα ≥ q for all
sufficiently small α. Then we have

kφ (δα , t0 ) − α1 qα e1 k2 = φ (δα , t0 ) − et0 J(0) δα 2 + et0 J(0) − eµ̃1 t0 e1 e>



1 δα 2
 α  µ̃µ̃2
2 1 1
= O(α̃1 ) + kδα k2
α
µ̃ /µ̃1
= O(α̃12 + α1 2 α1−µ̃2 /µ̃1 )

= O(α12 ).

Let r := F −1 (µ̃1 t + log q1α ) ≤ R. Then t + 1


µ̃1
log α11qα = Tα̃ (r) if α̃ < r.
α̃1
By Lemma 6.12.3, kφ (δα , t0 )k2 ≤ α̃1 . Also, kα1 qα e1 k2 ≤ α1 = 1+κα̃1
≤ α̃1 .

204
By Lemma 6.12.5,

   
1 1 1 1
kxα (t) − zα1 (t)k2 ≤ φ φ (δα , t0 ) , t + log − φ α1 qα e 1 , t + log
µ̃1 α1 q α µ̃1 α1 qα 2
   
1 1
µ̃ t+ log +κr
= O α12 · e 1 µ̃1 α1 qα

= O(α1 ).

Combining this with the convergence rate for zα1 (t), we have

kxα (t) − z(t)k2 ≤ kxα (t) − zα1 (t)k2 + kzα1 (t) − z(t)k2 = O(α1 ).

Case 2. For t = (F (R) + log q)/µ̃1 + τ with τ > 0, φ(x, τ ) is locally Lipschitz with
respect to x. So

kxα (t) − z(t)k2 = kφ(xα ((F (R) + log q)/µ̃1 ), τ ) − φ(z((F (R) + log q)/µ̃1 ), τ )k2

= O(kxα ((F (R) + log q)/µ̃1 ) − z((F (R) + log q)/µ̃1 )k2 )

= O(α1 ),

which proves (6.27) for t > (F (R) + log q)/µ̃1 .

6.12.3 Extension to Non-Diagonalizable Case

The proof in Section 6.12.2 can be generalized to the case where J(0). Now we state the
theorem formally and sketch the proof idea. We use the notations g(x), φ(x0 , t), J(x)
as in Section 6.5.1, but we do not assume that J(0) is diagonalizable. Instead, we use
µ̃1 , µ̃2 , . . . , µ̃d ∈ C to denote the eigenvalues of J(0), repeated according to algebraic
multiplicity. We sort the eigenvalues in the descending order of the real part of each
eigenvalue, i.e., <(µ̃1 ) ≥ <(µ̃2 ) ≥ · · · ≥ <(µ̃d ), where <(z) stands for the real part of

205
a complex number z ∈ C. We call the eigenvalue with the largest real part the top
eigenvalue.

Theorem 6.12.7. Assume that x = 0 is a critical point and the following regularity
conditions hold:

1. g(x) is C 2 -smooth;

2. φ(x0 , t) exists for all x0 and t;

3. The top eigenvalue of J(0) is unique and is a positive real number, i.e.,

µ̃1 > max{<(µ̃2 ), 0}.

Let ṽ1 , ũ1 be the left and right eigenvectors associated with µ̃1 , satisfying ũ>
1 ṽ1 = 1.

1
Let zα (t) := φ(αṽ1 , t + µ̃1
log α1 ) for every α > 0, then ∀t ∈ R, z(t) := lim zα (t) exists
α→0

and z(t) = φ(z(0), t). If δα converges to 0 with positive alignment with ũ1 as α → 0,
then for any t ∈ R and for any  > 0, there is a constant C > 0 such that for every
sufficiently small α,

γ̃
−
 
1
φ δα , t + µ̃1
log hδα1,ũ1 i − z(t) ≤ C · kδα k2µ̃1 +γ̃ , (6.29)
2

where γ̃ := µ̃1 − <(µ̃2 ) is the eigenvalue gap.

Proof Sketch. Define the following two types of matrices. For r ≥ 1, a, δ ∈ R, we


define  
a δ 
 
 a δ 
 
 
a δ
 
(r)
 
Ja,δ := 
  ∈ Rr×r .
. . . . 

 . . 

 

 a δ 

 
a
206
For r ≥ 1, a, b, δ ∈ R, we define
 
C δI 
 

 C δI 

 
C δI
 
(r)
 
Ja,b,δ := 
  ∈ R2r×2r ,
 . .
.. .. 

 
 

 C δI 

 
C

 a −b 
where C = b a ∈ R2×2 .
By linear algebra, the real matrix J(0) can be written in the real Jordan normal
form, i.e., J(0) = Ṽ diag(J[1] , . . . , J[m] )Ṽ −1 , where Ṽ ∈ Rd×d is an invertible matrix,
and each J[j] is a real Jordan block. Recall that there are two types of real Jordan
(r) (r)
blocks, Ja,1 or Ja,b,1 . The former one is associated with a real eigenvalue a, and the
latter one is associated with a pair of complex eigenvalues a ± bi. The sum of sizes
of all Jordan blocks corresponding to a real eigenvalue a is its algebraic multiplicity.
The sum of sizes of all Jordan blocks corresponding to a pair of complex eigenvalues
a ± bi is two times the algebraic multiplicity of a + bi or a − bi (note that a ± bi have
the same multiplicity).
(r) (r)
It is easy to see that Ja,δ = DJa,1 D−1 for D = diag(δ r , δ r−1 , . . . , δ) ∈ Rr×r and
(r) (r)
Ja,b,δ = DJa,b,1 D−1 for D = diag(δ r , δ r , δ r−1 , δ r−1 , . . . , δ, δ) ∈ R2r×2r . This means for
every δ > 0 there exists Ṽδ such that J(0) = Ṽδ Jδ Ṽδ−1 , where Jδ := diag(Jδ[1] , . . . , Jδ[m] ),
(r) (r) (r) (r)
Jδ[j] := Ja,δ if J[j] := Ja,1 , or Jδ[j] := Ja,b,δ if J[j] := Ja,b,1 . Since the top eigenvalue of
J(0) is positive and unique, µ̃1 corresponds to only one block [µ̃1 ] ∈ R1×1 . WLOG we
let J1 = [µ̃1 ], and thus Jδ[1] = [µ̃1 ].
We only need to select a parameter δ > 0 and prove the theorem in the case
of J(0) = Jδ since we can change the basis in a similar way as we have done in
Section 6.12.1. By scrutinizing the proof for Theorem 6.12.1, we can find that we only

207
need to reprove Lemma 6.12.2. However, Lemma 6.12.2 may not be correct since J(0)
is not diagonal anymore. Instead, we prove the following:

1. If δ ∈ (0, γ̃), then h> Jδ h ≤ µ̃1 khk22 for all h ∈ Rd ;

0
2. For any µ̃02 ∈ (<(µ̃2 ), µ̃1 ), if δ ∈ (0, µ̃02 − <(µ̃2 )), then etJδ − eµ̃1 t e1 e>
1 2
≤ eµ̃2 t
for all t ≥ 0.

Proof for Item 1. Let K be the set of pairs (k1 , k2 ) such that k1 6= k2 and the
entry of Jδ at the k1 -th row and the k2 -th column is non-zero. Then we have

d
> > Jδ + Jδ> X X
h Jδ h = h h= <(µ̃k )h2k + hk1 hk2 δ
2 k=1 (k1 ,k2 )∈K
d
X X h2k1 + h2k2
≤ <(µ̃k )h2k + δ.
k=1
2
(k1 ,k2 )∈K

Note that <(µ̃k ) ≤ <(µ̃2 ) for k ≥ 2. Also note that there is no pair in K has k1 = 1
or k2 = 1, and for every k ≥ 2 there are at most two pairs in K has k1 = k or k2 = k.
Combining all these together gives

d
X
>
h Jδ h ≤ µ̃1 h21 + (<(µ̃2 ) + δ) h2k ≤ µ̃1 khk22 ,
k=2

which proves Item 1.

Proof for Item 2. Since Jδ is block diagonal, we only need to prove that ketJδ[j] k2 ≤
0 (r)
eµ̃2 t for every j ≥ 2. If Jδ[j] = Ja,δ = aI + δN , where N is the nilpotent matrix, then

etJδ[j] = eatI+δtN = eatI eδtN = eat eδtN ,

208
where the second equality uses the fact that I and N are commutable. So we have

ketJδ[j] k2 ≤ eat keδtN k2 = eat eδtkN k2 ≤ e(a+δ)t .

(r)
If Jδ[j] = Ja,δ = D + δN 2 , where D = diag(C, C, . . . , C) and N is the nilpotent matrix,
then
2 2
etJδ[j] = etD+δtN = etD eδtN ,

where the second equality uses the fact that D and N 2 are commutable. Note that
h i
− sin(bt)
etC = eat cos(bt) tD tC at
sin(bt) cos(bt) , which implies ke k2 = ke k2 = e . So we have

2 2k
ketJδ[j] k2 ≤ ketD k2 · keδtN k2 = eat eδtkN 2
≤ e(a+δ)t .

Since δ ∈ (0, µ̃02 − <(µ̃2 )), we know that a + δ < µ̃02 , which completes the proof.

Proof for a fixed δ. Since Item 1 continues to hold for δ ∈ (0, γ̃), Lemmas 6.12.3
to 6.12.6 also hold. This proves that z(t) exists and satisfies (6.6).
It remains to prove (6.29) for any  > 0. Let γ̃ 0 ∈ (0, γ̃) be a number such
γ̃ 0 γ̃
that µ̃1 +γ̃ 0
≥ µ̃1 +γ̃
− . Fix µ̃02 = µ̃1 − γ̃ 0 , δ = µ̃02 − <(µ̃2 ). By Item 2, we have
0
etJδ − eµ̃1 t e1 e>
1 2
≤ eµ̃2 t for all t ≥ 0. By scrutinizing the proof for Theorem 6.12.1,
we can find that the only place we use Item 2 in Lemma 6.12.2 is in (6.28). For proving
(6.29), we can repeat the proof while replacing all the occurrences of µ̃2 by µ̃02 . Then
we know that for every t ∈ R, there is a constant C > 0 such that

  γ̃ 0
0
Ṽδ−1 φ δα , t + 1
µ̃1
log 1
hδα ,ũ1 i
− Ṽδ−1 z(t) ≤C· kṼδ−1 δα k2µ̃1 +γ̃ ,
2

209
γ̃ 0 γ̃
for every sufficiently small α. By definition of γ̃ 0 , µ̃1 +γ̃ 0
≥ µ̃1 +γ̃
− . Since δα → 0 as
α → 0, we have kṼδ−1 δα k2 < 1 for sufficiently small α. Then we have

   
φ δα , t + 1
µ̃1
log 1
hδα ,ũ1 i
− z(t) ≤ kṼδ k2 · Ṽδ−1 φ δα , t + 1
µ̃1
log 1
hδα ,ũ1 i
− Ṽδ−1 z(t)
2 2
γ̃ 0
µ̃1 +γ̃ 0
≤ kṼδ k2 · C · kṼδ−1 δα k2
γ̃ 0 γ̃
0 −
≤ C · kṼδ k2 · kṼδ−1 k2µ̃1 +γ̃ · kδα k2µ̃1 +γ̃ .

γ̃ 0
0
Absorbing kṼδ k2 · kṼδ−1 k2µ̃1 +γ̃ into C proves (6.29).

6.13 Eigenvalues of Jacobians and Hessians

In this section we analyze the eigenvalues of the Jacobian J(W ) at critical points of
(6.2).
For notation simplicity, we write sz(A) := A + A> to denote the symmetric matrix
produced by adding up A and its transpose, and write ac{A, B} = AB + BA to
denote the anticommutator of two matrices A, B. Then g(W ) can be written as
g(W ) := −ac{∇f (W ), W }.
Let U0 ∈ Rd×r be a stationary point of the function L : Rd×r → R, L(U ) =
1
2
f (U U > ), i.e., ∇L(U0 ) = ∇f (U0 U0> )U0 = 0. By Lemma 6.10.1, this implies

∇f (W0 )W0 = 0 (6.30)

for W0 := U0 U0> , and thus W0 is a critical point of (6.2).


For function F (x), we use DF (x)[∆], D2 F (x)[∆1 , ∆2 ] to denote the first- and
second-order directional derivatives of F ( · ) at x. For convenience, for a real-
valued function F , we use D2 F and D(∇F ) interchangeably in the sense that
DF (x)[∆1 , ∆2 ] := hDF (x)[∆1 ], ∆2 i and D2 F (x)[∆] = D(∇F (x))[∆].

210
Define J(W ) := Dg(W ). By simple calculus, we can compute the formula for
J(W0 ):

J(W0 )[∆] = −ac{∇f (W0 ), ∆} − ac{D2 f (W0 )[∆], W0 },

J(W0 )[∆1 , ∆2 ] = − ∇f (W0 ), sz(∆1 ∆> 2 >


2 ) − D f (W0 )[∆1 , sz(W0 ∆2 )],

where ∆, ∆1 , ∆2 ∈ Rd×d .
We can also compute the formula for D2 L(U0 ):

D2 L(U0 )[∆] = ∇f (W0 )∆ + D2 f (W0 )[sz(∆U0> )]U0 ,


1
D2 L(U0 )[∆1 , ∆2 ] = ∇f (W0 ), sz(∆1 ∆> 2 > >

2 ) + D f (W 0 )[sz(∆1 U 0 ), sz(∆2 U 0 )] ,
2

where ∆, ∆1 , ∆2 ∈ Rd×r .

6.13.1 Eigenvalues at the Origin

The eigenvalues of J(0) is given in Lemma 6.5.4. Now we provide the proof.

Proof for Lemma 6.5.4. For W0 = 0, we have

J(0)[∆] = −∇f (0)∆ − ∆∇f (0)

J(0)[∆1 , ∆2 ] = − ∇f (0), sz(∆1 ∆>


2)

It is easy to see from the second equation that J(0) is symmetric.

211
Pd
Let −∇f (0) = i=1 µi u1[i] u>
1[i] be the eigendecomposition of the symmetric matrix

−∇f (0). Then we have

d
X
µi u1[i] u> >

J(0)[∆] = 1[i] ∆ + ∆u1[i] u1[i]
i=1
d X
X d
µi u1[i] u> > > >

= 1[i] ∆u1[j] u1[j] + u1[j] u1[j] ∆u1[i] u1[i]
i=1 j=1
d X
X d
= (µi + µj )u1[i] u> >
1[i] ∆u1[j] u1[j]
i=1 j=1
d X
X d
= (µi + µj ) ∆, u1[i] u> >
1[j] u1[i] u1[j] ,
i=1 j=1

which proves (6.8).


For ∆ = u1[i] u> >
1[j] + u1[j] u1[i] , we have

J(0)[∆] = (µi + µj )u1[i] u> >


1[j] + (µi + µj )u1[j] u1[i] = (µi + µj )∆.

So u1[i] u> >


1[j] + u1[j] u1[i] is an eigenvector of J(0) associated with eigenvalue µi + µj .

Note that {u1[i] u> >


1[j] + u1[j] u1[i] : i, j ∈ [d]} spans all the symmetric matrices, so these

are all the eigenvectors in the space of symmetric matrices.


For every antisymmetric matrix ∆ (i.e., ∆ = −∆> ), we have

J(0)[∆] = J(0)[∆> ] = J(0)[−∆].

So J(0)[∆] = 0 and every antisymmetric matrix is an eigenvector associated with


eigenvalue 0.
Since every matrix can be expressed as the sum of a symmetric matrix and an
antisymmetric matrix, we have found all the eigenvalues.

212
6.13.2 Eigenvalues at Second-Order Stationary Points

Now we study the eigenvalues of J(W0 ) when U0 is a second-order stationary point of


L( · ), i.e., ∇L(U0 ) = 0, D2 L(U0 )[∆, ∆] ≥ 0 for all ∆ ∈ Rd×r . We further assume that
U0 is full-rank, i.e., rank(U0 ) = r. This condition is meet if W0 := U0 U0T is a local
minimizer of f ( · ) in S+ +
d but not a minimizer in Sd .

Lemma 6.13.1. For r ≤ d, if U0 ∈ Rd×r is a second-order stationary point of L( · ),


then either rank(U0 ) = rank(W0 ) = r, or W0 is a minimizer of f ( · ) in S+
d , where

W0 = U0 U0> .

Proof. Assume to the contrary that U0 has rank < r and W0 is a minimizer of f ( · ) in
S+ r
d . The former one implies that there exists a unit vector q ∈ R such that U0 q = 0,

and the latter one implies that there exists v ∈ Rd such that v > ∇f (W0 )v < 0 by
Lemma 6.10.2.
Let ∆ = vq > . Then we have

1
D2 L(U0 )[∆, ∆] = ∇f (W0 ), vv > + D2 f (W0 )[sz(v(U0 q)> ), sz(v(U0 q)> )]
2
1
= ∇f (W0 ), vv > + D2 f (W0 )[0, 0]
2
= ∇f (W0 ), vv > .

So D2 L(U0 )[∆, ∆] < 0, which leads to a contradiction.

By (6.30), the symmetric matrices −∇f (W0 ) and W0 commute, so they can be
simultaneously diagonalizable. Since (6.30) also implies that they have different
column spans, we can have the following diagonalization:

d−r
X d
X
− ∇f (W0 ) = µi vi vi> , W0 = µi vi vi> . (6.31)
i=1 i=d−r+1

213
First we prove the following lemma on the eigenvalues and eigenvectors of the
linear operator −D2 L(U0 ):

Lemma 6.13.2. For every ∆ ∈ Rd×d , if

U0 ∆> + ∆U0> = 0 (6.32)

then ∆ is an eigenvector of the linear operator −D2 L(U0 )[ · ] : Rd×r → Rd×r associated
with eigenvalue 0. Moreover, the solutions of (6.32) spans a linear space of dimension
r(r−1)
2
.

Proof. Suppose U0 ∆> + ∆U0> = 0. Then we have U0 ∆> = −∆U0> , and thus ∆> =
−U0+ ∆U0> , where U0+ is the pseudoinverse of the full-rank matrix U0 . This implies
that there is a matrix R ∈ Rr×r , such that ∆ = U0 R. Then we have

−D2 L(U0 )[∆] = −∇f (W0 )U0 R − D2 f (W0 )[U0 ∆> + ∆U0> ]U0

= − (∇f (W0 )U0 ) R − D2 f (W0 )[0]U0

= 0.

Replacing ∆ with U0 R in (6.32) gives U0 (R + R> )U0> = 0, which is equivalent to


R = −R> since U0 is full-rank. Since the dimension of r × r antisymmetric matrices
r(r−1) r(r−1)
is 2
, the span spanned by the solutions of (6.32) also has dimension 2
.

Definition 6.13.3 (Eigendecomposition of −D2 L(U0 )). Let

rd
X
−D2 L(U0 )[∆] = ξp hEp , ∆i Ep
p=1

be the eigendecomposition of the symmetric linear operator −D2 L(U0 )[ · ] : Rd×r →


Rd×r , where ξ1 , . . . , ξrd ∈ R are eigenvalues, E1 , . . . , Erd ∈ Rd×r are eigenvectors

214
satisfying hEp , Eq i = δpq . We enforce ξp to be 0 and Ep to be a solution of (6.32) for
r(r−1)
every rd − 2
< p ≤ rd.

Lemma 6.13.4. Let A ∈ RD×D be a matrix. If {û1 , . . . , ûK } is a set of linearly inde-
pendent left eigenvectors associated with eigenvalues λ̂1 , . . . , λ̂K and {ṽ1 , . . . , ṽD−K } is a
set of linearly independent right eigenvectors associated with eigenvalues λ̃1 , . . . , λ̃D−K ,
and hûi , ṽj i = 0 for all 1 ≤ i ≤ K, 1 ≤ j ≤ D − K, then λ̂1 , . . . , λ̂K , λ̃1 , . . . , λ̃D−K are
all the eigenvalues of A.

Proof. Let Û := (û1 , . . . , ûK )> ∈ RK×D and Ṽ := (ṽ1 , . . . , ṽD−K ) ∈ RD×(D−K) . Then
both Û and Ṽ are full-rank. Let Û + = Û > (Û Û > )−1 , Ṽ + = (Ṽ > Ṽ )−1 Ṽ > be the
pseudoinverses of Û and Ṽ .
Now we define  
 Û 
 
P :=   , Q := Û +
Ṽ .
Ṽ +

Then we have  
+
 Û Û Û Ṽ 
PQ =  .
Ṽ + Û + Ṽ + Ṽ

Note that Û Û + = IK , Û Ṽ = 0, Ṽ + Û + = (Ṽ > Ṽ )−1 (Û Ṽ )> (Û Û > )−1 = 0, Ṽ + Ṽ = ID−K .
So P Q = ID , or equivalently Q = P −1 . Then we have
 
diag(λ̂1 , . . . , λ̂K ) ∗
P −1 AP =  ,

0 diag(λ̃1 , . . . , λ̃D−K )

where ∗ can be any K × (D − K) matrix. Since P −1 AP is upper-triangular, we know


that P −1 AP has eigenvalues λ̂1 , . . . , λ̂K , λ̃1 , . . . , λ̃D−K , and so does A.

Theorem 6.13.5. The eigenvalues of J(W0 ) can be fully classified into the following
3 types:

215
1. µi + µj is an eigenvalue for every 1 ≤ i ≤ j ≤ d − r, and Ûij := vi vj> + vj vi> is
an associated left eigenvector.

r(r−1)
2. ξp is an eigenvalue for every 1 ≤ p ≤ rd − 2
, and Ṽp := Ep U0> + U0 Ep> is
an associated right eigenvector.

3. 0 is an eigenvalue, and any antisymmetric matrix is an associated right eigen-


d(d−1)
vector, which spans a linear space of dimension 2
.

Proof of Theorem 6.13.5. We first prove each item respectively, and then prove that
these are all the eigenvalues of J(W0 ).

Proof for Item 1. For Ûij = vi vj> + vj vi> , it is easy to check:

ac{−∇f (W0 ), Ûij } = (λi + λj )Ûij

Ûij W0 = 0

W0 Ûij = 0

So we have

D E D E
J(W0 )[∆, Ûij ] = (λi + λj ) ∆, Ûij − D2 f (W0 )[∆, 0] = (λi + λj ) ∆, Ûij ,

which shows that Ûij is a left eigenvector associated with eigenvalue λi + λj .

Proof for Item 2. By definition of eigenvector, we have −D2 L(U0 )[Ep ] = ξp Ep , so

ξp Ep = −∇f (W0 )Ep − D2 f (W0 )[U0 Ep> + Ep U0> ]U0 .

216
Right-multiplying both sides by U0> , we get

ξp Ep U0> = −∇f (W0 )Ep U0> − D2 f (W0 )[Ṽp ]W0

= −∇f (W0 )(Ep U0> + U0 Ep> ) − D2 f (W0 )[Ṽp ]W0

= −∇f (W0 )Ṽp − D2 f (W0 )[Ṽp ]W0 ,

where the second equality uses the fact that ∇f (W0 )U0 = 0 since U0 is a critical point.
Taking both sides into sz(·) gives

ξp Ṽp = −sz(∇f (W0 )Ṽp ) − sz(D2 f (W0 )[Ṽp ]W0 )

= J(W0 )[Ṽp ],

which proves that Ṽp is a right eigenvector associated with eigenvalue ξp .

Proof for Item 3. Since ∇f (W ) is symmetric, g(W ) is also symmetric. For any
∆ = −∆> ,
J(W0 )[∆] = J(W0 )[∆> ] = J(W0 )[−∆].

So J(W0 )[∆] = 0 and ∆ is an eigenvector associated with eigenvalue 0.

No other eigenvalues. Let Sd be the space of symmetric matrices and Ad be the


space of antisymmetric matrices. It is easy to see that Sd and Ad are orthogonal to
each other, and Sd and Ad are invariant subspaces of J(W0 )[∆]. Let h : Sd → Sd , ∆ 7→
J(W0 )[∆] be the linear operator J(W0 )[∆] restricted on symmetric matrices. We only
need to prove that h is diagonalizable.
It is easy to see that {Ûij } are linearly independent to each other and thus spans
(d−r)(d−r+1)
a subspace of Sd with dimension 2
. We can also prove that {Ṽp } spans a
r(r−1)
subspace of Sd with dimension rd − 2
by contradiction. Assume to the contrary
that there exists scalars αp for 1 ≤ p ≤ rd − r(r − 1)/2, not all zero, such that
217
Prd−r(r−1)/2 Prd−r(r−1)/2
p=1 α p Ṽp = 0. Then p=1 αp Ep is a solution of (6.32). However,
Prd−r(r−1)/2
this suggests that p=1 αp Ep lies in the span of {Ep }rd−r(r−1)/2<p≤rd , which
contradicts to the linear independence of {Ep }1≤p≤rd .
Note that

 
(d − r)(d − r + 1) r(r − 1) d(d + 1)
+ rd − = = dim(Sd ).
2 2 2

D E
Also note that Ûij , Ṽp = 2vi> Ep U0> vj + 2vj> Ep U0> vi = 0. By Lemma 6.13.4, Items 1
and 2 give all the eigenvalues of h, and thus Items 1, 2, 3 give all the eigenvalues of
J(W0 ).

6.14 Proofs for the Depth-2 Case

6.14.1 Proof for Theorem 6.5.6

Proof for Theorem 6.5.6. Since W (t) is always symmetric, it suffices to study the
dynamics of the lower triangle of W (t). For any symmetric matrix W ∈ Sd , let
d(d+1) d(d+1)
vecLT (W ) ∈ R 2 be the vector consisting of the 2
entries of W in the lower
triangle, permuted according to some fixed order.
Let g(W ) be the function defined in (6.2), which always maps symmetric matrices to
d(d+1) d(d+1)
symmetric matrices. Let g̃ : R 2 →R 2 be the function such that g̃(vecLT (W )) =
vecLT (g(W )) for any W ∈ Sd . For W (t) evolving with (6.2), we view vecLT (W (t)) as
a dynamical system.
d
vecLT (W (t)) = g̃(vecLT (W (t))).
dt

By Lemma 6.5.4, the spaces of symmetric matrices Sd and antisymmetric matrices Ad


n o
are invariant subspaces of J(0), and (µi + µj , u1[i] u>
1[j] + u u >
1[j] 1[i] ) is the set
1≤i≤j≤d
of all the eigenvalues and eigenvectors in the invariant subspace Sd . Thus, µ̃1 := 2µ1
and µ̃2 := µ1 + µ2 are the largest and second largest eigenvalues of the Jacobian
218
of g̃(·) at vecLT (W ) = 0, and ũ1 = ṽ1 = u1 u>
1 are the corresponding left and right

eigenvectors of the top eigenvalue. Then it is easy to translate Theorem 6.5.3 to


Theorem 6.5.6.

6.14.2 Proof for Theorem 6.5.8

The proof for Theorem 6.5.8 relies on the following Lemma on the gradient flow around
a local minimizer:

Lemma 6.14.1. If x̄ is a local minimizer of L(x) and for all kx − x̄k2 ≤ r, x satisfies
Lojasiewicz inequality:
k∇L(x)k2 ≥ c (L(x) − L(x̄))µ

for some µ ∈ [1/2, 1), then the gradient flow x(t) = φ(x0 , t) converges to a point x∞
near x̄ if x0 is close enough to x̄, and the distance can be bounded by kx∞ − x̄k2 =
2(1−µ)
O(kx0 − x̄k2 ).

Proof. For every t ≥ 0, if kx(t) − x̄k2 ≤ r,

 
d 1−µ −µ dx
(L(x(t)) − L(x̄)) = (1 − µ) (L(x(t)) − L(x̄)) · ∇L,
dt dt
dx
= −(1 − µ) (L(x(t)) − L(x̄))−µ · k∇Lk2 ·
dt 2
dx
≤ −(1 − µ)c .
dt 2

Rt dx 1 2(1−µ)
Therefore, kx(t) − x0 k2 ≤ 0 dt 2
dt ≤ (1−µ)c
L(x0 )1−µ = O(kx0 − x̄k2 ). If we
choose kx(t) − x̄k2 small enough, then kx(t) − x̄k2 ≤ kx(t) − x0 k2 + kx0 − x̄k2 =
2(1−µ) R +∞ dx
O(kx0 − x̄k2 ) < r, and thus 0 dt 2
dt is convergent and finite. This implies
2(1−µ)
that x∞ := limt→+∞ x(t) exists and kx∞ − x̄k2 = O(kx0 − x̄k2 ).

Proof for Theorem 6.5.8. Since W1G (t) ∈ S+


d,≤1 satisfies (6.2), there exists u(t) ∈ R
d

such that u(t)u(t)> = W1G (t) and u(t) satisfies (6.1), i.e., du
dt
= −∇L(u), where
219
L : Rd → R, u 7→ 12 f (uu> ). If W1G (t) does not diverge to infinity, then so does u(t).
This implies that there is a limit point ū of the set {u(t) : t ≥ 0}.
Let U := {u : L(u) ≥ L(ū)}. Since L(u(t)) is non-increasing, we have u(t) ∈ U for
all t. Note that ū is a local minimizer of L( · ) in U. By analyticity of f ( · ), Lojasiewicz
inequality holds for L( · ) around ū [126]. Applying Lemma 6.14.1 for L restricted
on U, we know that if u(t0 ) is sufficiently close to ū, the remaining length of the
trajectory of u(t) (t ≥ t0 ) is finite and thus limt→+∞ u(t) exists. As ū is a limit point,
this limit can only be ū. Therefore, W 1 := limt→+∞ W1G (t) = ūū> exists.
If W 1 is a minimizer of f ( · ), U = (ū, 0, · · · , 0) ∈ Rd×d is also a minimizer
of L : Rd×d → R, U 7→ 12 f (U U > ). By analyticity of f ( · ), Lojasiewicz inequality
holds for L( · ) around U . For every  > 0, we can always find a time t such that
ku(t ) − ūk2 ≤ /2. On the other hand, by Theorem 6.5.6, there exists a number α
such that for every α < α ,

1 1
φ(Wα , T (Wα ) + t ) − W1G (t ) ≤ /2, where T (W ) := log .
2 2µ1 W, u1 u>
1

Combining these together we have φ(Wα , T (Wα ) + t ) − W 1 2


≤ .
>
It is easy to construct a factorization φ(Wα , T (Wα ) + t ) := Uα, Uα, such that
Uα, − U 2
= O(), e.g., we can find an arbitrary factorization and then right-multiply
an orthogonal matrix so that the row vector with the largest norm aligns with the
direction of ū. Applying Lemma 6.14.1, we know that gradient flow starting with Uα,
converges to a point that is only O(2(1−µ) ) far from ū. So we have

lim φ(Wα , T (Wα ) + t) − W 1 = O(2(1−µ) ).


t→+∞
2

Taking  → 0 complete the proof.

220
6.14.3 Proof for Theorem 6.5.11

Theorem 6.14.2. Let W be a critical point of (6.2) satisfying that W is a local


minimizer of f ( · ) in S+ +
d,≤r for some r ≥ 1 but not a minimizer in Sd . Let −∇f (W ) =
Pd >
i=1 µi vi vi be the eigendecomposition of −∇f (W ). If µ1 > µ2 , the following limit

exists and is a solution of (6.2).

 
1 1
G
W (t) := lim φ W + v1 v1> , log + t .
→0 2µ1 

For {Wα } ⊆ S+
d , if there exists time Tα ∈ R for every α so that φ(Wα , Tα ) converges

to W with positive alignment with the top principal component v1 v1> as α → 0, then
∀t ∈ R, !
1 1
lim φ Wα , Tα + log >
+ t = W G (t).
α→0 2µ1 φ(Wα , Tα ), v1 v1

Moreover, there exists a constant C > 0 such that


!
γ̃
1 1
φ Wα , Tα + log + t − W G (t) ≤ C kφ(Wα , Tα )kF2µ1 +γ̃
2µ1 φ(Wα , Tα ), v1 v1>
F

for every sufficiently small α, where γ̃ := 2µ1 − max{µ1 + µ2 , 0}.

Proof. Following Section 6.14.1, we view vecLT (W (t)) as a dynamical system.

d
vecLT (W (t)) = g̃(vecLT (W (t))).
dt

>
Let W = U U be a factorization of W , where U ∈ Rd×r . Since W is a local minimizer
of f ( · ) in S+
d,≤r , U is also a local minimizer of L : R
d×r
→ R, U 7→ 12 f (U U > ). Since W
is not a minimizer of f ( · ) in S+
d , by Lemma 6.13.1, U is full-rank. By Theorem 6.13.5,

J(W ) has eigenvalues µi + µj , ξp , 0. By a similar argument as in Section 6.14.1, the


Jacobian of g̃ at vecLT (W (t)) has eigenvalues µi + µj , ξp .

221
Since U is a local minimizer, ξp ≤ 0 for all p. If µ1 > µ2 , then 2µ1 is the unique
largest eigenvalue, and Theorem 6.13.5 shows that vecLT (v1 v1> ) is a left eigenvector
associated with 2µ1 . The eigenvalue gap γ̃ := 2µ1 − max{µ1 + µ2 , max{ξp : 1 ≤ p ≤
r(r−1)
rd − 2
}} ≥ 2µ1 − max{µ1 + µ2 , 0}.
Also note that φ(Wα , Tα ) − W , v1 v1> = φ(Wα , Tα ), v1 v1> because W , v1 v1> =
0 by (6.31). If φ(Wα , Tα ) converges to W as α → 0, then it has positive alignment
hφ(Wα ,Tα ),v1 v> i
with v1 v1> iff lim inf α→0 φ(W ,T )−W1 > 0. Then it is easy to translate Theorem 6.5.3
k α α kF
to Theorem 6.14.2.

6.14.4 Gradient Flow only finds minimizers (Proof for Theo-

rem 6.5.10)

The proof for Theorem 6.5.10 is based on the following two theorems from the literature.

Theorem 6.14.3 (Theorem 3.1 in Du and Lee 127). Let f : Rd×d → R be a C 2 convex
function. Then L : Rd×k → R, L(U ) = f (U U > ), k ≥ d satisfies that (1). Every
local minimizer of L is also a global minimizer; (2). All saddles are strict. Here
saddles denote those stationary points whose hessian are not positive semi-definite
3
(thus including local maximizers).

Theorem 6.14.4 (Theorem 2 in Lee et al. 119). Let g be a C 1 mapping from X → X


and det(Dg(x)) 6= 0 for all x ∈ X . Then the set of initial points that converge to
an unstable fixed point has measure zero, µ {x0 : limk→∞ g k (x0 ) ∈ A∗g } = 0, where


A∗g = {x : g(x) = x, maxi |λi (Dg(x))| > 1}.

Theorem 6.14.5 (GF only finds minimizers, a continuous analog of Theorem 6.14.4).
Let f : Rd → Rd be a C 1 -smooth function, and φ : Rd × R → Rd be the solution of the
Pn
3
Though the original theorem is proven for convex functions of form i=1 `(xi U U > x> i , yi ), where
`(·, ·) is C 2 convex for its first variable. By scrutinizing their proof, we can see the assumption can be
relaxed to f is C 2 convex.
222
following differential equation,

dφ(x, t)
= f (φ(x, t)), φ(x, 0) = x, ∀x ∈ Rd , t ∈ R.
dt

Then the set of initial points that converge to a unstable critical point has measure
zero, µ x0 : limt→∞ φ(x0 , t) ∈ Uf∗ = 0, where Uf∗ = {x : f (x) = 0, λ1 (Df (x)) > 0}
 

and Df is the Jacobian matrix of f .

Proof of Theorem 6.14.5. By Theorem 1 in Section 2.3, Perko [128], we know φ(·, ·)
is C 1 -smooth for both x, t. We let g(x) = φ(x, 1), then we know g −1 (x) = φ(x, −1)
and both g, g −1 are C 1 -smooth. Note that Dg −1 (x) is the inverse matrix of Dg(x). So
both of the two matrices are invertible. Thus we can apply Theorem 6.14.4 and we
know µ {x0 : limk→∞ g k (x0 ) ∈ A∗g } = 0.


Note that if limt→∞ φ(x, t) exists, then limk→∞ g k (x) = limt→∞ φ(x, t). It remains
to show that Uf∗ ⊆ A∗g . For f (x0 ) = 0, we have φ(x0 , t) = x0 and thus g(x0 ) = x0 . Now
it suffices to prove that λ1 (Dg(x0 )) > 1. For every t ∈ [0, 1], by Corollary of Theorem

1 in Section 2.3, Perko [128], we have ∂t
Dφ(x, t) = Df (φ(x, t))Dφ(x, t), ∀x, t. Thus,


Dφ(x0 , t) = Df (φ(x0 , t))Dφ(x0 , t) = Df (x0 )Dφ(x0 , t).
∂t

Solving this ODE gives Dg(x0 ) = Dφ(x, 1) = eDf (x0 ) Dφ(x, 0) = eDf (x0 ) , where the
last equality is due to Dφ(x, 0) ≡ I, ∀x. Combining this with λ1 (Df (x0 )) > 0, we
have λ1 (Dg(x0 )) > 1.
Thus we have Uf∗ := {x0 : f (x0 ) = 0, λ1 (Df (x0 )) > 0} ⊆ A∗g , which implies that
{x0 : limt→∞ φ(x0 , t) ∈ U ∗ } ⊆ {x0 : limk→∞ g k (x0 ) ∈ A∗g }

Theorem 6.5.10. Let f : Rd×d → R be a convex C 2 -smooth function. (1). All


stationary points of L : Rd×d → R, L(U ) = 12 f (U U > ) are either strict saddles or global

223
minimizers; (2). For any random initialization, GF (6.1) converges to strict saddles
of L(U ) with probability 0.

Proof of Theorem 6.5.10. For (1), by Theorem 6.14.3, we immediately know all the
stationary points of L( · ) are either global minimizers or strict saddles. (2) is just a
direct consequence of Theorem 6.14.5 by setting f in the above proof to −∇L.

6.15 Proofs for Deep Matrix Factorization

6.15.1 Preliminary Lemmas

Lemma 6.15.1. If W (0)  0, then W (t)  0 and rank(W (t)) = rank(W (0)) for all
t.

Proof. Note that we can always find a set of balanced Ui (t), such that U1 (t) . . . UL (t) =
W (t), d2 = d3 = · · · = dL = rank(W (t)) and write the dynamics of W (t) in the space
of {Ui }Li=1 . Thus it is clear that for all t0 , rank(W (t0 )) ≤ rank(W (t)). We can apply
the same argument for t0 and we know rank(W (t)) ≤ rank(W (t0 )). Thus rank(W (t))
is constant over time, and we denote it by k. Since eigenvalues are continuous matrix
functions, and ∀t, λi (W (t)), i ∈ [k] 6= 0. Thus they cannot change their signs and it
must hold that W (t)  0.

aP −bP
Lemma 6.15.2. ∀a, b, P ∈ R, if a > b ≥ 0, P ≥ 1, then a−b
≤ P aP −1 .

Proof. Let f (x) = P (1 − x) − (1 − xP ). Since f 0 (x) = −P + P xP −1 < 0 for all


b
x ∈ [0, 1), f (x) ≥ f (0) = 0. Then substituting x by a
completes the proof.

Recall we use DF (N )[M ] to denote the directional derivative along M of F at N .

Lemma 6.15.3. Let F : S+ + P


d → Sd , M 7→ M , where P ≥ 1 and P ∈ Q. Then

∀M, N  0,
kDF (N )[M ]kF ≤ P kN kP2 −1 kM kF ,
224
F (N +tM )−F (N )
where DF (N )[M ] := limt→0 t
is the directional derivative of F along M .

Proof. Let N = U ΣU > , where U U > = I and Σ = diag(σ1 , · · · , σd ). Note that


F (U M U > ) = U F (M )U > for any M ∈ S+
d . Then we have

kF (N + tM ) − F (N )kF
kDF (N )[M ]kF = lim
t→0 t
>
F (Σ + tU M U ) − F (Σ) F
= lim
t→0 t
= DF (Σ)[U > M U ] F
.

Therefore, it suffices to prove the lemma for the case where N is diagonal, i.e., N = Σ.
1
q
Assume P = p
, where p, q ∈ N and q ≥ p > 0. Define G(N ) = N p . Then
G(Σ)p = Σ. Taking directional derivative on both sides along direction M , we have

p
X
G(Σ)i−1 DG(Σ)[M ]G(Σ)p−1 = M,
i=1

So we have
mij
[DG(Σ)[M ]]ij = P k−1 p−k .
p p p
k=1 σi σj

Let H(G) = Gq . With the same argument, we know

q k−1 q−k
X
[DH(G(Σ))[M ]]ij = mij σi p σj p .
k=1

Note that H(G(Σ)) = F (Σ). By chain rule, we have

DF (Σ)[M ] = DH(G(Σ))[DG(Σ)[M ]].

That is,
k−1 q−k
Pq p
k=1 σi σj p
[DF (Σ)[M ]]ij = mij P k−1 p−k .
p p p
k=1 σi σj

225
q−p

When σi = σj , clearly [DF (Σ)[M ]]ij = mij · q


p
· σi p = P mij σiP −1 . Otherwise, we
assume WLOG that σi > σj , we multiply σi − σj to both numerator and denominator
and we have

σiP − σjP
|[DF (Σ)[M ]]ij | = |mij | ≤ |mij | P σiP −1 ≤ |mij | P kΣkP2 −1 .
σi − σj

where the first inequality is by Lemma 6.15.2. Thus we conclude the proof.

Lemma 6.15.4. For any A, B  0 and P ∈ R, P ≥ 1,

n o
AP − B P F
≤ P kA − BkF max kAkP2 −1 , kBkP2 −1 .

Proof. Since both sides are continuous in P and Q is dense in R, it suffices to prove
the lemma for P ∈ Q. Let ρ := max {kAk2 , kBk2 } and F (M ) = M P . Define
N : [0, 1] → S+
d , N (t) = (1 − t)A + tB, we have

1. kN (t)k2 ≤ ρ, since k·k2 is convex.

2. kDF (N (t))[B − A]kF ≤ P kN (t)kP2 −1 kB − AkF by Lemma 6.15.3.

Therefore,

Z 1
dF (N (t))
kF (N (1)) − F (N (0))kF ≤ dt
0 dt F
Z 1
= kDF (N (t))[B − A]kF dt
t=0

≤ P kA − BkF ρP −1 ,

which completes the proof.

226
For a locally Lipschitz function f ( · ), the Clarke subdifferential [129–131] of f at
any point x is the following convex set

∂ ◦ f (x) n o
:= co lim ∇f (xk ) : xk → x, f is differentiable at xk ,
∂x k→∞

where co denotes the convex hull.


Clarke subdifferential generalize the standard notion of gradients in the sense that,
∂ ◦ f (x)
when f is smooth, ∂x
= {∇f (x)}. Clarke subdifferential satisfies the chain rule:

Theorem 6.15.5 (Theorem 2.3.10, Clarke 130). Let F : Rk → Rd be a differentiable


function and g : Rd → R Lipschitz around F (x). Then f = g ◦ F is Lipschitz around
x and one has
∂ ◦ f (x) ∂ ◦ g(F (x)) dF (x)
⊆ ◦ .
∂x ∂F dx

Let λm : Sd → R, M 7→ λm (M ) be the m-th largest eigenvalue of a symmetric


matrix M . The following theorem gives the Clarke’s subdifferentials of the eigenvalue:

Theorem 6.15.6 (Theorem 5.3, Hiriart-Urruty and Lewis 132). The Clarke subdif-
ferential of the eigenvalue function λm is given below, where co denotes the convex
hull:
∂ ◦ λm (M )
= co{vv > : M v = λm (M )v, kvk2 = 1}.
∂M

6.15.2 Proof of Lemma 6.6.1

Since W (t)  0 by Lemma 6.15.1, (6.11) can be rewritten as the following:

L−1
dW X 2i 2i+2
=− W L ∇f (W )W 2− L . (6.33)
dt i=0

227
Proof for Lemma 6.6.1. Suppose W (t) is a symmetric solution of (6.11). By
Lemma 6.15.1, we know W (t) also satisfies (6.33). Below we prove the lemma for
even L and odd L respectively:

1
• L is odd: let R(t) be the solution of the following ODE with R(0) := (W (0)) L .
1
Note we do not define R(t) by (W (t)) L .

L−1
dR X
=− (−1)i Ri ∇f (RL )RL−1−i . (6.34)
dt i=0

The calculation below shows that RL (t) also satisfies (6.33).

L−1 L−1 X
L−1
dRL X
j dR L−1−j
X
=− R R = (−1)i Ri+j ∇f (RL )R2L−2−i−j
dt j=0
dt j=0 i=0
L−1 X k
!
X
=− (−1)i Rk ∇f (RL )R2L−2−k
k=0 i=0
2L−2 L−1
!
X X
− (−1)i Rk ∇f (RL )R2L−2−k
k=L i=k−L+1
k 2+k
X
=− (RL ) L ∇f (RL )(RL )2− L

0≤k≤2L−2
k even
L−1
2+2i
L 2i
X
=− (R ) ∇f (RL )(RL )2−
L L .
i=0

Since RL (0) = W (0), by existence and uniqueness theorem, RL (t) = W (t),


∀t ∈ R. So

dM dR dR
=R + R = −∇f (M L/2 )M L/2 − M L/2 ∇f (M L/2 ),
dt dt dt

which completes the proof when L is odd.

228
• L is even: let M f(0) := (W (0)) L2 .
f(t) be the solution of the following ODE with M
f(t) by (W (t)) L2 .
Note we do not define M

dMf
f)L/2 )(M
f)L/2 − (M
f)L/2 ∇f ((M
f)L/2 ).
= −∇f ((M (6.35)
dt

f)L/2 (t) also satisfies (6.33).


The calculation below shows that (M

f)L/2 L/2−1
d(M
L−1
f)j dM (M
X f X
= (M f)L/2−1−j = − f)j ∇f ((M
(M f)L/2 )(M
f)L−1−j
dt j=0
dt j=0

f)L/2 (0) = W (0), by existence and uniqueness theorem, (M


Since (M f)L/2 (t) =
f(t) = W 2/L (t) = M (t), ∀t ∈ R. This completes
W (t), ∀t ∈ R. In other words, M
the proof when L is even.

6.15.3 Proof for Theorem 6.6.2

Now we turn to prove Theorem 6.6.2. Let P = L/2. Then (6.12) can be rewritten as

dM
= − ∇f (M P )M P + M P ∇f (M P ) .

(6.36)
dt

The following lemma about the growth rate of λk (M ) is used later in the proof.

Lemma 6.15.7. Suppose M (t) satisfies (6.36), we have for any T 0 > T , and k ∈ [d],

Z T0
0
λk (M (T )) − λk (M (T )) ≤ 2λk (M (t))P k∇f (M P (t))k2 dt. (6.37)
T

and

Z T0
1 0
λ1−P 1−P
2k∇f (M P (t))k2 dt.

(M (T )) − λ (M (T )) ≤ (6.38)
P −1 k k
T

229
Proof. Since λk (M (t)) is locally Lipschitz in t, by Rademacher’s theorem, we know
λk (M (t)) is differentiable almost everywhere, and the following holds

Z T0
0 dλk (M (t))
λk (M (T )) − λk (M (T )) = dt.
T dt

dλk (M (t))
When dt
exists, we have

∂ ◦ λk (M )
  
dλk (M (t)) dM (t)
∈ G, :G∈
dt dt ∂M
∂ ◦ λk (M )
 
P P
= 2λk (M (t)) G, −∇f (M (t)) : G ∈
∂M

Note that kGkF ≤ kGk∗ = 1. So G, −∇f (M P (t)) ≤ k∇f (M P (t))k2 . We can


prove (6.38) with a similar argument.

To prove Theorem 6.6.2, it suffices to consider the case that M (0) = α̂I where
α̂ := α1/P . WLOG we can assume −∇f (0) = diag(µ1 , . . . , µd ) by choosing a suitable
standard basis. By assumption in Theorem 6.6.2, we have µ1 > max{µ2 , 0} and
µ1 = k∇f (0)k2 . We use φm (M0 , t) to denote the solution of M (t) when M (0) = M0 .
Let R > 0. Since f ( · ) is C 3 -smooth, there exists β > 0 such that

k∇f (W1 ) − ∇f (W2 )kF ≤ β kW1 − W2 k2

for all W1 , W2 with kW1 k2 , kW2 k2 ≤ R.


1
Let κ = β/µ1 . We assume WLOG that R ≤ κ(P −1)
. Let Fα̂ (x) :=
R α̂−(P −1) (P −1)x−P P −1
dz
x−(P −1) 1+κz −P/(P −1)
. Then Fα̂0 (x) = 1+κxP
= (1+κxP )xP
. We will use this
1
function to bound norm growth. Let gα̂,c (t) = α̂−(P −1) −κ(P −1)c−2µ1 (P −1)t
. Define
α̂−(P −1) −κ(P −1)r−r−(P −1)
Tα̂ (r) = 2µ1 (P −1)
. It is easy to verify that gα̂,r (Tα̂ (r)) = rP −1 .

230
Lemma 6.15.8. For any x ∈ [α̂, R] we have

α̂−(P −1) − x−(P −1) − Fα̂ (x) ∈ [0, κ(P − 1)x].




Proof. On the one hand, we have

Z α̂−(P −1)  
−(P −1) −(P −1) 1
α̂ −x − Fα̂ (x) = 1− dz ≥ 0.
x−(P −1) 1 + κz −P/(P −1)

On the other hand,

Z α̂−(P −1) Z α̂−(P −1)


−(P −1) −(P −1) κ 1
α̂ −x − Fα̂ (x) = dz ≤ κ dz
x−(P −1) z P/(P −1) + κ x−(P −1) z P/(P −1)
α̂−(P −1)
−1
= κ(P − 1) ·
z 1/(P −1) x−(P −1)

≤ κ(P − 1)x,

which completes the proof.

Lemma 6.15.9. Let M0 be a PSD matrix with kM0 k2 ≤ 1. For M (t) := φm (α̂M0 , t)
and t ≤ Tα̂ (c),
1
kM (t)k2 = λ1 (M (t)) ≤ gα̂,c (t) P −1 .

Proof. Since k∇f (M P )k2 ≤ k∇f (0)k2 + βkM kP2 ≤ µ1 + β(λ1 (M ))P , by Lemma 6.15.7,
we have

Z t
λ1 (M (t)) ≤ λ1 (M (0)) + 2(µ1 + β(λ1 (M (τ )))P )(λ1 (M (τ )))P dτ
0
Z t

= α̂ + 2µ1 (P − 1) 0
0 Fα̂ (λ1 (M (τ ))

So
Fα̂ (λ1 (M (t))) ≤ 2µ1 (P − 1)t.

231
1
If kM (t)k2 < α̂, then kM (t)k2 ≤ gα̂,c (t) P −1 . If kM (t)k2 ≥ α̂, then by Lemma 6.15.8,

Fα̂ (kM (t)k2 ) ≤ 2µ1 (P − 1)Tα̂ (c) = α̂−(P −1) − κ(P − 1)c − c−(P −1) ≤ Fα̂ (c),

so kM (t)k2 ≤ c for all t ≤ Tα̂ (c). Applying Lemma 6.15.8 again, we have

−(P −1)
α̂−(P −1) − kM (t)k2 ≤ F (kM (t)k2 ) + κ(P − 1)c ≤ 2µ1 (P − 1)t + κ(P − 1)c,

1
which implies kM (t)k2 ≤ gα̂,c (t) P −1 by definition.

Consider the following ODE:

dMc 
P P

= − ∇f (0)M + M ∇f (0) .
c c
dt

We use φ̂m (M
c0 , t) to denote the solution of M
c(t) when M
c(0) = M
c0 . For diagonal

matrix M
c0 , M
c(t) is also diagonal for any t, and it is easy to show that

 1
 P −1
1
e>
i M0 ei 6= 0,

 c
−(P −1)

e> (α̂e>i M
c0 ei ) −2µi (P −1)t
i M (t)ei = (6.39)
c

e>

0
i M0 ei = 0.
c

Remark 6.15.10. Unlike depth-2 case, the closed form solution, M


c(t) is only tractable

for diagonal initialization, i.e., (6.39) (note that the identity matrix is diagonal). And
this is the main barrier for extending our two-phase analysis to the case of general
initialization when L ≥ 3. In Section 6.6.1, we give a more detailed discussion on this
barrier.

The following lemma shows that the trajectory of M (t) is close to M


c(t).

232
Lemma 6.15.11. Let M0 be a diagonal PSD matrix with kM0 k2 ≤ 1. For M (t) :=
φm (α̂M0 , t) and M
c(t) := φ̂m (α̂M0 , t), we have

c(Tα̂ (r))kF = O(rP +1 ).


kM (Tα̂ (r)) − M

Proof. We bound the difference D := M − M


c between M and M
c.

dD  
= 2 ∇f (0) M P − M
cP + ∇f (M P ) − ∇f (0) M P

dt F F
 
P
cP kF + k∇f (M P ) − ∇f (0)kF kM P k2
≤ 2 k∇f (0)k2 kM − M
 
P −1 P −1 2P
≤ 2 µ1 P max{kM k2 , kM k2 }kDkF + β kM k2 ,
c

where the last step is by Lemma 6.15.4. This implies that

Z t Z t
dD(τ )  2P

kD(t)kF ≤ dτ ≤ 2 µ1 P gα̂,r (τ ) kD(τ )kF + βgα̂,r (τ ) P −1 dτ.
τ =0 dτ F 0

So
!
Z Tα̂ (r) Z Tα̂ (r)
2P
kD(Tα̂ (r))kF ≤ 2βgα̂,r (t) P −1 exp 2µ1 P gα̂,r (τ )dτ dt
0 t
Z Tα̂ (r)  
2P P gα̂,r (Tα̂ (r))
= 2βgα̂,r (t) P −1 exp ln dt
0 P −1 gα̂,r (t)
Z Tα̂ (r)
P P
= 2βgα̂,r (t) P −1 gα̂,r (Tα̂ (r)) P −1 dt
0
1 1 P
= 2β · gα̂,r (Tα̂ (r)) P −1 · gα̂,r (Tα̂ (r)) P −1
2µ1
P +1
= κgα̂,r (Tα̂ (r)) P −1

= κrP +1 .

which proves the bound.

233
Lemma 6.15.12. Let M (t) = φm (α̂M0 , t), M f0 , t). If max{kM0 k2 , kM
f(t) = φm (α̂M f0 k2 } ≤

1. For t ≤ Tα̂ (r), we have

 r P P
kM (t) − M
f(t)kF ≤ e2κr kM (0) − M
f(0)kF .
α̂

Proof. Define D(t) = M (t) − M


f(t). Then we have

dD      
=2 ∇f (M P ) M P − M
fP + ∇f (M P ) − ∇f (M
fP ) MfP
dt F F
 
≤ 2 k∇f (M P )k2 kM P − M
fP kF + βkM P − M
fP kF kM
fP k2
 
≤ 2 µ1 + βkM k2 + βkM k2 P max{kM kP2 −1 , kM
f P P fkP −1 }kDkF ,
2

where the last step is by Lemma 6.15.4. So


!
Z Tα̂ (r)  
P
kD(Tα̂ (r))kF ≤ kD(0)kF · exp 2P µ1 1 + 2κgα̂,r (t) P −1 gα̂,r (t)dt
0
 
P gα̂,r (Tα̂ (r)) P
≤ kD(0)kF · exp ln + 2κgα̂,r (Tα̂ (r)) P −1
P −1 gα̂,r (0)
 r P P
≤ kD(0)kF e2κr ,
α̂

which proves the bound.


 
α̂−(P −1)
Let MαG (t) := φm αe1 e> ,
1 2µ1 (P −1) + t . Let M (t) := lim MαG (t).
α→0

Lemma 6.15.13. For every t ∈ (−∞, +∞), M (t) exists and Mα̂G (t) converges to
M (t) in the following rate:

Mα̂G (t) − M (t) F


= O(α̂).

−κ(P −1)c−c−(P −1)


Proof. Let c be a sufficiently small constant. Let T̄ := 2µ1 (P −1)
. We prove this
lemma in the cases of t ∈ (−∞, T̄ ] and t > T̄ respectively.

234
α̂−(P −1)
Case 1. Fix t ∈ (−∞, T̄ ]. Then 2µ1 (P −1)
+ t ≤ Tα̂ (c). Let α̃ be the unique number
such that κ(P − 1)α̃ + α̃−(P −1) = α̂−(P −1) . Let α̂0 < α̂ be an arbitrarily small number.
(α̂0 )−(P −1) −α̂−(P −1)
Let t0 := Tα̂0 (α̃) = 2µ1 (P −1)
. By Lemma 6.15.11 and (6.39), we have

φm (α̂0 e1 e> >


1 , t0 ) − α̂e1 e1 F
≤ φm (α̂0 e1 e> 0 >
1 , t0 ) − φ̂m (α̂ e1 e1 , t0 ) ≤ O(α̃P +1 ).
F

By Lemma 6.15.9, kφm (α̂0 e1 e>


1 , t0 )k2 ≤ α̃. Then by Lemma 6.15.12, we have

 c P P
φm (α̂0 e1 e> >
1 , t0 + t) − φ(α̂e1 e1 , t) F
≤ e2κc · O(α̃P +1 ) = O(α̃) = O(α̂).
α̃

This implies that {Mα̂G (t)} satisfies Cauchy’s criterion for every t, and thus the limit
M (t) exists for t ≤ T̄ . The convergence rate can be deduced by taking limits for
α̂0 → 0 on both sides.

Case 2. For t = T̄ + τ with τ > 0, φm (M, τ ) is locally Lipschitz with respect to M .


So

Mα̂G (t) − Mα̂G0 (t) F


= φm (Mα̂G (T̄ ), τ ) − φm (Mα̂G0 (T̄ ), τ ) F

= O( Mα̂G (T̄ ) − Mα̂G0 (T̄ ) F


)

= O(α̂),

which proves the lemma for t > T̄ .

Theorem 6.15.14. For every t ∈ (−∞, +∞), as α → 0, we have:

α̂−(P −1)
 
1
φm α̂I, + t − M (t) = O(α̂ P +1 ), (6.40)
2µ1 (P − 1) F

235
and for any 2 ≤ k ≤ d,

α̂−(P −1)
  
λk φm α̂I, +t = O(α̂). (6.41)
2µ1 (P − 1)

 
α̂−(P −1)
Proof. Let Mα̂ (t) := φm α̂I, 2µ 1 (P −1)
+ t . Again we let c be a sufficiently small
−κ(P −1)c−c−(P −1)
constant and T̄ := 2µ1 (P −1)
. We prove in the cases of t ∈ (−∞, T̄ ] and t > T̄
respectively.

1
Case 1. Fix t ∈ (−∞, T̄ ]. Let α̂1 := α̂ P +1 . Let α̃1 be the unique number such that
−(P −1)
−(P −1) −(P −1) α̂−(P −1) −α̂1
κ(P − 1)α̃1 + α̃1 = α̂1 . Let t0 := Tα̂ (α̃1 ) = 2µ1 (P −1)
. Then

φm (α̂I, t0 ) − α̂1 e1 e>


1 F
≤ φm (α̂I, t0 ) − φ̂m (α̂I, t0 ) + φ̂m (α̂I, t0 ) − α̂1 e1 e>
1
F F

= O(α̃1P +1 + α̂)

= O(α̂).

By Lemma 6.15.9, kφm (α̂0 I, t0 )k2 ≤ α̃1 . Then by Lemma 6.15.12, we have

Mα̂ (t) − Mα̂G1 (t) F


= φm (α̂I, t0 + t) − φm (α̂1 e1 e>
1 , t) F
 P
c P 1
≤ e2κc · O(α̂) = O(α̂ P +1 ).
α̃1

Combining this with the convergence rate for Mα̂G1 (t) proves the bound (6.40).
For (6.41), by Lemma 6.15.7, we have

Z T̄
λ1−P
k (Mα̂ (T̄ )) − λ1−P
k (Mα̂ (t0 )) ≤ 2(P − 1) ∇f (Mα̂P (t)) 2
dt
t0
Z T̄
≤ 2(P − 1)(µ1 + β kMα̂ (t)kP2 ))dt (6.42)
t0
 κ 1

≤ −2(P − 1) µ1 (t − T1 ) + · gα̂,c (t) P −1 .
2

236
By Lemma 6.15.11, λ1 (Mα̂ (T̄ )) = Mα̂ (T̄ ) 2
= c + O(cP +1 ). For k ≥ 2,

−(P −1) −(P −1)


 κ 
λk (Mα̂ (T̄ )) ≥ Ω(α̂ ) − 2(P − 1) µ1 (T̄ − T1 ) + · c
2
−1
−P −(P −1)
α̂ P +1 −c
≥ Ω(α̂−(P −1) ) − − O(c)
2µ1 (P − 1)
≥ Ω(α̂−(P −1) ).

Thus λk (Mα̂ (T̄ )) ≤ O(α̂).

Case 2. For t = T̄ + τ with τ > 0, φm (M, τ ) is locally Lipschitz with respect to M .


So

Mα̂ (t) − Mα̂G1 (t) F


= φm (Mα̂ (T̄ ), τ ) − φm (Mα̂G1 (T̄ ), τ ) F
1
= O Mα̂ (T̄ ) − Mα̂G1 (T̄ ) F = O(α̂ P +1 ),


which proves the bound (6.40).


For (6.41), again by Lemma 6.15.7, we have

λ1−P
k (Mα̂ (T̄ )) − λ1−P
k (Mα̂ (T̄ + τ ))
Z T̄ +τ
≤ 2(P − 1) ∇f (Mα̂P (t)) 2 dt

Z T̄ +τ
2(P − 1) β Mα̂P (t) − (M G )P (t) G P
 
≤ 2
+ ∇f (M ) (t) 2
dt

Z T̄ +τ
1 P
≤ 2(P − 1)(O(α̂ 1+P ) + β M G (t) 2
)dt

≤ O(1).

Thus λ1−P
k (Mα̂ (T̄ + τ )) = Ω(α̂−(P −1) ), that is, λk (Mα̂ (T̄ + τ )) = O(α̂), ∀k ≥ 2.

237
P
Proof of Theorem 6.6.2. Note that M (t) = W (t) and

P !
α̂−(P −1) α̂−(P −1)
 
φm α̂I, +t = φ αI, +t .
2µ1 (P − 1) 2µ1 (P − 1)

By Theorem 6.15.14, We have

α−(1−1/P )
 
φ αI, + t − W (t)
2µ1 (P − 1) F
  −(P −1)
P
α̂ P
≤ φm α̂I, +t − M (t)
2µ1 (P − 1)
F
−(P −1)
P −1
α̂−(P −1)
    
α̂
≤ P φm α̂I, + t − M (t) max φm α̂I, +t , M (t) 2
2µ1 (P − 1) F 2µ1 (P − 1) 2
1 1
= O(α̂ P +1 )O(1) = O(α P (P +1) ),

and for 2 ≤ k ≤ d,

α−(1−1/P ) α̂−(P −1)


     
λk φ αI, +t = λk φm α̂I, +t = O(α̂P ) = O(α).
2µ1 (P − 1) 2µ1 (P − 1)

6.16 Proof of Linear Convergence to Minimizer

In this section, we will present the theorems that guarantee the linear convergence to
a minimizer W0 of f ( · ) if the dynamics (6.43) is initialized sufficiently close to W0 ,
i.e., kW (0) − W0 kF is sufficiently small. In Section 6.16.3, we will apply this result to
prove Theorem 6.6.4.

L−1
dW X 2i 2i+2
=− W L ∇f (W )W 2− L =: g(W ). (6.43)
dt i=0

238
Throughout this section, we assume rank(W0 ) = k and use m := λk (W0 ) to denote
the k-th smallest non-zero eigenvalue of W0 . The tangent space of manifold of rank-k
symmetric matrices at W0 is T = {V W0> + W0 V > : V ∈ Rd×d }. It can be shown that
k(k+1) k(2d−k+1)
dim(T ) = k(d − k) + 2
= 2
.
Let J(W ) be the Jacobian of g(W ) in (6.43). For depth-2 case, we have shown
that T is an invariant subspace of J(W0 ) in Theorem 6.13.5, property 2. This can
be generalize to the deep case where L ≥ 3. Therefore, we can use J(W0 )|T : T → T
2
to denote the linear operator J(W0 ) restricted on T . We also define Πd1 (W ) as the
2 2
projection of W ∈ Rd×d on T , and Πd2 (W ) := W − Πd1 (W ).
Towards showing the main convergence result in the section, we make the following
assumption.

Assumption 6.16.1. Suppose J(W0 )|T is diagonalizable and all eigenvalues are
negative real numbers.

W0 is a minimizer, so it is clear that J(W0 )|T has no eigenvalues with positive real
parts (otherwise there is a descending direction of f ( · ) from W0 , since the loss f ( · )
strictly decreases along the trajectory of (6.43)). If further Assumption 6.16.1 holds,
then we know J(W0 )|T : T → T can be diagonalized as J(W0 )|T [ · ] = V(ΣV −1 ( · )),
where Σi = diag(−µ1 , . . . , −µdim(T ) ), V : Rdim(T ) → T , V(x) = dim(T )
P
i=1 xi Vi , and Vi is
the eigenvector associated with eigenvalue −µi .
As shown in Theorem 6.16.3 below, this assumption implies that if W (0) is rank-k
and is sufficiently close to W0 , then kW (t) − W0 kF ≤ Ce−µ1 t for some constant C.
For depth-2 case, the above assumption is equivalent to that L(U0 ) is “strongly convex”
at U0 , except those 0 eigenvalues due to symmetry, by property 2 of Theorem 6.13.5).
For the case where L ≥ 3, because this dynamics is not gradient flow, in general it
does not correspond to a loss function and strongly convexity does not make any
sense. Nevertheless, in experiments we do observe linear convergence to W0 , so this
assumption is reasonable.
239
6.16.1 Rank-k Initialization

For convenience, we define for all W ∈ Sd ,

 2
 2 2
−1
kW kV := V Πd1 (W ) , kW kF,1 := Πd1 (W ) , kW kF,2 := Πd2 (W ) .
F F F

The reason for such definition of norms, as we will see later, is that the norm (or
the difference) in the tangent space of the manifold of symmetric rank-r matrices,
kW − W 0 kF,1 , dominates that in the orthogonal complement of the tangent space,
kW − W 0 kF,2 , when both W, W 0 get very close to the W0 (see a more rigorous statement
in Lemma 6.16.2). WLOG, we can assume

k · kF,1
≤ k · kV ≤ k · kF,1 ,
K

for some constant K, which may depend on f and W0 . This also implies that
k · kV ≤ k · kF . Below we also assume for sufficiently small R, and any W such that
kW − W0 kF ≤ R, we have k∇f (W )k2 ≤ ρ and kJ(W )[∆]kF ≤ β k∆kF for any ∆.
In the proof below, we assume such properties hold as long as we can show the
boundedness of W (t) − W0 .

Lemma 6.16.2. Let max{kW − W0 kF,1 , kW 0 − W0 kF,1 } = r, when r ≤ m


2
, we have

5r
kW − W 0 kF,2 ≤ kW − W 0 kF,1 .
m

As a special case, we have

5 kW − W 0 k2F,1
kW − W0 kF,2 ≤ .
m

240
Proof. WLOG we can assume W0 is only non-zero in the first k dimension, i.e.,
[W0 ]ij = 0, for all i ≥ k + 1, j ≥ k + 1. We further denote W and W 0 by

   
> 0 0>
A B  A B 
W =  and W 0 =  ,
B C B0 C 0

where A, A0 ∈ Rk×k , B, B 0 ∈ R(d−k)×k , C, C 0 ∈ R(d−k)×(d−k) . By definition, we have


kA − A0 kF , kB − B 0 kF ≤ kW − W 0 kF,1 and kW − W 0 kF,2 = kC − C 0 kF . Moreover,
m
we have λmin (A) ≥ m − kA − W0 kF ≥ m − kW − W0 kF,1 ≥ 2
.
Since W, W 0 is rank-k, we have C = BA−1 B > , C 0 = B 0 A0 −1 B 0 > . Thus

kW − W 0 kF,2

= kC − C 0 kF
−1 >
= BA−1 B > − B 0 A0 B 0
F
−1 > −1 >
≤ kB − B 0 kF kA−1 B > kF + kBA−1 kF kA0 − AkF kA0 B 0 kF + kB 0 A0 kF kB > − B 0 kF
 2
0 2r 0 2r 2r
≤ kW − W kF,1 + kW − W kF,1 + kW − W 0 kF,1
m m m
5r
≤ kW − W 0 kF,1 .
m

Theorem 6.16.3 (Linear convergence of rank-k matrices). Suppose that rank(W (0)) =
rank(W0 ) = k and

 
m µ1
kW (0) − W0 kV ≤ R := max , 2 ,
2K K (29β + 10ρ/m)

we have kW (t) − W0 kV ≤ Ce−µ1 t kW (0) − W0 kV for some constant C depending on


W0 , where W (t) satisfies (6.43).

241
2 2
Proof. For convenience, we define W1 (t) := Πd1 (W (t) − W0 ) , W2 (t) := Πd2 (W (t) − W0 ) =
2
Πd2 (W (t)). We also use h·, ·iV −1 = hV −1 (·) , V −1 (·)i for short.

d kW1 (t)k2V d kW (t) − W0 k2V


=
dt  dt
  
d 2 dW (t) d2
= 2 Π1 , Π1 (W (t) − W0 )
dt V −1
D 2 E
= 2 Πd1 (g(W (t))) , W1 (t)
V −1
D 2 E
≤ 2 Πd1 (J(W0 )[W (t) − W0 ]) , W1 (t)
V −1

+ 2 kg(W (t) − W0 ) − J(W0 )[W (t) − W0 ]kV kW (t) − W0 kV


D 2 E
= 2 Πd1 (J(W0 )[W1 (t) + W2 (t)]) , W1 (t)
V −1

+ 2 kg(W (t) − W0 ) − J(W0 )[W (t) − W0 ]kV kW1 (t)kV


D 2 E
= 2 Πd1 (J(W0 )[W1 (t)]) , W1 (t) + 2 kJ(W0 )[W2 (t)]kV kW1 (t)kV
V −1

+ 2 kg(W (t) − W0 ) − J(W0 )[W (t) − W0 ]kV kW1 (t)kV .

D E
2
For the first term Πd1 (J(W0 )[W1 (t)]) , W1 (t) , we know W1 (t) ∈ T , and T is an
V −1
invariant space of J(W0 ). Recall J(W0 )|T [·] = V (ΣV −1 (·)), we have

D 2 E
2 Πd1 (J(W0 )[W1 (t)]) , W1 (t) = 2 ΣV −1 (W1 (t)) , V −1 (W1 (t)) ≤ −2µ1 kW1 (t)kF,1 .
V −1

For the second term 2β kJ(W0 )[W2 (t)]kV kW1 (t)kV , we have

2 kJ(W0 )[W2 (t)]kV ≤ 2 kJ(W0 )[W2 (t)]kF ≤ 2 kJ(W0 )k2 kW2 (t)kF = 2ρ kW2 (t)kF .

242
For the third term 2 kg(W (t) − W0 ) − J(W0 )[W (t) − W0 ]kV kW1 (t)kV , we have

2 kg(W (t) − W0 ) − J(W0 )[W (t) − W0 ]kV ≤ 2β kW (t) − W0 k2F

≤ 4β(kW1 (t)k2F + kW2 (t)k2F )

≤ 4β(K 2 kW1 (t)k2V + kW2 (t)k2F ).

Thus we have shown the following. Note so far we have not used the assumption that
W is rank-k.

d kW1 (t)k2V
≤ −2µ1 kW1 (t)k2V +2 kW1 (t)kV ρ kW2 (t)kF + 2βK 2 kW1 (t)k2V + 2β kW2 (t)k2F ,

dt

that is,

d log kW1 (t)k2V 2 4β kW2 (t)k2F + 2ρ kW2 (t)kF


≤ −2µ1 + 4βK kW1 (t)kV + . (6.44)
dt kW1 (t)kV

Let T := sup{t ≥ 0 : kW1 (t)kV ≤ m


2K
}. Setting W 0 = W0 in Lemma 6.16.2, we have
m
for t < T , r = kW (t) − W0 kF,1 ≤ kW (t) − W0 kF ≤ K kW (t) − W0 kV ≤ 2
. Thus,

5 kW (t) − W0 k2F,1 5K 2 kW (t) − W0 k2V 5


kW2 (t)kF = kW2 (t)kF,2 ≤ ≤ ≤ m.
m m 4

Thus, from (6.44) we can derive that

d log kW1 (t)k2V


≤ −2µ1 + K 2 (29β + 10ρ/m)) kW1 (t)kV ≤ −µ1 . (6.45)
dt

Since µ1 < 0, kW1 (t)kV decreases for [0, T ). Thus T must be ∞, otherwise
kW1 (T )kV = limt→T − kW1 (t)kV < R1 . Contradiction.

243
µ1
Therefore, for any t ∈ [0, ∞), we have kW1 (t)kV ≤ kW1 (0)kV e− 2
t
. That is,

Z ∞
2 2R
kW1 (t)kV dt ≤ kW1 (0)kV ≤ .
0 µ1 µ1

Thus from (6.45), we have

Z ∞
K2
 
kW (t)kV = kW1 (t)kV ≤ kW1 (0)kV exp −µ1 t + (29β + 10ρ/m) kW1 (t)kV dt
2 0
K 2R
 
≤ kW1 (0)kV exp −µ1 t + (29β + 10ρ/m)
µ1
=: C kW (0)kV e−µ1 t ,

which completes the proof.

6.16.2 Almost Rank-k Initialization

We use M (t) to denote the top-k components of W (t) in SVD, and N (t) to denote
the rest part, i.e., W (t) − M (t). One can think M (t) as the main part and N (t) as
the negligible part.
Below we show that for deep overparametrized matrix factorization, where W (t)
satisfies (6.43), if the trajectory is initialized at some W (0) in a small neighborhood
of the k-th critical point W0 of deep GLRL, and W (0) is approximately rank-k, in
the sense that N (0) is very small, then inf t≥0 kW (t) − W0 kV is roughly at the same
magnitude of N (0).

Theorem 6.16.4 (Linear convergence of almost rank-k matrices, deep case). Suppose
W0 is a critical point of rank k and W0 satisfies Assumption 6.16.1, there exists
constants C0 and r, such that if C0 kN (0)kF ≤ kW1 (0)kV ≤ r, then there exists a time
T and constants C, C 0 , such that

(1). kW (t) − W0 kV ≤ Ce−µ1 t/2 kW (0) − W0 kV , for t ≤ T .

244
(2). kW (T ) − W0 kF ≤ C 0 kN (0)kF .

λmin (W0 ) λmin (W0 )


Proof. When kW (t) − W0 kF ≤ 4
, kN (t)kF ≤ 4
, thus we have

λmin (W0 )
kM (t) − W0 kF,1 ≤ kW (t) − W0 kF,1 + kN (t)kF,1 ≤ ,
2

thus by Lemma 6.16.2, we have

kW2 (t)kF,2 ≤ kM (t) − W0 kF,2 + kN (t)kF,2


5 kM (t) − W0 k2F,1
≤ + kN (t)kF,2
λmin (W0 )
10 kW1 (t)k2F,1 + 10 kN (t)k2F
≤ + kN (t)kF,2
λmin (W0 )
10K 2 kW1 (t)k2V + 10 kN (t)k2F
≤ + kN (t)kF,2 .
λmin (W0 )

Thus we can pick constant C0 large enough and r small enough, such that for any
t ≥ 0, if C0 kN (t)kF ≤ kW1 (t)kV ≤ r, then it holds that:

• The “small terms” in the RHS of (6.44) satisfies that

2 4β kW2 (t)k2F + 2ρ kW2 (t)kF


4βK kW1 (t)kV + ≤ C1 kW1 (t)kV +C2 kN (t)kF ≤ µ1
kW1 (t)kV

for some C1 and C2 independent of t.

1
• The spectral norm 2
k∇f (W (t))k2 ≤ k∇f (W0 )k2 =: ρ for all t ≥ 0.
2
κL x L −1 2 L−2
• ∀x < r, (L−2)ρ
> µ1
ln C2r0 x , where κL = 1 − 0.5 L .

Note these conditions can always be satisfied by some C0 and r because we can first
find 3 groups (C0 , r) to satisfy each individual condition, and then take the maximal
C0 and minimal r, it’s easy to check these conditions are still verified. And we let
TC0 ,r be the earliest time that such condition, i.e., C0 kN (t)kF ≤ kW1 (t)kV ≤ r fails.

245
µ1 t
Thus by (6.44), for t ∈ [0, TC0 ,r ), we have kW (t)kV = kW1 (t)kV ≤ kW1 (0)kV e− 2 =
µ1 t
kW (0)kV e− 2 . Thus (1) holds for any T smaller than TC0 ,r . If TC0 ,r = ∞, then clearly
we can pick a sufficiently large T , such that (2) holds. Therefore, below it suffices to
consider the case where TC0 ,r is finite. And we know the condition that fails must be
C0 kN (t)kF ≤ kW1 (t)kV , i.e. C0 kN (TC0 ,r )kF = kW1 (TC0 ,r )kV .
By (6.38) in Lemma 6.15.7, we have

2 2
−1 −1
kN (0)k2L − kN (t)k2L ≤ (L − 2)ρt.

2 −1
2 2
κL kN (0)k2L −1 −1
Define T 0 := (L−2)ρ
, we know for any t < T 0 , we have kN (0)k2L − kN (t)k2L ≤
2
−1
κL kN (t)k2L . That is,

2
−1   h
kN (t)k2L 1 L−2 L−2
i kN (t)k2
2 ∈ 1 − κL , = 0.5 L , 0.5− L =⇒ ∈ [1/2, 2].
kN (0)k2L
−1 1 − κL kN (0)k2

Now we claim it must hold that T 0 ≥ TC0 ,r . Otherwise, we have

C0 0 0
kN (0)k2 ≤ C0 kN (T 0 )kF ≤ kW1 (T 0 )kV ≤ e−µ1 T /2 kW1 (0)kV ≤ e−µ1 T /2 r.
2

2 −1
κL kN (0)k2L
Therefore, (L−2)ρ
= T0 ≤ 2
µ1
ln C0 kN2r(0)k , which contradicts to the definition of C0
2

and r.
As a result, we have


2C0 d kN (0)k2 ≥ 2C0 kN (0)kF ≥ C0 kN (Tc0 ,r )kF = kW1 (TC0 ,r )kV

≥ kW1 (0)kV e−µ1 TC0 ,r /2 ,

and therefore,
2 kW1 (0)kV
TC0 ,r ≤ ln √ .
µ1 2 dC0 kN (0)kF

246
Thus by Lemma 6.16.2, we know

kW (TC0 ,r ) − W0 kF ≤ kW (TC0 ,r ) − W0 kF,1 + kW (TC0 ,r ) − W0 kF,2

≤ K kW1 (TC0 ,r )kV + kM (TC0 ,r ) − W0 kF,2 + kN (TC0 ,r )kF,2

≤ O(kN (0)kF ) + O(kN (0)k2F ) + O(kN (0)kF )

= O(kN (0)kF ).

6.16.3 Proof for Theorem 6.6.4

Proof for Theorem 6.6.4. Let C0 , r be the constants predicted by Theorem 6.16.4 w.r.t.
to W (∞). We claim that we can pick large enough constant T , and α0 sufficiently
small, such that for all α ≤ α0 , the initial condition in Theorem 6.16.4 holds, i.e.
 −(P −1)

C0 kN (0)kF ≤ kW1 (0)kV ≤ r, where W (0) := φ αI, 2µα−1 (P −1) + T .
1

This is because we can first ensure W (T ) − W (∞) 2


is sufficiently small, i.e.,
smaller than 2r . By Theorem 6.6.2, we know when α → 0, W (T ) − W (0) V

K W (T ) − W (0) F
= o(1) and kN (0)kF = O(α).
By Theorem 6.16.4, we know there is a time T (either TC0 ,r or some sufficiently
large number when TC0 ,r = ∞), such that kW (T ) − W0 kF = O(kN (0)kF ) = O(α).

247
Chapter 7

Implicit Bias of Parametrization:


On Equivalence to Mirror Descent

As part of the effort to understand implicit bias of gradient descent in overparametrized


models, several results have shown how the training trajectory on the overparametrized
model can be understood as mirror descent on a different objective. The main result in
this chapter is a characterization of this phenomenon under a notion termed commuting
parametrization, which encompasses all the previous results in this setting. It is shown
that gradient flow with any commuting parametrization is equivalent to continuous
mirror descent with a related Legendre function. Conversely, continuous mirror descent
with any Legendre function can be viewed as gradient flow with a related commuting
parametrization. The latter result relies upon Nash’s embedding theorem.

7.1 Introduction

Implicit bias refers to the phenomenon in machine learning that the solution obtained
from loss minimization has special properties that were not implied by value of the
loss function and instead arise from the trajectory taken in parameter space by the
optimization. Quantifying implicit bias necessarily has to go beyond the traditional
248
black-box convergence analyses of optimization algorithms. Implicit bias can explain
how choice of optimization algorithm can affect generalization [20, 25, 133].
Many existing results about implicit bias treat training (in the limit of infinitesimal
step size) as a differential equation or process {x(t)}t≥0 ⊂ RD . To show the implicit
bias of x(t), the idea is to show for another (more intuitive or better understood)
process {w(t)}t≥0 ⊂ Rd that x(t) is simulating w(t), in the sense that there exists a
mapping G : RD → Rd such that w(t) = G(x(t)). Then the implicit bias of x(t) can
be characterized by translating the special properties of w(t) back to x(t) through
G. A related term, implicit regularization, refers to a handful of such results where
particular update rules are shown to lead to regularized solutions; specifically, x(t) is
simulating w(t) where w(t) is solution to a regularized version of the original loss.
The current paper develops a general framework involving optimization in the
continuous-time regime of a loss L : Rd → R that has been re-parametrized before
optimization1 as w = G(x) for some G : RD → Rd . Then the original loss L(w) in
the w-space induces the implied loss (L ◦ G)(x) ≡ L(G(x)) in the x-space, and the
gradient flow in the x-space is given by

dx(t) = −∇(L ◦ G)(x(t))dt. (7.1)

Using w(t) = G(x(t)) and the fact that ∇(L ◦ G)(x) = ∂G(x)> ∇L(G(x)) where
∂G(x) ∈ Rd×D denotes the Jacobian of G at x, the corresponding dynamics of (7.1)
in the w-space is

dw(t) = ∂G(x(t))dx(t) = −∂G(x(t))∂G(x(t))> ∇L(w(t))dt. (7.2)

1
Two examples from recent years, where G does not change expressiveness of the model, involve
(a) overparametrized linear regression where the parameter vector w is reparametrized (for example as
w = u 2 −v 2 [20]) and (b) deep linear nets [95] where a matrix W is factorized as W = W1 W2 · · · WL
where each W` is the weight matrix for the `-th layer.
249
Our framework is developed to fully understand phenomena in recent papers [20,
105, 134–138], which give examples suggesting that gradient flow in the x-space
could end up simulating a more classical algorithm, mirror descent (specifically, the
continuous analog, mirror flow) in the w-space. Recall that mirror flow is continuous-
time limit of the classical mirror descent, written as d∇R(w(t)) = −∇L(w(t))dt where
R : Rd → R ∪ {∞} is a strictly convex function [139, 140], which is called mirror map
or Lengendre function in literature. Equivalently it is Riemannian gradient flow with
metric tensor ∇2 R, an old notion in geometry:

dw(t) = −∇2 R(w(t))−1 ∇L(w(t))dt. (7.3)

If there exists a Legendre function R such that ∂G(x(t))∂G(x(t))> = ∇2 R(w(t))−1 for


all t, then (7.2) becomes a simple mirror flow in the w-space. Many existing results
about implicit bias indeed concern reparametrizations G that satisfy ∂G(x)∂G(x)> =
∇2 R(w)−1 for a strictly convex function R, and the implicit bias/regularization is
demonstrated by showing that the convergence point satisfies the KKT conditions
needed for minimizing R among all minimizers of the loss L. A concrete exam-
ple is that wi (t) = Gi (x(t)) = (xi (t))2 for all i ∈ [d], so here D = d. In this
case, the Legendre function R must satisfy (∇2 R(w(t)))−1 = ∂G(x(t))∂G(x(t))> =
4diag((x1 (t))2 , . . . , (xd (t))2 ) = 4diag(w1 (t), . . . , wd (t)) which suggests R is the classical
negative entropy function, i.e., R(w) = di=1 wi (ln wi − 1).
P

However, in general, it is hard to decide whether gradient flow for a given


parametrization G can be written as mirror flow for some Legendre function R,
especially when D > d and G is not an injective map. In such cases, there could be
multiple x’s mapping to the same G(x) yet having different ∂G(x)∂G(x)> . If more
than one of such x can be reached by gradient flow, then the desired Legendre function

250
2
cannot exist. If only one of such x can be reached by gradient flow, we must decide
which x it is in order to decide the value of ∇2 R using ∂G∂G> . Conversely, Amid and
Warmuth [137] raises the following question: for what Legendre function R can the
corresponding mirror flow be the result of gradient flow after some reparametrization
G? Answering the questions in both directions requires a deeper understanding of the
impact of parametrizations.
The following are the main contributions of the current paper:

• In Section 7.4, building on classic study of commuting vector fields we identify


a notion of when a parametrization w = G(x) is commuting (Definition 7.4.1)
and use it to give a sufficient condition (Theorem 7.4.9) and a slightly weaker
necessary condition (Theorem 7.4.10) of when the gradient flow in the x-space
governed by −∇(L ◦ G) is simulating a mirror flow in the w-space with respect
to some Legendre function R : Rd → R, which encompasses all the previous
results [20, 105, 134–138]. Moreover, the Legendre function is independent of
the loss L and depends only on the initialization xinit and the parametrization G.

• We recover and generalize existing implicit regularization results for underde-


termined linear regression as implications of the above characterization (Corol-
lary 7.4.17). We also give new convergence analysis in such settings (Theo-
rem 7.4.15), filling the gap in previous works [20, 105, 138] where parameter
convergence is only assumed but not proved.

• In the reverse direction, we use the famous Nash embedding theorem to show
that every mirror flow in the w-space with respect to some Legendre function R
simulates a gradient flow with commuting parametrization under some embedding
x = F (w) where F : Rd → RD and the parametrization G is the inverse of F
(Theorem 7.5.1). This provides an affirmative and fully general answer to the
2
To avoid such an issue, Amid and Warmuth [137] has to assume all the preimages of G at w
have the same ∂G(∂G)> and a recent paper Ghai et al. [141] assumes that G is injective.
251
question of when such reparametrization functions exist, giving a full answer to
questions raised in a more restricted setting in Amid and Warmuth [137].

7.2 Related work

Implicit bias. With high overparametrization as used in modern machine learning,


there usually exist multiple optima, and it is crucial to understand which particular
solutions are found by the optimization algorithm. Implicit bias of gradient descent
for classification tasks with separable data was studied in Soudry et al. [100], Nacson
et al. [101], Gunasekar et al. [142], Ji and Telgarsky [143], Moroshko et al. [144], Ji
and Telgarsky [145] and for non-separable data in Ji and Telgarsky [146, 147], where
the implicit bias appears in the form of margin maximization. The implicit bias for
regression problems has also been analyzed by leveraging tools like mirror descent [20,
134–137, 142], later generalized in Azulay et al. [138].
The sharp contrast between the so-called kernel and rich regimes [20] reflects the
importance of the initialization scale, where a large initialization often leads to the
kernel regime with features barely changing during training [15, 148–156], while with
a small initialization, the solution exhibits richer behavior with the resulting model
having lower complexity [14, 25, 95, 98, 105, 112, 157–162]. Recently Yang and Hu
[163] give a complete characterization on the relationship between initialization scale,
parametrization and learning rate in order to avoid kernel regime.
There are also papers on the implicit bias of other types of optimization algorithms,
e.g., stochastic gradient descent [21, 133, 164–167] and adaptive and momentum-based
methods [168–171], to name a few.

Understanding mirror descent. In the continuous-time limit as step size goes


to 0, the mirror flow is equivalent to the Riemannian gradient flow. Gunasekar
et al. [172] showed that a partial discretization of the latter gives rise to the classical
252
mirror descent. Assuming the existence of some reparametrization function, Amid
and Warmuth [137] showed that a particular mirror flow can be reparametrized as a
gradient flow. Our paper shows that such reparametrization always exists by using
Nash’s embedding theorem. Ghai et al. [141] generalizes the equivalence result of
Amid and Warmuth [137] to discrete updates.

7.3 Preliminaries and notations

Notations. We denote N as the set of natural numbers. For any positive integer n,
we denote {1, 2, . . . , n} by [n]. For any vector u ∈ RD , we denote its i-th coordinate
by ui . For any vector u, v ∈ RD and α ∈ R, we define u v = (u1 v1 , . . . , uD vD )>
and u α
= ((u1 )α , . . . , (uD )α )> . For any k ∈ N ∪ {∞}, we say a function f is C k
if it is k times continuously differentiable, and use C k (M ) to denote the set of all
C k functions from M to R. We use ◦ to denote the composition of functions, e.g.,
f ◦ g(x) = f (g(x)). For any convex function R : RD → R ∪ {∞}, we denote its
domain by dom R = {w ∈ RD | R(w) < ∞}. For any set S, we denote its interior by
int(S) and its closure by S.

We assume that the model has parameter vector w ∈ Rd and C 1 loss func-
tion L : Rd → R. Training involves a reparametrized vector x ∈ RD , which is a
reparametrization of w such that w = G(x) for some differentiable parametrization
function G, and the objective is L(G(x)). From now on, we follow the convention
that d is the dimension of the original parameter w and D is the dimension of the
reparametrized x. We also refer to Rd as the w-space and RD as the x-space.
In particular, we are interested in understanding the dynamics of gradient flow
under the objective L ◦ G on some submanifold M ⊆ RD . Most of our results also
generalize to the following notion of time-dependent loss.

253
Definition 7.3.1 (Time-dependent loss). A time-dependent loss Lt (w) is a function
piecewise constant in time t and continuously differentiable in w ∈ Rd , that is, there
exists k ∈ N, 0 = t1 < t2 < · · · < tk+1 = ∞ and C 1 loss functions L(1) , L(2) , . . . , L(k)
such that for each i ∈ [k] and all t ∈ [ti , ti+1 ),

Lt (w) = L(i) (w), ∀w ∈ Rd .

We denote the set of such time-dependent loss functions by L.

7.3.1 Manifold and vector field

Vector fields are a natural way to formalize the continuous-time gradient descent (a
good reference is Lee [173]). Let M be any smooth submanifold of RD . A vector field
X on M is a continuous map from M to RD such that for any x ∈ M , X(x) is in the
tangent space of M at x, which is denoted by Tx (M ). Formally, Tx (M ) := { dγ
dt t=0
|
∀ smooth curves γ : R → M, γ(0) = x}.

Definition 7.3.2 (Complete vector field; p.215, Lee 173). Let M be a smooth
submanifold of RD and X be a vector field on M . We say X is a complete vector
field on M if and only if for any initialization xinit ∈ M , the differential equation
dx(t) = X(x(t))dt has a solution on (−∞, ∞) with x(0) = xinit .

When the smooth submanifold M ⊆ RD is equipped with a metric tensor g, we


then have a Riemannian manifold (M, g), where for each x ∈ M , gx : Tx M ×Tx M → R
is a positive definite bilinear form. In particular, the standard Euclidean metric g
corresponds to g x (u, v) = u> v for each x ∈ M and u, v ∈ Tx M , under which the
length of any arc on M is given by its length as a curve in RD .
For any differentiable function f : M → R, we denote by ∇g f its gradient vector
field with respect to metric tensor g. More specifically, ∇g f (x) is defined as the
unique vector in RD such that ∇g f (x) ∈ Tx (M ) and df (γ(t)) dγ(t) 
dt t=0
= gx ∇f (x), dt t=0
.
254
Throughout the paper, we assume by default that the metric on the submanifold
M ⊆ RD is inherited from (RD , g), and we will use ∇f as a shorthand for ∇g f . If M
is an open set of RD , ∇f is then simply the ordinary gradient of f .
For any x ∈ M and C 1 function f : M → R, we denote by φtf (x) the point on M
reached after time t by following the vector field −∇f starting at x, i.e., the solution
at time t (when it exists) of

dφtf = −∇f (φtf )dt, φ0f (x) = x.

We say φtf (x) is well-defined at time t when the above differential equation has a
solution at time t. Moreover, for any differentiable function X : M → Rd , we denote
its Jacobian by

∂X(x) = (∇X1 (x), ∇X2 (x), . . . , ∇Xd (x))> .

Definition 7.3.3 (Lie bracket). Let M be a smooth submanifold of RD . Given two


C 1 vector fields X, Y on M , we define the Lie Bracket of X and Y as [X, Y ](x) :=
∂Y (x)X(x) − ∂X(x)Y (x).

7.3.2 Parametrizations

We use the term parametrization to refer to differentiable maps from a smooth


submanifold of RD (x-space) to Rd (w-space). We reserve G to denote parametrizations,
and omit the dependence on G for notations of objects related to G when it is clear
from the context.
The following notion of regular parametrization plays an important role in our
analysis, and it is necessary for our main equivalence result between mirror flow and
gradient flow with reparametrization. This is because if the null space of ∂G(x) is non-
trivial, i.e., it contains some vector u 6= 0, then the gradient flow with parametrization
255
G obviously cannot simulate any mirror flow with nonzero velocity in the direction of
u.

Definition 7.3.4 (Regular parametrization). Let M be a smooth submanifold of RD .


A regular parametrization G : M → Rd is a C 1 parametrization such that ∂G(x) is of
rank d for all x ∈ M .

Note that a regular parametrization G can become irregular when its domain is
changed. For example, G(x) = x2 is regular on R+ , but it is not regular on R as
∂G(0) = 0.
Given a C 2 parametrization G : M → Rd , for any x ∈ M and µ ∈ Rd , we define

ψ(x; µ) := φµG11 ◦ φµG22 ◦ · · · ◦ φµGdd (x) (7.4)

when it is well-defined, i.e., the corresponding integral equation has a solution. For
any x ∈ M , we define the domain of ψ(x; ·) as

U(x) = µ ∈ Rd | ψ(x; µ) is well-defined .



(7.5)

When every ∇Gi is a complete vector field on M as in Definition 7.3.2, we have


U(x) = Rd . However, such completeness assumption is relatively strong, and most
3
polynomials would violate it. For example, consider G(x) = x for x ∈ Rd , then the
solution to dxi (t) = 3xi (t)2 dt explodes in finite time for each i ∈ [d]. To relax this, we
consider parametrizations such that the domain of the flows induced by its gradient
vector fields is pairwise symmetric. More specifically, for any x ∈ M and i, j ∈ [d], we
define

Uij (x) = (s, t) ∈ R2 | φsGi ◦ φtGj (x) is well-defined ,




and we make the following assumption.


256
Assumption 7.3.5. Let M be a smooth submanifold of RD and G : M → Rd
be a parametrization. We assume that for any x ∈ M and i ∈ [d], φtx (x) is well-
defined for t ∈ (T− , T+ ) such that either limt→T+ kφtx (x)k2 = ∞ or T+ = ∞ and
similarly for T− . Also, we assume that for any x ∈ M and i, j ∈ [d], it holds that
Uji (x) = {(t, s) ∈ R2 | (s, t) ∈ Uij (x)}, i.e., φsGi ◦ φtGj (x) is well-defined if and only if
φtGj ◦ φsGi (x) does.

Indeed, under Assumption 7.3.5, we can show that for any x ∈ M , U(x) is a
hyperrectangle, as summarized in the following lemma. See Section 7.7 for a proof.

Lemma 7.3.6. Let M be a smooth submanifold of RD and G : M → Rd be a


C 2 parametrization satisfying Assumption 7.3.5. Then for any x ∈ M , U(x) is a
hyperrectangle, i.e., U(x) can be decomposed as

U(x) = I1 (x) × I2 (x) × · · · × Id (x)

where Ij (x) := {x0j | x0 ∈ U(x)} is an open interval.

For any initialization xinit ∈ M , the set of points that are reachable via gradient
flow under G with respect to some time-dependent loss (see Definition 7.3.1) is a
subset of M that depends on G and xinit .

Definition 7.3.7 (Reachable set). Let M be a smooth submanifold of RD . For any


C 2 parametrization G : M → Rd and any initialization xinit ∈ M , the reachable set
Ωx (xinit ; G) is defined as

n o
Ωx (xinit ; G) = φµL11 ◦G ◦ φµL22 ◦G ◦ · · · ◦ φµLkk ◦G (xinit ) ∀k ∈ N, ∀i ∈ [k], Li ∈ C 1 (Rd ), µi ≥ 0 .

It is clear that the above definition induces a transitive “reachable” relationship


between points on M , and it is also reflexive since for all L ∈ C 1 (Rd ) and t > 0,
φtL◦G ◦ φt(−L)◦G is the identity map on the domain of φt−L◦G . In this sense, the reachable
257
sets are orbits of the family of gradient vector fields {∇(L ◦ G) | L ∈ C 1 (Rd )}, i.e., the
reachable sets divide the domain M into equivalent classes. The above reachable set
in the x-space further induces the corresponding reachable set in the w-space given by
Ωw (xinit ; G) = G(Ωx (xinit ; G)).
In most natural examples, the parametrization G is smooth (though this is not
necessary for our results), and by Sussman’s Orbit Theorem [174], each reachable set
Ωx (xinit ; G) is an immersed submanifold of M . Moreover, it follows that Ωx (xinit ; G)
can be generated by {∇Gi }di=1 , i.e., Ωx (xinit ; G) = {φµG1j ◦ φµG2j ◦ · · · ◦ φµGkj (xinit ) | ∀k ∈
1 2 k

N, ∀i ∈ [k], ji ∈ [d], µi ≥ 0}.

7.3.3 Mirror descent and mirror flow

Next, we introduce some basic notions for mirror descent [139, 140]. We refer the
readers to Section 7.6 for more preliminaries on convex analysis.

Definition 7.3.8 (Legendre function and mirror map). Let R : Rd → R ∪ {∞} be a


differentiable convex function. We say R is a Legendre function when the following
holds:

(a) R is strictly convex on int(dom R).

(b) For any sequence {wi }∞


i=1 going to the boundary of dom R, limi→∞ k∇R(wi )k2 =

∞.

In particular, we call R a mirror map if R further satisfies the following condition (see
p.298 in Bubeck et al. 175):

(c) The gradient map ∇R : int(dom R) → Rd is surjective.

Given a Legendre function R : Rd → R ∪ {∞}, for any initialization w0 = winit ∈


int(dom R), mirror descent with step size η updates as follows:

∇R(wk+1 ) = ∇R(wk ) − η∇L(wk ). (7.6)

258
Usually ∇R is required to be surjective so that after a discrete descent step in the
dual space, it can be projected back to the primal space via (∇R)−1 . Nonetheless,
as long as ∇R(wk ) − η∇L(wk ) is in the range of ∇R, the above discrete update is
well-defined. In the limit of η → 0, (7.6) becomes the continuous mirror flow:

d∇R(w(t)) = −∇L(w(t))dt. (7.7)

Given a differentiable function R, the corresponding Bregman divergence DR is


defined as

DR (w, w0 ) = R(w) − R(w0 ) − h∇R(w0 ), w − w0 i.

We recall a well-known implicit bias result for mirror flow (which holds for mirror
descent as well) [142], which shows that for a specific type of loss, if mirror flow
converges to some optimal solution, then the convergence point minimizes some convex
regularizer among all optimal solutions.

Theorem 7.3.9. Given any data Z ∈ Rn×d and corresponding label Y ∈ Rn , suppose
the loss L(w) is in the form of L(w) = L(Zw)
e e : Rn → R.
for some differentiable L
Assume that initialized at w(0) = winit , the mirror flow (7.7) converges and the
convergence point w∞ = limt→∞ w(t) satisfies Zw∞ = Y , then

DR (w∞ , w0 ) = min DR (w, w0 ).


w:Zw=Y

See Section 7.7 for a proof. The above theorem is the building block for proving
the implicit bias induced by any commuting parametrization in overparametrized
linear models (see Theorem 7.4.16).

259
t t
φtGi i (x) −∇Gj φtGi i ◦ φGj j (x) = φGj j ◦ φtGi i (x)
tj

−∇Gi ti ti −∇Gi

tj
x −∇Gj t
φGj j (x)

Figure 7.1: Illustration of commuting parametrizations. Suppose G : M → Rd is


a commuting parametrization satisfying Assumption 7.3.5, then starting from any
x ∈ M , first moving along −∇Gi for time ti then moving along −∇Gj for time tj
yields the same result as first moving along −∇Gj for time tj then moving along
t t
−∇Gi for time ti does, i.e., φtGi i ◦ φGj j (x) = φGj j ◦ φtGi i (x).

7.4 Any gradient flow with commuting parametriza-

tion is a mirror flow

7.4.1 Commuting parametrization

We now formalize the notion of commuting parametrization. We remark that M is a


smooth submanifold of RD , and it is the domain of the parametrization G.

Definition 7.4.1 (Commuting parametrization). Let M be a smooth submanifold of


RD . A C 2 parametrization G : M → Rd is commuting in a subset S ⊆ M if and only
if for any i, j ∈ [d], the Lie bracket [∇Gi , ∇Gj ](x) = 0 for all x ∈ S. Moreover, we say
G is a commuting parametrization if it is commuting in the entire M .

In particular, when M is an open subset of Rd , {∇Gi }di=1 are ordinary gradients


in RD , and the Lie bracket between any pair of ∇Gi and ∇Gj is given by

[∇Gi , ∇Gj ](x) = ∇2 Gj (x)∇Gi (x) − ∇2 Gi (x)∇Gj (x).

This provides an easy way to check whether G is commuting or not.

260
The above definition of commuting parametrizations builds upon the differential
properties of the gradient vector fields {∇Gi }di=1 , where each Lie bracket [∇Gi , ∇Gj ]
characterizes the change of ∇Gj along the flow generated by ∇Gi . In particular,
when G is a commuting parametrization satisfying Assumption 7.3.5, it is further
equivalent to a characterization of ‘commuting’ in the integral form, as summarized in
Theorem 7.4.2. Also see Figure 7.1 for an illustration.

Theorem 7.4.2 (Adapted from Theorem 9.44 in Lee [173]). Let M be a smooth
submanifold of RD and G : M → Rd be a C 2 parametrization. For any i, j ∈ [d],
[∇Gi , ∇Gj ](x) = 0 for all x ∈ M if and only if for any x ∈ M , whenever both
φsGi ◦ φtGj (x) and φtGj ◦ φsGi (x) are well-defined for all (s, t) in some rectangle I1 × I2
where I1 , I2 ⊆ R are open intervals, it holds that φsGi ◦ φtGj (x) = φtGj ◦ φsGi (x) for all
(s, t) ∈ I1 × I2 .

Under Assumption 7.3.5, Lemma 7.3.6 implie s that the domain of φsGi ◦ φtGj (x) is
exactly Ii (x) × Ij (x), and thus the above theorem simplifies into the following.

Theorem 7.4.3. Let M be a smooth submanifold of RD and G : M → Rd be a C 2


parametrization satisfying Assumption 7.3.5. For any i, j ∈ [d], [∇Gi , ∇Gj ](x) = 0
for all x ∈ M if and only if for any x ∈ M , it holds that φsGi ◦ φtGj (x) = φtGj ◦ φsGi (x)
for all (s, t) ∈ I1 (x) × I2 (x).

The commuting condition clearly holds when each Gi only depends on a different
subset of coordinates of x, because we then have ∇2 Gi (·)∇Gj (·) ≡ 0 for any distinct
i, j ∈ [d] as ∇2 Gi and ∇Gj live in different subspaces of RD . We call such G separable
parametrizations, and this case covers all the previous examples [20, 105, 134, 136, 137].
Another interesting example is the quadratic parametrization: We parametrize w ∈ Rd
by G : RD → Rd where for each i ∈ [d], there is a symmetric matrix Ai ∈ RD×D such
that Gi (x) = 12 x> Ai x. Then each Lie bracket [Gi , Gj ](x) = (Aj Ai − Ai Aj )x, and thus
G is a commuting parametrization if and only if matrices {Ai }di=1 commute.
261
For concreteness, we analyze two examples below. The first one is both a separable
parametrization and a commuting quadratic parametrization. The second one is a
quadratic parametrization but not commuting.

2 2
Example 7.4.4 (u −v parametrization, Woodworth et al. [20]). Parametrize
w ∈ Rd by w = u 2
− v 2 . Here D = 2d, and the parametrization G is given by
for x = uv ∈ RD . Since each Gi (x) involves only ui and vi , G is
2 2

G(x) = u −v
a separable parametrization and hence a commuting parametrization. Meanwhile,
each Gi (x) is a quadratic form in x, and it can be directly verified that the matrices
underlying these quadratic forms commute with each other.

Example 7.4.5 (Matrix factorization). As a counter-example, consider two


parametrizations for matrix factorization: G(U ) = U U > and G(U, V ) = U V > ,
where U, V ∈ Rd×r and d ≥ 2, r ≥ 1. These are both non-commuting quadratic
parametrizations. Here we only demonstrate for the parametrization G(U ) = U U > ,
and G(U, V ) = U V > follows a similar argument. For each i, j ∈ [d], we define Eij ∈ Rd
as the one-hot matrix with the (i, j)-th entry being 1 and the rest being 0, and denote
E ij = 12 (Eij + Eji ). For r = 1, we have Gij (U ) = Ui Uj = U > E ij U for any i, j ∈ [d], so
G is a quadratic parametrization. Note that E ii E ij = 12 Eij =
6 12 Eji = E ij E ii for all
distinct i, j ∈ [d], which implies that [∇Gij , ∇Gii ] 6= 0, so G is non-commuting. More


generally, we can reshape U as a vector U := [U:1> , . . . , U:r> ]> ∈ Rrd where each U:j is
the j-th column of U , and the resulting quadratic form for the (i, j)-entry of G(U )
corresponds to a block-diagonal matrix:
 
E ij 

−  →
−
Gij (U ) = ( U )> 

...
U.
 
E ij

Therefore, ∇2 Gij does not commute with ∇2 Gii due to the same reason as in the
rank-1 case.
262
Remark 7.4.6. This non-commuting issue for general matrix factorization does not
conflict with the theoretical analysis in Gunasekar et al. [105] where the measurements
are commuting, or equivalently, only involves diagonal elements, as {Gii }di=1 are
indeed commuting parametrizations. Gunasekar et al. [105] is the first to identify the
above non-commuting issue and conjectured that the implicit bias result for diagonal
measurements can be extended to the general case.

7.4.2 Main Equivalence Result

Next, we proceed to present our analysis for gradient flow with commuting parametriza-
tion. The following two lemmas highlight the special properties of commuting
parametrizations. Lemma 7.4.7 shows that the point reached by gradient flow with
any commuting parametrization is determined by the integral of the negative gradient
of the loss along the trajectory.

Lemma 7.4.7. Let M be a smooth submanifold of RD and G : M → Rd be a


commuting parametrization. For any initialization xinit ∈ M , consider the gradient
flow for any time-dependent loss L· ∈ L as in Definition 7.3.1:

dx(t) = −∇(Lt ◦ G)(x(t))dt, x(0) = xinit .

Rt
Further define µ(t) = 0
−∇Lt (G(x(s)))ds. Suppose µ(t) ∈ U(xinit ) for all t ∈ [0, T )
where T ∈ R ∪ {∞}, then it holds that x(t) = ψ(xinit ; µ(t)) for all t ∈ [0, T ).

Based on Lemma 7.4.7, the next key lemma reveals the essential approach to find
the Legendre function.

Lemma 7.4.8. Let M be a smooth submanifold of RD and G : M → Rd be a


commuting and regular parametrization satisfying Assumption 7.3.5. Then for any
xinit ∈ M , there exists a Legendre function Q : Rd → R ∪ {∞} such that ∇Q(µ) =

263
G(ψ(xinit ; µ)) for all µ ∈ U(xinit ). Moreover, let R be the convex conjugate of Q, then
R is also a Legendre function and satisfies that int(dom R) = Ωw (xinit ; G) and

−1
∇2 R(G(ψ(xinit ; µ))) = ∂G(ψ(xinit ; µ))∂G(ψ(xinit ; µ))>

for all µ ∈ U(xinit ).

Next, we present our main result on characterizing any gradient flow with com-
muting parametrization by a mirror flow.

Theorem 7.4.9. Let M be a smooth submanifold of RD and G : M → Rd be


a commuting and regular parametrization satisfying Assumption 7.3.5. For any
initialization xinit ∈ M , consider the gradient flow for any time-dependent loss function
Lt : Rd → R:

dx(t) = −∇(Lt ◦ G)(x(t))dt, x(0) = xinit .

Define w(t) = G(x(t)) for all t ≥ 0, then the dynamics of w(t) is a mirror flow with
respect to the Legendre function R given by Lemma 7.4.8, i.e.,

d∇R(w(t)) = −∇Lt (w(t))dt, w(0) = G(xinit ).

Moreover, this R only depends on the initialization xinit and the parametrization G,
and is independent of the loss function Lt .

Proof of Theorem 7.4.9. Recall that the gradient flow in the x-space governed by
−∇(Lt ◦ G)(x) is

dx(t) = −∇(Lt ◦ G)(x(t))dt = −∂G(x(t))> ∇Lt (G(x(t)))dt.

264
Using w(t) = G(x(t)), the corresponding dynamics in the w-space is

dw(t) = ∂G(x(t))dx(t) = −∂G(x(t))∂G(x(t))> ∇Lt (w(t))dt. (7.8)

By Lemma 7.4.7, we know that the solution to the gradient flow satisfies x(t) =
Rt
ψ(xinit ; µ(t)) where µ(t) = 0 −∇Lt (G(x(s)))ds. Therefore, applying Lemma 7.4.8,
we get a Legendre function R : Rd → R ∪ {∞} with domain Ωw (xinit ; G) such that

−1
∇2 R(w(t)) = ∇2 R(G(ψ(xinit ; µ(t)))) = ∂G(ψ(xinit ; µ(t)))∂G(ψ(xinit ; µ(t)))

for all t ≥ 0. Then the dynamics of w(t) in (7.8) can be rewritten as

dw(t) = −∇2 R(w(t))−1 ∇Lt (w(t))dt,

or equivalently,

d∇R(w(t)) = −∇Lt (w(t))dt,

which is exactly the mirror flow with respect to R initialized at w(0) = G(xinit ).
Further note that the result of Lemma 7.4.8 is completely independent of the loss
function Lt , and thus R only depends on the initialization xinit and the parametrization
G. This finishes the proof.

Theorem 7.4.9 provides a sufficient condition for when a gradient flow with certain
parametrization G is simulating a mirror flow. The next question is then: What are
the necessary conditions on the parametrization G so that it enables the gradient flow
to simulate a mirror flow? We provide a (partial) characterization of such G in the
following theorem.

265
Theorem 7.4.10 (Necessary condition on smooth parametrization to be commuting).
Let M be a smooth submanifold of RD and G : M → Rd be a smooth parametrization.
If for any xinit ∈ M , there is a Legendre function R such that for all time-dependent
loss Lt ∈ L, the gradient flow under Lt ◦ G initialized at xinit can be written as the
mirror flow under Lt with respect to R, then G must be a regular parametrization, and
it also holds that for each x ∈ M ,

Lie≥2 (∂G) x
⊆ ker(∂G(x)), (7.9)

where Lie≥K (∂G) := span [[[[∇Gj1 , ∇Gj2 ], . . .], ∇Gjk−1 ], ∇Gjk ] | k ≥ K, ∀i ∈ [k], ji ∈


[d]} is the subset of the Lie algebra generated by the gradients of coordinate functions of
G only containing elements of order higher than K, and ker(∂G(x)) is the orthogonal
complement of span({∇Gi (x)}di=1 ) in RD .

With the above necessary condition (7.9), we can formally refute the possibility
that one can use mirror flow to characterize the implicit bias of gradient flow for matrix
factorization in general settings, as summarized in Corollary 7.4.11. In Chapter 6
we will constructed a concrete counter example showing that the implicit bias for
commuting measurements, that gradient flow finds the solution with minimal nuclear
norm, does not hold for the general case, where gradient flow could prefer the solution
with minimal rank instead.

Corollary 7.4.11 (Gradient flow for matrix factorization cannot be written as mirror
flow). For any d, r ∈ N, let M be an open set in Rd×r and G : M → Rd×d be a smooth
parametrization given by G(U ) = U U > . Then there exists a initial point xinit ∈ M
and a time-dependent loss Lt such that the gradient flow under Lt ◦ G starting from
Uinit cannot be written as a mirror flow with respect to any Legendre function R under
the loss Lt .

266
Proof of Corollary 7.4.11. It turns out that the necessary condition in Theorem 7.4.10
is already violated by only considering the Lie algebra spanned by {∇G11 , ∇G12 }. We
follow the notation in Example 7.4.5 to define each Eij ∈ Rd as the one-hot matrix with
the (i, j)-th entry being 1, and denote E ij = 12 (Eij + Eji ) and ∆ij = Eij − Eji . Then
[∇G11 , ∇G12 ](U ) = 4(E 11 E 12 − E 12 E 11 )U = ∆12 U and [∇G11 , [∇G11 , ∇G12 ]](U ) =
(E 11 ∆12 − ∆12 E 11 )U = E 12 U . Further noting that h[∇G11 , [∇G11 , ∇G12 ]], ∇G12 i =
2
2 E 12 U F = 12 ri=1 (U1i2 + U2i2 ) must be positive at some U in every open set M , by
P

Theorem 7.4.10, we know such Uinit and Lt exist. Moreover, Lt will only depend on
G11 (U ) and G12 (U ).

The following corollary shows that gradient flow with non-commuting parametriza-
tion cannot be mirror flow, when the dimension of the reachable set matches with
that of the w-space.

Corollary 7.4.12. Let M be a smooth submanifold of RD and G : M → Rd be a


regular parametrization. Then G must be a commuting parametrization if for any
xinit ∈ M , the following holds:

(a) Ωx (xinit ; G) is a submanifold of dimension d.

(b) There is a Legendre function R such that for any time-dependent loss Lt ∈ L,
the gradient flow governed by −∇(Lt ◦ G) with initialization xinit can be written
as a mirror flow with respect to R.

Proof of Corollary 7.4.12. By the condition (b) and Theorem 7.4.10, we know that
each Lie bracket [∇Gi , ∇Gj ] ∈ ker(∂G). By the condition (a), we know that each Lie
bracket [∇Gi , ∇Gj ] ∈ span{∇Gi }di=1 . Combining these two facts, we conclude that
each [∇Gi , ∇Gj ] ≡ 0, so G is a commuting parametrization.

Next, we establish the convergence of w(t) = G(x(t)) when x(t) is given by some
gradient flow with the commuting parametrization G. Here we require that the convex
267
function R given by Lemma 7.4.8 is a Bregman function (see definition in Section 7.6).
The proofs of Theorem 7.4.13, Corollary 7.4.14 and Theorem 7.4.15 are in Section 7.8.

Theorem 7.4.13. Under the setting of Theorem 7.4.9, further assume that the loss L
is quasi-convex, ∇L is locally Lipschitz and argmin{L(w) | w ∈ dom R} is non-empty
where R : Rd → R ∪ {∞} is the convex function given by Lemma 7.4.8. Suppose
R is a Bregman function, then as t → ∞, w(t) converges to some w∗ such that
∇L(w∗ )> (w − w∗ ) ≥ 0 for all w ∈ dom R. Moreover, if the loss function L is convex,
then w(t) converges to a minimizer in dom R.

Corollary 7.4.14. Under the setting of Theorem 7.4.13, if the reachable set in the w-
space satisfies Ωw (xinit ; G) = Rd , then R is a Bregman function and all the statements
in Theorem 7.4.13 hold.

Theorem 7.4.15. Under the setting of Theorem 7.4.13, consider the commuting
quadratic parametrization G : RD → Rd where each Gi (x) = 12 x> Ai x, for symmetric
matrices A1 , A2 , . . . , Ad ∈ RD×D that commute with each other, i.e., Ai Aj − Aj Ai = 0
for all i, j ∈ [d]. For any xinit ∈ RD , if {∇Gi (xinit )}di=1 = {Ai xinit }di=1 are linearly
independent, then the following holds:

(a) For all µ ∈ Rd , ψ(xinit ; µ) = exp( di=1 µi Ai )xinit where exp(·) is the matrix
P

exponential defined as exp(A) := ∞ Ak


P
k=0 k! .

(b) For each j ∈ [d] and all µ ∈ Rd , Gj (ψ(xinit ; µ)) = 12 xinit > exp( di=1 2µi Ai )Aj xinit .
P

2
kψ(xinit ; µ)k22 =
1 1
Pd
(c) Q(µ) = 4 4
exp( i=1 µi Ai )xinit 2
is a Legendre function with
domain Rd .

(d) R is a Bregman function with dom R = range ∇Q where range ∇Q is the range
of ∇Q, and thus all the statements in Theorem 7.4.13 hold.

268
7.4.3 Solving underdetermined linear regression with com-

muting parametrization

Next, we specialize to underdetermined linear regression problems to showcase our


framework.

Setting: underdetermined linear regression. Let {(zi , yi )}ni=1 ⊂ Rd × R be a


dataset of size n. Given any parametrization G, the output of the linear model on
the i-th data is zi> G(x). The goal is to solve the regression for the label vector Y =
(y1 , y2 , . . . , yn )> . For notational convenience, we define Z = (z1 , z2 , . . . , zn ) ∈ Rd×n .
We can apply Theorem 7.3.9 to obtain the implicit bias of gradient flow with any
commuting parametrization.

Theorem 7.4.16. Let M be a smooth submanifold of Rd and G : M → Rd be a


commuting and regular parametrization satisfying Assumption 7.3.5. Suppose the loss
function L satisfies L(w) = L(Zw)
e e : Rn → R. For any
for some differentiable L
xinit ∈ M , consider the gradient flow

dx(t) = −∇(L ◦ G)(x(t))dt, x(0) = xinit .

There exists a convex function R (given by Lemma 7.4.8, depending only on the
initialization xinit and the parametrization G), such that for any dataset {(zi , yi )}ni=1 ⊂
Rd × R, if w(t) = G(x(t)) converges as t → ∞ and the convergence point w∞ =
limt→∞ w(t) satisfies Zw∞ = Y , then

R(w∞ ) = min R(w),


w:Zw=Y

that is, gradient flow implicitly minimizes the convex regularizer R among all interpo-
lating solutions.

269
Proof of Theorem 7.4.16. By Theorem 7.4.9, w(t) obeys the following mirror flow:

d∇R(w(t)) = −∇L(w(t))dt, w(0) = G(xinit ).

Applying Theorem 7.3.9 yields

DR (w∞ , G(xinit )) = min DR (w, G(xinit )).


w:Zw=Y

Therefore, for any w ∈ int(dom R) such that Zw = Y , we have

R(w∞ ) − R(G(xinit )) − h∇R(G(xinit ), w∞ − G(xinit )i

≤ R(w) − R(G(xinit )) − h∇R(G(xinit ), w − G(xinit )i

which can be reorganized as

R(w∞ ) ≤ R(w) − h∇R(G(xinit )), w − w∞ i. (7.10)

Note that by Lemma 7.4.8, we also have

∇R(G(xinit )) = ∇R(G(ψ(xinit ; 0))) = ∇R(∇Q(0)) = 0 (7.11)

where the last equality follows from the property of convex conjugate. Combining
(7.10) and (7.11), we get R(w∞ ) ≤ R(w) for all w ∈ int(dom R) such that Zw = Y .
By the continuity of R, this property can be further extended to the entire dom R,
and for any w ∈
/ dom R, we have R(w) = ∞ by definition, so R(w∞ ) ≤ R(w) holds
trivially. This finishes the proof.

Note that the identity parametrization w = G(x) = x is a commuting parametriza-


tion. Therefore, if we run the ordinary gradient flow on w itself and it converges to

270
some interpolating solution, then the convergence point is closest to the initialization
in Euclidean distance among all interpolating solutions. This recovers the well-known
implicit bias of gradient flow for underdetermined regression.
Furthermore, we can recover the results on the quadratically overparametrized
linear model studied in a series of papers [20, 105, 138], as summarized in the following
Corollary 7.4.17. Note that their results assumed convergence in order to characterize
the implicit bias, whereas our framework enables us to directly prove the convergence
as in Theorem 7.4.15, where the convergence guarantee is also more general than
existing convergence results for Example 7.4.4 in Li et al. [133], Pesme et al. [176].

Corollary 7.4.17. Consider the underdetermined linear regression problem with data
Z ∈ Rd×n and Y ∈ Rn . Let L
e : Rn → R be a differentiable loss function such that L
e
e is locally Lipschitz, and Y ∈ Rn is its unique global minimizer.
is quasi-convex, ∇L
Consider solving minw L(Zw)
e by running gradient flow on L(w) = L(Zw)
e with the
quadratic parametrization w = G(x) = u 2 − v 2 where x = uv ∈ R2d

+ , for any

initialization xinit ∈ R2d


+:

dx(t) = −∇(L ◦ G)(x(t))dt, x(0) = xinit .

Then as t → ∞, w(t) = G(x(t)) converges to some w∞ such that Zw∞ = Y and

R(w∞ ) = min R(w)


w:Zw=Y

where R is given by

d
1 X  w  q
i 2 2 2 u0,i 
R(w) = wi arcsinh − wi + 4u0,i v0,i − wi ln .
4 i=1 2u0,i v0,i v0,i

271
7.5 Every mirror flow is a gradient flow with com-

muting parametrization

Consider any smooth Legendre function R : Rd → R ∪ {∞}, and recall the correspond-
ing mirror flow:

d∇R(w(t)) = −∇L(w(t))dt.

Note that int(dom R) is a convex open set of Rd , hence a smooth manifold (see
Example 1.26 in Lee [173]). Then ∇2 R is a continuous positive-definite metric on
int(dom R). As discussed previously, the above mirror flow can be further rewritten
as the Riemannian gradient flow on the Riemannian manifold (int(dom R), ∇2 R), i.e.,

dw(t) = −∇2 R(w(t))−1 ∇L(w(t))dt.

The goal is to find a parametrization G : U → Rd , where U is an open set of RD and


initialization xinit ∈ U , such that the dynamics of w(t) = G(x(t)) can be induced by
the gradient flow on x(t) governed by −∇(L ◦ G)(x). Formally, we have the following
result:

Theorem 7.5.1. Let R : Rd → R ∪ {∞} be a smooth Legendre function. There exist


a smooth submanifold of RD denoted by M , an open neighborhood U of M and a
smooth and regular parametrization G : U → Rd such that for mirror flow on any
time-dependent loss function Lt with any initialization winit ∈ int(dom R)

d∇R(w(t)) = −∇Lt (w(t))dt, w(0) = winit , (7.12)

272
it holds that w(t) = G(x(t)) for all t ≥ 0 where x(t) is given by the gradient flow under
the objective Lt ◦ G initialized at xinit , i.e.,

dx(t) = −∇(Lt ◦ G)(x(t))dt, x(0) = xinit . (7.13)

Moreover, G restricted on M , denoted by G|M is a commuting and regular parametriza-


tion and ∂G = ∂G|M on M , which implies x(t) ∈ M for all t ≥ 0. If R is further a
mirror map, then {∇Gi |M }di=1 are complete vector fields on M .

To illustrate the idea, let us first suppose such a smooth and regular parametrization
G exists and is a bijection between the reachable set Ωx (xinit ; G) ⊂ RD and int(dom R),
whose inverse is denoted by F . It turns out that we can show

∂F (w)> ∂F (w) = (∂G(F (w))∂G(F (w))> )−1 = ∇2 R(w)

where the second equality follows from the relationship between R and G as discussed
in the introduction on Equation (7.2). Note that this corresponds to expressing the
metric tensor ∇2 R using an explicit map F , which is further equivalent to embedding
the Riemannian manifold (int(dom R), ∇2 R) into a Euclidean space (RD , g) in a way
that preserves its metric. This refers to a notion called isometric embedding in
differential geometry.

Definition 7.5.2 (Isometric embedding). Let (M, g) be a Riemannian submanifold


of Rd . An isometric embedding from (M, g) to (RD , g) is an differentiable injective
map F : M → RD that preserves the metric in the sense that for any two tangent
vectors v, w ∈ Tx (M ) we have gx (v, w) = g x (∂F (x)v, ∂F (x)w) where the standard
euclidean metric tensor g is defined as g x (u, v) = hu, vi for all u, v ∈ Rd .

273
Nash’s embedding theorem is a classic result in differential geometry that guarantees
the existence of isometric embedding of any Riemannian manifold into a Euclidean
space with a plain geometry.

Theorem 7.5.3 (Nash’s embedding theorem, Nash [177], Gunther [178]). Any d-
dimensional Riemannian manifold has an isometric embedding to (RD , g) for some
D ≥ d.

The other way to understand Theorem 7.4.9 is that we can view ∇2 R(w)−1 ∇L(w)
as the gradient of L with respect to metric tensor gR , where g R is the Hessian
metric induced by strictly convex function R in the sense gxR (u, v) := u> ∇2 R(x)v
for any u, v ∈ Rd . It is well-known that gradient flow is invariant under isometric
embedding and thus we can use Nash’s embedding theorem to write the gradient flow
on riemmanian manifold (int(dom R), g R ) as that on (RD , g).

7.5.1 Existence of non-separable commuting parametrization

Despite the recent line of works on the connection between mirror descent and
gradient descent [136–138, 141, 142], so far we have not seen any concrete example
of non-separable parametrizaiton (in the sense of Definition 7.5.4) such that the
reparametrized gradient flow can be written as a mirror flow. In this subsection, we
discuss how we can use Theorem 7.5.1 to construct non-separable, yet commuting
parametrizations.

Definition 7.5.4 (Separable parametrization in the general sense). Let M be an


open subset of RD . We say a function G : M → Rd is a generalized separable
parametrization if and only if there exist d projection matrices {Pi }di=1 satisfying
Pd d
i=1 Pi = Id , Pi Pj = 1{i = j}·Pi , a function G : M → R satisfying Gi (x) = Gi (Pi x),
b b b

274
a matrix A ∈ Rd×d and a vector b ∈ Rd , such that

G(x) = AG(x)
b + b, ∀x ∈ M.

Given the above definition, it is easy to check that G


b is a commuting parametrization

as ∇2 G b j = P i ∇2 G
bi ∇G bi Pi · Pj ∇G
bj ≡ 0 for all i =
6 j, so each Lie bracket [∇Gi , ∇Gj ] is
also 0 by the linearity.
As a concrete example, for matrix sensing with commutable measurement
A1 , . . . , Am ∈ Rd×d , let V = (v1 , . . . , vd ) ∈ Rd×d be a common eigenvector matrix
for {Ai }m >
= dj=1 σi,j vi vi> for each i ∈ [m].
P
i=1 such that we can write Ai = V Σi V

With parametrization G : Rd×r → d where each Gi (U ) = vi> U U > vi , we can write


hAi , U U > i = dj=1 σi,j Gj (U ).
P

However, the bad news is that separable commuting parametrizations can only
express a restricted class of Legendre functions. It is easy to see ∂ G(x)∂
b b > must be
G(x)
diagonal for every x. Thus ∂G(x)∂G(x)> are simultaneously diagonalizable for all x,
and so are the Hessian of the corresponding Legendre function (given by Lemma 7.4.8).
There are interesting Legendre functions that does not always have their Hessians
simultaneously diagonalizable, such as

d
X  d
X   d
X  
R(w) = wi (ln wi − 1) + 1 − wi ln 1 − wi −1 ,
i=1 i=1 i=1

Pd Pd
where each wi > 0 and i=1 wi < 1. We can check that ∇R(w) = i=1 ln 1−Pwdi wi
i=1
2
and ∇ R(w) = diag(w (−1)
)+ 1d 1>
d. It is proposed as an open problem by [137]
that whether we can find a parametrization G such that the reparametrized gradient
flow in the x-space simulates the mirror flow in the w-space with respect to the
aforementioned Legendre function R.
Our Theorem 7.5.1 answers the open problem by [137] affirmatively since it shows
every mirror flow can be written as some reparametrized gradient flow. According
275
to the previous discussion, every mirror flow for Lengendre function whose Hessian
cannot be simultaneously diagonalized always induces a non-separable commuting
parametrization. But this type of construction has two caveats: First, the construction
of the Legendre function uses Nash’s Embedding theorem, which is implicit and hard
to implement; second, the parametrization given by Theorem 7.5.1, though defined on
an open set in RD , is only commuting on the reachable set, which is a d-dimensional
submanifold of RD . This is different from all the natural examples of commuting
parametrizations which are commuting on an open set, leading to the following open
question.

Open Question: Is there any smooth, regular, commuting, yet non-separable (in
the sense of Definition 7.5.4) parametrization from an open subset of RD to Rd , for
some integers D and d?

Theorem 7.5.5. All smooth, regular and commuting parametrizations are non-
separable when D = 1.

Proof of Theorem 7.5.5. Note that [∇Gi , ∇Gj ] ≡ 0 implies that all Gi share the same
set of stationary points, i.e., {x ∈ R | ∇Gi (x) = 0} is the same for all i ∈ [d]. Since
D = 1, without loss of generality, we can assume G0i (x) = ∇Gi (x) > 0 for all x ∈ M
and i ∈ [d] since G is regular. Then it holds that sign(G0i )(ln |G0i |)0 = sign(G0j )(ln |G0j |)0 ,
which implies that |G0i |/|G0j | is equal to some constant independent of x. This completes
the proof.

Remark 7.5.6. We note that the assumption that the parametrization is regular is
necessary for the open question to be non-trivial. Otherwise, consider the following
example with D = 1 and d = 2: Let f1 , f2 : R → R be any smooth function supported
Rx
on (0, 1) and (1, 2) respectively. Define Gi (x) = 0 fi (t)dt for all x ∈ R. Then
parametrization G is non-separable.

276
7.6 Related basics for convex analysis

We first introduce some additional notations. For any function f , we denote its range
(or image) by range f . For any set S, we use S to denote its closure. For any matrix
Λ ∈ Rd×D and set S ⊆ RD , we define ΛS = {Λx | x ∈ S} ⊆ Rd .
Below we collect some related basic definitions and results in convex analysis. We
refer the reader to Rockafellar [179] and Bauschke et al. [180] as main reference sources.
In particular, Sections 2, 3 and 4 in Bauschke et al. [180] provide a clear summary of
the related concepts.
Here we consider a convex function f : Rd → R ∪ {∞} whose domain is dom f =
{w ∈ Rd | f (w) < ∞}. From now on, we assume by default that f is continuous
on dom f , the interior of its domain int(dom f ) is non-empty, and f is differentiable
on int(dom f ).
The notions of essential smoothness and essential strict convexity defined below
describe certain nice properties of a convex function (see Section 26 in Rockafellar
[179]).

Definition 7.6.1 (Essential smoothness and essential strict convexity). If for any
sequence {wn }∞
n=1 ⊂ int(dom f ) going to the boundary of dom f as n → ∞, it holds

that k∇f (wn )k → ∞, then we say f is essentially smooth. If f is strictly convex on


every convex subset of int(dom f ), then we say f is essentially strictly convex.

The concept of convex conjugate is critical in our derivation. Specifically, given a


convex function f : Rd → R ∪ {∞}, its convex conjugate f ∗ is defined as

f ∗ (w) = sup hw, yi − f (y).


y∈Rd

The following results characterize the relationship between a convex function and its
conjugate.

277
Theorem 7.6.2 (Theorem 26.3, Rockafellar [179]). A convex function f is essentially
strictly convex if and only if its convex conjugate f ∗ is essentially smooth.

Proposition 7.6.3 (Proposition 2.5, Bauschke et al. [180]). If f is essentially strictly


convex, then range ∂f = int(dom f ∗ ) = dom ∇f ∗ , where ∂f is the subgradient of f .

Lemma 7.6.4 (Corollary 2.6, Bauschke et al. [180]). If f is essentially strictly convex,
then it holds for all w ∈ int(dom f ) that ∇f (w) ∈ int(dom f ∗ ) and ∇f ∗ (∇f (w)) = w.

The class of Legendre functions defined in Definition 7.3.8 contains convex functions
that are both essentially smooth and essentially strictly convex.

Theorem 7.6.5 (Theorem 26.5, Rockafellar [179]). A convex function f is a Legendre


function if and only if its conjugate f ∗ is. In this case, the gradient mapping ∇f :
int(dom f ) → int(dom f ∗ ) satisfies (∇f )−1 = ∇f ∗ .

Next, we introduce the notion of Bregman function [181, 182]. It has been shown
in Bauschke et al. [180] that the properties of Bregman functions are crucial to prove
the trajectory convergence of Riemannian gradient flow where the metric tensor is
given by the Hessian of some Bregman function f .

Definition 7.6.6 (Bregman functions; Definition 4.1, Alvarez et al. [183]). A function
f is called a Bregman function if it satisfies the following properties:

(a) dom f is closed. f is strictly convex and continuous on dom f . f is C 1 on


int(dom f ).

(b) For any w ∈ dom f and α ∈ R, {y ∈ dom f | DR (w, y) ≤ α} is bounded.

(c) For any w ∈ dom f and sequence {wi }∞


i=1 ⊂ int(dom f ) such that limi→∞ wi = w,

it holds that limi→∞ DR (w, wi ) → 0.

The following theorem provides a special sufficient condition for f to be a Bregman


function.
278
Theorem 7.6.7 (Theorem 4.7, Alvarez et al. [183]). If f is a Legendre function with
dom f = Rd , then dom f ∗ = Rd implies that f is a Bregman function.

The following theorem from Alvarez et al. [183] provides a convenient tool for
proving the convergence of a Riemannian gradient flow.

Theorem 7.6.8 (Theorem 4.2, Alvarez et al. [183]). Suppose f : Rd → R ∪ {∞}


is a Bregman function and also a Legendre function, and satisfies that f is twice
continuously differentiable on int(dom f ) and ∇2 f is locally Lipschitz. Consider the
following Riemannian gradient flow:

dw(t) = −∇2 f (w(t))−1 ∇L(w(t))dt, w(0) = winit ∈ int(dom f )

where the loss L : Rd → R satisfies that L is quasi-convex, ∇L is locally Lipschitz,


and argmin{L(w) | w ∈ dom f } is non-empty. Then as t → ∞, w(t) converges to
some w∗ ∈ dom f such that h∇L(w∗ ), w − w∗ i ≥ 0 for all w ∈ dom f . If the loss L is
further convex, then w∗ is a minimizer of L on dom f .

7.7 Omitted proofs in Section 7.3

Here we first present the proof for the result on the domain of the flow induced by G.

Proof of Lemma 7.3.6. Fix any x ∈ M . For each i ∈ [d], let Ii (x) be the domain of
φtGj (x) in terms of t. If ∇Gi is a complete vector field on M as in Definition 7.3.2, then
Ii (x) = Rd , otherwise φtGj (x) is defined for t in an open interval containing 0 (see, e.g.,
Theorem 2.1 in Lang [184]). Then we claim that for any distinct j1 , j2 , . . . , jk ∈ [d]
µ µj
where k ∈ [d], the set of all (µj1 , . . . , µjk ) ∈ Rk such that φGjj1 ◦ · · · ◦ φGjk (x) is well-
1 k

defined is a hyperrectangle given by Ij1 (x) × Ij2 (x) × · · · × Ijk (x). Then the desired
result can be obtained by letting (j1 , j2 , . . . , jd ) = (1, 2, . . . , d). We prove the claim by
induction over k ∈ [d].
279
The base case for k = 1 has already been established above. Next, assume the claim
holds for 1, 2, . . . , k − 1 where k ≥ 3, and we proceed to show it for k. By the claim for
µ µj
k−2, φGjj3 ◦· · ·◦φGjk (x) is well-defined for (µj3 , . . . , µjk ) ∈ Ij3 (x)×· · ·×Ijk (x). For any
3 k
µj
such (µj3 , . . . , µjk ), φtGj ◦φµG3j ◦· · ·◦φGjk (x) is well-defined for t in and only in the open
1 3 k
µj
interval Ij1 (x) by applying the claim for k − 1, and similarly φtGj ◦ φµG3j ◦ · · · ◦ φGjk (x)
2 3 k

is also well-defined for t in and only in the open interval Ij2 (x). Note that for any
(s, t) ∈ Ij1 (x) × Ij2 (x),

µ µj
φsGj1 ◦ φ−t t j3
Gj ◦ φGj2 ◦ φGj ◦ · · · ◦ φGj (x)
k
2 3 k

is well-defined, so by Assumption 7.3.5, we see that

−t µ µj
φG j
◦ φsGj1 ◦ φtGj2 ◦ φGjj3 ◦ · · · ◦ φGjk (x)
2 3 k

µ µj
is also well-defined, which further implies that φsGj ◦ φtGj ◦ φGjj3 ◦ · · · ◦ φGjk (x) is
1 2 3 k
µ µj
well-defined. Therefore, we conclude that φGjj1 ◦ ··· ◦ φGjk (x) is well-defined for and
1 k

only for (µj1 , . . . , µjk ) ∈ Ij1 (x) × · · · × Ijk (x). This completes the induction and hence
finishes the proof.

Next, we provide the proof for the implicit bias of mirror flow summarized in
Theorem 7.3.9. We need the following lemma that characterizes the KKT conditions
for minimizing a convex function R in a linear subspace.

Lemma 7.7.1. For any convex function R : Rd → R ∪ {∞} and Z ∈ Rn×d , suppose
∇R(w∗ ) = Z > λ for some λ ∈ Rn , then

R(w∗ ) = min R(w).


w:Z(w−w∗ )=0

280
Proof of Lemma 7.7.1. Consider another convex function defined as R(w)
e = R(w) −
w> Z > λ, then ∇R(w
e ∗ ) = ∇R(w∗ ) − Z > λ = 0, which implies that

e ∗ ) = min R(w) − w> Z > λ


R(w
w∈Rd

≤ min R(w) − w> Z > λ


w:Z(w−w∗ )=0

= min R(w) − w∗> Z > λ.


w:Z(w−w∗ )=0

e ∗ ) = R(w∗ ) − w∗> Z > λ, it follows that


Since R(w

R(w∗ ) ≤ min R(w),


w:Z(w−w∗ )=0

and the equality is achieved at w = w∗ . This finishes the proof.

We then can prove Theorem 7.3.9 by using Lemma 7.7.1.

Proof of Theorem 7.3.9. Since L(w) = L(Zw−Y


e ), the mirror flow (7.7) can be further
written as

d∇R(w(t)) = −Z > ∇L(Zw(t)


e − Y )dt.

Integrating the above yields that for any t ≥ 0,

Z t
>
∇R(w(t)) − ∇R(w0 ) = −Z ∇L(Zw(s)
e − Y )ds ∈ span(X > ),
0

which further implies that ∇R(w∞ ) − ∇R(w0 ) ∈ span(Z > ). Therefore,

∇DR (w, w0 )|w=w∞ = ∇R(w∞ ) − ∇R(w0 ) ∈ span(Z > ).

281
Then applying Lemma 7.7.1 yields

DR (w∞ , w0 ) = min DR (w, w0 ).


w:Z(w−w∞ )=0

This finishes the proof.

7.8 Omitted proofs in Section 7.4

Here we provide the omitted proofs in Section 7.4, including four main parts:

(1) Properties of commuting parametrizations (Section 7.8.1);

(2) Necessary condition for a smooth parametrization to be commuting (Sec-


tion 7.8.2);

(3) Convergence for gradient flow with commuting parametrization (Section 7.8.3);

(4) Results for the underdetermined linear regression (Section 7.8.4).

7.8.1 Properties of commuting parametrizations

We first show the representation formula for gradient flow with commuting parametriza-
tion given in Lemma 7.4.7.

Proof of Lemma 7.4.7. Let µ(t) be given by the following differential equation:

dµ(t) = −∇Lt (G(ψ(xinit ; µ(t))))dt, µ(0) = 0.

282
For any µ ∈ U(x) and j ∈ [d], µ + δej ∈ U(x) for all sufficiently small δ, thus

∂ ψ(xinit ; µ + δej ) − ψ(xinit ; µ)


ψ(xinit ; µ) = lim
∂µj δ→0 δ
δ
φG (ψ(xinit ; µ)) − ψ(xinit ; µ)
= lim j
δ→0 δ
= ∇Gj (ψ(xinit ; µ))

where the second equality follows from the assumption that G is a commuting
∂ψ(xinit ;µ)
parametrization and Theorem 7.4.2. Then we have ∂µ
= ∂G(ψ(xinit ; µ))>
for all µ ∈ U(xinit ), and thus when µ(t) ∈ U(xinit ),

∂ψ(xinit ; µ(t))
dψ(xinit ; µ(t)) = dµ(t)
∂µ(t)
= −∂G(xinit ; µ(t))∇Lt (G(ψ(xinit ; µ(t))))dt

= −∇(Lt ◦ G)(ψ(xinit ; µ(t)))dt.

Then since ψ(xinit ; µ(0)) = xinit and ψ(xinit ; µ(t)) follows the same differential equation
and has the same initialization as x(t), we have x(t) ≡ ψ(xinit ; µ(t)) for all t ∈ [0, T ).
Therefore,

Z t Z t
µ(t) = µ(0) + −∇Lt (G(ψ(xinit ; µ(s))))ds = −∇Lt (G(x(s)))ds
0 0

for all t ∈ [0, T ), which completes the proof.

Next, to prove Lemma 7.4.8, we need the following lemma which provides a
sufficient condition for a vector function to be gradient of some other function.

Lemma 7.8.1. Let Ψ : C → Rd be a differentiable function where C is a simply


∂ ∂
connected open subset of Rd . If for all w ∈ C and any i, j ∈ [d], Ψ (w)
∂wj i
= Ψ (w),
∂wi j

then there exists some function Q : C → R such that Ψ = ∇Q.

283
Proof of Lemma 7.8.1. This follows from a direct application of Corollary 16.27 in
Lee [173].

Based on the above results, we proceed to prove Lemma 7.4.8.

Proof of Lemma 7.4.8. By Lemma 7.3.6, U(xinit ) is hyperrectangle, and hence is


∂ψ(xinit ;µ)
convex. Next, recall that by the proof of Lemma 7.4.7, we have ∂µ
=
∂G(ψ(xinit ; µ))> for all µ ∈ U(xinit ). Denoting Ψ(µ) = G(ψ(xinit ; µ)), we further have

∂G(ψ(xinit ; µ)) ∂ψ(xinit ; µ)


∂Ψ(µ) = = ∂G(ψ(xinit ; µ))∂G(ψ(xinit ; µ))> , ∀µ ∈ U(x).
∂ψ(xinit ; µ) ∂µ

Since G is regular, ∂G(ψ(xinit ; µ)) is of full-rank for all µ ∈ U(xinit ), so ∂Ψ is symmetric


and positive definite for all µ ∈ U(xinit ), which implies that Ψ is the gradient of some
strictly convex function Q : Rd → R ∪ {∞} by Lemma 7.8.1. This Q satisfies that
∇Q(µ) = Ψ(µ) = G(ψ(xinit ; µ)) for all µ ∈ U(xinit ). Therefore, Q is a strictly convex
function with dom ∇Q = U(xinit ) and range ∇Q = Ωw (xinit ; G).
Next, we show that Q is essentially smooth. If U(xinit ) = Rd , then dom Q = Rd and
the boundary of dom Q is empty, so it is trivial that Q is essentially smooth. Otherwise,
it suffices to show that for any µ on the boundary of dom Q and any sequence {µk }∞
k=1 ⊂

U(xinit ) such that limk→∞ µk = µ∞ , we have limk→∞ k∇Q(µk )k2 = ∞. Since each
∇Q(µk ) = G(ψ(xinit ; µk )), we only need to show that limk→∞ kG(ψ(xinit ; µk ))k2 = ∞.
Suppose otherwise, then {G(ψ(xinit ; µk )}∞
k=1 is bounded. Note that by Lemma 7.4.7,

let Hk (x) = hµk , G(x)i, and we have

Z 1
ψ(xinit ; µk ) = φ1−Hk (xinit ) = xinit + ∇Hk (φs−Hk (xinit ))ds.
0

284
Therefore,

s
Z 1 Z 1
2
kψ(xinit ; µk ) − xinit k2 ≤ ∇Hk (φs−Hk (xinit )) 2 ds ≤ ∇Hk (φs−Hk (xinit )) 2 ds.
0 0

(7.14)

where the second inequality follows from Cauchy-Schwarz inequality. Further note
that

Z 1
d
Hk (ψ(xinit ; µk )) − Hk (xinit ) = Hk (φs−Hk (xinit ))ds
ds
Z0 1 
dφs−Hk (xinit )

s
= ∇Hk (φ−Hk (xinit )), ds
0 ds
Z 1
= k∇Hk (φs−Hk (xinit ))k22 ds. (7.15)
0

Then combining (7.14) and (7.15), we get

p
kψ(xinit ; µk ) − xinit k2 ≤ hµk , G(ψ(xinit ; µk )) − G(xinit )i
p
≤ kµk k2 · kG(ψ(xinit ; µk )) − G(xinit )k2 ,

which implies that {ψ(xinit ; µk )}∞


k=1 is bounded. Then there exists a convergent subse-

quence of {ψ(xinit ; µk )}∞


k=1 , and without loss of generality we assume that ψ(xinit ; µk )

itself converges to some x∞ ∈ M as k → ∞. Note that ψ(x∞ ; µ) is well-defined for µ


in a small open neighborhood of 0, and since limk→∞ ψ(xinit ; µk ) = x∞ , for sufficiently
large k, ψ(ψ(xinit ; µk ); µ) is well-defined for µ in a small neighborhood of 0 that does
not depend on k. Thus there exists some µ ∈ Rd such that µk + µ ∈
/ U(xinit ) but
ψ(ψ(xinit ; µk ); µ) is well-defined for sufficiently large k. But by Lemma 7.3.6 and
Theorem 7.4.2, ψ(ψ(xinit ; µk ); µ) = ψ(xinit ; µk + µ) and thus µk + µ ∈ U(xinit ), which
leads to a contradiction. Hence, we conclude that Q is essentially smooth.

285
Combining the above, it follows that Q is a Legendre function. Let R : Rd →
R ∪ {∞} be the convex conjugate of Q. Then by Theorem 7.6.5, R is also a Legendre
function. Note that for any µ ∈ U(xinit ), by the result in Crouzeix [185], we have

∇2 R(G(ψ(xinit ; µ))) = ∇2 R(∇Q(µ)) = ∇2 Q(µ)−1 = (∂G(ψ(xinit ; µ))∂G(ψ(xinit ; µ))> )−1 .

Therefore, R and Q are both Legendre functions, and by Proposition 7.6.3, we


further have range ∇R = int(dom Q) = dom ∇Q = U(x) and conversely dom ∇R =
range ∇Q = Ωw (xinit ; G). This finishes the proof.

7.8.2 Necessary condition for a smooth parametrization to

be commuting

Proof of Theorem 7.4.10. Fix any initialization xinit ∈ M , and let the Legendre func-
tion R be given such that for all time-dependent loss Lt , the gradient flow under Lt ◦ G
initialized at x can be written as the mirror flow under Lt with respect to the Legendre
function R. We first introduce a few notations that will be useful for the proof. For
any s ∈ R, we define a time-shifting operator Ts such that for any time-dependent loss
Lt (·), (Ts L)t (·) = Lt−s (·). We say a time-dependent loss Lt is supported on finite time
if Lt = ki=1 1t∈[ti ,ti+1 ) L(i) for some k ≥ 1 where t1 = 0, tk+1 = ∞ and L(k) ≡ 0, and
P

we denote len(L) = tk . We further define the concatenation of two time-dependent


loss Lt , L0t supported on finite time as L k L0 = L + Tlen(L) L0 . We also use L to denote
the time-reverse of the time-dependent loss L which is supported on finite time, that
is, Lt = Llen(L)−t for all t ≥ 0. For any j ∈ [d] and δ > 0, we define the following loss
function

`j,δ
t (w) = 10≤t≤δ · hej , wi (7.16)

286
where ej is the j-th canonical base of Rd .
Now for any k ≥ 2, let {ji }ki=1 be any sequence where each ji ∈ [d]. Then
we recursively define a sequence of time-dependent losses as follows: First define
L1,δ = −`j1 ,δ , then sequentially for each i = 2, 3, . . . , k, we define

√  √   √  √
i,δ i−1, δ ji , δ i−1, δ
L =L k −` k −L k `ji , δ
(7.17)

√ √
i−1, δ
where we write L for convenience. Denote ιi (δ) = len(Li,δ ) for each
= Li−1, δ

√ √
i ∈ [k]. Then ι1 (δ) = δ and ιi (δ) = 2 δ + 2ιi−1 ( δ) for i = 2, 3, . . . , k, which further
implies

i−1
m i−1
X
ιi (δ) = 2m δ 1/2 + 2i−1 δ 1/2 for all i ∈ [k].
m=1

Moreover, for each i = 2, 3, . . . , k, the gradient of Li,δ with respect to w is given by



 √
i−1, δ




 ∇Lt (w) 0 ≤ t ≤ ιi−1 ( δ),

√ √ √






 −eji ιi−1 ( δ) < t ≤ ιi−1 ( δ) + δ,
 √
√ √ √ √

∇Li,δ
t (w) = −∇Lt
i−1, δ
(w) ιi−1 ( δ) + δ < t ≤ 2ιi−1 ( δ) + δ, (7.18)


√ √ √ √






 eji 2ιi−1 ( δ) + δ < t ≤ 2ιi−1 ( δ) + 2 δ,

√ √



0
 t > 2ιi−1 ( δ) + 2 δ.

This inductively implies that for any t ∈ [0, ιk (δ)], ∇Lk,δ d


t (w) ∈ {ej }j=1 does not depend

on w and is only determined by t. Therefore, for any initialization x ∈ M , for all


ι (δ)
sufficiently small δ > 0, the gradient flow under Lk,δ for ιk (δ) time, i.e., φLkk,δ (x), is

287
well-defined. Moreover, it follows from (7.18) that
√ √ √
Z ιk−1 (δ) Z ιk−1 ( δ) √
Z ιk−1 ( δ)+ δ
∇Lk,δ
t (w(t))dt = ∇Lk−1, δ
(w(t))dt + √
−ejk dt
0 0 ιk−1 ( δ)
√ √ √ √
Z 2ιk−1 ( δ)+ δ √ Z 2ιk−1 ( δ)2 δ
k−1, δ
+ √ √
−∇L (w(t))dt + √ √
ejk dt
ιk−1 ( δ)+ δ 2ιk−1 ( δ)+ δ

Z ιk−1 ( δ)  √ √ 
k−1, δ
= ∇Lk−1,
t
δ
(w(t)) − ∇Lt (w(t)) dt = 0
0


where the last two equalities follow from the fact that ∇Ltk−1, δ (w) does not depend
on w and is only determined by t by our construction.
Hence, the mirror flow with respect to the Legendre function R for the time-
dependent loss Lk,δ will return to the initialization after ιk (δ) time since

Z ιk (δ)
∇R(w(ιk (δ))) − ∇R(w(0)) = −∇Lk,δ (w(t))dt = 0.
0

This further implies that

ι (δ) 
G(xinit ) = G φLkk,δ ◦G (xinit )

for all sufficiently small δ. Then differentiating with δ on both sides yields

ι (δ)
dφ kk,δ (xinit )
∂G(x) · L ◦G = 0. (7.19)
dδ δ=0

Note that if the following holds:

ι (δ)
dφLkk,δ ◦G (xinit )
= [[[[∇Gj1 , ∇Gj2 ], . . .], ∇Gjk−1 ], ∇Gjk ](xinit ), (7.20)
dδ δ=0

then combining (7.19) and (7.20) completes the proof, so it remains to verify (7.20).

288
We will prove by induction over k, and now let {ji }∞
i=1 be an arbitrary sequence

where each ji ∈ [d]. For notational convenience, we denote for each k ≥ 1,

ι (δ)
πk,δ (·) := φδ−`jk ,δ (·) and Πk,δ (·) := φLkk,δ (·).

−1 ι (δ)
Then their inverse maps are given by πk,δ (·) = φδ`jk ,δ (·) and Π−1
k,δ (·) = φ
k
k,δ (·) respec-
−L
k
tively. Since G is smooth, each Πk,√δ is a C ∞ function of δ 1/2 , and we can expand it
k
in δ 1/2 as

2 k k
X δ i/2
Πk,√δ (x) = x + ∆k,i (x) + rk,δ (x) (7.21)
i=1
i!

where the remainder term rk,δ (x) is continuous in x and for each x ∈ M , rk,δ (x) = o(δ)
rk,δ (x)
(i.e., limδ→0 δ
= 0), and each ∆k,i is defined as

di Πk,√δ (x)
∆k,i (x) = .
d(δ 1/2k )i δ=0

In particular, for k = 1, we have

√ δ
Π1,√δ (x) = π1,√δ (x) = x + δ∇Gj1 (x) + ∂(∇Gj1 )(x)∇Gj1 (x) + r1,δ (x) (7.22)
2

where the second equality holds as well for any other Gj in place of Gj1 , with a different
but similar remainder term. For any fixed K ≥ 2, there is a small open neighborhood
of xinit on M , denoted by Nxinit ⊆ M , such that for all k ∈ [K], we have rk,δ (x) = o(δ)
uniformly over all x ∈ Nxinit , so we can replace all rk,δ (x) by o(δ) when x ∈ Nxinit .
Then we claim that for each k = 2, 3, . . . , K,

2 k−1k
1 X δ i/2
lim √ ∆k,i (x) = [[[∇Gj1 , ∇Gj2 ], . . .], ∇Gjk ](x), ∀x ∈ Nxinit , (7.23)
δ→∞ δ i=1 i!

289
which directly implies (7.20). With a slight abuse of notation, the claim is also true
for k = 1 since ∆1,1 (x) = ∇Gj1 (x) by (7.22), so we use this as the base case of the
induction. Then, assuming (7.23) holds for k − 1 < K, we proceed to prove it for k.
For convenience, further define LieG (j1:k ) = [[[∇Gj1 , ∇Gj2 ], . . .], ∇Gjk ].
Combining the Taylor expansion in (7.21) and (7.23) for k − 1, we obtain for all
x ∈ Nxinit that

k−1
2X
√ δ i/2
k−1

Πk−1, δ (x) = x + δ · LieG (j1:(k−1) )(x) + ∆k−1,i (x) + o(δ)
i!
i=2k−2 +1

for sufficiently small δ. Further apply (7.22) with Gjk in place of Gj1 for sufficiently
small δ, and then


Πk−1,√δ πk,√δ (x)

 

δ
= Πk−1, δ x + δ∇Gjk (x) + ∂(∇Gjk )(x)∇Gjk (x) + o(δ)
2
√ δ
= x + δ∇Gjk (x) + ∂(∇Gjk )(x)∇Gjk (x) + o(δ)
2 
√ √

δ
+ δ · LieG (j1:(k−1) ) x + δ∇Gjk (x) + ∂(∇Gjk )(x)∇Gjk (x) + o(δ)
2
k−1
2X
δ i/2
k−1

 
δ
+ ∆k−1,i x + δ∇Gjk (x) + ∂(∇Gjk )(x)∇Gjk (x) + o(δ)
k−2
i! 2
i=2 +1

 
δ
+ rk−1,δ x + δ∇Gjk (x) + ∂(∇Gjk )(x)∇Gjk (x) + o(δ)
2

where the second equality follows from the Taylor expansion of Πk−1,√δ and that
πk,√δ (x) ∈ Nxinit for sufficiently small δ. Then by the Taylor expansion of LieG (j1:(k−1) )

290
and each ∆k−1,i , we have for all x ∈ Nxinit ,

 √ √ δ
Πk−1,√δ πk,√δ (x) = x + δ∇Gjk (x) + δ · LieG (j1:(k−1) )(x) + ∂(∇Gjk )(x)∇Gjk (x)
2
k−1
2X k−1
δ i/2
+ δ · ∂LieG (j1:(k−1) )(x)∇Gjk (x) + ∆k−1,i (x) + o(δ)
k−2
i!
i=2 +1

(7.24)

for sufficiently small δ. For the other way around, we similarly have

k−1
2X
√ δ i/2
k−1
 

πk,√δ √ √
Πk−1, δ (x) = πk, δ x + δ · LieG (j1:(k−1) )(x) + ∆k−1,i (x) + o(δ)
i!
i=2k−2 +1
√ √ δ
=x+ δ · LieG (j1:(k−1) ) + ∂(∇Gjk )(x)∇Gjk (x)
δ∇Gjk (x) +
2
k−1
2X k
δ i/2
+ δ∂(∇Gjk )(x)LieG (j1:(k−1) )(x) + ∆k−1,i (x) + o(δ)
k−2
i!
i=2 +1

(7.25)

−1√ ◦ Π−1 √ ◦ Π √
for all x ∈ Nxinit , when δ is sufficiently small. Note that x = πk, δ k−1, δ k−1, δ ◦

πk,√δ (x), thus

−1√ ◦ Π−1 √ ◦ π √ ◦ Π √
Πk,δ (x) − x = πk, δ k−1, δ k, δ k−1, δ (x) − x

−1√ ◦ Π−1 √ ◦ π √ ◦ Π √ −1 −1√ √ √


= πk, k−1, δ (x) − πk, δ ◦ Πk, δ ◦ Πk, δ ◦ πk, δ (x)

δ k−1, δ k, δ

−1√ ◦ Π−1 √ ◦ π √ ◦ Π √ √ √
= πk, δ k−1, δ k, δ k−1, δ (x) − πk, δ ◦ Πk−1, δ (x)

+ πk,√δ ◦ Πk−1,√δ (x) − Πk,√δ ◦ πk,√δ (x)


−1√ ◦ Π−1√ ◦ Π √ ◦ π √ (x)
+ Πk−1,√δ (x) ◦ πk,√δ − πk, δ k, δ k, δ k, δ

= Πk−1,√δ ◦ πk,√δ (x) − πk,√δ ◦ Πk−1,√δ (x) + o(δ) (7.26)

291
−1√ ◦ Π−1 √ (·) in terms
where the last equality follows from the Taylor expansion of πk, δ k−1, δ

of δ. Now, combining (7.24), (7.25) and (7.26), we obtain


Πk,δ (x) − x = δ ∂(∇Gjk )(x)LieG (j1:(k−1) )(x) − ∂LieG (j1:(k−1) )(x)∇Gjk (x) + o(δ)

= δ · [LieG (j1:(k−1) ), ∇Gjk ](x) + o(δ) (7.27)

where the second equality follows from the definition of Lie bracket. Comparing (7.27)
with (7.21) yields (7.23). This completes the induction for k ∈ [K] and hence finishes
the proof as K is arbitrary.

7.8.3 Convergence for gradient flow with commuting parametriza-

tion

Proof of Theorem 7.4.13. Recall that the dynamics of w(t) is given by

dw(t) = −∇2 R(w(t))−1 ∇L(w(t))dt, w(0) = G(xinit ).

By Lemma 7.4.8, we know that R is a Legendre function. Therefore, when R is further


a Bregman function, we can apply Theorem 7.6.8 to obtain the convergence of w(t).
This finishes the proof.

Based on Theorem 7.6.7, we can prove the trajectory convergence of w(t) for the
special case where Ωw (xinit ; G) = Rd as summarized in Corollary 7.4.14.

Proof of Corollary 7.4.14. It suffices to verify that R is a Bregman function in this


case. By Lemma 7.4.8, we know that R is a Legendre function and satisfies that
Rd = Ωw (xinit ; G) = dom ∇R ⊆ dom R ⊆ Rd , which implies dom R = Rd . Moreover,
the domain of its convex conjugate Q is also Rd . Then by Theorem 7.6.7, we see that
R is a Bregman function. This finishes the proof.

292
Next, we prove that for a class of commuting quadratic parametrizations, the
corresponding Legendre function is also a Bregman function, thus guaranteeing the
trajectory convergence.

Proof of Theorem 7.4.15. Since A1 , A2 , . . . , Ad commute with each other, these matri-
ces can be simultaneously diagonalized. Thus we can assume without loss of generality
that each Ai = diag(λi ) where λi ∈ RD , then Gi (x) = λ> 2
i x . For convenience, we

denote Λ = (λ1 , λ2 , . . . , λd )> ∈ Rd×D , so the parametrization is given by G(x) = Λx 2 .


Note that for each i ∈ [d], ∇Gi (x) = 2λi x and ∇2 Gi (x) = 2diag(λi ), so for any
i, j ∈ [d], we have

[∇Gi , ∇Gj ](x) = 4diag(λi )λj x − 4diag(λj )λi x = 0.

Therefore, we see that G : RD d


+ → R is a commuting parametrization. Also, for any
Rt
t ∈ R, x(t) = xinit − 0 ∇Gi (x(s))ds = xinit e−2λi t , which proves the first and the
second claims. Moreover, if the sign of each coordinate of x will not change from
that of initialization, (sign means +,− or 0). Without loss of generality, below we
will assume every coordinate is non-zero at initialization (otherwise we just ignore it).
We can also assume the coordinates at initialization are all positive, as the negatives
will induce the same trajectory in terms of G(x). By Theorem 7.4.9, the dynamics of
w(t) = G(x(t)) is given by

dw(t) = −∇2 R(w(t))−1 ∇L(w(t))dt, w(0) = G(xinit )

for some Legendre function R whose conjugate is denoted by Q. To apply the results
in Theorem 7.4.13, it suffices to show that this R is a Bregman function.
2
To do so, we further denote w
e=x and G(x)
e = x 2 , then w = Λw
e and in this
case G e defined on M = RD
e is a commuting parametrization for w + . Also, we have

∂G(x) = Λ∂ G(x).
e e : Rd → R be defined by L(
Let L e w)
e = L(Λw),
e which satisfies that
293
∇L( e = Λ> ∇L(Λw).
e w) e Then the gradient flow with parametrization G
e governed by

−∇(L
e ◦ G)(x)
e is given by

> e e
dx(t) = −∇(L
e ◦ G)(x)dt
e = −∂ G(x(t))
e ∇L(G(x(t))dt
> >
= −∂ G(x(t))
e Λ ∇L(ΛG(x(t))dt
e

= −∂G(x(t))> ∇L(G(x(t))dt,

which yields the same dynamics of the gradient flow with parametrization G governed
by −∇(L ◦ G)(x). Therefore, we have w(t) = G(x(t)) = ΛG(x(t))
e = Λw(t),
e where
again by Theorem 7.4.9, the dynamics of w(t)
e is

e = −∇2 R(
dw(t) e −1 ∇L(
e w(t)) e w(t))dt,
e w(0)
e = G(x
e init )

for some Legendre function R e For any x ∈ M and


e whose conjugate is denoted by Q.

e ∈ RD , we define ψ(x;
µ e µ e) = φµG
e1 µ
e ◦ φG
e2 µ
e ◦ · · · ◦ φG
eD
e (x). We need the following lemma.
1 2 D

Lemma 7.8.2. In the setting of the proof of Theorem 7.4.15, for any µ ∈ Rd and
e Λ> µ).
x ∈ M , we have ψ(x; µ) = ψ(x;

Recall from Lemma 7.4.8 that ∇Q(µ) = G(ψ(xinit ; µ)) for any µ ∈ Rd and ∇Q(e
e µ) =

G(
e ψ(x
e init ; µ e ∈ RD . Note that
e)) for any µ

∇Q(µ) = Λψ(xinit ; µ) 2 e init ; Λ> µ)


= Λψ(x 2
= ΛG( e init ; Λ> µ)) = Λ∇Q(Λ
e ψ(x e > µ)

(7.28)

where the second equality follows from Lemma 7.8.2. This implies that Q(µ) =
e > µ) + C for some constant C. Recall the definition of convex conjugate, and we
Q(Λ

294
have

R( e = sup he
e w) e − Q(e
µ, wi e µ), R(w) = sup hµ, wi − Q(µ).
e∈RD
µ µ∈Rd

e ∈ RD , we have
Then for any w

e − Q(µ) = sup hΛ> µ, wi


e = sup hµ, Λwi
R(Λw) e > µ) − C
e − Q(Λ
µ∈Rd µ∈Rd

= sup he e − Q(e
µ, wi e µ) − C ≤ sup he e − Q(e
µ, wi e µ) − C = R( e −C
e w) (7.29)
e∈Λ> Rd
µ e∈RD
µ

e ∈ dom R,
Therefore, for any w e ≤ R(
e it holds that R(Λw) e − C < ∞, so Λ dom R
e w) e⊆

dom R, where Λ dom R


e On the other hand, by (7.28) and Proposition 7.6.3, we have

dom ∇R = range ∇Q ⊆ Λ range ∇Q


e = Λ dom ∇R
e

and it follows that

int(dom R) = dom ∇R ⊆ Λ dom ∇R


e = Λ int(dom R).
e

Combining the above, we see that dom R = Λ dom R. e As discussed in Section 7.1,

e = D ei (ln x2wei − 1), which is indeed


P
here it is straightforward to verify that R(
e w)
i=1 w init,i

e = RD
a Bregman function with domain dom R D
+ . Thus dom R = ΛR+ is also a closed

set. This yields the first condition in Definition 7.6.6.


Next, we verify the second condition in Definition 7.6.6. For any µ ∈ Rd , we have

∇R(G(ψ(xinit ; µ))) = ∇R(∇Q(µ)) = µ

295
and

> e e > >


∇R(
e G(ψ(x
e init ; µ))) = ∇R(G(ψ(xinit ; Λ µ))) = ∇R(∇Q(Λ µ)) = Λ µ.
e e e

Comparing the above two equalities, we get

∇R( e = Λ> ∇R(Λw)


e w) e (7.30)

e ∈ RD
for all w e ∈ RD
+ . Then for any w y ∈ int(dom R), we have
+ and y = Λe

DR (Λw, e − R(y) − h∇R(y), Λw


e y) = R(Λw) e − yi

= R(Λw) y ) − hΛ> ∇R(Λe


e − R(Λe e − yei
y ), w

e − R(Λe
= R(Λw) y ) − h∇R(e e − yei
e y ), w

e − R(Λe
= R(Λw) y ) − R(
e w) e y ) + D e (w,
e + R(e R e y e) (7.31)

≥ R(Λw)
e − R( e + C + DRe (w,
e w) e ye)

where the inequality follows from (7.29). Therefore, we further have for any α ∈ R

{y ∈ int(dom R) | DR (Λw, y ∈ RD
e y) ≤ α} ⊆ Λ{e + | DR e ye) ≤ α − R(Λw)
e (w, e + R( e − C}
e w)

where the right-hand side is bounded since R


e is a Bregman function, and so is the

left-hand side.
Finally, we verify the third condition in Definition 7.6.6. Consider any w ∈ dom R
and sequence {wi }∞
i=1 ⊂ int(dom R) such that limi→∞ wi = w. Since dom R =

Λ dom R, e ∈ RD
e there is some w + such that w = Λw ei ∈ RD
e and some w + for each i ∈ N
+

296
such that wi = Λw
ei . We have that

Z 1
R(w) − R(wi ) = h∇R((1 − t)wi + tw), w − wi idt
0
Z 1
= hΛ> ∇R(Λ((1 − t)w
ei + tw)),
e w e−w
ei idt
Z0 1
= h∇R((1
e − t)w
ei + tw),
e w e−w
ei idt
0

= R( e − R(
e w) ew ei ).

Combining this with (7.31), we get DR (w, wi ) = DRe (w,


e wei ). Note that we can always
choose each w
ei properly such that limi→∞ w
ei = w.
e Then since R
e is a Bregman function,

we have

lim DR (w, wi ) = lim DRe (w,


e wei ) = 0.
i→∞ i→∞

Therefore, we conclude that R is also a Bregman function. This finishes the


proof.

Proof of Lemma 7.8.2. For each i ∈ [D] and any t > 0, we have

Z t Z t D
X
φtGi (x) =x+ −∇Gi (φsfi (x))ds =x+ − ej (φs (x))ds = ψ(x;
λi,j ∇G fi
e tλi )
s=0 s=0 j=1

where the last equality follows from Lemma 7.4.7. Therefore, for any µ ∈ Rd , we
further have

ψ(x; µ) = φµG11 ◦ φµG22 ◦ · · · ◦ φµGdd (x)


µ λ1,1 µ λ1,D µ λd,1 µ λd,D
= φGe1 ◦ · · · ◦ φGe1 ◦ · · · ◦ φGed ◦ · · · ◦ φGed (x)
1 D 1 D
Pd Pd
µi λi,1 µi λi,D
= φGe i=1 ◦ · · · ◦ φGe i=1 (x)
1 D

(Λ> µ)1 (Λ> µ)D e Λ> µ).


= φGe ◦ · · · φGe (x) = ψ(x;
1 D

297
where the third equality follows from the assumption that G
e is a commuting

parametrization. This finishes the proof.

7.8.4 Results for underdetermined linear regression

Here we provide the proof for the implicit bias result for the quadratically over-
parametrized linear model.

Proof of Corollary 7.4.17. By symmetry, we assume without loss of generality that


all coordinates of xinit are positive. Note that for M = RD
+ with D = 2d, G : M → R
d

can be written as Gi (x) = x> Ai x where each Ai = ei e> >


i − ed+i ed+i . Therefore, this

parametrization G satisfies the conditions in Theorem 7.4.15, which then implies the
convergence of w(t).
Next, we identify the function R given by Theorem 7.4.9. we have ψ(xinit ; µ) =
u0 e−2µ

v0 e2µ
and thus

G(ψ(xinit ; µ)) = u0 2 e−4µ − v0 2 e4µ

= (u0 2 + v0 2 ) sinh(4µ) + (u0 2 − v0 2 ) cosh(4µ).

So G(ψ(xinit ; µ)) is the gradient of Q(µ) = 14 (u0 2 + v0 2 ) cosh(4µ) + 14 (u0 2 − v0 2 )


sinh(4µ) + C where C is an arbitrary constant. Also note that (∇Q(µ))i only depends
on µi , then we have

s  2 
1 wi wi 1 v0,i
(∇R(w))i = (∇Q(µ))−1
i (w) = ln 1+ + + ln
4 2u0,i v0,i 2u0,i v0,i 4 u0,i
 
1 wi 1 v0,i
= arcsinh + ln
4 2u0,i v0,i 4 u0,i

298
which further implies that

d    q 
1X wi 2 2 2 u0,i
R(w) = wi arcsinh − wi + 4u0,i v0,i − wi ln + C.
4 i=1 2u0,i v0,i v0,i

This finishes the proof.

7.9 Omitted proofs in Section 7.5

We first prove the following intermediate result that will be useful in the proof of
Theorem 7.5.1.

Lemma 7.9.1. Under the setting of Theorem 7.5.1, let F be the smooth map that
isometrically embeds (int(dom R), g R ) into (RD , g). Let M = range(F ), and denote
e : M → Rd . Then for any w ∈ int(dom R), it holds that
the inverse of F by G

∂F (w)(∂F (w)> ∂F (w))−1 = ∂ G(F


e (w))> and ∂ G(F e (w))> = ∇2 R(w)−1 .
e (w))∂ G(F

Proof of Lemma 7.9.1. For any x ∈ M and v ∈ Tx (M ), consider a parametrized curve


dx(t)
{x(t)}t≥0 ⊂ M such that x(0) = x and dt t=0
= v. Since x(t) = F (G(x(t)))
e for any
t ≥ 0, differentiating with respect to t on both sides and evaluating at t = 0 yield

v = ∂F (G(x))∂
e G(x)v.
e (7.32)

Now, for any w ∈ int(dom R), let x = F (w), then for any v ∈ Tx (M ), it follows from
(7.32) that

v > ∂F (w) = v > (∂F (w)∂ G(F


e (w)))> ∂F (w) = v > ∂ G(F
e (w))> ∂F (w)> ∂F (w).

299
Note that the span of the column space of ∂F (w) is exactly Tx (M ), so for any v in
the orthogonal complement of Tx (M ), it holds that

v > ∂F (w) = 0 = v > ∂ G(F


e (w))> ∂F (w)> ∂F (w)

where the second equality follows from the fact that for any i ∈ [d], ∇G
ei (x) ∈ Tx (M ).

Therefore, combining the above two cases, we conclude that

e (w))> ∂F (w)> ∂F (w).


∂F (w) = ∂ G(F

Since ∂F (w)> ∂F (w) = ∇2 R(w) is invertible, we then get

e (w))> = ∂F (w)(∂F (w)> ∂F (w))−1 .


∂ G(F

Next, for any w ∈ int(dom R), since G(F


e (w)) = w, differentiating on both sides

yields

∂ G(F
e (w))∂F (w) = Id .

Therefore, using the identity proved above, we have

∂ G(F e (w))> = ∂ G(F


e (w))∂ G(F e (w))∂F (w)(∂F (w)> ∂F (w))−1

= (∂F (w)> ∂F (w))−1 = ∇2 R(w)−1 .

This finishes the proof.

Proof of Theorem 7.5.1. By Nash’s embedding theorem, there is a smooth map F :


int(dom R) → RD that isometrically embeds (int(dom R), g R ) into (RD , g). Denote
M = range(F ), i.e., the embedding of int(dom R) in RD . We further denote the

300
e : M → Rd . Note (M, G)
inverse of F on M by G e is a global atlas for M , we have
ei (x)}di=1 ) for all x ∈ M . This G
that Tx (M ) = span({∇G e is almost the commuting

parametrization that we seek for, except now it is only defined on M but not on an
open neighborhood of M . Yet we can extend G
e to an open neighbourhood of M in the

following way: First by Foote [186], for each x ∈ M , there is an open neighbourhood
Ux of x such that projection function P defined by

P (y) = argmin ky − y 0 k2
y 0 ∈M

is smooth in Ux . Then we define U = ∪x∈M Ux , and extend G


e to U by defining
e (x)) for all x ∈ U . We have G(x) = G(x)
G(x) := G(P e for all x ∈ M , and we can
verify that ∂G ≡ ∂ G
e on M as well. For any v ∈ Tx (M ), let {γ(t)}t≥0 be a parametrized
dγ(t)
curve on M such that γ(0) = x and dt t=0
= v, then for sufficiently small t, by
Taylor expansion we have

γ(t) = P (γ(t)) = P (x) + ∂P (x)(γ(t) − x) + o(kγ(t) − xk2 )

= x + ∂P (x)(γ(t) − x) + o(kγ(t) − xk2 )

which implies that v = ∂P (x)v by letting t → 0. While for any v in the orthogonal
complement of Tx (M ), for sufficiently small δ > 0, we have P (x + δv) is smooth in δ.
Then since P (x + δv) ∈ M for all sufficiently small δ by its definition, we have

dP (x + δv) P (x + δv) − P (x)


∂P (x)v = = lim =: u ∈ Tx (M ). (7.33)
dδ δ=0
δ→0 δ

Note that kx + δv − P (x + δv)k2 ≤ kx + δv − P (x)k2 = δkvk2 , and by Taylor expansion,


we have

kx + δv − P (x + δv)k2 = kx + δv − δ∂P (x)v + O(δ 2 )k2 = kx + δv − δu + O(δ 2 )k2

301
where O(δ 2 ) denotes a term whose norm is bounded by Cδ 2 for a constant C > 0 for
all sufficiently small δ, and the second equality follows from (7.33). Then dividing
both sides by δ and letting δ → 0, we have kvk2 ≥ kv − uk2 . Since u is orthogonal to
v, we must have u = 0. As v is arbitrary, we conclude that ∂P (x) is the orthogonal
projection matrix onto Tx (M ). Then differentiating both sides of G(x) = G(P
e (x))

with x yields

∂G(x) = ∂ G(P
e (x))∂P (x) = ∂ G(x)
e (7.34)

ei (x)}d ). This
where the second equality follows from the fact that Tx (M ) = span({∇G i=1

further implies that the solution of Equation (7.13) satisfies dx/dt = −∇(L ◦ G)(x)
e ∈
Tx (M ), and thus x(t) ∈ M for all t ≥ 0.
Now we consider the mirror flow

dw(t) = −∇2 R(w(t))−1 ∇Lt (w(t))dt, w(0) = winit .

Since ∇2 R(w) = ∂F (w)> ∂F (w) by the fact that F is an isometric embedding, we


further have

−1
dw(t) = − ∂F (w(t))> ∂F (w(t)) ∇Lt (w(t))dt.

Now define x(t) = F (w(t)), and it follows that

dx(t) = ∂F (w(t))dw(t) = −∂F (w(t))(∂F (w(t))> ∂F (w(t)))−1 ∇Lt (w(t))dt

= −∂G(F (w(t)))> ∇Lt (w(t))dt = −∇(Lt ◦ G)(x(t))dt

where the third equality follows from Lemma 7.9.1 and (7.34).

302
Next, we verify that G restricted on M , G,
e is a commuting and regular parametriza-
e > = ∂F (G(x))(∂F
tion. First, for any x ∈ M , we have ∂ G(x) e (G(x))
e >
∂F (G(x)))
e −1

by Lemma 7.9.1 and (7.34). Since ∇2 R(w) = ∂F (w)> ∂F (w) is of rank d for all
w ∈ int(dom R), it follows that ∂F (w) is also of rank d for all w ∈ int(dom R), thus
∂ G(x)
e ei }d follows directly
is of rank d for all x ∈ M . The commutability of {∇G i=1

from Corollary 7.4.12. Here we just need to show rank(Ωx (x; G))
e = rank(M ). This is
ei (x)}di=1 )) = rank(M ), and on
e ≥ rank(span({∇G
because on one hand rank(Ωx (x; G))
e ≤ rank(M ) since Ωx (x; G)
the other hand, rank(Ωx (x; G)) e ⊂ M , for any x ∈ M .

Finally, we show that when R is a mirror map, each ∇G


ej is a complete vector

field on M . For any xinit ∈ M , consider loss Lt (w) = hej , wi, and the corresponding
gradient flow is

>
dx(t) = −∇(Lt ◦ G)(x(t))dt
e = −∂ G(x(t))
e ∇Lt (G(x(t)))dt
e = −∇G
ej (x(t)),

so x(t) = φtGe (xinit ) for all t ≥ 0. On the other hand, w(t) = G(x(t))
e satisfies that
j

>
dw(t) = ∂ G(x(t))dx(t)
e = −∂ G(x(t))∂
e G(x(t))
e ∇Lt (w(t))dt

= −∇2 R(w(t))−1 ∇Lt (w(t))dt = −∇2 R(w(t))−1 ej dt

where the third equality follows from Lemma 7.9.1 and Equation (7.34). Therefore,
rewriting the above as a mirror Flow yields

d∇R(w(t)) = −ej dt,

the solution to which exists for all t ∈ R and is given by ∇R(w(t)) = ej t, so


w(t) = (∇R)−1 (ej t) is defined for all t ∈ R as ∇R is surjective. This further implies
that x(t) = F (w(t)) is well-defined for all t ∈ R, hence ∇G
ej is a complete vector

field.

303
Part III

Implicit Bias Around Manifold of


Minimizers

304
Chapter 8

Implicit Bias of Stochastic


Gradient Descent

Understanding the implicit bias of Stochastic Gradient Descent (SGD) is one of the key
challenges in deep learning, especially for overparametrized models, where the local
minimizers of the loss function L can form a manifold. Intuitively, with a sufficiently
small learning rate η, SGD tracks Gradient Descent (GD) until it gets close to such
manifold, where the gradient noise prevents further convergence. In such regime,
Blanc et al. [165] proved that SGD with label noise locally decreases a regularizer-like
term, the sharpness of loss, tr[∇2 L].
This chapter gives a general framework for such analysis by adapting ideas
from Katzenberger [187]. It allows in principle a complete characterization for the
regularization effect of SGD around such manifold—i.e., the ”implicit bias”—using a
stochastic differential equation (SDE) describing the limiting dynamics of the parame-
ters, which is determined jointly by the loss function and the noise covariance. This
yields some new results: (1) a global analysis of the implicit bias valid for η −2 steps,
in contrast to the local analysis of Blanc et al. [165] that is only valid for η −1.6 steps
and (2) allowing arbitrary noise covariance.

305
As an application, we show with arbitrary large initialization, label noise SGD can
always escape the kernel regime and only requires O(κ ln d) samples for learning a
κ-sparse overparametrized linear model in Rd [20], while GD initialized in the kernel
regime requires Ω(d) samples. This upper bound is minimax optimal and improves
e 2 ) upper bound [21].
the previous O(κ

8.1 Introduction

The implicit bias underlies the generalization ability of machine learning models trained
by stochastic gradient descent (SGD). But it still remains a mystery to mathematically
characterize such bias. We study SGD in the following formulation


xη (k + 1) = xη (k) − η(∇L(xη (k)) + Ξ · σξk (xη (k))) (8.1)

where η is the learning rate (LR), L : RD → R is the training loss and σ(x) =
[σ1 (x), σ2 (x), . . . , σΞ (x)] ∈ RD×Ξ is a deterministic noise function. Here ξk is sampled
uniformly from {1, 2, . . . , Ξ} and it satisfies Eξk [σξk (x)] = 0, ∀x ∈ Rd and k.
It is widely believed that large LR (or equivalently, small batch size) helps SGD
find better minima. For instance, some previous works argued that large noise enables
SGD to select a flatter attraction basin of the loss landscape which potentially benefits
generalization [164, 188]. However, there is also experimental evidence [50] that small
LR also has equally good implicit bias (albeit with higher training time), and that is
the case studied here. Presumably low LR precludes SGD jumping between different
basins since under general conditions this should require Ω(exp(1/η)) steps [189].
In other words, there should be a mechanism to reach better generalization while
staying within a single basin. For deterministic GD similar mechanisms have been
demonstrated in simple cases [14, 16, 100] and referred to as implicit bias of gradient

306
descent. This chapter presents a study of implicit bias of Stochastic GD, which turns
out to be quite different, mathematically.
Recent work [165] shed light on this direction by analyzing effects of stochasticity
in the gradient. For sufficiently small LR, SGD will reach and be trapped around some
manifold of local minimizers, denoted by Γ (see Figure 8.2). The effect is shown to be
an implicit deterministic drift in a direction corresponding to lowering a regularizer-like
term along the manifold. They showed SGD with label noise locally decreases the
sharpness of loss, tr[∇2 L], by Θ(η 0.4 ) in η −1.6 steps. However, such an analysis is
actually local, since the natural time scale of analysis should be η −2 , not η −1.6 .
The contribution of this chapter is a more general and global analysis of this type.
We introduce a more powerful framework inspired by the classic paper [187].

8.1.1 Intuitive explanation of regularization effect due to

SGD

We start with an intuitive description of the implicit regularization effect described


in [165]. For simplification, we show it for the canonical SDE approximation (See
Section 8.8.1 for more details) of SGD (8.1) [190, 191]. Here W (t) is the standard
Ξ-dimensional Brownian motion. The only property about label noise SGD we will
use is that the noise covariance σσ > (x) = ∇2 L(x) for every x in the manifold Γ (See
derivation in Section 8.6).

eη (t) = −η∇L(X
dX eη (t))dt + η · σ(X
eη (t))dW (t). (8.2)

eη (0) is already close to some local minimizer point X∗ ∈ Γ. The goal is


Suppose X
eη (t) will move in the tangent space and steadily decrease tr[∇2 L]. At first
to show X
glance, this seems impossible as the gradient ∇L vanishes around Γ, and the noise
has zero mean, implying SGD should be like random walk instead of a deterministic

307
(a) Taylor Expansion of ∇L (b) Normal Space Dynamics (c) Tangent Space Dynamics

Figure 8.1: Illustration for limiting flow in R2 . Γ is an 1D manifold of minimizers of


loss L.

drift. The key observation of Blanc et al. [165] is that the local dynamics of X
eη (t) is

completely different in tangent space and normal space — the fast random walk in
eη (t) to move slowly (with velocity Θ(η 2 )) but deterministically
normal space causes X
eη (t) − X∗ , Taylor expansion
in certain direction. To explain this, letting ∆(t) = X
of (8.2) gives d∆(t) ≈ −η∇2 L(X∗ )∆dt + ησ(X∗ )dW (t), meaning ∆ is behaving like
an Ornstein-Uhlenbeck (OU) process locally in the normal space. Its mixing time is
Θ(η −1 ) and the stationary distribution is the standard multivariate gaussian in the

normal space scaled by η (see Figure 8.1b), because noise covariance σσ > = ∇2 L.
Though this OU process itself doesn’t form any regularization, it activates the second
order Taylor expansion of ∇L(X∗ + ∆(t)), i.e., − 12 ∂ 2 (∇L)(X∗ )[∆(t), ∆(t)], creating a
Θ(η 2 ) velocity in the tangent space. Since there is no push back force in the tangent
space, the small velocity accumulates over time, and in a longer time scale of Ω(η −1 ),
the time average of the stochastic velocity is roughly the same as the expected velocity
when ∆ is sampled from its stationary distribution. This simplifies the expression of
η2
the velocity in tangent space to 2
∇T tr[∇2 L] (see Figure 8.1c), where ∇T means the
gradient is only taken in the tangent space.
However, the above approach only gives a local analysis for O(η −1.6 ) time, where
the total movement due to implicit regularization is O(η 2−1.6 ) = O(η 0.4 ) and thus is
negligible when η → 0. In order to get a non-trivial limiting dynamics when η → 0, a

308
global analysis for Ω(η −2 ) steps is necessary and it cannot be done by Taylor expansion
with a single reference point. Recent work by Damian et al. [166] glues analyses of
multiple local phases into a global guarantee that SGD finds a (, γ)-stationary point
for the regularized loss, but still doesn’t show convergence for trajectory when η → 0
and cannot deal with general noise types, e.g., noise lying in the tangent space of
the manifold. The main technical difficulty here is that it’s not clear how to separate
the slow and fast dynamics in different spaces and how to only take limit for the
slow dynamics, especially when shifting to a new reference point in the Taylor series
calculation.

8.1.2 Our Approach: Separating the Slow from the Fast

In this work, we tackle this problem via a different angle. First, since the anticipated
limiting dynamics is of speed Θ(η 2 ), we change the time scaling to accelerate (8.2) by
η −2 times, which yields

dXη (t) = −η −1 ∇L(Xη (t))dt + σ(Xη (t))dW (t). (8.3)

The key idea here is that we only need to track the slow dynamic, or equivalently,
some projection of X onto the manifold Γ, Φ(X). Here Φ : RD → Γ is some function
to be specified and hopefully we can simplify the dynamics (8.3) via choosing suitable
Φ. To track the dynamics of Φ(Xη ), we apply Ito’s lemma (a.k.a. stochastic chain
rule, see Lemma 8.4.10) to Equation (8.3), which yields

dΦ(Xη (t)) = −η −1 ∂Φ(Xη (t))∇L(Xη (t))dt + ∂Φ(Xη (t))σ(Xη (t))dW (t)


1 XD
+ ∂ij Φ(Xη (t))(σ(Xη (t))σ(Xη (t))> )ij dt.
2 i,j=1

309
Note the first term −η −1 ∂Φ(Xη )∇L(Xη ) is going to diverge to ∞ when η → 0, so a
natural choice for Φ is to kill the first term. Further note −∂Φ(X)∇L(X) is indeed
the directional derivative of Φ at X towards −∇L, killing the first term becomes
equivalent to making Φ invariant under Gradient Flow (GF) of −∇L(X)! Thus it
suffices to take Φ(X) to be the limit of GF starting at X. (Formally defined in
Section 8.3; see Lemma 8.9.2 for a proof of ∂Φ(X)∇L(X) ≡ 0.)
Also intuitively Xη will be infinitely close to Γ, i.e., d(Xη (t), Γ) → 0 for any t > 0
as η → 0, so we have Φ(Xη ) ≈ Xη . Thus we can rewrite the above equation as

1 XD
dXη (t) ≈ ∂Φ(Xη (t))σ(Xη (t))dW (t) + ∂ij Φ(Xη (t))(σ(Xη (t))σ(Xη (t))> )ij dt,
2 i,j=1

(8.4)

and the solution of (8.4) shall converge to that of the following (in an intuitive sense):

1 XD
dX(t) = ∂Φ(X(t))σ(X(t))dW (t) + ∂ij Φ(X(t))(σ(X(t))σ(X(t))> )ij dt,
2 i,j=1

(8.5)

The above argument for SDE was first formalized and rigorously proved by Katzen-
berger [187]. It included an extension of the analysis to the case of asymptotic
continuous dynamics (Theorem 8.5.2) including SGD with infinitesimal LR, but the
result is weaker in this case and no convergence is shown. Another obstacle for
applying this analysis is that 2nd order partial derivatives of Φ are unknown. We
solve these issues in Section 8.5 and our main result Theorem 8.5.7 gives a clean and
complete characterization for the implicit bias of SGD with infinitesimal LR in Θ(η −2 )
steps. Finally, our Corollary 8.6.2 shows (8.5) gives exactly the same regularization as
tr[∇2 L] for label noise SGD.
The main contributions of this chapter are summarized as follows.

310
1. In Section 8.5, we propose a mathematical framework to study the implicit bias
of SGD with infinitesimal LR. Our main theorem (Theorem 8.5.7) gives the
limiting diffusion of SGD with LR η for Θ(η −2 ) steps as η → 0 and allows any
covariance structure.

2. In Section 8.6, we give limiting dynamics of SGD with isotropic noise and label
noise.

3. In Section 8.7, we show for any initialization, SGD with label noise achieves
O(κ ln d) sample complexity for learning a κ-sparse overparametrized linear
model [20]. In this case, the implicit regularizer is a data-dependent weighted
`1 regularizer, meaning noise can help reduce the norm and even escape the
kernel regime. The O(κ ln d) rate is minimax optimal [192] and improves over
e 2 ) upper bound by HaoChen et al. [21]. In contrast, vanilla GD requires
O(κ
Ω(d) samples to generalize in the kernel regime.
For technical contributions, we rigorously prove the convergence of GF for OLM
(Lemma 8.7.3), unlike many existing implicit bias analyses which have to assume
the convergence. We also prove the convergence of limiting flow to the global
minimizer of the regularizer (Lemma 8.7.5) by a trajectory analysis via our
framework. It cannot be proved by previous results [165, 166], as they only
assert convergence to stationary point in the best case.

8.2 Related Works

Loss Landscape of Overparametrized Models A phenomenon known as mode


connectivity has been observed that local minimizers of the loss function of a neural
network are connected by simple paths [193–195], especially for overparametrized
models [196–199]. Later this phenomanon is explained under generic assumptions
by Kuditipudi et al. [200]. Moreover, it has been proved that the local minimizers
311
of an overparametrized network form a low-dimensional manifold [201, 202] which
possibly has many components. Fehrman et al. [203] proved the convergence rate of
SGD to the manifold of local minimizers starting in a small neighborhood.

Implicit Bias in Overparametrized Models Algorithmic regularization has


received great attention in the community [25, 29, 95, 100, 105, 105, 112, 142]. In
particular, the SGD noise is widely believed to be a promising candidate for explaining
the generalization ability of modern neural networks [9, 204–207]. Beyond the size of
noise [164, 188], the shape and class of the noise also play an important role [208, 209].
It is shown by HaoChen et al. [21] that parameter-dependent noise will bias SGD
towards a low-complexity local minimizer. Similar implicit bias has also been studied
for overparametrized nonlinear statistical models by Fan et al. [210]. Several existing
works [20, 134, 211] have shown that for the quadratically overparametrized linear
2 2
model, i.e., w = u −v or w = u v, gradient descent/flow from small initialization
implicitly regularizes `1 norm and provides better generalization when the groundtruth
is sparse. This is in sharp contrast to the kernel regime, where neural networks trained
by gradient descent behaves like kernel methods [15, 155, 212]. This allows one to
prove convergence to zero loss solutions in overparametrized settings [149–153, 213],
where the learnt function minimizes the corresponding RKHS norm [148, 154].

Modelling Stochastic First-Order Methods with Itô SDE Apart from the
discrete-time analysis, another popular approach to study SGD is through the
continuous-time lens using SDE [190, 191, 214]. Such an approach is often more
elegant and can provide fruitful insights like the linear scaling rule [215, 216] and
the intrinsic learning rate [50]. A recent work by Li et al. [217] justifies such SDE
approximation. Xie et al. [218] gave a heuristic derivation explaining why SGD favors
flat minima with SDE approximation. Wojtowytsch [219] showed that the invariant
distribution of the canonical SDE approximation of SGD will collapse to some manifold
312
of minimizers and in particular, favors flat minima. By approximating SGD using a
SDE with slightly modified covariance for the overparametrized linear model, Pesme
et al. [176] relates the strength of implicit regularization to training speed.

8.3 Notations

Given loss L, the GF governed by L can be described through a mapping φ : RD ×


Rt
[0, ∞) → RD satisfying φ(x, t) = x − 0 ∇L(φ(x, s))ds. We further denote the limiting
mapping Φ(x) = limt→∞ φ(x, t) whenever the limit exists. We denote 1ξ ∈ RΞ as the
one-hot vector where ξ-th coordinate is 1, and 1 the all 1 vector. For any integer
k, we denote C k as the set of the k times continuously differentiable functions. We
denote a ∧ b = min{a, b}. For any vector u, v and α ∈ R, we define [u v]i = ui vi
and [v α
]i = viα . For any matrix A, we denote its pseudo inverse by A† . For mapping
F : RD → RD , we denote the Jacobian of F at x by ∂F (x) ∈ RD×D where the (i, j)-th
entry is ∂j Fi (x). We also use ∂F (x)[u] and ∂ 2 F (x)[u, v] to denote the first and second
order directional derivative of F at x along the derivation of u (and v). We abuse
the notation of ∂ 2 F by viewing it a linear mapping defined on RD ⊗ RD ∼
2
= RD ,
in the sense that ∂ 2 F (x)[Σ] = D 2 D×D
P
i,j=1 ∂ F (x)[ei , ej ]Σij , for any Σ ∈ R . For any
submanifold Γ ⊂ RD and x ∈ Γ, we denote by Tx (Γ) the tangent space of Γ at x and
Tx⊥ (Γ) the normal space of Γ at x.

8.4 Preliminaries on Stochastic Processes

In this section, we review a few basics of stochastic processes that will be useful
for proving our results. We refer the reader to classics like Karatzas and Shreve
[220], Billingsley [221], Pollard [222] for more systematic derivations.
Throughout the rest of this section, let E be a Banach space equipped with norm
k · k, e.g., (R, | · |) and (RD , k · k2 ).
313
8.4.1 Càdlàg Function and Metric

Definition 8.4.1 (Càdlàg function). Let T ∈ [0, ∞]. A function g : [0, T ) → E is


càdlàg if for all t ∈ [0, T ) it is right-continuous at t and its left limit g(t−) exists. Let
DE [0, T ) be the set of all càdlàg function mapping [0, T ) into E. We also use DE [0, T )
to denote the set of all continuous function mapping [0, T ) into E. By definition,
CE [0, T ) ⊂ DE [0, T ).

Definition 8.4.2 (Continuity modulus). For any function f : [0, ∞) → E and any
interval I ⊆ [0, ∞), we define

ω(f ; I) = sup kf (s) − f (t)k.


s,t∈I

For any N ∈ N and θ > 0, we further define the continuity modulus of continuous f as

ωN (f, θ) = sup {ω(f ; [t, t + θ])}.


0≤t≤t+θ≤N

Moreover, the continuity modulus of càdlàg f ∈ DE [0, ∞) is defined as

 
0
ωN (f, θ) = inf max ω(f ; [ti−1 , ti ) : 0 ≤ t0 < · · · < tr = N, inf (ti − ti−1 ) ≥ θ .
i≤r i<r

Definition 8.4.3 (Jump). For any g ∈ DE [0, T ), we define the jump of g at t to be

∆g(t) = g(t) − g(t−).

For any δ > 0, we define hδ : [0, ∞) → [0, ∞) by




0
 if r ≤ δ
hδ (r) = .

1 − δ/r
 if r ≥ δ

314
We then further define Jδ : DRD [0, ∞) → DRD [0, ∞) [187] as

X
Jδ (g)(t) = hδ (k∆g(s)k)∆g(s). (8.6)
0<s≤t

Definition 8.4.4 (Skorokhod metric on DE [0, ∞)). For each finite T > 0 and each
pair of functions f, g ∈ DE [0, ∞), define dT (f, g) as the infimum of all those values of
δ for which there exist grids 0 ≤ t0 < t1 < · · · < tm and 0 < s0 < s1 < · · · < · · · < sm ,
with tk , sk ≥ T , such that |ti − si | ≤ δ for i = 0, . . . , k, and

kf (t) − g(s)k ≤ δ if (t, s) ∈ [ti , ti+1 ) × [si , si+1 )

for i = 0, . . . , k − 1. The Skorokhod metric on DE [0, ∞) is defined to be


X
d(f, g) = 2−T min{1, dT (f, g)}.
T =1

8.4.2 Stochastic Processes and Stochastic Integral

Let (Ω, F, {Ft }t≥0 , P) be a filtered probability space.

Definition 8.4.5 (Cross variation). Let X and Y be two {Ft }t≥0 -adapted stochastic
processes such that X has sample paths in DRD×e [0, ∞) and Y has samples paths
in DRe [0, ∞), then the cross variation of X and Y on (0, t], denoted by [X, Y ](t), is
defined to be the limit of

m−1
X
(X(ti+1 ) − X(ti ))(Y (ti+1 ) − Y (ti ))
i=0

315
in probability as the mesh size of 0 = t0 < t1 < · · · < tm = t goes to 0, if it exists.
Moreover, for Y itself, we write

e
X
[Y ] = [Yi , Yi ]
i=1

Definition 8.4.6 (Martingale). Let {X(t)}t≥0 be a {Ft }t≥0 -adapted stochastic process.
If for all 0 ≤ s ≤ t, it holds that

E[X(t) | Fs ] = X(s),

then X is called a martingale.

Definition 8.4.7 (Local martingale). Let {X(t)}t≥0 be a {Ft }t≥0 -adapted stochastic
process. If there exists a sequence of {Ft }t≥0 -stopping time, {τk }k≥0 , such that

• P[τk < τk+1 ] = 1, P[limk→∞ τk = ∞] = 1,

• and {X τk (t)}t≥0 is a {Ft }t≥0 -adapted martingale,

then X is called a local martingale.

Definition 8.4.8 (Semimartingale). Let {X(t)}t≥0 be a {Ft }t≥0 -adapted stochastic


process. If there exists a local martingale {M (t)}t≥0 and a càdlàg {Ft }t≥0 -adapted
process {A(t)}t≥0 with bounded total variation that X(t) = M (t) + A(t), then X is
called a semimartingale.

Definition 8.4.9 (Itô’s Stochastic Integral). If {X(t)}t≥0 and {Y (t)}t≥0 are adapted
stochastic processes, X has sample paths in DRd×e [0, ∞), sample paths in DRe [0, ∞)
Rt
and Y is a semimartingale, then the integral s XdY is defined, as the limit of
Pn−1
i=0 X(ri )(Y (ri+1 ) − Y (ri )) where s = r0 < r1 < . . . < rn = t, the limit being in

probability as the mesh size goes to 0. Standard results in stochastic calculus imply
that this limit exists. We call X the integrand and Y the integrator.
316
Since all deterministic process are adapted, the above definition of integral also
makes sense for deterministic functions and is a generalization of standard Riemman-
Stieltjes Integral. The difference is that in the above Itô’s Stochastic Integral we use
the left-end value of the integrand but the existence of Riemman-Stieltjes Integral
requires the limit exists for any point within the interval. When X and Y don’t jump
together, Riemman-Stieltjes Integral exists and coincides with the Itô’s Integral.

Lemma 8.4.10 (Itô’s Lemma). Let {X(t)}t≥0 be defined through the following Itô
drift-diffusion process:

dX(t) = µ(t)dt + σ(t)dW (t).

where {W (t)}t≥0 is the standard Brownian motion. Then for any twice differentiable
function f , it holds that

 
∂f 1
df (t, X(t)) = + (∇x f ) µt + tr[σ ∇x f σ] dt + (∇x f )> σ(t)dW (t).
> > 2
∂t 2

8.4.3 Weak Convergence for Stochastic Processes

Let (DE [0, ∞), A, d) be a metric space equipped with a σ-algebra A and the Skorokhod
metric defined in the previous subsection.
Let {Xn }n≥0 be a sequence of stochastic processes on a sequence of probability
spaces {(Ωn , Fn , Pn )}n≥0 such that each Xn has sample paths in DE [0, ∞). Also, let
X be a stochastic process on (Ω, F, P) with sample paths on DE [0, ∞).

Definition 8.4.11 (Weak convergence). A sequence of stochastic process{Xn }n≥0 is


said to converge in distribution or weakly converge to X (written as Xn ⇒ X) if and
only if for all A-measurable, bounded, and continuous function f : DE [0, ∞) → R, it

317
holds that

lim E [f (Xn )] = E [f (X)] . (8.7)


n→∞

Though we define weak convergence for a countable sequence of stochastic processes,


but it is still valid if we index the stochastic processes by real numbers, e.g., {Xη }η≥0 ,
and consider the weak convergence of Xη as η → 0. This is because the convergence in
(8.7) is for a sequence of real numbers, which is also well-defined if we replace limn→∞
by limη→0 .

Definition 8.4.12 (δ-Prohorov distance). Let δ > 0. For any two probability measures
P and Q on a metric space with metric d, let (X, Y ) be a coupling such that P is the
marginalized law of X and Q that of Y . We define

ρδ (P, Q) = inf{ > 0 : ∃(X, Y ), P[d(X, Y ) ≥ ] ≤ δ}.

Note this distance is not a metric because it does not satisfy triangle inequality.

Definition 8.4.13 (Prohorov metric). For any two probability measures P and Q on
a metric space with metric d, let (X, Y ) be a coupling such that P is the marginalized
law of X and Q that of Y . Denote the marginal laws of X and Y by L(X) and L(Y )
respectively. We define the Prohorov metric as

ρ(P, Q) = inf{ > 0 : ∃(X, Y ), L(X) = P, L(Y ) = Q, P[d(X, Y ) ≥ ] ≤ }.

It can be shown that Xn ⇒ X is equivalent to limn→∞ ρ(Xn , X) = 0.

Theorem 8.4.14 (Skorokhod Representation Theorem). Suppose Pn , n = 1, 2, . . .


and P are probability measures on E such that Pn ⇒ P . Then there is a probability
space (Ω, F, P) on which are defined E-valued random variables Xn , n = 1, 2, . . . and
X with distributions Pn and P respectively, such that limn→∞ Xn = X a.s.
318
The main convergence result in Katzenberger [187] (Theorem 8.8.8) are in the
sense of Skorokhod metric in Definition 8.4.4, which is harder to understand and
use compared to the more common uniform metric (Definition 8.4.15). However,
convergence in Skorokhod metric and uniform metric indeed coincide with each other
when the limit is in CRD [0, ∞), i.e., the continuous functions.

Definition 8.4.15 (Uniform metric on DE [0, ∞)). For each finite T > 0 and each
pair of functions f, g ∈ DE [0, T ), the uniform metric is defined to be

dU (f, g; T ) = sup kf (t) − g(t)k.


t∈[0,T )

The uniform metric on DE [0, ∞) is defined to be


X
dU (f, g) = 2−T min{1, dU (f, g; T )}.
T =1

Lemma 8.4.16 (Problem 7, Section 5, Pollard [222]). If Xn ⇒ X in the Skorokhod


sense, and X has sample paths in CRD [0, ∞), then Xn ⇒ X in the uniform metric.

Remark 8.4.17. We shall note the uniform metric defined above is weaker than
supt∈[0,∞) kf (t) − g(t)k. Convergence in the uniform metric on [0, ∞] defined in
Definition 8.4.15 is equivalent to convergence in the uniform metric on each compact
set [0, T ] for T ∈ N+ . The same holds for the Skorokhod topology.

8.5 Main Result: Limiting Diffusion of SGD on

Manifold of Minimizers

In this section, we first state our assumptions about the loss function in Section 8.5.1.
In Section 8.5.1 In Section 8.5.2 we recap the main result of Katzenberger [187]. In
Section 8.5.3 we derive the closed-form expressions of ∂Φ and ∂ 2 Φ. We present our
319
main result in Section 8.5.4. We remark that sometimes we omit the dependency on t
to make things clearer.

8.5.1 Key Assumptions on Manifold of Local Minimizers

Following Fehrman et al. [203], we make the following important assumption about
the loss function.

Assumption 8.5.1. Assume that the loss L : RD → R is a C 3 function, and that Γ


is a (D − M )-dimensional C 1 -submanifold of RD for some integer 0 ≤ M ≤ D, where
for all x ∈ Γ, x is a local minimizer of L and rank(∇2 L(x)) = M .

Let U be the sets of points starting from which, gradient flow w.r.t. loss L
converges to some point in Γ, that is, U := {x ∈ RD | Φ(x) exists and Φ(x) ∈ Γ}.
Assumption 9.5.1 implies that U is open and Φ is C 3 on U . (By Lemma 8.8.2)
When does such a manifold exist? The vast overparametrization in modern
deep learning is a major reason for the set of global minimizers to appear as a
Riemannian manifold (possibly with multiple connected components), instead of
isolated ones. Suppose all global minimizers interpolate the training dataset, i.e.,
∀x ∈ RD , L(x) = minx0 ∈RD L(x0 ) implies fi (x) = yi for all i ∈ [n], then by preimage
theorem [223], the manifold Γ := {x ∈ RD | fi (x) = yi , ∀i ∈ [n]} is of dimension
D − n if the Jacobian matrix [∇f1 (x), . . . , ∇fn (x)] has rank n for all x ∈ Γ. Note
this condition is equivalent to that NTK at x has full rank, which is very common in
literature.
The smoothness assumption is satisfied for networks with smooth activation
functions like tanh and GeLU [76].The assumption rank (∇2 L(x)) = M basically saies
∇2 L(x) always attains the maximal rank in the normal space of the manifold, which
ensures the differentiability of Φ and is crucial to our current analysis, though it’s not
clear if it is necessary.

320
8.5.2 Recap of Katzenberger’s Theorem

Let {An }n≥1 be a sequence of integrators, where each An : R → R is a non-decreasing


function with An (0) = 0. Let {Zn }n≥1 be a sequence of R|Ξ| -valued stochastic processes
defined on R. Given loss function L and noise covariance function σ, we consider the
following stochastic process:

Z t Z t
Xn (t) = X(0) + σ(Xn (s)dZn (s) + −∇L(Xn (s))dAn (s) (8.8)
0 0

In particular, when the integrator sequence {An }n≥1 increases infinitely fast,
meaning that ∀ > 0, inf t≥0 (An (t + ) − An (t)) → ∞ as n → ∞, we call (8.8) a
Katzenberger process.
One difficulty for directly studying the limiting dynamics of Xn (t) is that the
point-wise limit as n → ∞ become discontinuous at t = 0 if X(0) ∈
/ Γ. The
reason is that clearly limn→∞ Xn (0) = X(0), but for any t > 0, since {An }n≥1
increases infinitely fast, one can prove limn→∞ Xn (t) ∈ Γ! To circumvent this issue,
we consider Yn (t) = Xn (t) − φ(X(0), An (t)) + Φ(X(0)). Then for each n ≥ 1, we have
Yn (0) = Φ(X(0)) and limn→∞ Yn (t) = limn→∞ Xn (t). Thus Yn (t) has the same limit
on (0, ∞) as Xn (t), but the limit of the former is further continuous at t = 0.

Theorem 8.5.2 (Informal version of Theorem 8.8.8, Katzenberger 187). Suppose


the loss L, manifold Γ and neighborhood U satisfies Assumption 8.5.1 and ??. Let
{Xn }n≥1 be a sequence of Katzenberger process with {An }n≥1 , {Zn }n≥1 . Let Yn (t) =
Xn (t) − φ(X(0), An (t)) + Φ(X0 ). Under technical assumptions, it holds that if (Yn , Zn )
converges to some (Y, W ) in distribution, where {W (t)}t≥0 is the standard Brownian
motion, then Y stays on Γ and admits

Z t Z t
1 XD
Y (t) = Y (0) + ∂Φ(Y )σ(Y )dW (s) + ∂ij Φ(Y )(σ(Y )σ(Y )> )ij ds.
0 2 i,j=1 0

(8.9)
321
Indeed, SGD (8.1) can be rewritten into a Katzenberger process as in the following
lemma.

Lemma 8.5.3. Let {ηn }∞ n=1 be any positive sequence with limn→∞ ηn = 0, An (t) =
Pbt/η2 c √ i.i.d.
ηn bt/ηn2 c, and Zn (t) = ηn k=1n Ξ(1ξk − Ξ1 1), where ξ1 , ξ2 , . . . ∼ Unif([Ξ]). Then
with the same initialization Xn (0) = xηn (0) ≡ X(0), Xn (kηn2 ) defined by (8.8) is a
Katzenberger process and is equal to xηn (k) defined in (8.1) with LR equal to ηn for
all k ≥ 1. Moreover, the counterpart of (8.9) is

Z t Z t
1
Y (t) = Φ(X(0)) + ∂Φ(Y )σ(Y )dW (s) + ∂ 2 Φ(Y )[Σ(Y )]ds, (8.10)
0 2 0

where Σ ≡ σσ > and {W (t)}t≥0 is a Ξ-dimensional standard Brownian motion.

However, there are two obstacles preventing us from directly applying Theorem 8.5.2
to SGD. First, the stochastic integral in (8.10) depends on the derivatives of Φ, ∂Φ
and ∂ij Φ, but Katzenberger [187] did not give their dependency on loss L. To resolve
this, we explicitly calculate the derivatives of Φ on Γ in terms of the derivatives of L
in Section 8.5.3.
The second difficulty comes from the convergence of (Yn , Zn ) which we assume
as granted for brevity in Theorem 8.5.2. In fact, the full version of Theorem 8.5.2
(see Theorem 8.8.8) concerns the stopped version of Yn with respect to some compact
µ (K)
K ⊂ U , i.e., Yn n (t) = Yn (t ∧ µn (K)) where µn (K) is the stopping time of Yn
leaving K. As noted in Katzenberger [187], we need the convergence of µn (K) for
µ (K)
Yn n to converge, which is a strong condition and difficult to prove in our cases.
We circumvent this issue by proving Theorem 8.8.10, a user-friendly interface for the
original theorem in Katzenberger [187], and it only requires the information about the
limiting diffusion. Building upon these, we present our final result as Theorem 8.5.7.

322
8.5.3 Closed-Form expression of the limiting diffusion

We can calculate the derivatives of Φ by relating to those of L. Here the key observation
is the invariance of Φ along the trajectory of GF. The proofs of this section are deferred
into Section 8.9.

Lemma 8.5.4. For any x ∈ Γ, ∂Φ(x) is the orthogonal projection matrix onto tangent
space Tx (Γ).

To express the second-order derivatives compactly, we introduce the notion of


Lyapunov operator.

Definition 8.5.5 (Lyapunov Operator). For a symmetric matrix H, we define WH =


{Σ ∈ RD×D | Σ = Σ> , HH † Σ = Σ = ΣHH † } and Lyapunov Operator LH : WH →
WH as LH (Σ) = H > Σ + ΣH. It’s easy to verify L−1
H is well-defined on WH .

Lemma 8.5.6. Let x be any point in Γ and Σ = Σ(x) = σσ > (x) ∈ RD×D be the noise
covariance at x1 . Then Σ can be decomposed as Σ = Σk + Σ⊥ + Σk,⊥ + Σ⊥,k , where
Σk := ∂ΦΣ∂Φ, Σ⊥ := (ID − ∂Φ)Σ(ID − ∂Φ) and Σk,⊥ = Σ>
⊥,k = ∂ΦΣ(ID − ∂Φ) are the

noise covariance in tangent space, normal space and across both spaces, respectively.
Then it holds that

∂ 2 Φ[Σ] = −(∇2 L)† ∂ 2 (∇L) Σk −∂Φ∂ 2 (∇L) L−1


    2
 2 † 
∇2 L (Σ⊥ ) −2∂Φ∂ (∇L) (∇ L) Σ⊥,k .

(8.11)

8.5.4 Main Result

Now we are ready to present our main result. It’s a direct combination of Theo-
rem 8.8.10 and Lemma 8.5.6.

Theorem 8.5.7. Suppose the loss function L, the manifold of local minimizer Γ and
the open neighborhood U satisfy Assumption 8.5.1 and ??, and xη (0) = x(0) ∈ U
1
For notational convenience, we drop dependency on x.
323
for all η > 0. If SDE (8.12) has a global solution Y with Y (0) = x(0) and Y never
leaves U , i.e., P[Y (t) ∈ U, ∀t ≥ 0] = 1, then for any T > 0, xη (bT /η 2 c) converges in
distribution to Y (T ) as η → 0.

1 1
dY (t) = Σk2 (Y )dW (t) − ∇2 L(Y )† ∂ 2 (∇L)(Y ) Σk (Y ) dt
 
(8.12)
| {z } |2 {z }
Tangent Noise Tangent Noise Compensation
 
1 2
 2 †
 2
 −1 
− ∂Φ(Y ) 2 ∂ (∇L)(Y ) ∇ L(Y ) Σ⊥,k (Y ) + ∂ (∇L)(Y ) L∇2 L (Σ⊥ (Y )) dt,
2 | {z } | {z }
Mixed Regularization Normal Regularization

where Σ ≡ σσ > and Σk , Σ⊥ , Σ⊥,k are defined in Lemma 8.5.6.

Based on the above theorem, the limiting dynamics of SGD can be understood as
1/2
follows: (a) the Tangent Noise, Σk (Y )dW (t), is preserved, and the second term
of (8.12) can be viewed as the necessary Tangent Noise Compensation for the
limiting dynamics to stay on Γ. Indeed, Lemma 8.9.7 shows that the value of the second
term only depends on Γ itself, i.e., it’s same for all loss L which locally defines the
same Γ. (b) The noise in the normal space is killed since the limiting dynamics always
stay on Γ. However, its second order effect (Itô correction term) takes place as a vector
field on Γ, which induces the Noise Regularization and Mixed Regularization
term, corresponding to the mixed and normal noise covariance respectively.

Remark 8.5.8. In Section 8.8.4 we indeed prove a stronger version of Theorem 8.5.7
eη (t) = xη (bt/η 2 c),
that the sample paths of SGD converge in distribution, i.e., let x
then x
eη weakly converges to Y on [0, T ]. Moreover, we only assume the existence of
a global solution for ease of presentation. As long as there exists a compact K ⊆ Γ
such that Y stays in K on [0, T ] with high probability, Theorem 8.8.10 still provides
the convergence of SGD iterates (stopped at the boundary of K) before time T with
high probability.

324
8.6 Implications and Examples

In this section, we derive the limiting dynamics for two notable noise types, where we
fix the expected loss L and the noise distribution, and only drive η to 0. The proofs
are deferred into Section 8.10.

Type I: Isotropic Noise. Isotropic noise means Σ(x) ≡ ID for any x ∈ Γ [189].
The following theorem shows that the limiting diffusion with isotropic noise can be
viewed as a Brownian Motion plus Riemannian Gradient Flow with respect to the
pseudo-determinant of ∇2 L.

Corollary 8.6.1 (Limiting Diffusion for Isotropic Noise). If Σ ≡ ID on Γ, SDE (8.12)


is then

1 1
dY (t) = ∂Φ(Y )dW − ∇2 L(Y )† ∂ 2 (∇L)(Y ) [∂Φ(Y )] dt − ∂Φ(Y )∇(ln |∇2 L(Y )|+ )dt
| 2 {z } |2 {z }
Brownian Motion on Manifold Normal Regularization

(8.13)

|∇2 L(Y )+αID |


where |∇2 L(Y )|+ = limα→0 2
αD−rank(∇ L(Y ))
is the pseudo-determinant of ∇2 L(Y ).
|∇2 L(Y )|+ is also equal to the sum of log of non-zero eigenvalue values of ∇2 L(Y ).

Type II: Label Noise. When doing SGD for `2 -regression on dataset {(zi , yi )}ni=1 ,
adding label noise [165, 166] means replacing the true label at iteration k, yik , by a
i.i.d.
fresh noisy label yeik := yik + δk , where δk ∼ Unif{−δ, δ} for some constant δ > 0.
Then the corresponding loss becomes 12 (fik (x) − yeik )2 , where fik (x) is the output of
the model with parameter x on data zik . So the label noise SGD update is

xk+1 = xk − η/2 · ∇x (fik (xk ) − yik + δk )2 = xk − η(fik (xk ) − yik + δk )∇x fik (xk ).
(8.14)

325
Suppose the model can achieve the global minimum of the loss L(x) := 12 E[(fi (x)− yei )2 ]
at x∗ , then the model must interpolate the whole dataset, i.e., fi (x∗ ) = yi for all
i ∈ [n], and thus here the manifold Γ is a subset of {x ∈ RD | fi (x) = yi , ∀i ∈
[n]}. Here the key property of the label noise used in previous works is Σ(x) =
δ2
Pn > 2 2
n i=1 ∇x fi (x)∇x fi (x) = δ ∇ L(x). Lately, Damian et al. [166] further generalizes

the analysis to other losses, e.g., logistic loss and exponential loss, as long as they
satisfy Σ(x) = c∇2 L(x) for some constant c > 0.
In sharp contrast to the delicate discrete-time analysis in Blanc et al. [165] and
Damian et al. [166], the following corollary recovers the same result but with much
simpler analysis – taking derivatives is all you need. Under our framework, we no
longer need to do Taylor expansion manually nor carefully control the infinitesimal
variables of different orders together. It is also worth mentioning that our framework
immediately gives a global analysis of Θ(η −2 ) steps for SGD, far beyond the local
coupling analysis in previous works. In Section 8.7, we will see how such global analysis
allows us to prove a concrete generalization upper bound in a non-convex problem,
the overparametrized linear model [20, 21].

Corollary 8.6.2 (Limiting Flow for Label Noise). If Σ ≡ c∇2 L on Γ for some
constant c > 0, SDE (8.12) can be simplified into (8.15) where the regularization is
from the noise in the normal space.

dY (t) = −1/4 · ∂Φ(Y (t))∇ tr[c∇2 L(Y (t))]dt. (8.15)

Example: k-Phase Motor We also give an example with rigorous proof where the
implicit bias induced by noise in the normal space cannot be characterized by a fixed
regularizer, which was first discovered by Damian et al. [166] but was only verified via
experiments.

326
Note the normal regularization in both cases of label noise and isotropic noise
induces Riemmanian gradient flow against some regularizer, it’s natural to wonder if
the limiting flow induced by the normal noise can always be characterized by certain
regularizer. Interestingly, Damian et al. [166] answers this question negatively via
experiments in their Section E.2. We adapt their example into the following one, and
rigorously prove the limiting flow moves around a cycle at a constant speed and never
stops using our framework.
x1:2

Suppose dimension D = k + 2 ≥ 5. For each x ∈ RD , we decompose x = x3:D

where x1:2 ∈ R2 and x3:D ∈ RD−2 . Let Qθ ∈ R2×2 be the rotation matrix of angle θ, i.e.,
θ − sin θ
and the loss L(x) := 18 (kx1:2 k22 − 1)2 + 12 D
 j−3 2
Qθ = cos
P
sin θ cos θ j=3 (2 + hQα v, x1:2 i xj ,

where α = D−2
and v is any vector in R2 with unit norm. Here the manifold is given
by Γ := {x | L(x) = 0} = {x ∈ RD | x21 + x22 = 1, xj = 0, ∀j = 3, . . . , D}.
The basic idea is that we can add noise in the ‘auxiliary dimensions’ for j = 3, . . . , D
to get the regularization force on the circle {x21 + x22 = 1}, and the goal is to make the
vector field induced by the normal regularization always point to the same direction,
say anti-clockwise. However, this cannot be done with a single auxiliary dimension
because from the analysis for label noise, we know when L−1
∇2 L (Σ⊥ ) is identity, the

normal regularization term in Equation (8.12) has 0 path integral along the unit
circle and thus it must have both directions. The key observation here is that we
can align the magnitude of noise with the strength of the regularization to make the
path integral positive. By using k ≥ 3 auxiliary dimensions, we can further ensure
the normal regularization force is anti-clockwise and of constant magnitude, which is
reminiscent of how a three-phase induction motor works.

Lemma 8.6.3. Let Σ ∈ RD×D be given by Σij (x) = (1 + Qj−3


α v, Q−π/2 x1:2 )(2 +

hQj−3
α v, x1:2 i), if i = j ≥ 3 or 0 otherwise, then the solution of SDE (8.12) is the

following (8.16) , which implies that Y (t) moves anti-clockwise with a constant angular

327
speed of (D − 2)/2.

Y1:2 (t) = Qt(D−2)/2 Y1:2 (0) and Y3:D (t) ≡ 0. (8.16)

8.7 Provable Generalization Benefit with Label

Noise

In this section, we show provable benefit of label noise in generalization using our
framework (Theorem 8.8.8) in a concrete setting, the overparametrized linear models
(OLM) [20]. While the existing implicit regularization results for Gradient Flow often
relates the generalization quality to initialization, e.g., Woodworth et al. [20] shows
that for OLM, small initialization corresponds to the rich regime and prefers solutions
with small `1 norm while large initialization corresponds to the kernel regime and
prefers solutions with small `2 norm, our result Theorem 8.7.1 surprisingly proves
that even if an OLM is initialized in the kernel regime, label noise SGD can still
help it escape and then enter the rich regime by minimizing its weighted `1 norm.
When the groundtruth is κ-sparse, this provides a O(κ
e ln d) vs Ω(d) sample complexity

separation between SGD with label noise and GD when both initialized in the kernel
regime. Here d is the dimension of the groundtruth. The lower bound for GD in the
kernel regime is folklore, but for completeness, we state the result as Theorem 8.7.7 in
Section 8.7.3 and append its proof in Section 8.11.6.

Theorem 8.7.1. In the setting of OLM, suppose the groundtruth is κ-sparse and
n ≥ Ω(κ ln d) training data are sampled from either i.i.d. Gaussian or Boolean
distribution. Then for any initialization xinit (except a zero-measure set) and any
 > 0, there exist η0 , T > 0 such that for any η < η0 , OLM trained with label noise
SGD (8.14) with LR equal to η for bT /η 2 c steps returns an -optimal solution, with
probability of 1 − e−Ω(n) over the randomness of the training dataset.
328
The proof roadmap of Theorem 8.7.1 is the following:

1. Show Assumption 8.5.1 is satisfied, i.e., the set of local minimizers, Γ, is indeed
a manifold and the hessian ∇2 L(x) is non-degenerate on Γ (by Lemma 8.7.2);

2. Show ?? is satisfied, i.e., Φ(U ) ⊂ Γ (by Lemma 8.7.3);

3. Show the limiting flow (8.15) converges to the minimizer of the regularizer (by
Lemma 8.7.5);

4. Show the minimizer of the regularizer recovers the groundtruth (by Lemma 8.7.6).

Our setting is more general than HaoChen et al. [21], which assumes w∗ ∈ {0, 1}d
and their reparametrization can only express positive linear functions, i.e., w = u 2 .
e 2 ) rate is achieved with a delicate three phase LR schedule, while our
Their O(κ
O(κ ln d) rate only uses a constant LR.

i.i.d.
Setting: Let {(zi , yi )}i∈[n] be the training dataset where z1 , . . . , zn ∼ Unif({±1}d )
or N (0, Id ) and each yi = hzi , w∗ i for some unknown w∗ ∈ Rd . We assume that w∗ is
κ-sparse for some κ < d. Denote x = uv ∈ RD = R2d , and we will use x and (u, v)


exchangeably as the parameter of functions defined on RD in the sequel. For each


i ∈ [n], define fi (x) = fi (u, v) = zi> (u 2
− v 2 ). Then we fit {(zi , yi )}i∈[n] with an
overparametrized model through the following loss function:

Xn
L(x) = L(u, v) = 1
n
`i (u, v), where `i (u, v) = 21 (fi (u, v) − yi )2 . (8.17)
i=1

4
Pn zi u zi u >
 
It is straightforward to verify that ∇2 L(x) = n i=1 −zi v −zi v
, ∀x ∈ Γ. For
simplicity, we define Z = (z1 , . . . , zn )> ∈ Rn×d and Y = (y1 , . . . , yn )> ∈ Rn . Consider
the following manifold:

Γ = x = (u> , v > )> ∈ U : Z(u 2


− v 2) = Y , where U = (R \ {0})D .

(8.18)

329
We verify that the above loss function L and manifold Γ satisfy Assumption 8.5.1 by
Lemma 8.7.2, and that the neighborhood U and Γ satisfy ?? by Lemma 8.7.3.

Lemma 8.7.2. Consider the loss L defined in (8.17) and manifold Γ defined in (8.18).
If data is full rank, i.e., rank(Z) = n, then it holds that (a). Γ is a smooth manifold
of dimension D − n; (b). rank(∇2 L(x)) = n for all x ∈ Γ. In particular, rank(Z) = n
holds with probability 1 for Gaussian distribution and with probability 1−cd for Boolean
distribution for some constant c ∈ (0, 1).

Lemma 8.7.3. Consider the loss function L defined in (8.17), manifold Γ and its
dxt
open neighborhood defined in (8.18). For gradient flow dt
= −∇L(xt ) starting at any
x0 ∈ U , it holds that Φ(x0 ) ∈ Γ.

Remark 8.7.4. In previous works [20, 138], the convergence of gradient flow is
only assumed. Recently Pesme et al. [176] proved it for a specific initialization, i.e.,
uj = vj = α, ∀j ∈ [n] for some α > 0. Lemma 8.7.3 completely removes the technical
assumption.

Therefore, by the result in the previous section, the implicit regularizer on the
manifold is R(x) = tr(Σ(x)) = tr(δ 2 ∇2 L(x)). Without loss of generality, we take
δ = 1. Hence, it follows that

4 X D X n 2  2
R(x) = zi,j (uj + vj2 ). (8.19)
n j=1 i=1

The limiting behavior of label noise SGD is described by a Riemannian gradient flow
on Γ as follows:

dxt = −1/4 · ∂Φ(xt )∇R(xt )dt, with x0 = Φ(xinit ) ∈ Γ. (8.20)

The goal is to show that the above limiting flow will converge to the underlying
∗ 1/2 1/2
groundtruth x∗ = uv∗ where (u∗ , v ∗ ) = ([w∗ ]+ , [−w∗ ]+ ).
330
8.7.1 Limiting Flow Converges to Minimizers of Regularizer

In this subsection we show limiting flow (8.15) starting from anywhere on Γ converges
to the minimizer of regularizer R (by Lemma 8.7.5). The proof contains two parts:
(a) the limiting flow converges; (b) the limit point of the flow cannot be sub-optimal
stationary points. These are indeed the most technical and difficult parts of proving
the O(κ ln d) upper bound, where the difficulty comes from the fact that the manifold
Γ is not compact, and the stationary points of the limiting flow are in fact all located
on the boundary of Γ. However, the limiting flow itself is not even defined on the
boundary of the manifold Γ. Even if we can extend ∂Φ(·)∇R(·) continuously to entire
RD , the continuous extension is not everywhere differentiable.
Thus the non-compactness of Γ brings challenges for both (a) and (b). For (a), the
convergence for standard gradient flow is often for free, as long as the trajectory is
bounded and the objective is analytic or smooth and semialgebraic. The latter ensures
the so-called Kurdyka-Lojasiewicz (KL) inequality [224], which implies finite trajectory
length and thus the convergence. However, since our flow does not satisfy those nice
properties, we have to show that the limiting flow satisfies Polyak-Lojasiewicz condition
(a special case of KL condition) [225] via careful calculation (by Lemma 8.11.16).
For (b), the standard analysis based on center stable manifold theorem shows that
gradient descent/flow converges to strict saddle (stationary point with at least one
negative eigenvalue in hessian) only for a zero-measure set of initialization [119, 121].
However, such analyses cannot deal with the case where the flow is not differentiable
at the sub-optimal stationary point. To circumvent this issue, we prove the non-
convergence to sub-optimal stationary points with a novel approach: we show that for
any stationary point x, whenever there exists a descent direction of the regularizer
R at x, we can construct a potential function which increases monotonically along
the flow around x, while the potential function is equal to −∞ at x, leading to a
contradiction. (See proof of Lemma 8.7.5.)

331
Lemma 8.7.5. Let {xt }t≥0 ⊆ RD be generated by the flow defined in (8.20) with any
initialization x0 ∈ Γ. Then x∞ = limt→∞ xt exists. Moreover, x∞ = x∗ is the optimal
solution of (8.21).

8.7.2 Minimizer of the Regularizer Recovers the Sparse

Groundtruth
1
Pn 2 iid
Note n i=1 zi,j = 1 when zi,j ∼ Unif{−1, 1}, and we can show minimizing R(x) on
Γ, (8.21), is equivalent to finding the minimum `1 norm solution of Equation (8.17).
Standard results in sparse recovery imply that minimum `1 norm solution recovers
with the sparse groundtruth. The gaussian case is more complicated but still can be
proved with techniques from Tropp [226].

4 X d X n 2  2
minimize R(x) = zi,j (uj + vj2 ),
n j=1 i=1
(8.21)
2 2 ∗
subject to Z(u − v ) = Zw .

i.i.d.
Lemma 8.7.6. Let z1 , . . . , zn ∼ Unif({±1}d ) or N (0, Id ). Then there exist some
constants C, c > 0 such that if n ≥ Cκ ln d, then with probability at least 1 − e−cn , the
optimal solution of (8.21), (b
u, vb), is unique up to sign flips of each coordinate and
recovers the groundtruth, i.e., u
b 2
− vb 2
= w∗ .

8.7.3 Lower Bound for Gradient Descent in the Kernel

Regime

In this subsection we show GD needs at least Ω(d) samples to learn OLM, when
initialized in the kernel regime. This lower bound holds for all learning rate schedules
and numbers of steps. This is in sharp contrast to the O(κ
e ln d) sample complexity

upper bound of SGD with label noise. Following the setting of kernel regime in [20],
we consider the limit of u0 = v0 = α1, with α → ∞. It holds that fi (u0 , v0 ) = 0 and
332
∇fi (u0 , v0 ) = [αzi , −αzi ] for each i ∈ [n]. Standard convergence analysis for NTK
(Neural Tangent Kernel, Jacot et al. [15]) shows that upon convergence, the distance
traveled by parameter converges to 0, and thus the learned model shall converge
in function space, so is the generalization performance. For ease of illustration, we
directly consider the lower bound for test loss when the NTK is fixed throughout the
training.

i.i.d.
Theorem 8.7.7. Assume z1 , . . . , zn ∼ N (0, Id ) and yi = zi> w∗ , for all i ∈ [n].
Define the loss with linearized model as L(x) = ni=1 (fi (x0 ) + h∇fi (x0 ), x − x0 i − yi )2 ,
P

where x = uv and x0 = uv00 = α 11 . Then for any groundtruth w∗ , any learning


  

rate schedule {ηt }t≥1 , and any fixed number of steps T , the expected `2 loss of x(T )
is at least (1 − nd ) kw∗ k22 , where x(T ) is the T -th iterate of GD on L, i.e., x(t + 1) =
x(t) − ηt ∇L(x(t)), for all t ≥ 0.

8.8 Derivation for Limiting Diffusion of SGD

In this section, we give a complete derivation of the limiting diffusion of SGD. Here
we use ⇒ to denote the convergence in distribution. For any U ⊆ RD , we denote by
Ů its interior. For linear space S, we use S ⊥ to denote its orthogonal complement.
First, as mentioned in ??, we verify that the mapping Φ is C 2 in ??. In Section 8.8.1
we discuss how different time scalings could affect the coefficients in SDE (8.2) and
(8.3). Then we check the necessary conditions for applying the results in Katzenberger
[187] in Section 8.8.2 and recap the corresponding theorem for the asymptotically
continuous case in Section 8.8.3. Finally, we provide a user-friendly interface for
Katzenberger’s theorem in Section 8.8.4.

Lemma 8.8.1. If limn→∞,n∈N φ(x, n) exists, then Φ(x) also exists and Φ(x) =
limn→∞,n∈N φ(x, n).

333
Proof of Lemma 8.8.1. Suppose K ⊆ RD is a compact set and φ(x, t) ∈ K for all
t ∈ [0, T ], we have that

d k∇L(φ(x, t))k2
= −2∇L(φ(x, t))∇2 L(φ(x, t))∇L(φ(x, t)) ≤ 2ρK k∇L(φ(x, t))k2 ,
dt

where ρK denotes supx∈K k∇2 L(φ(x, t))k. This implies that k∇L(φ(x, t))k ≤
eρK T k∇L(x)k and kφ(x, t) − xk ≤ eρK T k∇L(x)k for all t ∈ [0, T ].
Now suppose limn→∞,n∈N φ(x, n) = x∗ , we know ∇L(x∗ ) and k∇L(φ(x, n))k must
converges to 0 due to the continuity of ∇L. Take any compact neighborhood of x∗ as the
above defined K, we know there exists N > 0, such that for all n > N , k∇L(Φ(x, n))k
and kφ(x, n) − x∗ k are small enough such that φ(x, n + δ) ∈ K for all δ ∈ [0, 1].
Therefore, we know that when t → ∞ as a real number, kφ(x, t) − φ(x, btc)k ≤
eρK k∇L(φ(x, btc))k → 0. This completes the proof.

Lemma 8.8.2. Let U = {x ∈ RD | Φ(x) exists and Φ(x) ∈ Γ}. If L is C k , then we


have that U is open and Φ is C k−1 on U .

Proof of Lemma 8.8.2. Since Γ is a (D−M ) dimensional manifold and rank(∇2 L(x)) =
M , it holds that for any x ∈ Γ, there exists a small open set containing x, Vx , such that
Γ ∩ Vx is the set of the stationary points of L in Vx , i.e., Γ ∩ Vx = {y ∈ Vx | ∇L(y) = 0}.
Thus Γ is the set of the stationary points of L in open set V := ∪x∈Γ Vx . We further
define f : RD → RD , f (x) = φ(x, 1), and we have Γ is the set of the fixed points of
mapping f . Since L is C k , ∇L is C k−1 and thus f is C k−1 . By Theorem 5.1 in [227],
we know that there is an open set N containing Γ and f ∞ (x) is well-defined and C 3 on
N with f ∞ (x) ∈ Γ for any x ∈ Γ, where f ∞ (x) := limn→∞ f n (x), f n (x) = f (f n−1 (x))
and f 1 (x) = f (x). By Lemma 8.8.1, we know that Φ(x) = f ∞ (x).
Since V ⊇ Γ is open and Φ(x) ∈ Γ for all x ∈ U , we know that there is a t > 0,
such that φ(x, t) ∈ V . Thus U = ∪t≥0 φ(V, −t) is a union of open sets, which is still

334
open. Moreover, Φ(x) = Φ(φ(x, t)) for each x ∈ U and some t > 0 with φ(x, t) ∈ V .
Since Φ is C 3 in V and φ(·, t) is C k−1 for any t, we conclude that Φ is C 3 in U .

8.8.1 Approximating SGD by SDE

Let’s first clarify how we derive the SDEs, (8.2) and (8.3), that approximate SGD
(8.1) under different time scalings. Recall W (t) is Ξ-dimensional Brownian motion
and that σ(X) : RD → RD×Ξ is a deterministic noise function. As proposed by [190],
one approach to approximate SGD (8.1) by SDE is to consider the following SDE:


dX(t) = −∇L(X(t))dt + ησ(X(t))dW (t),

where the time correspondence is t = kη, i.e., X(kη) ≈ xη (k).


Now rescale the above SDE by considering X(t)
e = X(tη), which then yields

e = dX(tη) = −∇L(X(tη))d(tη) + √
dX(t) ησ(X(tη))dW (tη)

= −η∇L(X(tη))dt + ησ(X(tη))dW (tη).

Now we define W 0 (t) = √1 W (tη),


η
and it’s easy to verify that W 0 (t) is also a Ξ-
d
dimensional brownian motion, which means W 0 = W , i.e., W and W 0 have the same
sample paths in CRd [0, ∞). Thus

e = −η∇L(X(tη))dt + ησ(X(tη))dW 0 (t)


dX(t)
0
= −η∇L(X(t))dt
e + ησ(X(t))dW
e (t),

where the time correspondence is t = k, i.e., X(k)


e ≈ xη (k). The above SDE is exactly
the same as (8.2).

335
Then, to accelerate the above SDE by η −2 times, let’s define X̄(t) = X(t/η
e 2
).
Then it follows that

2 2
dX̄(t) = dX(t/η
e ) = −η∇L(X(t/η
e ))dt/η 2 + ησ(X(t/η
e 2
))dW (t/η 2 )
1
= − ∇L(X̄(t))dt + σ(X̄(t))d ηW (t/η 2 )

η

d
Again note that ηW (t/η 2 ) = W (t) in sample paths and thus is also a Ξ-Brownian
motion. Here the time correspondence is t = kη 2 , i.e., evolving for constant time with
the above SDE approximates Ω(1/η 2 ) steps of SGD. In this way, we derive SDE (8.3)
in the main context.

8.8.2 Necessary Conditions

Below we collect the necessary conditions imposed on {Zn }n≥1 and {An }n≥1 in Katzen-
berger [187]. Recall that we consider the following stochastic process

Z t Z t
Xn (t) = X(0) + σ(Xn (s))dZn (s) − ∇L(Xn (s))dAn (s).
0 0

For any stopping time τ , the stopped process is defined as Xnτ (t) = Xn (t ∧ τ ). For any
compact K ⊂ U , we define the stopping time of Xn leaving K as λn (K) = inf{t ≥ 0 |
Xn (t−) ∈
/ K̊ or Xn (t) ∈
/ K̊}.

Condition 8.8.3. The integrator sequence {An }n≥1 is asymptotically continuous:


sup |An (t) − An (t−)| ⇒ 0 where An (t−) = lims→t− An (s) is the left limit of An at t.
t>0

Condition 8.8.4. The integrator sequence {An }n≥1 increases infinitely fast: ∀ > 0,
inf (An (t + ) − An (t)) ⇒ ∞.
t≥0

336
Condition 8.8.5 (Eq.(5.1), Katzenberger 187). For every T > 0, as n → ∞, it holds
that

sup k∆Zn (t)k2 ⇒ 0.


0<t≤T ∧λn (K)

Condition 8.8.6 (Condition 4.2, Katzenberger 187). For each n ≥ 1, let Yn be a


{Ftn }-semimartingale with sample paths in DRD [0, ∞). Assume that for some δ > 0
(allowing δ = ∞) and every n ≥ 1 there exist stopping times {τnm | m ≥ 1} and a
decomposition of Yn − Jδ (Yn ) into a local martingale Mn plus a finite variation process
Fn such that P[τnm ≤ m] ≤ 1/m, {[Mn ](t ∧ τnm ) + Tt∧τnm (Fn )}n≥1 is uniformly integrable
for every t ≥ 0 and m ≥ 1, and

 
lim lim sup P sup (Tt+γ (Fn ) − Tt (Fn )) >  = 0,
γ→0 n→∞ 0≤t≤T

for every  > 0 and T > 0, where Tt (·) denotes total variation on the interval [0, t].

Lemma 8.8.7. For SGD iterates defined using the notation in Lemma 8.5.3, the
sequences {An }n≥1 and {Zn }n≥1 satisfy Condition 8.8.3, 8.8.4, 8.8.5 and 8.8.6.

Proof of Lemma 8.8.7. Condition 8.8.3 is obvious from the definition of {An }n≥1 .
Next, for any  > 0 and t ∈ [0, T ], we have

t +  − ηn2  − ηn2
   
t+ t t
An (t + ) − An (t) = ηn · − η n · ≥ − = ,
ηn2 ηn2 ηn ηn ηn

which implies that inf 0≤t≤T (An (t + ) − An (t)) > /(2ηn ) for small enough ηn . Then
taking n → ∞ yields the Condition 8.8.4.

337
For Condition 8.8.5, note that


ηn Ξ(1ξk − 1 1) if t = k · ηn2 ,


Ξ
∆Zn (t) =

0
 otherwise.


Therefore, we have k∆Zn (t)k2 ≤ 2ηn Ξ for all t > 0. This implies that k∆Zn (t)k2 → 0
uniformly over t > 0 as n → ∞, which verifies Condition 8.8.5.
We proceed to verify Condition 8.8.6. By the definition of Zn , we know that
{Zn (t)}t≥0 is a jump process with independent increments and thus is a martingale.
Therefore, by decomposing Zn = Mn + Fn with Mn being a local martingale and Fn a
finite variation process, we must have Fn = 0 and Mn is Zn itself. It then suffices to
show that [Mn ](t ∧ τnm ) is uniformly integrable for every t ≥ 0 and m ≥ 1. Since Mn
is a pure jump process, we have

X X
[Mn ](t ∧ τnm ) = k∆Mn (s)k22 ≤ k∆Mn (s)k22
0<s≤t∧τnm 0<s≤t
2c
bt/ηn 2c
bt/ηn

  2
X 1 X
= ηn Ξ 1ξk − 1 ≤ 4Ξ ηn2 ≤ 4Ξt.
k=1
Ξ 2 k=1

This implies that [Mη ](t ∧ τηm ) is universally bounded by 4t, and thus [Mη ](t ∧ τηm ) is
uniformly integrable. This completes the proof.

Lemma 8.5.3. Let {ηn }∞ n=1 be any positive sequence with limn→∞ ηn = 0, An (t) =
Pbt/η2 c √ i.i.d.
ηn bt/ηn2 c, and Zn (t) = ηn k=1n Ξ(1ξk − Ξ1 1), where ξ1 , ξ2 , . . . ∼ Unif([Ξ]). Then
with the same initialization Xn (0) = xηn (0) ≡ X(0), Xn (kηn2 ) defined by (8.8) is a
Katzenberger process and is equal to xηn (k) defined in (8.1) with LR equal to ηn for
all k ≥ 1. Moreover, the counterpart of (8.9) is

Z t Z t
1
Y (t) = Φ(X(0)) + ∂Φ(Y )σ(Y )dW (s) + ∂ 2 Φ(Y )[Σ(Y )]ds, (8.10)
0 2 0

338
where Σ ≡ σσ > and {W (t)}t≥0 is a Ξ-dimensional standard Brownian motion.

Proof of Lemma 8.5.3. For any n ≥ 1, it suffices to show that given Xn (kηn2 ) = xηn (k),
we further have Xn ((k + 1)ηn2 ) = xηn (k + 1). By the definition of Xn (t) and note that
An (t), Zn (t) are constants on [kηn2 , (k + 1)ηn2 ), we have that Xn (t) = Xn (kηn2 ) for all
t ∈ [kηn2 , (k + 1)ηn2 ), and therefore

Xn ((k + 1)ηn2 ) − Xn (kηn2 )


Z (k+1)ηn2 Z 2
(k+1)ηn
=− ∇L(Xn (t))dAn (t) + σ(Xn (t))dZn (t)
2
kηn 2
kηn

= − ∇L(Xn (kηn2 ))(An ((k + 1)ηn2 ) − An (kηn2 )) + σ(Xn (kηn2 ))(Zn ((k + 1)ηn2 ) − Zn (kηn2 ))

= − ηn ∇L(Xn (kηn2 )) + ηn Ξσξk (Xn (kηn2 ))

= − ηn ∇L(xηn (k)) + ηn Ξσξk (xηn (k)) = xηn (k + 1) − xηn (k)

where the second equality is because An (t) and Zn (t) are constant on interval [kηn2 , (k +
1)ηn2 ). This confirms the alignment between {Xn (kηn2 )}k≥1 and {xηn (k)}k≥1 .
For the second claim, note that σ(x)EZn (t) ≡ 0 for all x ∈ RD , t ≥ 0 (since the
noise has zero-expectation) and that {Zn (t) − EZn (t)}t≥0 will converge in distribution
to a Brownian motion by the classic functional central limit theorem (see, for example,
Theorem 4.3.5 in Whitt [228]). Thus, the limiting diffusion of Xn as n → ∞ can be
obtained by substituting Z with the standard Brownian motion W in (8.23). This
completes the proof.

8.8.3 Katzenberger’s Theorem for Asymptotically Continu-

ous Case

The full Katzenberger’s theorem deals with a more general case, which only requires
the sequence of intergrators to be asymptotically continuous, thus including SDE (8.3)
and SGD (8.1) with η goes to 0.
339
To describe the results in Katzenberger [187], we first introduce some definitions.
For each n ≥ 1, let (Ωn , F n , {Ftn }t≥0 , P) be a filtered probability space, Zn an Re -
valued cadlag {Ftn }-semimartingale with Zn (0) = 0 and An a real-valued cadlag
{Ftn }-adapted nondecreasing process with An (0) = 0. Let σn : U → M(D, e) be
continuous with σn → σ uniformly on compact subsets of U . Let Xn be an RD -valued
càdlàg {Ftn }-semimartingale satisfying, for all compact K ⊂ U ,

Z t Z t
Xn (t) = X(0) + σ(Xn )dZn + −∇L(Xn )dAn (8.22)
0 0

for all t ≤ λn (K) where λn (K) = inf{t ≥ 0 | Xn (t−) ∈


/ K̊ or Xn (t) ∈
/ K̊} is the
stopping time of Xn leaving K.

Theorem 8.8.8 (Theorem 6.3, Katzenberger 187). Suppose X(0) ∈ U , Assump-


tion 8.5.1 and ??, Condition 8.8.3, 8.8.4, 8.8.5 and 8.8.6 hold. For any compact
K ⊂ U , define µn (K) = inf{t ≥ 0 | Yn (t−) ∈
/ K̊ or Yn (t) ∈
/ K̊}, then the sequence
µ (K) µ (K)
{(Yn n , Zn n , µn (K)} is relatively compact in DRD×e [0, ∞)×[0, ∞). If (Y, Z, µ) is a
limit point of this sequence under the skorohod metric (Definition 8.4.4), then (Y, Z) is
a continuous semimartingale, Y (t) ∈ Γ for every t ≥ 0 a.s., µ ≥ inf{t ≥ 0 | Y (t) ∈
/ K̊}
a.s. and Y (t) admits

Z t∧µ
Y (t) = Y (0) + ∂Φ(Y (s))σ(Y (s))dZ(s)
0
D e Z t∧µ
1 X X
+ ∂ij Φ(Y (s))σ(Y (s))ik σ(Y (s))jl d[Zk , Zl ](s). (8.23)
2 i,j=1 k,l=1 0

We note that by Lemma 8.4.16, convergence in distribution under skorohod metric


is equivalent to convergence in distribution under uniform metric Definition 8.4.15,
therefore in the rest of the paper we will only use the uniform metric in the rest of
the paper, e.g., whenever we mention Prohorov metric and δ-Prohorov distance, the
underlying metric is the uniform metric.

340
8.8.4 A User-friendly Interface for Katzenberger’s Theorem

Based on the Lemma 8.8.7, we can immediately apply Theorem 8.8.8 to obtain the
following limiting diffusion of SGD.

Theorem 8.8.9. Let the manifold Γ and its open neighborhood U satisfy Assump-
tion 8.5.1 and ??. Let K ⊂ U be any compact set and fix some x0 ∈ K. Consider the
SGD formulated in Lemma 8.5.3 where Xηn (0) ≡ x0 . Define

Yηn (t) = Xηn (t) − φ(Xηn (0), Aηn (t)) + Φ(Xηn (0))

µ (K)
/ K̊}. Then the sequence {(Yηnηn
and µηn (K) = min{t ∈ N | Yηn (t) ∈ , Zηn , µηn (K))}n≥1
is relatively compact in DRD ×Rn [0, ∞) × [0, ∞]. Moreover, if (Y, Z, µ) is a limit point
of this sequence, it holds that Y (t) ∈ Γ a.s for all t ≥ 0, µ ≥ inf{t ≥ 0 | Y (t) ∈
/ K̊}
and Y (t) admits

Z t∧µ Z t∧µ D
1X
Y (t) = ∂Φ(Y (s))σ(Y (s))dW (s) + ∂ij Φ(Y (s))(σ(Y (s))σ(Y (s))> )ij ds
s=0 s=0 2 i,j=1

(8.24)

where {W (s)}s≥0 is the standard Brownian motion and σ(·) is as defined in


Lemma 8.5.3.

However, the above theorem is hard to parse and cannot be directly applied if we
want to further study the implicit bias of SGD through this limiting diffusion. There-
fore, we develop a user-friendly interface to it in below. In particular, Theorem 8.5.7
is the a special case of Theorem 8.8.10. In Theorem 8.5.7, we replace ∂Φ(Y (t))σ(Y (t))
1
with Σk2 (Y (t)) to simplify the equation, since ∂Φ(Y (t))σ(Y (t)) (∂Φ(Y (t))σ(Y (t)))> =
Σk (Y (t)) and thus this change doesn’t affect the distribution of the sample paths of
the solution.

341
Theorem 8.8.10. Under the same setting as Theorem 8.8.9, we change the integer
index back to η > 0 with a slight abuse of notation. For any stopping time µ and
stochastic process {Y (t)}t≥0 such that µ ≥ inf{t ≥ 0 | Y (t) ∈
/ K̊}, Y (0) = Φ(x0 )
and that (Y, µ) satisfy Equation (8.24) for some standard Brownian motion W . For
any compact set K ⊆ U and T > 0, define µ(K) = inf{t ≥ 0 | Y (t) ∈
/ K̊} and
δ = P(µ(K) ≤ T ). Then for any  > 0, it holds for all sufficiently small LR η that:

ρ2δ (Yηµη (K)∧T , Y µ(K)∧T ) ≤ , (8.25)

which means there is a coupling between the distribution of the stopped processes
µ (K)∧T
Yη η and Y µ(K)∧T , such that the uniform metric between them is smaller than 
µ (K)∧T
with probability at least 1 − 2δ. In other words, limη→0 ρ2δ (Yη η , Y µ(K)∧T ) = 0.
Moreover, when {Y (t)}t≥0 is a global solution to the following limiting diffusion

Z t Z t D
1X
Y (t) = ∂Φ(Y (s))σ(Y (s))dW (s) + ∂ij Φ(Y (s))(σ(Y (s))σ(Y (s))> )ij ds
s=0 s=0 2 i,j=1

and Y never leaves U , i.e. P[∀t ≥ 0, Y (t) ∈ U ] = 1, it holds that YηT converges in
distribution to Y T as η → 0 for any fixed T > 0.

For clarity, we break the proof of Theorem 8.8.10 into two parts, devoted to the
two claims respectively.

Proof of the first claim of Theorem 8.8.10. First, Theorem 8.8.9 guarantees there ex-
e and a stochastic process {Ye (t)}t≥0 such that
ists a stopping time µ

1. (Ye , µ
e) satisfies Equation (8.24);

2. Ye ∈ Γ a.s.;

e≥µ
3. µ e(K) := inf{t ≥ 0 | Ye (t) ∈
/ K̊}.

342
The above conditions imply that Ye µe(K) ∈ Γ a.s.. Since the coefficients in Equa-
d
e(K)) = (Y µ(K) , µ(K)). To
tion (8.24) are locally Lipschitz, we claim that (Ye µe(K) , µ
see this, note that for any compact K ⊆ U , the noise function σ, ∂Φ and ∂ 2 Φ are all
Lipschitz on K, thus we can extend their definitions to RD such that the resulting
functions are still locally Lipschitz. Based on this extension, applying classic theorem
on weak uniqueness (e.g., Theorem 1.1.10, Hsu 229) to the extended version of Equa-
tion (8.24) yields the equivalence in law. Thus we only need to prove the first claim
for Ye .
Let ET be the event such that µ
e(K) > T on ET . Then restricted on ET , we have
Ye (T ∧ µ
e) = Ye (T ∧ µ e≥µ
e(K)) as µ e(K) holds a.s. We first prove the claim for any
convergent subsequence of {Yη }η>0 .
µ (K)
Now, let {ηm }m≥1 be a sequence of LRs such that ηm → 0 and Yηmηm ⇒ Ye µe as
m → ∞. By applying the Skorohod representation theorem, we can put {Yηm }m≥1
µ (K)
and Ye under the same probability space such that Yηmηm → Ye µe a.s. in the Skorohod
metric, or equivalently the uniform metric (since Ye µe is continuous) i.e.,

dU (Yηµmηm (K) , Ye µe ) → 0, a.s.,

which further implies that for any  > 0, there exists some N > 0 such that for all
m > N,

h i
P dU (Yηµmηm (K)∧T , Ye µe∧T ) ≥  ≤ δ.

343
µ (K)∧T µ (K)∧T
Restricted on ET , we have dU (Yηmηm , Ye µe∧T ) = dU (Yηmηm , Ye µe(K)∧T ), and it
follows that for all m > N ,

h i h i
P dU (Yηµmηm (K)∧T , Ye µe(K)∧T ) ≥  ≤ P {dU (Yηm µηm (K)∧T e µ
,Y e(K)∧T
) ≥ } ∩ ET + P [ETc ]
h i
µηm (K)∧T e µ e∧T
= P {dU (Yηm ,Y ) ≥ } ∩ ET + P[ETc ]
h i
≤ P dU (Yηµmηm (K)∧T , Ye µe∧T ) ≥  + P[ETc ]

≤ 2δ,

where we denote the complement of ET by ETc .


By the definition of the Prohorov metric in Definition 8.4.13, we then get
µ (K)∧T
ρ2δ (Yηmηm , Ye µe(K)∧T ) ≤  for all m > N . Therefore, we have

lim ρ2δ (Yηµmηm (K)∧T , Ye µe(K)∧T ) = 0.


m→∞

µ (K)∧T
Now we claim that it indeed holds that limη→0 ρ2δ (Yη η , Ye µe(K)∧T ) = 0. We
prove this by contradiction. Suppose otherwise, then there exists some  > 0 such that
µ (K)∧T
for all η0 > 0, there exists some η < η0 with ρ2δ (Yη η , Ye µe(K)∧T ) > . Consequently,
µ (K)
there is a sequence {ηm }m≥1 satisfying limm→∞ ηm = 0 and ρ2δ (Yηmηm , Ye µe(K)∧T ) > 
µ (K)∧T
for all m. Since {(Yηmηm , Zηm , µηm (K))}m≥1 is relatively compact, there ex-
ists a subsequence (WLOG, assume it is the original sequence itself) converging
to (Ye µe∧T , W, µ
e) in distribution. However, repeating the exactly same argument as
µ (K)∧T
above, we would have ρ2δ (Yηmηm , Ye µe(K)∧T ) ≤  for all sufficiently large m, which
is a contradiction. This completes the proof.

Proof of the second claim of Theorem 8.8.10. We will first show there exists a se-
quence of compact set {Km }m≥1 such that ∪∞
m=1 Km = U and Km ⊆ Km+1 . For

m ∈ N+ , we define Hm = U \ (B1/m (0) + RD \ U ) and Km = Hm ∩ Bm (0). By


definition it holds that ∀m < m0 , Hm ⊆ Hm0 and Km ⊆ Km0 . Moreover, since Km

344
is bounded and closed, Km is compact for every m. Now we claim ∪∞
m=1 Km = U .

Note that ∪∞ ∞ ∞
m=1 Km = ∪m=1 Hm ∩ Bm (0) = ∪m=1 Hm . ∀x ∈ U , since U is open,

we know dU (x, RD \ U ) > 0, thus there exists m0 ∈ N+ , such that ∀m ≥ m0 ,


/ (B1/m (0) + RD \ U ) and thus x ∈ Hm , which implies x ∈ ∪∞
x ∈ m=1 Hm . On the

other hand, ∀x ∈ RD \ U , it holds that x ∈ (B1/m (0) + RD \ U ) for all m ∈ N+ , thus


x∈
/ Hm ⊂ Km .
Therefore, since Y ∈ U and is continuous almost surely, random variables
limm→∞ µ(Km ) = ∞ a.s., which implies µ(Km ) converges to ∞ in distribution, i,e,,
∀δ > 0, T > 0, ∃m ∈ N+ , such that ∀K ⊇ Km , it holds P[µ(K) ≤ T ] ≤ δ.
Now we will show for any T > 0 and  > 0, there exists η0 such that ρ (Y T , YηT ) ≤ 
for all η ≤ η0 . Fixing any T > 0, for any  > 0, let δ = 4 , then from above we know
exists compact set K, such that P(µ(K) ≤ T ) ≤ δ. We further pick K 0 = K + B20 (0),
where 0 can be any real number satisfying 0 < 0 <  and K 0 ⊆ U . Such 0 exists since
U is open. Note K ⊆ K 0 , we have P(µ(K 0 ) ≤ T ) ≤ P(µ(K) ≤ T ) ≤ δ. Thus by the
first claim of Theorem 8.8.10, there exists η0 > 0, such that for all η ≤ η0 , we have
µ (K 0 )∧T 0
ρ2δ (Yη η , Y µ(K )∧T ) ≤ 2−dT e 0 .
0
Note that ρδ (Y µ(K)∧T , Y µ(K )∧T ) = 0, so we have for all η ≤ η0 ,

0
ρ3δ (Y µ(K)∧T , Yηµη (K )∧T ) ≤ 2−dT e 0 .

By the definition of δ-Prohorov distance in Definition 8.4.12, we can assume


h i
µ (K 0 )∧T µ (K 0 )∧T
(Y µ(K)∧T , Yη η ) is already the coupling such that P dU (Y µ(K)∧T , Yη η ) ≥ 2−dT e 0 ≤
3δ. Below we want to show ρ3δ (Y µ(K)∧T , YηT ) ≤ 2−dT e 0 . Note that for all t ≥ 0,

345
Y µ(K)∧T (t) ∈ K, thus we know if µη (K 0 ) ≤ T , then

0 0
dU (Y µ(K)∧T , Yηµη (K )∧T ) ≥ 2−dT e Y µ(K)∧T (µη (K 0 )) − Yηµη (K )∧T (µη (K 0 )
2

≥ 2−dT e dU (K, Rd /K 0 )

≥ 2−dT e 0 .

µ (K 0 )∧T
On the other hand, if µη (K 0 ) > T , then YηT = Yη η . Thus we can conclude
µ (K 0 )∧T
that dU (Y µ(K)∧T , YηT ) ≥ 2−dT e 0 implies dU (Y µ(K)∧T , Yη η ) ≥ 2−dT e 0 . Therefore,
we further have

h 0
i
P dU (Y µ(K)∧T , YηT ) ≥ 2−dT e 0 ≤ P dU (Y µ(K)∧T , Yηµη (K )∧T ) ≥ 2−dT e 0 ≤ 3δ,
 

that is,

ρ3δ (Y µ(K)∧T , YηT ) ≤ 2−dT e 0 .

Finally, since ρδ (Y T , Y µ(K)∧T ) = 0, we have for all η ≤ η0 ,

ρ (Y T , YηT ) = ρ4δ (Y T , YηT ) ≤ ρ3δ (Y µ(K)∧T , YηT ) + ρδ (Y T , Y µ(K)∧T ) ≤ 2−dT e 0 + 0 ≤ ,

which completes the proof.

Now, we provide the proof of Theorem 8.5.7 as a direct application of Theo-


rem 8.8.10.

Proof of Theorem 8.5.7. We first prove that Y never leaves Γ, i.e., P[Y (t) ∈ Γ, ∀t ≥
0] = 1. By the result of Theorem 8.8.9, we know that for each compact set K ⊂ Γ,
Y µ(K) stays on Γ almost surely, where µ(K) := inf{t ≥ 0 | Ye (t) ∈
/ K̊} is the
earliest time that Y leaves K. In other words, for all compact set K ⊂ Γ, P[∃t ≥
0, Y (t) ∈
/ Γ, Y (t) ∈ K] = 0. Let {Km }m≥1 be any sequence of compact sets such that
346
∪m≥1 Km = U and Km ⊂ U , e.g., the ones constructed in the proof of the second claim
of Theorem 8.8.10. Therefore, we have


X
P[∃t ≥ 0, Y (t) ∈
/ Γ] = P[∃t ≥ 0, Y (t) ∈
/ Γ, Y (t) ∈ U ] ≤ P[∃t ≥ 0, Y (t) ∈
/ Γ, Y (t) ∈ Km ] = 0,
m=1

which means Y always stays on Γ.


Then recall the decomposition of Σ = Σk + Σ⊥ + Σk,⊥ + Σ⊥,k as defined in
Lemma 8.5.6. Since Y never leaves Γ, by Lemma 8.5.6, we can rewrite Equation (8.12)
as

1/2
dY (t) = Σk dW (t) + ∂ 2 Φ(Y (t))[Σ(Y (t))]dt
D
1X
= ∂Φ(Y (t))σ(Y (t))dW (t) + ∂ij Φ(Y (t))(σ(Y (t))σ(Y (t))> )ij dt
2 i,j=1

where the second equality follows from the definition that Σk = ∂ΦΣ∂Φ = ∂Φσσ > ∂Φ.
This coincides with the formulation of the limiting diffusion in Theorem 8.8.10.
Therefore, further combining Lemma 8.5.3 and the second part of Theorem 8.8.10, we
obtain the desired result.

Remark 8.8.11. Our result suggests that for tiny LR η, SGD dynamics have two
phases. In Phase I of Θ(1/η) steps, the SGD iterates move towards the manifold Γ
of local minimizers along GF. Then in Phase II which is of Θ(1/η 2 ) steps, the SGD
iterates stay close to Γ and diffuse approximately according to (8.12). See Figure 8.2
for an illustration of this two-phase dynamics. However, since the length of Phase I
gets negligible compared to that of Phase II when η → 0, Theorem 8.5.7 only reflects
the time scaling of Phase II.

347
Figure 8.2: Illustration for two-phase dynamics of SGD with the same example as in
Figure 8.1 . Γ is an 1D manifold of minimizers of loss L.

8.9 Explicit Formula of the Limiting Diffusion

In this section, we demonstrate how to compute the derivatives of Φ by relating to


those of the loss function L, and then present the explicit formula of the limiting
diffusion.

8.9.1 Explicit Expression of the Derivatives

For any x ∈ Γ, we choose an orthonormal basis of Tx (Γ) as {v1 , . . . , vD−M }. Let


{vD−M +1 , . . . , vD } be an orthonormal basis of Tx⊥ (Γ) so that {vi }i∈[D] is an orthonormal
basis of RD .

Lemma 8.9.1. For any x ∈ Γ and any v ∈ Tx (Γ), it holds that ∇2 L(x)v = 0.

Proof. For any x ∈ Tx (Γ), let {x(t)}t≥0 be a parametrized smooth curve on Γ such
dx(t) d∇L(xt )
that x(0) = x and dt t=0
= v. Then ∇L(xt ) = 0 for all t. Thus 0 = dt t=0
=
∇2 L(x)v.

Lemma 8.9.2. For any x ∈ RD , it holds that ∂Φ(x)∇L(x) = 0 and

∂ 2 Φ(x)[∇L(x), ∇L(x)] = −∂Φ(x)∇2 L(x)∇L(x).

348
dx(t)
Proof. Fixing any x ∈ RD , let dt
= −∇L(x(t)) be initialized at x(0) = x. Since
Φ(x(t)) = Φ(x) for all t ≥ 0, we have

d
Φ(x(t)) = −∂Φ(x(t))∇L(x(t)) = 0.
dt

Evaluating the above equation at t = 0 yields ∂Φ(x)∇L(x) = 0. Moreover, take the


second order derivative and we have

d2
 
2 dx(t) 2 dx(t)
Φ(x t ) = −∂ Φ(x(t)) , ∇L(x(t)) − ∂Φ(x(t))∇ L(x(t)) = 0.
dt2 dt dt

Evaluating at t = 0 completes the proof.

Now we can prove Lemma 8.5.4, restated in below.

Lemma 8.5.4. For any x ∈ Γ, ∂Φ(x) is the orthogonal projection matrix onto tangent
space Tx (Γ).

Proof of Lemma 8.5.4. For any v ∈ Tx (Γ), let {v(t), t ≥ 0} be a parametrized smooth
dv(t)
curve on Γ such that v(0) = x and dt t=0
= v. Since v(t) ∈ Γ for all t ≥ 0, we have
Φ(v(t)) = v(t), and thus

dv(t) d dv(t)
= Φ(v(t)) = ∂Φ(x) .
dt t=0 dt t=0 dt t=0

This implies that ∂Φ(x)v = v for all v ∈ Tx (Γ).


Next, for any u ∈ Tx⊥ (Γ) and t ≥ 0, consider expanding ∇L(x + t∇2 L(x)† u) at
t = 0:

∇L x + t∇2 L(x)† u = ∇2 L(x) · t∇2 L(x)† u + o(t)




= tu + o(t)

349
where the second equality follows from the assumption that ∇2 L(x) is full-rank when
restricted on Tx⊥ (Γ). Then since ∂Φ is continuous, it follows that

∂Φ(x + t∇2 L(x)† u)∇L(x + t∇2 L(x)† u)


lim = lim ∂Φ(x + t∇2 L(x)† )(u + o(1))
t→0 t t→0

= ∂Φ(x)u.

By Lemma 8.9.2, we have ∂Φ(x + t(∇2 L(x))† u))∇L(x + t(∇2 L(x))† u) = 0 for all t > 0,
which then implies that ∂Φ(x)u = 0 for all u ∈ Tx⊥ (Γ).
Therefore, under the basis {v1 , . . . , vD }, ∂Φ(x) is given by

 
ID−M 0 D×D
∂Φ(x) =  ∈R ,
0 0

that is, the projection matrix onto Tx (Γ).

Lemma 8.9.3. For any x ∈ Γ, it holds that ∂Φ(x)∇2 L(x) = 0.

Proof. It directly follows from Lemma 8.9.1 and Lemma 8.5.4.

Next, we proceed to compute the second-order derivatives.

Lemma 8.9.4. For any x ∈ Γ, u ∈ RD and v ∈ Tx (Γ), it holds that

∂ 2 Φ(x)[v, u] = −∂Φ(x)∂ 2 (∇L)(x)[v, ∇2 L(x)† u] − ∇2 L(x)† ∂ 2 (∇L)(x)[v, ∂Φ(x)u].

Proof of Lemma 8.9.4. Consider a parametrized smooth curve {v(t)}t≥0 on Γ such


dv(t)
that v(0) = x and dt t=0
= v. We define P (t) = ∂Φ(v(t)), P ⊥ (t) = ID − P (t) and
H(t) = ∇2 L(v(t)) for all t ≥ 0. By Lemma 8.9.1 and 8.5.4, we have

P ⊥ (t)H(t) = H(t)P ⊥ (t) = H(t), (8.26)

350
Denote the derivative of P (t), P ⊥ (t) and H(t) with respect to t as P 0 (t), (P ⊥ )0 (t) and
H 0 (t). Then differentiating with respect to t, we have

(P ⊥ )0 (t)H(t) = H 0 (t) − P ⊥ (t)H 0 (t) = P (t)H 0 (t). (8.27)

Then combining (8.26) and (8.27) and evaluating at t = 0, we have

P 0 (0)H(0) = −(P ⊥ )0 (0)H(0) = −P (0)H 0 (0)

We can decompose P 0 (0) and H(0) as follows

   
0 0
P11 (0) P12 (0) 0 0 
P 0 (0) =  , H(0) =  , (8.28)
0 0
P21 (0) P22 (0) 0 H22 (0)

0
where P11 (0) ∈ R(D−M )×(D−M ) and H22 is the hessian of L restricted on Tx⊥ (Γ). Also
note that
   
0 0
ID−M 0 H11 (0) H12 (0) 0 0 
P (0)H 0 (0)P ⊥ (0) =    
0 0
0 0 H21 (0) H22 (0) 0 IM
 
0
0 H12 (0)
= ,
0 0

and thus by (8.28) we have

   
0 0
0 P12 (0)H22 (0) 0 −H12 (0)
P 0 (0)H(0) =  =
  .
0
0 P22 (0)H22 (0) 0 0

0 0 0
This implies that we must have P22 (0) = 0 and P12 (0)H22 (0) = H12 (0). Similarly, by
0 0
taking transpose in (8.28), we also have H22 (0)P21 (0) = −H21 (0).

351
0
It then remains to determine the value of P11 (0). Note that since P (t)P (t) = P (t),
we have P 0 (t)P (t) + P (t)P 0 (t) = P 0 (t), evaluating at t = 0 yields

0 0
2P11 (0) = P11 (0).

0
Therefore, we must have P11 (0) = 0. Combining the above results, we obtain

P 0 (0) = −P (0)H 0 (0)H(0)† − H(0)† H 0 (0)P (0).

Finally, recall that P (t) = ∂Φ(v(t)), and thus

d
P 0 (0) = ∂Φ(v(t)) = ∂(∂Φ(x))[v].
dt t=0

Similarly, we have H 0 (0) = ∂(∇2 L)(x)[v], and it follows that

∂(∂Φ(x))[v] = −∂Φ(x)∂(∇2 L)(x)[v]∇2 L(x)† − ∇2 L(x)† ∂(∇2 L)(x)[v]∂Φ(x).

Thus we conclude that

∂ 2 Φ(x)[v, u] = −∂Φ(x)∂ 2 (∇L)(x)[v, ∇2 L(x)† u] − ∇2 L(x)† ∂ 2 (∇L)(x)[v, ∂Φ(x)u],

which completes the proof.

Lemma 8.9.5. For any x ∈ Γ and u ∈ Tx⊥ (Γ), it holds that

∂ 2 Φ(x)[uu> + ∇2 L(x)† uu> ∇2 L(x)] = −∂Φ(x)∂ 2 (∇L)(x)[∇2 L(x)† uu> ].

352
Proof of Lemma 8.9.5. For any u ∈ Tx⊥ (Γ), we define u(t) = x + t∇2 L(x)† u for t ≥ 0.
By Taylor approximation, we have

∇L(u(t)) = t∇2 L(x)∇2 L(x)† u + o(t) = tu + o(t) (8.29)

and

∇2 L(u(t)) = ∇2 L(x) + t∂ 2 (∇L)(x)[∇2 L(x)† u] + o(t). (8.30)

Combine (8.29) and (8.30) and apply Lemma 8.9.2, and it follows that

0 = ∂ 2 Φ(u(t))[∇L(u(t)), ∇L(u(t))] + ∂Φ(u(t))∇2 L(u(t))∇L(u(t))

= t2 ∂ 2 Φ(u(t))[u + o(1), u + o(1)] + t2 ∂Φ(u(t))∂(∇2 L)(x)[∇2 L(x)† u](u + o(1))


∂Φ(u(t)) 2
+ t2 ∇ L(x)(u + o(1))
t
= t2 ∂ 2 Φ(u(t))[u + o(1), u + o(1)] + t2 ∂Φ(u(t))∂(∇2 L)(x)[∇2 L(x)† u](u + o(1))
∂Φ(u(t)) − ∂Φ(x) 2
+ t2 ∇ L(x)(u + o(1))
t

where the last equality follows from Lemma 8.9.3. Dividing both sides by t2 and
letting t → 0, we get

∂ 2 Φ(x)[u, u] + ∂Φ(x)∂ 2 (∇L)(x)[∇2 L(x)† u, u] + ∂ 2 Φ(x)[∇2 L(x)† u, ∇2 L(x)u] = 0.

Rearranging the above equation completes the proof.

With the notion of Lyapunov Operator in Definition 8.5.5, Lemma 8.9.5 can be
further simplified into Lemma 8.9.6.

353
Lemma 8.9.6. For any x ∈ Γ and Σ ∈ span{uu> | u ∈ Tx⊥ (Γ)},

h∂ 2 Φ(x), Σi = −∂Φ(x)∂ 2 (∇L)(x)[L−1


∇2 L(x) (Σ)]. (8.31)

Proof of Lemma 8.9.6. Let A = uu> + ∇2 L(x)† uu> ∇2 L(x) and B = ∇2 L(x)† uu> .
The key observation is that A + A> = L∇2 L(x) (B + B > ). Therefore, by Lemma 8.9.5,
it holds that

∂ 2 Φ(x)[L∇2 L(x) (B+B > )] = ∂ 2 Φ(x)[A+A> ] = 2∂Φ(x)∂ 2 (∇L)(x)[B] = ∂Φ(x)∂ 2 (∇L)(x)[B+B > ].

Since ∇2 L(x)† is full-rank when restricted to Tx⊥ (Γ), we have span{∇2 L(x)† uu> +
uu> ∇2 L(x)† | u ∈ Tx⊥ (Γ)} = span{uu> | u ∈ Tx⊥ (Γ)}. Thus by the linearity of above
equation, we can replace B + B > by any Σ ∈ span{uu> | u ∈ Tx⊥ (Γ)}, resulting in the
desired equation.

Then Lemma 8.5.6 directly follows from Lemma 8.9.4 and 8.9.5.

8.9.2 Tangent Noise Compensation only Dependends on the

Manifold Itself

Here we show that the second term of (8.12), i.e., the tangent noise compensation
for the limiting dynamics to stay on Γ, only depends on Γ itself.

Lemma 8.9.7. For any x ∈ Γ, suppose there exist a neighborhood Ux of x and


two loss functions L and L0 that define the same manifold Γ locally in Ux , i.e.,
Γ ∩ Ux = {x | ∇L(x) = 0} = {x | ∇L0 (x) = 0}. Then for any v ∈ Tx (Γ), it holds that
(∇2 L(x))† ∂ 2 (∇L)(x) [v, v] = (∇2 L0 (x))† ∂ 2 (∇L0 )(x) [v, v].

Proof of Lemma 8.9.7. Let {v(t)}t≥0 be a smooth curve on Γ with v(0) = x and
dv(t)
dt t=0
= v. Since v(t) stays on Γ, we have ∇L(v(t)) = 0 for all t ≥ 0. Taking deriva-
2
tive for two times yields ∂ 2 (∇L)(v(t))[ dv(t)
dt
, dv(t)
dt
] + ∇2 L(v(t)) d dtv(t)
2 = 0. Evaluating it
354
at t = 0 and multiplying both sides by ∇2 L(x)† , we get

d2 v(t) d2 v(t)
∇2 L(x)† ∂ 2 (∇L)(x) [v, v] = −∇2 L(x)† ∇2 L(x) = −∂Φ(x) .
dt2 t=0 dt2 t=0

Since ∂Φ(x) is the projection matrix onto Tx (Γ) by Lemma 8.5.4, it does not depend
2
on L, so analogously we also have ∇2 L0 (x)† ∂ 2 (∇L0 )(x) [v, v] = −∂Φ(x) d dtv(t)
2 t=0
as
d2 v(t)
well. The proof is thus completed. Note that ∂Φ(x) dt2 t=0
is indeed the second
fundamental form for v at x, and the value won’t change if we choose another
parametric smooth curve with a different second-order time derivative. (See Chapter
6 in Do Carmo [230] for a reference.)

8.10 Proof of results in Section 8.6

Now we are ready to give the missing proofs in Section 8.6 which yield explicit formula
of the limiting diffusion for label noise and isotropic noise.

Corollary 8.6.1 (Limiting Diffusion for Isotropic Noise). If Σ ≡ ID on Γ, SDE (8.12)


is then

1 1
dY (t) = ∂Φ(Y )dW − ∇2 L(Y )† ∂ 2 (∇L)(Y ) [∂Φ(Y )] dt − ∂Φ(Y )∇(ln |∇2 L(Y )|+ )dt
| 2 {z } |2 {z }
Brownian Motion on Manifold Normal Regularization

(8.13)

|∇2 L(Y )+αID |


where |∇2 L(Y )|+ = limα→0 2
αD−rank(∇ L(Y ))
is the pseudo-determinant of ∇2 L(Y ).
|∇2 L(Y )|+ is also equal to the sum of log of non-zero eigenvalue values of ∇2 L(Y ).

Proof of Corollary 8.6.1. Set Σk = ∂Φ, Σ⊥ = ID − ∂Φ and Σ⊥,k = Σk,⊥ = 0 in the de-
composition of Σ by Lemma 8.5.6, and we need to show ∇(ln |Σ|+ ) = ∂ 2 (∇L)[(∇2 L)† ].
Holbrook [231] shows that the gradient of pseudo-inverse determinant satis-
fies ∇|A|+ = |A|+ A† . Thus we have for any vector v ∈ RD , hv, ∇ ln |∇2 L|+ i =
355
D E
|∇2 L|+ ∇2 L
|∇2 L|+
, ∂ 2 (∇L)[v] = h∇2 L, ∂ 2 (∇L)[v]i = ∂ 2 (∇L)[v, ∇2 L] = v, ∂ 2 (∇L)[(∇2 L)† ] ,
which completes the proof.

Corollary 8.6.2 (Limiting Flow for Label Noise). If Σ ≡ c∇2 L on Γ for some
constant c > 0, SDE (8.12) can be simplified into (8.15) where the regularization is
from the noise in the normal space.

dY (t) = −1/4 · ∂Φ(Y (t))∇ tr[c∇2 L(Y (t))]dt. (8.15)

Proof of Corollary 8.6.2. Since Σ = c∇2 L, here we have Σ⊥ = Σ and Σk , Σ⊥,k , Σk,⊥ =
0. Thus it suffices to show that 2∂ 2 (∇L) L−1
  2
∇2 L (Σ⊥ ) = ∇ tr[∇ L]. Note that for any

v ∈ RD ,

v > ∇ tr[∇2 L] = ID , ∂(∇2 L)[v] = ID − ∂Φ, ∂(∇2 L)[v] , (8.32)

where the second equality is because the the tangent space of symmetric rank-(D − M )
matrices at ∇2 L is {A∇2 L + ∇2 LA> | A ∈ RD×D }, and every element in this
tangent space has zero inner-product with ∂Φ by Lemma 8.5.4. Also note that
L−1 2 1 2 −1 2 2
∇2 L (∇ L) = 2 (ID − ∂Φ), thus hID − ∂Φ, ∂ (∇L)[v]i = 2 L∇2 L (∇ L), ∂ (∇L)[v] =

2v > ∂ 2 (∇L)[L−1 2
∇2 L (∇ L)].

Lemma 8.6.3. Let Σ ∈ RD×D be given by Σij (x) = (1 + Qj−3


α v, Q−π/2 x1:2 )(2 +

hQj−3
α v, x1:2 i), if i = j ≥ 3 or 0 otherwise, then the solution of SDE (8.12) is the

following (8.16) , which implies that Y (t) moves anti-clockwise with a constant angular
speed of (D − 2)/2.

Y1:2 (t) = Qt(D−2)/2 Y1:2 (0) and Y3:D (t) ≡ 0. (8.16)

356
Proof of Lemma 8.6.3. Note that for any x ∈ Γ, it holds that

2 + hQj−3

α v, x1:2 i if i = j ≥ 3,






∇2 L(x)

ij
= xi xj if i, j ∈ {1, 2}, (8.33)





0
 otherwise.

Then clearly Σ only brings about noise in the normal space, and specifically, it holds
that L−1 0 D−3
∇2 L(x) (Σ(x)) = diag(0, 0, 1+ Qα v, Q−π/2 x1:2 , . . . , 1+ Qα v, Q−π/2 x1:2 ). Fur-
ther note that, by the special structure of the hessian in (8.33) and Lemma 8.9.3, for any
x1:2 Q−π/2 x1:2 >
x ∈ Γ, we have ∂Φ(x) = (x2 , −x1 , 0, . . . , 0)> (x2 , −x1 , 0, . . . , 0) = Q−π/2
 
0 0
.
Combining these facts, the dynamics of the first two coordinates in SDE (8.12) can
be simplified into

 
dx1:2 (t) 1 2 −1
=− ∂Φ(x(t))∂ (∇L)(x(t))[L∇2 L (Σ(x(t))]
dt 2 1:2
D
1 X
= − Q−π/2 x1:2 x> >
1 + Qj−3

1:2 Q−π/2 α v, Q−π/2 x1:2 ∇1:2 (∂jj L)(x)
2 j=3
* D
+
1 X
1 + Qj−3 Qj−3

= − Q−π/2 x1:2 Q−π/2 x1:2 , α v, Q−π/2 x1:2 α v
2 j=3
* D
+ D
!
1 X X 2
= − Q−π/2 x1:2 Q−π/2 x1:2 , Qj−3
α v + Qj−3
α v, Q−π/2 x1:2
2 j=3 j=3
 
1 D−2 2 D−2
= − Q−π/2 x1:2 0 + Q−π/2 x1:2 2 = Qπ/2 x1:2 ,
2 2 2

where the second to the last equality follows from the property of Qα and the last
equality follows from the fact that kx1:2 k22 = 1 for all x ∈ Γ. Note we require k ≥ 3 (or
2 2
D ≥ 5) to allow D j−3
= D−2
P
j=3 Qα v, Q−π/2 x1:2 2
Q−π/2 x1:2 2 . On the other hand,
dx3:D (t)
we have dt
= 0 as ∂Φ kills the movement on that component.

357
The proof is completed by noting that the solution of x1:2 is

 
D−2
x1:2 (t) = exp t · Qπ/2 x1:2 (0),
2

and by Lemma 8.10.1,

 
D−2 t(D−2) t(D−2)
exp t · Qπ/2 = (exp(Qπ/2 )) 2 = Q1 2 = Q t(D−2) .
2 2

Lemma 8.10.1. exp(( 01 −1 cos 1 − sin 1


0 )) = ( sin 1 cos 1 ).

P∞ At
Proof. By definition, for matrix A = ( 01 −1
0 ), exp(A) = t=0 t! . Note that A2 = −I,
A3 = −A and A4 = I. Using this pattern, we can easily check that
   

P∞ i 1
P∞ i 1
X At  − i=0 (−1)
i=0 (−1) (2i)! cos 1 − sin 1
(2i+1)! 
= P = .
t! ∞ i 1
P∞ i 1
t=0 i=0 (−1) (2i+1)! i=0 (−1) (2i)! sin 1 cos 1

8.11 Proof of results in Section 8.7

In this section, we present the missing proofs in Section 8.7 regarding the over-
parametrized linear model.
For convenience, for any p, r ≥ 0 and u ∈ RD , we denote by Brp (u) the `p norm
ball of radius r centered at u. We also denote vi:j = (vi , vi+1 , . . . , vj )> for i, j ∈ [D].

8.11.1 Proof of Theorem 8.7.1

In this subsection, we provide the proof of Theorem 8.7.1.

358
Theorem 8.7.1. In the setting of OLM, suppose the groundtruth is κ-sparse and
n ≥ Ω(κ ln d) training data are sampled from either i.i.d. Gaussian or Boolean
distribution. Then for any initialization xinit (except a zero-measure set) and any
 > 0, there exist η0 , T > 0 such that for any η < η0 , OLM trained with label noise
SGD (8.14) with LR equal to η for bT /η 2 c steps returns an -optimal solution, with
probability of 1 − e−Ω(n) over the randomness of the training dataset.

Proof of Theorem 8.7.1. First, by Lemma 8.7.6, it holds with probability at least
1 − e−Ω(n) that the solution to (8.21), x∗ , is unique up to and satisfies |x∗ | = ψ(w∗ ).
Then on this event, for any  > 0, by Lemma 8.7.5, there exists some T > 0 such that
xT given by the Riemannian gradient flow (8.20) satisfies that xT is an /2-optimal
solution of the OLM. For this T , by Theorem 8.5.7, we know that the bT /η 2 c-th
SGD iterate, xη (bT /η 2 c), satisfies kxη (bT /η 2 c) − xT k2 ≤ /2 with probability at
least 1 − e−Ω(n) for all sufficiently small η > 0, and thus xη (bT /η 2 c) is an -optimal
solution of the OLM. Finally, the validity of applying Theorem 8.5.7 is guaranteed by
Lemma 8.7.2 and 8.7.3. This completes the proof.

In the following subsections, we provide the proofs of all the components used in
the above proof.

8.11.2 Proof of Lemma 8.7.2


zi u
Recall that for each i ∈ [n] fi (x) = f (u, v) = zi> (u 2

− v 2 ), ∇fi (x) = 2 zi v
, and
K(x) = (Kij (x))i,j∈[n] where each Kij (x) = h∇fi (x), ∇fj (x)i. Then

 
 zi u 
 
2
∇ `i (x) = 2   (zi u) >
−(zi v) > + (fi (u, v) − yi ) · diag(zi , zi ).
−zi v

359
So for any x ∈ Γ, it holds that
 
n
 zi u 
 
2 2 X
∇ L(x) =  (zi u)> −(zi v)> . (8.34)
n

i=1 −zi v

Lemma 8.11.1. For any fixed x ∈ RD , suppose {∇fi (x)}i∈[n] is linearly independent,
then K(x) is full-rank.

Proof of Lemma 8.11.1. Suppose otherwise, then there exists some λ ∈ Rn such that
λ 6= 0 and λ> K(x)λ = 0. However, note that

n
X
>
λ K(x)λ = λi λj Kij (x)
i,j=1
X n
= λi λj h∇fi (x), ∇fj (x)i
i,j=1
n 2
X
= λi ∇fi (x) ,
i=1 2

Pn
which implies that i=1 λi ∇fi (x) = 0. This is a contradiction since by assumption
{∇fi (x)}i∈[n] is linearly independent.

Lemma 8.7.2. Consider the loss L defined in (8.17) and manifold Γ defined in (8.18).
If data is full rank, i.e., rank(Z) = n, then it holds that (a). Γ is a smooth manifold
of dimension D − n; (b). rank(∇2 L(x)) = n for all x ∈ Γ. In particular, rank(Z) = n
holds with probability 1 for Gaussian distribution and with probability 1−cd for Boolean
distribution for some constant c ∈ (0, 1).

Proof of Lemma 8.7.2. (1) By preimage theorem [223], it suffices to check the jacobian
z1 u zn u
 
[∇f1 (x), . . . , ∇fn (x)] = 2[ −z 1 v
, . . . , −zn v
] is full rank. Similarly, for the second
zi u

claim, due to (8.34). it is also equivalent to show that { −z i v
}i∈[n] is of rank n.
Since uv ∈ Γ ⊂ U , each coordinate is non-zero, thus we only need to show that


{zi }i∈[n] is of rank n. This happens with probability 1 in the Gaussian case, and
360
probability at least 1 − cd for some constant c ∈ (0, 1) by Kahn et al. [232]. This
completes the proof.

8.11.3 Proof of Lemma 8.7.3

We first establish some auxiliary results. The following lemma shows the PL condition
along the trajectory of gradient flow.

Lemma 8.11.2. Along the gradient flow generated by −∇L, it holds that
k∇L(x(t))k2 ≥ 16
λ (ZZ > )
n min
· mini∈[d] |ui (0)vi (0)|L(x(t)), ∀t ≥ 0.

To prove Lemma 8.11.2, we need the following invariance along the gradient flow.

Lemma 8.11.3. Along the gradient flow generated by −∇L, uj (t)vj (t) stays constant
for all j ∈ [d]. Thus, sign(uj (t)) = sign(uj (0)) and sign(vj (t)) = sign(vj (0)) for any
j ∈ [d].

Proof of Lemma 8.11.3.

∂ ∂uj (t) ∂vj (t)


(uj (t)vj (t)) = · vj (t) + uj (t) ·
∂t ∂t ∂t
= ∇u L(u(t), v(t))j · vj (t) + uj (t) · ∇v L(u(t), v(t))j
n n
2X 2uj (t) X
= (fi (u(t), v(t)) − yi )zi,j uj (t)vj (t) − (fi (u(t), v(t)) − yi )zi,j vj (t)
n i=1 n i=1

= 0.

Therefore, any sign change of uj (t), vj (t) would enforce uj (t) = 0 or vj (t) = 0 for
some t > 0 since uj (t), vj (t) are continuous in time t. This immediately leads to a
contradiction to the invariance of uj (t)vj (t).

We then can prove Lemma 8.11.2.

361
Proof of Lemma 8.11.2. Note that

n
1 X
k∇L(x)k22 = 2 (fi (x) − yi )(fj (x) − yj ) h∇fi (x), ∇fj (x)i
n i,j=1
n
1 X
≥ 2 (fi (x) − yi )2 λmin (K(x))
n i=1
2
= L(x)λmin (K(x)),
n

where K(x) is a n × n p.s.d. matrix with Kij (x) = h∇fi (x), ∇fj (x)i. Below we
lower bound λmin (K(x)), the smallest eigenvalue of K(x). Note that Kij (x(t)) =
4 dh=1 zi,h zj,h ((uh (t))2 + (vh (t))2 ), and we have
P

K(x(t)) = 4Zdiag((u(t)) 2
+ (v(t)) 2 )Z >  8Zdiag(|u(t) v(t)|)Z >
(∗)
= 8Zdiag(|u(0) v(0)|)Z >  8 min |ui (0)vi (0)|ZZ T
i∈[d]

where (∗) is by Lemma 8.11.3. Thus λmin (K(x(t)) ≥ 8 mini∈[d] |ui (0)vi (0)|λmin (ZZ T )
for all t ≥ 0, which completes the proof.

We also need the following characterization of the manifold Γ.

Lemma 8.11.4. All the stationary points in U are global minimizers, i.e., Γ = {x ∈
U | ∇L(x) = 0}.

Proof of Lemma 8.11.4. Since Γ is the set of local minimizers, each x in Γ must satisfy
∇L(x) = 0. The other direction is proved by noting that rank({zi }i∈[n] ) = n, which
implies rank({∇fi (x)}i∈[n] ) = n.

Now, we are ready to prove Lemma 8.7.3 which is restated below.

Lemma 8.7.3. Consider the loss function L defined in (8.17), manifold Γ and its
dxt
open neighborhood defined in (8.18). For gradient flow dt
= −∇L(xt ) starting at any
x0 ∈ U , it holds that Φ(x0 ) ∈ Γ.
362
dx(t)
Proof of Lemma 8.7.3. It suffices to prove gradient flow dt
= −∇L(x(t)) converges
when t → ∞, as long as x(0) ∈ U . Whenever it converges, it must converge to a
stationary point in U . The proof will be completed by noting that all stationary point
of L in U belongs to Γ (Lemma 8.11.4).
Below we prove limt→∞ x(t) exists. Denote C = 16
n
mini∈[d] |ui (0)vi (0)|λmin (ZZ > ),
then it follows from Lemma 8.11.2 that

k∇L(x(t))k22 − dL(x(t))
p
dx(t) dt 1 d L(x(t))
= k∇L(x(t))k ≤ p = p =− √ .
dt CL(x(t)) L(x(t)) 2 C dt

R∞ R∞ √
dx(t) 1 d L(x(t))
Thus the total GF trajectory length is bounded by t=0 dt
dt ≤ t=0
− 2 C ddt dt
√ ≤
L(x(0))

2 C
, where the last inequality uses that L is non-negative over RD . Therefore, the
GF must converge.

8.11.4 Proof of results in Section 8.7.2


2
P
Without loss of generality, we will assume i=1 zi,j > 0 for all j ∈ [d], because
otherwise we can just delete the unused coordinate, since there won’t be any update
in the parameter corresponding to that coordinate. Moreover, in both gaussian and
2
P
boolean setting, it can be shown that with probability 1, i=1 zi,j > 0 for all j ∈ [d].
To study the optimal solution to (8.21), we consider the corresponding d-
dimensional convex program in terms of w ∈ Rd , which has been studied in Tropp
[226]:

d n
!
4X X 2
minimize R(w) = z |wj |,
n j=1 i=1 i,j
(8.35)
subject to Zw = Zw∗ .

363
Here we slightly abuse the notation of R and the parameter dimension will be clear
from the context. We can relate the optimal solution to (8.21) to that of (8.35) via a
canonical parametrization defined as follows.

u

Definition 8.11.5 (Canonical Parametrization). For any w ∈ Rd , we define v
=
1/2 1/2 >
ψ(w) = ([w> ]+ , [−w> ]+ ) as the canonical parametrization of w. Clearly, it holds
2 2
that u −v = w.

Indeed, we can show that if (8.35) has a unique optimal solution, it immediately
follows that the optimal solution to (8.21) is also unique up to sign flips of each
coordinate, as summarized in the lemma below.

Lemma 8.11.6. Suppose the optimal solution to (8.35) is unique and equal to w∗ .
Then the optimal solution to (8.21) is also unique up to sign flips of each coordi-
u∗ , ve∗ ) = ψ(w∗ ), that is, the canonical
nate. In particular, one of them is given by (e
parametrization of w∗ .

Proof of Lemma 8.11.6. Let (b


u, vb) be any optimal solution of (8.21) and we define
w
b=u
b 2
− vb 2 , which is also feasible to (8.35). By the optimality of w∗ , we have

d n
! d n
! d n
!
X X X X X X
2
zi,j |wj∗ | ≤ 2
zi,j |w
bj | ≤ 2
zi,j u2j + vbj2 ).
(b (8.36)
j=1 i=1 j=1 i=1 j=1 i=1

u∗ , ve∗ ) = ψ(w∗ ) is feasible to (8.21). Thus, it follows from the


On the other hand, (e
optimality of (b
u, vb) that

d n
! d n
! d n
!
X X X X X X
2
zi,j u2j
(b + vbj2 ) ≤ 2
zi,j u∗j )2
((e + vj∗ )2 )
(e = 2
zi,j |wj∗ |.
j=1 i=1 j=1 i=1 j=1 i=1

(8.37)

364
Combining (8.36) and (8.37) yields

d n
! d n
! d n
!
X X X X X X
2
zi,j u2j + vbj2 ) =
(b 2
zi,j |wj∗ | = 2
zi,j u2j − vbj2 |
|b (8.38)
j=1 i=1 j=1 i=1 j=1 i=1

which implies that u


b 2
− vb 2
is also an optimal solution of (8.35). Since w∗ is the
unique optimal solution to (8.35), we have u
b 2
− vb 2
= w∗ . Moreover, by (8.38), we
must have u
b 2
= [w∗ ]+ and u
b 2
= [w∗ ]+ , otherwise the equality would not hold. This
completes the proof.

Therefore, the unique optimality of (8.21) can be reduced to that of (8.35). In the
sequel, we show that the latter holds for both Boolean and Gaussian random vectors.
We divide Lemma 8.7.6 into to Lemma 8.11.8 and 8.11.7 for clarity.

i.i.d.
Lemma 8.11.7 (Boolean Case). Let z1 , . . . , zn ∼ Unif({±1}d ). There exist some
constants C, c > 0 such that if the sample size n satisfies

n ≥ C[κ ln(d/κ) + κ]

2
then with probability at least 1 − e−cn , the optimal solution of (8.21), (b
u, vb), is unique
up to sign flips of each coordinate and recovers the groundtruth, i.e., u
b 2
− vb 2
= w∗ .

i.i.d.
Proof of Lemma 8.11.7. By the assumption that z1 , . . . , zn ∼ Unif({±1}d ), we have
Pn 2
i=1 zi,j = n for all j ∈ [d]. Then (8.35) is equivalent to the following optimization

problem:

minimize g(w) = kwk1 ,


(8.39)
∗ 2 ∗ 2
subject to Zw = Z((u ) − (v ) ).


This model exactly fits the Example 6.2 in Tropp [226] with σ = 1 and α = 1/ 2.
Then applying Equation (4.2) and Theorem 6.3 in Tropp [226], (8.39) has a unique

365
2
optimal solution equal to (u∗ ) 2
− (v ∗ ) 2
with probability at least 1 − e−ch for some
constant c > 0, given that the sample size satisfies

n ≥ C(κ ln(d/κ) + κ + h)

n
for some absolute constant C > 0. Choosing h = 2C
and then adjusting the choices of
C, c appropriately yield the desired result. Finally, applying Lemma 8.11.6 finishes
the proof.

The Gaussian case requires more careful treatment.


i.i.d.
Lemma 8.11.8 (Gaussian Case). Let z1 , . . . , zn ∼ N (0, Id ). There exist some
constants C, c > 0 such that if the sample size satisfies

n ≥ Cκ ln d,

then with probability at least 1 − (2d + 1)e−cn , the optimal solution of (8.21), (b
u, vb),
is unique up to sign flips of each coordinate of u
b and vb and recovers the groundtruth,
i.e., u
b 2
− vb 2
= w∗ .
i.i.d.
Proof of Lemma 8.11.8. Since z1 , . . . , zn ∼ N (0, Id ), we have

" n
#
X
P 2
zi,j ∈ [n/2, 3n/2], ∀j ∈ [d] ≥ 1 − 2de−cn
i=1

for some constant c > 0, and we denote this event by En . Therefore, on En , we have

D
X D
X
2 2
2 (uj + vj ) ≤ R(x) ≤ 6 (u2j + vj2 )
j=1 j=1

or equivalently,

2(ku 2 k1 + kv 2 k1 ) ≤ R(x) ≤ 6(ku 2 k1 + v 2 k1 ).

366
Define w∗ = (u∗ ) 2
− (v ∗ ) 2 , and (8.35) is equivalent to the following convex
optimization problem

d n
!
4X X 2
minimize g(w) = z |wj + wj∗ |,
n j=1 i=1 i,j
(8.40)
subject to Zw = 0.

The point w = 0 is feasible for (8.40), and we claim that this is the unique optimal
solution when n is large enough. In detail, assume that there exists a non-zero feasible
point w for (8.40) in the descent cone [226] D(g, w∗ ) of g, then

kZwk2
λmin (Z; D(g, w∗ )) ≤ =0
kwk2

where the equality follows from that w is feasible. Therefore, we only need to show
that λmin (Z; D(g, x∗ )) is bounded from below for sufficiently large n.
On En , it holds that g belongs to the following function class

( d
)
X
G= h : Rd → R | h(w) = υj |wj |, υ ∈ Υ with Υ = {υ ∈ Rd : υj ∈ [2, 6], ∀j ∈ [d]}.
j=1

We identify gυ ∈ G with υ ∈ Υ, then D(g, w∗ ) ⊆ ∪υ∈Υ D(gυ , w∗ )) := DΥ , which further


implies that

λmin (Z; D(g, w∗ )) ≥ λmin (Z; DΥ ).

Recall the definition of minimum conic singular value [226]:

λmin (Z; DΥ ) = inf sup hq, Zpi.


p∈DΥ ∩S d−1 q∈S n−1

367
where S n−1 denotes the unit sphere in Rn . Applying the same argument as in Tropp
[226] yields

√ 2
P λmin (Z; DΥ ) ≥ n − 1 − w(DΥ ) − h ≥ 1 − e−h /2 .
 

Take the intersection of this event with En , and we obtain from a union bound that


λmin (Z; D(g, w∗ )) ≥ n − 1 − w(DΥ ) − h (8.41)

2 /2
with probability at least 1 − e−h − 2de−cn . It remains to determine w(DΥ ), which
is defined as
" # " #
w(DΥ ) = Ez∼N (0,Id ) sup hz, pi = Ez∼N (0,Id ) sup sup hz, pi . (8.42)
p∈DΥ ∩S d−1 υ∈Υ p∈D(gυ ,x∗ )∩S d−1

Without loss of generality, we assume that w∗ = (w1∗ , . . . , wκ∗ , 0, . . . , 0)> with


w1∗ , . . . , wκ∗ > 0, otherwise one only needs to specify the signs and the nonzero set of
w∗ in the sequel. For any υ ∈ Υ and any p ∈ D(gυ , w∗ ) ∩ S d−1 , there exists some
τ > 0 such that gυ (w∗ + τ · p) ≤ gυ (w∗ ), i.e.,

d
X d
X
υj |wj∗ + τ pj | ≤ υj |wj∗ |
j=1 j=1

which further implies that

d
X κ
X κ
X
τ υj |pj | ≤ υj (|wj∗ | − |wj∗ − τ pj |) ≤ τ υj |pj |
j=κ+1 j=1 j=1

368
where the second inequality follows from the triangle inequality. Then since each
υj ∈ [2, 6], it follows that

d
X κ
X
|pj | ≤ 3 |pj |.
j=κ+1 j=1

Note that this holds for all ξ ∈ Ξ simultaneously. Now let us denote p1:κ =
(p1 , . . . , pκ ) ∈ Rκ and p(κ+1):d = (pκ+1 , . . . , pd ) ∈ Rd−κ , and similarly for other d-
dimensional vectors. Then for all p ∈ DΥ ∩ S d−1 , by Cauchy-Schwartz inequality, we
have


kp(κ+1):d k1 ≤ 3kp1:κ k1 ≤ 3 κkp1:κ k2 .

Thus, for any z ∈ Rd and any p ∈ DΥ ∩ S d−1 , it follows that

hz, pi = hz1:κ , p1:κ i + hz(κ+1):d , p(κ+1):d i

≤ kz1:κ k2 kp1:κ k2 + kp(κ+1):d k1 · max |zj |


j∈{κ+1,...,d}

≤ kz1:κ k2 kp1:κ k2 + 3 κkp1:κ k2 · max |zj |
j∈{κ+1,...,d}

≤ kz1:κ k2 + 3 κ · max |zj |
j∈{κ+1,...,d}

where the last inequality follows from the fact that p ∈ S d−1 . Therefore, combine the
above inequality with (8.42), and we obtain that


 
w(DΥ ) ≤ E kz1:κ k2 + 3 κ · max |zj |
j∈{κ+1,...,d}
√ √
 
≤ κ+3 κ·E max |zj | . (8.43)
j∈{κ+1,...,d}

p √
where the second inequality follows from the fact that E[kz1:κ k2 ] ≤ E[kz1:κ k22 ] = κ.
To bound the second term in (8.43), applying Lemma 8.11.9, it follows from (8.43)

369
that

√ p
w(DΥ ) ≤ κ + 3 2κ ln(2(d − κ)). (8.44)

Therefore, combining (8.44) and (8.41), we obtain

√ √ p
λmin (Z; D(g, w∗ )) ≥ n − 1 − κ − 3 2κ ln(2(d − κ)) − h.


Therefore, choosing h = n − 1/2, as long as n satisfies that n ≥ C(κ ln d) for some
constant C > 0, we have λmin (Z; D(g, w∗ )) > 0 with probability at least 1−(2d+1)e−cn .
Finally, the uniqueness of the optimal solution to (8.21) in this case follows from
Lemma 8.11.6.

  p
Lemma 8.11.9. Let z ∼ N (0, Id ), then it holds that E maxi∈[d] |zi | ≤ 2 ln(2d).

Proof of Lemma 8.11.9. Denote M = maxi∈[d] |zi |. For any λ > 0, by Jensen’s in-
equality, we have

  Xd
λ|zi |
λ·E[M ]
E eλ|zi | .
 λM   
e ≤E e = E max e ≤
i∈[d]
i=1

Note that E[eλ|zi | ] ≤ 2 · E[eλzi ]. Thus, by the expression of the Gaussian moment
generating function, we further have

d
2
X
λ·E[M ]
E eλzi = 2deλ /2 ,
 
e ≤2
i=1

from which it follows that

ln(2d) λ
E[M ] ≤ + .
λ 2

p
Choosing λ = 2 ln(2d) yields the desired result.

370
8.11.5 Proof of Lemma 8.7.5

Instead of studying the convergence of the Riemannian gradient flow directly, it is more
convenient to consider it in the ambient space RD . To do so, we define a Lagrange
function L(x; λ) = R(x) + ni=1 λi (fi (x) − yi ) for λ ∈ Rn . Based on this Lagrangian,
P

we can continuously extend ∂Φ(x)∇R(x) to the whole space RD . In specific, we can


find a continuous function F : RD → RD such that F (·)|Γ = ∂Φ(·)∇R(·). Such an F
can be implicitly constructed via the following lemma.

Lemma 8.11.10. The `2 norm has a unique minimizer among {∇x L(x; λ) |
λ ∈ Rn } for any fixed x ∈ RD . Thus we can define F : RD → RD by
F (x) = argming∈{∇x L(x;λ)|λ∈Rn } kgk2 . Moreover, it holds that hF (x), ∇fi (x)i = 0 for
all i ∈ [n].

Proof of Lemma 8.11.10. Fix any x ∈ RD . Note that {∇x L(x; λ) | λ ∈ Rn } is the
subspace spanned by {∇fi (x)}i∈[n] shifted by ∇R(x), thus there is unique minimizer
of the `2 norm in this set. This implies that F (x) = argming∈{∇x L(x;λ)|λ∈Rn } kgk2 is
well-defined.
To show the second claim, denote h(λ) = k∇x L(x; λ)k22 /2, which is a quadratic
function of λ ∈ Rn . Then we have
     
Pn
 h∇R(x), ∇f1 (x)i   i=1 λi h∇f1 (x), ∇fi (x)i   h∇R(x), ∇f1 (x)i 
 ..   ..   .. 
∇h(λ) = 
 . +
  . =
  .  + K(x)λ.

  P   
n
h∇R(x), ∇fn (x)i i=1 λi h∇fn (x), ∇fi (x)i h∇R(x), ∇fn (x)i

For any λ such that ∇x L(x; λ) = F (x), we must have ∇h(λ) = 0 by the definition of
F (x), which by the above implies

(K(x)λ)i = −h∇R(x), ∇fi (x)i for all i ∈ [n].

371
Therefore, we further have

n
X
hF (x), ∇fi (x)i = h∇R(x), ∇fi (x)i + λj h∇fi (x), ∇fj (x)i = h∇R(x), ∇fi (x)i + (K(x)λ)i = 0
j=1

for all i ∈ [n]. This finishes the proof.

Hence, with any initialization x(0) ∈ Γ, the limiting flow (8.20) is equivalent to
the following dynamics

dx(t) 1
= − F (x(t)). (8.45)
dt 4

Thus Lemma 8.7.5 can be proved by showing that the above x(t) converges to x∗ as
t → ∞. We first present a series of auxiliary results in below.

Lemma 8.11.11 (Implications for F (x) = 0). Let F : RD → RD be as defined in


Lemma 8.11.10. For any x = uv ∈ RD such that F (x) = 0, it holds that for each


j ∈ [d], either uj = 0 or vj = 0.

Proof. Since F (x) = 0, it holds for all j ∈ [d] that,

n
" n n
#
∂R X ∂fi 4X 2 X
0= (x) + λi (x) (x) = 2uj z + λi (x)zi,j ,
∂uj i=1
∂uj n i=1 i,j i=1
n
" n n
#
∂R X ∂fi 4X 2 X
0= (x) + λi (x) (x) = 2vj z − λi (x)zi,j .
∂vj i=1
∂vj n i=1 i,j i=1

If there exists some j ∈ [d] such that uj 6= 0 and vj 6= 0, then it follows from the above
two identities that

n
X
2
zi,j =0
i=1

which happens with probability 0 in both the Boolean and Gaussian case. Therefore,
we must have uj = 0 or vj = 0 for all j ∈ [d].
372
Lemma 8.11.12. Let F : RD → RD be as defined in Lemma 8.11.10. Then F is
continuous on RD .

Proof. Case I. We first consider the simpler case of any fixed x∗ ∈ U = (R \ {0})D ,
assuming that K(x∗ ) is full-rank. Lemma 8.11.10 implies that for any λ ∈ Rn such
that ∇x L(x∗ ; λ) = F (x∗ ), we have

K(x∗ )λ = −[∇f1 (x) . . . ∇fn (x)]> ∇R(x).

Thus such λ is unique and given by

λ(x∗ ) = −K(x∗ )−1 [∇f1 (x) . . . ∇fn (x)]> ∇R(x).

Since K(x) is continuous around x∗ , there exists a sufficiently small δ > 0 such that
for any x ∈ Bδ (x∗ ), K(x) is full-rank, which further implies that K(x)−1 is also
continuous in Bδ (x). Therefore, by the above characterization of λ, we see that λ(x)
is continuous for x ∈ Bδ (x∗ ), and so is F (x) = ∇R(x) + ni=1 λi (x)∇fi (x).
P

Case II. Next, we consider all general x∗ ∈ RD . Here for simplicity, we reorder the
coordinates as x = (u1 , v1 , u2 , v2 , . . . , ud , vd ) with a slight abuse of notation. Without
loss of generality, fix any x∗ such that for some q ∈ [d], (ui (0))2 + (vi (0))2 > 0 for all
i = 1, . . . , q and u∗i = vi∗ = 0 for all i = q + 1, . . . , d. Then ∇R(x∗ ) and {∇fi (x∗ )}i∈[n]
only depend on {zi,j }i∈[n],j∈[q] , and for all i ∈ [n], it holds that

(∇R(x∗ ))(2q+1):D = (∇fi (x∗ ))(2q+1):D = 0.

Note that if we replace {∇fi (x)}i∈[n] by any fixed and invertible linear transform of
itself, it would not affect the definition of F (x). In specific, we can choose an invertible
matrix Q ∈ Rn×n such that, for some q 0 ∈ [q], (e
z1 , . . . , zen ) = (z1 , . . . , zn )Q satisfies
zi,1:q }i∈[q0 ] is linearly independent and zei,1:q = 0 for all i = q 0 + 1, . . . , n. We then
that {e
373
h i
consider ∇fe1 (x), . . . , ∇fen (x) = [∇f1 (x), . . . , ∇fn (x)] Q and the corresponding F (x).
For notational simplicity, we assume that Q can be chosen as the identity matrix, so
that (z1 , . . . , zn ) itself satisfies the above property, and we repeat it here for clarity

{zi,1:q }i∈[q0 ] is linearly independent and zei,1:q = 0 for all i = q 0 + 1, . . . , n. (8.46)

This further implies that

(∇fi (x))1:(2q) = 0, for all i ∈ {q 0 + 1, . . . , n} and x ∈ RD . (8.47)

In the sequel, we use λ for n-dimensional vectors and λ̄ for q 0 -dimensional vectors.
Denote2

n
X
λ(x) ∈ argmin ∇R(x) + λi ∇fi (x) ,
λ∈Rn i=1 2
q0
!
X
λ̄(x) ∈ argmin ∇R(x) + λ̄i ∇(fi (x) .
λ̄∈Rq0 i=1 1:(2q) 2

Then due to (8.46) and (8.47), we have

q 0 !
X n
X
∗ ∗ ∗ ∗
∇R(x ) + λ̄i (x )∇fi (x ) = ∇R(x ) + λi (x)∇fi (x∗ ) = kF (x∗ )k2 .
i=1 1:(2q) 2 i=1 2

(8.48)

2
We do not care about the specific choice of λ(x) or λ̄(x) when there are multiple candidates,
and we only need their properties according to Lemma 8.11.10, so they can be arbitrary. Also, the
minimum of `2 -norm of an affine space can always be attained so argmin exists.
374
On the other hand, for any x ∈ RD , by (8.47), we have

q0
  n
!
X X
∇R(x) + λ̄i (x)∇fi (x) = minn ∇R(x) + λi (x)∇fi (x)
λ∈R
i=1 1:(2q) 2 i=1 1:(2q) 2
 n
X 
≤ ∇R(x) + λi (x)∇fi (x) = kF1:(2q) (x)k2
i=1 1:(2q) 2
n
X
≤ kF (x)k2 ≤ ∇R(x) + λi (x∗ )∇fi (x)
i=1 2

(8.49)

where the first and third inequalities follow from the definition of F (x). Let x → x∗ ,
by the continuity of ∇R(x) and {∇fi (x)}i∈[n] , we have

n
X n
X
∗ ∗
lim ∇R(x) + λi (x )∇fi (x) = ∇R(x ) + λi (x∗ )∇fi (x∗ ) (8.50)
x→x∗
i=1 2 i=1 2

Denote K(x)
e e ij (x))(i,j)∈[q0 ]2 = (h∇fi (x)1:(2q) , ∇fi (x)1:(2q) i)(i,j)∈[q0 ]2 . By apply-
= (K
e ∗ ) is full-rank, it also holds that
ing the same argument as in Case I, since K(x
limx→x∗ λ̄(x) = λ̄(x∗ ), and thus

 q0  q 0 
X X
∗ ∗
lim ∇R(x) + λ̄i (x)∇fi (x)1:(2q) = ∇R(x) + λ̄i (x )∇fi (x ) .
x→x∗
i=1 2 i=1 1:(2q) 2

(8.51)

Combing (8.48), (8.49), (8.50) and (8.51) yields

 n
X 
lim∗ kF1:(2q) (x)k2 = lim∗ minn ∇R(x) + λi ∇fi (x) = kF (x∗ )k2 . (8.52)
x→x x→x λ∈R
i=1 1:(2q) 2

p
Moreover, since kF(2q+1):D (x)k2 = kF (x)k22 − kF1:(2q) (x)k22 , we also have

lim kF(2q+1):D (x)k2 = 0. (8.53)


x→x∗

375
It then remains to show that limx→x∗ F1:(2q) (x) = F1:(2q) (x∗ ), which directly follows
from limx→x∗ λ1:q0 (x) = λ1:q0 (x∗ ) = λ̄(x∗ ).
e ∗ )  0, we can
Now, for any  > 0, due to the convergence of λ̄(x) and that K(x
pick a sufficiently small δ1 such that for some constant α > 0 and all x ∈ Bδ1 (x∗ ), it
holds that kλ̄(x) − λ̄(x∗ )k2 ≤ /2 and

2 2
 q0   q0 
X X
∇R(x) + λ̄i ∇fi (x) ≥ ∇R(x) + λ̄i (x)∇fi (x) + αkλ̄ − λ̄(x)k22 .
i=1 1:(2q) 2 i=1 1:(2q) 2

(8.54)

for all λ̄ ∈ Rp , where the inequality follows from the strong convexity. Meanwhile, due
to (8.47), we have

 q0   n 
X X
lim∗ ∇R(x) + λi (x)∇fi (x) = lim∗ ∇R(x) + λi (x)∇fi (x)
x→x x→x
i=1 1:(2q) 2 i=1 1:(2q) 2
 q0 
X
∗ ∗
= ∇R(x) + λ̄i (x )∇fi (x )
i=1 1:(2q) 2
 q0 
X
= lim∗ ∇R(x) + λ̄i (x)∇fi (x) .
x→x
i=1 1:(2q) 2

where the second equality follows from (8.52) and the second equality is due to (8.51).
Therefore, we can pick a sufficiently small δ2 such that

q0 q0
α2
 X   X 
∇R(x) + λi (x)∇fi (x) ≤ ∇R(x) + λ̄i (x)∇fi (x) +
i=1 1:(2q) 2 i=1 1:(2q) 2 4

(8.55)

for all x ∈ Bδ2 (x∗ ). Setting δ = min(δ1 , δ2 ), it follows from (8.54) and (8.55) that


kλ1:q0 (x) − λ̄(x)k2 ≤ , for all x ∈ Bδ (x∗ ).
2

376
Recall that we already have kλ̄(x) − λ̄(x∗ )k ≤ /2, and thus

kλ1:q0 (x) − λ(x∗ )1:q0 k2 = kλ1:q0 (x) − λ̄(x∗ )k2 ≤ kλ1:q0 (x) − λ̄(x)k2 + kλ̄(x) − λ̄(x∗ )k2 ≤ 

for all x ∈ Bδ (x∗ ). Therefore, we see that limx→x∗ λ1:q0 (x) = λ(x∗ )1:q0 .
Finally, it follows from the triangle inequality that

kF (x) − F (x∗ )k2


 

≤ F (x) − F (x ) + kF(2q+1):D (x)k2 + kF(2q+1):D (x∗ )k2
1:(2q) 2 | {z }
0
q0 q0
!
X X
= ∇R(x) + λi (x)∇fi (x) − ∇R(x∗ ) − λi (x∗ )∇fi (x∗ ) + kF(2q+1):D (x)k2
i=1 i=1 1:(2q) 2
q 0
X
≤ λi (x)∇fi (x) − λi (x∗ )∇fi (x∗ ) + k∇R(x) − ∇R(x∗ )k2 + kF(2q+1):D (x)k2
i=1 2

where, as x → x∗ , the first term vanishes by the convergence of λ1:q0 (x) and the
continuity of each ∇fi (x), the second term converges to 0 by the continuity of ∇R(x)
and the third term vanishes by (8.53). Therefore, we conclude that

lim F (x) = F (x∗ ),


x→x∗

that is, F is continuous.

Lemma 8.11.13. For any initialization x∗ ∈ Γ, the Riemmanian Gradient Flow


(8.20) (or equivalently, (8.45)) is defined on [0, ∞).

Proof of Lemma 8.11.13. Let [0, T ) be the right maximal interval of existence of the
solution of Riemannian gradient glow and suppose T 6= ∞. Since R(x(t)) is monotone
decreasing, thus R(x(t)) is upper bounded by R(x(0)) and therefore k∇R(x(t))k is
dx(t)
also upper bounded. Since dt
≤ k∇R(x(t))k2 for any t < T , the left limit
2

377
x(T −) := limτ →T − x(τ ) must exist. By Corollary 1, Perko [233], x(T −) belongs to
boundary of U , i.e., uj (T −) = 0 or vj (T −) = 0 for some j ∈ [d] by Lemma 8.11.11.
By the definition of the Riemannian gradient flow in (8.20), we have

 
d dx(t)
(uj (t)vj (t)) = vj (t)e>
j uj (t)e>
j
dt dt
 
1
=− v (t)e> uj (t)e> F (x(t)).
4 j j j

Pn
By the expression of F (x(t)) = ∇R(x(t)) + i=1 λi (x(t))∇fi (x(t)), we then have

d
(uj (t)vj (t))
dt " # " n #
n n n
2X 2 1X 2X 2 1X
=− z + λi (x(t))zi,j uj (t)vj (t) − z − λi (x(t))zi,j uj (t)vj (t)
n i=1 i,j 2 i=1 n i=1 i,j 2 i=1
n
!
4X 2
=− z uj (t)vj (t).
n i=1 i,j

Pn
Denote sj = 4
n
2
i=1 zi,j . It follows that |uj (t)vj (t)| = |uj (0)vj (0)|e−sj t for all t ∈ [0, T ).
Taking the limit we have |uj (T −)vj (T −)| ≥ |uj (0)vj (0)|e−sj T > 0. Contradiction with
T 6= ∞!

Before showing that F satisfies the PL condition, we need the following two
intermediate results. Given two points u and v in Rd , we say u weakly dominate v
(written as u ≤ v) if and only if ui ≤ vi , for all i ∈ [d]. Given two subsets A and B of
RD , we say A weakly dominates B if and only if for any point v in B, there exists a
point u ∈ A such that u ≤ v.

Lemma 8.11.14. For some q ∈ [D], let S be any q-dimensional subspace of RD


and P = {u ∈ RD | ui ≥ 0, ∀i ∈ [D]}. Let u? be an arbitrary point in P and
Q = P ∩ (u? + S). Then there exists a radius r > 0, such that Br1 (0) ∩ Q is non-empty
and weakly dominates Q, where Br1 (0) is the `1 -norm ball of radius r centered at 0.

378
As a direct implication, for any continuous function f : P → R, which is coordinate-
wise non-decreasing, minx∈U f (x) can always be achieved.

Proof of Lemma 8.11.14. We will prove by induction on the environment dimension


D. For the base case of D = 1, either S = {0} or S = R, and it is straight-forward to
verify the desired for both scenarios.
Suppose the proposition holds for D − 1, below we show it holds for D. For
each i ∈ [D], we apply the proposition with D − 1 to Q ∩ {u ∈ P | ui = 0} (which
can be seen as a subset of RD−1 ), and let ri be the corresponding `1 radius. Set
r = maxi∈[D] ri , and we show that choosing the radius to be r suffices.
For any v ∈ Q, we take a random direction in S, denoted by ω. If ω ≥ 0 or ω ≤ 0,
we denote by y the first intersection (i.e., choosing the smallest λ) between the line
{v − λ|ω|}λ≥0 and the boundary of U , i.e., ∪D D
i=1 {z ∈ R | zi = 0}. Clearly y ≤ v. By

the induction hypothesis, there exists a u ∈ Br1 (0) ∩ Q such that u ≤ y. Thus u ≤ v
and meets our requirement.
If ω has different signs across its coordinates, we take y1 , y2 to be the first
intersections of the line {v − λ|ω|}λ∈R and the boundary of U in directions of
λ > 0 and λ < 0, respectively. Again by the induction hypothesis, there exist
u1 , u2 ∈ Br1 (0) ∩ Q such that u1 ≤ y1 and u2 ≤ y2 . Since v lies in the line con-
necting u1 and u2 , there exists some h ∈ [0, 1] such that v = (1 − h)u1 + hu2 . It
then follows that (1 − h)u1 + hu2 ≤ (1 − h)y1 + hy2 = v. Now since Q is convex,
we have (1 − h)u1 + hu2 ∈ Q, and by the triangle inequality it also holds that
k(1 − h)u1 + hu2 k1 ≤ r, so (1 − h)u1 + hu2 ∈ Br1 (0) ∩ Q. Therefore, we conclude that
Br1 (0) ∩ Q weakly dominates Q, and thus the proposition holds for D. This completes
the proof by induction.

Lemma 8.11.15. For some q ∈ [D], let S be any q-dimensional subspace of RD


and P = {u ∈ RD | ui ≥ 0, ∀i ∈ [D]}. Let u? be an arbitrary point in P and
Q = P ∩ (u? + S). Then there exists a constant c ∈ (0, 1] such that for any sufficiently
379
small radius r > 0, c · Q weakly dominates P ∩ (u? + S + Br2 (0)), where Br2 (0) is the
`2 -norm ball of radius r centered at 0.

Proof of Lemma 8.11.15. We will prove by induction on the environment dimension


D. For the base case of D = 1, either S = {0} or S = R. S = R is straight-forward;
for the case S = {0}, we just need to ensure c|u? | ≤ |u? | − r, and it suffices to pick
r = |u? | and c = 0.5.
Suppose the proposition holds for D − 1, below we show it holds for D. For each
i ∈ [D], we first consider the intersection between P ∩ (u? + S + Br2 (0)) and Hi := {u ∈
RD | ui = 0}. Let ui be an arbitrary point in P ∩(u? +S)∩Hi , then P ∩(u? +S)∩Hi =
P ∩ (ui + S) ∩ Hi = P ∩ (ui + S ∩ Hi ). Furthermore, there exists {αi }i∈[D] which only
depends on S and satisfies P ∩ (u∗ + S + Br2 (0)) ∩ Hi ⊂ P ∩ (ui + S ∩ Hi + Bα2 i r (0) ∩ Hi ).
Applying the induction hypothesis to P ∩ (ui + S ∩ Hi + Bα2 i r (0) ∩ Hi ), we know there
exists a c > 0 such that for sufficiently small r, c(P ∩(u? +S)∩Hi ) = c(P ∩(ui +S ∩Hi ))
weakly dominates P ∩ (ui + S ∩ Hi + Bα2 i r (0) ∩ Hi ).
For any point v in Q and any z ∈ Br2 (0), we take a random direction in S, denoted
by ω. If ω ≥ 0 or ω ≤ 0, we denote by y the first intersection between {v +z −λ|ω|}λ≥0
and the boundary of U . Clearly y ≤ v. Since y ∈ P ∩ (u? + S + Br2 (0)) ∩ Hi ⊂ P ∩ (ui +
S ∩Hi +Bα2 i r (0)∩Hi ), by the induction hypothesis, there exists a u ∈ c(P ∩(u? +S)∩Hi )
such that u ≤ y. Thus z ≤ v + z and z ∈ c(P ∩ (u? + S)) = c · Q.
If ω has different signs across its coordinates, we take y1 , y2 to be the first inter-
sections of the line {v + z − λ|ω|}λ∈R and the boundary of U in directions of λ > 0
and λ < 0, respectively. By the induction hypothesis, there exist u1 , u2 ∈ c · Q
such that u1 ≤ y1 and u2 ≤ y2 . Since v + z lies in the line connecting u1 and u2 ,
there exists some h ∈ [0, 1] such that v + z = (1 − h)y1 + hy2 . It then follows that
(1−h)u1 +hu2 ≤ (1−h)y1 +hy2 = v+z. Since Q is convex, we have (1−h)u1 +hu2 ∈ cQ.
Therefore, we conclude that cQ ∩ Q weakly dominates P ∩ (u? + S + Br2 (0)) for all

380
sufficiently small r, and thus the proposition holds for D. This completes the proof
by induction.

Lemma 8.11.16. (Polyak-Lojasiewicz condition for F .) For any x∗ such that L(x∗ ) =
0, i.e., x∗ ∈ Γ, there exist a neighbourhood U 0 of x∗ and a constant c > 0, such that
kF (x)k22 ≥ c · max(R(x) − R(x∗ ), 0) for all x ∈ U 0 ∩ Γ. Note this requirement is only
non-trivial when kF (x∗ )k2 = 0 since F is continuous.

Proof of Lemma 8.11.16. It suffices to show the PL condition for {x | F (x) = 0}.
We need to show for any x∗ satisfying F (x∗ ) = 0, there exist some  > 0 and
C > 0, such that for all x ∈ Γ ∩ B2 (x∗ ) with R(x) > R(x∗ ), it holds that kF (x)k22 ≥
C(R(x) − R(x∗ )).

u

Canonical Case. We first prove the case where x = v
itself is a canonical
parametrization of w = u 2
− v 2 , i.e., uj vj = 0 for all j ∈ [d]. Since x∗ satisfies
∇F (x∗ ) = 0, by Lemma 8.11.11, we have x∗ = ψ(w∗ ) where w∗ = (u∗ ) 2
− (v ∗ ) 2 . In
this case, we can rewrite both R and F as functions of w ∈ Rd . In detail, we define
R0 (w) = R(ψ(w)) and F 0 (w) = F (ψ(w)) for all w ∈ Rd . For any w in a sufficiently
small neighbourhood of w∗ , it holds that sign(wj ) = sign(wj∗ ) for all j ∈ [q]. Below
we show that for each possible sign pattern of w(q+1):d , there exists some constant
C which admits the PL condition in the corresponding orthant. Then we take the
minimum of all C from different orthant and the proof is completed. W.L.O.G., we
assume that wj ≥ 0, for all j = q + 1 . . . , d.
We temporarily reorder the coordinates as x = (u1 , v1 , u2 , v2 , . . . , ud , vd )> . Recall
that Z = [z1 , . . . , zn ]> is a n-by-d matrix, and we have

2
kF 0 (w)k2 = minn (a − sign(w) Z > λ) 2 , |w| ,
λ∈R

381
Pn 2
where a = 8
n
∈ Rd . Since F (x∗ ) = 0, there must exist λ∗ ∈ Rn , such that
i=1 zi

the first 2q coordinates of ∇R(x∗ ) + ni=1 λ∗i ∇fi (x∗ ) are equal to 0. As argued in the
P

proof of Lemma 8.11.12, we can assume the first q 0 rows of Z are linear independent
0
on  q coordinates for some q ∈ [q]. In other words, Z can be written as
 the first
ZA ZB  0
  where ZA ∈ Rq ×q . We further denote λa := λ1:q0 , λb := λ(q0 +1):n , aa := a1:q
0 ZD
and ab := a(q+1):d , wa := w1:q and wb := w(q+1):d for convenience, then we have

2
kF 0 (w)k2 = minn (aa + sign(wa ) ZA> λa ) 2 , |wa | + (ab + ZB> λa + ZD
>
λb ) 2 , wb .
λ∈R

(8.56)

Pn
Since every w in Γ is a global minimizer, R0 (w) = R0 (w) + i=1 λ∗i (zi> w − yi ) :=
g > w + R0 (w∗ ), where g = sign(w) a + Z > λ∗ . Similarly we define ga := g1:q and
gb := g(q+1):d . It holds that ga = 0 and we assume ZD gb = 0 without loss of generality,
because this can always be done by picking suitable λ∗i for i = q 0 + 1, . . . , n. (We have
such freedom on λ∗q0 +1:n because they doesn’t affect the first 2q coordinates.)
We denote λa − λ∗a by ∆λa , then since 0 = ga = sign(wa ) ZA> λ∗a + aa , we further
have

(aa + sign(wa ) ZA> λa ) 2 , |wa | = (aa + sign(wa ) ZA> λ∗a + sign(wa ) ZA> ∆λa ) 2 , |wa |

= (sign(wa ) ZA> ∆λa ) 2 , |wa | .

On the other hand, we have gb = sign(wb ) ab + ZB> λ∗a + ZD


> ∗
λb = ab + ZB> λ∗a + ZD
> ∗
λb
by the assumption that each coordinate of wb is non-negative. Combining this with
the above identity, we can rewrite Equation (8.56) as:

2
kF 0 (w)k2 = min (ZA> ∆λa ) 2 , |wa | + (gb + ZB> ∆λa + ZD
>
λb ) 2 , wb . (8.57)
λ∈RD

382
Now suppose R0 (w) − R0 (w∗ ) = gb> wb = δ for some sufficiently small δ (which can
be controlled by ). We will proceed in the following two cases separately.

• Case I.1: k∆λa k2 = Ω( δ). Since ZA has full row rank, (ZA> ∆λa ) 2
1
=
2
(ZA> ∆λa ) 2
≥ k∆λa k22 λ2min (ZA ) is lower-bounded. On the other hand, we can
choose  small enough such that ∀i ∈ [q]|(wa )2i | ≥ 12 (wa∗ )2i . Thus the first term of
Equation (8.57) is lower bounded by k∆λa k22 λ2min (ZA ) · mini∈[q] 12 (wa∗ )2i = Ω(δ) =
Ω(R0 (w) − R0 (w∗ )).

• Case I.2: k∆λa k2 = O( δ). Let u = gb + ZB> ∆λa + ZD
>
λb , then we have
> 0
u ∈ S + Bc2√δ (0) for some constant c > 0, where S = {gb + ZD λb | λb ∈ Rn−q }.
1
By Lemma 8.11.14, there exists some constant c0 ≥ 1, such that c0
· S weakly
dominates S + Bc2√δ (0). Thus we have kF 0 (w)k22 ≥ inf u∈S+Bc√δ (0) hu 2 , wb i ≥
inf u∈ 1 ·S hs 2 , wb i, where the last step is because each coordinate of wb is non-
c0

negative.

Let A be the orthogonal complement of span(ZD , gb ), i.e., the spanned space of


δ
columns of ZD and gb , we know wb ∈ g
kgb k22 b
+ A, since ZD wb = ZD w∗2 = 0 and
gb> wb = δ. Therefore,

kF 0 (w)k22 D
2 wb
E
inf ≥ inf inf u ,
w:R0 (w)−R0 (w∗ )=δ>0 R0 (w) − R0 (w ∗ ) wb :R0 (w)−R0 (w∗ )=δ>0 u∈ 1 ·S δ
c0

1
≥ inf u 2 , wb . (8.58)
c20 wb ∈ δ
gb +A,wb ≥0,u∈S
kgb k22

Note hu 2 , wb i is a monotone non-decreasing function in the first joint orthant,


0
i.e., {(u, wb ) ∈ Rd ×Rd−q | u ≥ 0, wb ≥ 0}, thus by Lemma 8.11.15 the infinimum
can be achieved by some finite (u, wb ) in the joint first orthant. Applying the same
argument to each other orthant of u ∈ Rd , we conclude that the right-hand-side
of (8.58) can be achieved.

383
On the other hand, we have u> wb = δ > 0 for all wb ∈ δ
g
kgb k22 b
+ A and u ∈ S,
by ZD gb = 0 and the definition of A. This implies there exists at least one
i ∈ [d − q 0 ] such that w2,i ui > 0, which further implies hu 2 , wb i > 0. Therefore,
we conclude that kF 0 (w)k22 = Ω(R0 (w) − R0 (w0 )).

u
 2 2
General Case. Next, for any general x = v
, we define w = u −v and
m = min{u 2 , v 2 }, where min is taken coordinate-wise. Then we can rewrite kF (x)k22
as

kF (x)k22
       2
a  Z   u
= minn   +   λ  
λ∈R
a −Z v
2
     2  
2
a  Z   u 
= minn   +   λ  
λ∈R 2
a −Z v
1
     2   
a  Z   2 m
= minn   +   λ ψ(w) +  

λ∈R
a −Z m
1
     2      2  
a  Z   a  Z   m
≥ minn   +   λ ψ(w) 2 + minn   +   λ  
λ∈R λ∈R
a −Z a −Z m
1 1
     2        2

a  Z   a  Z    m
= minn   +   λ ψ(w) + minn   +   λ √  .
λ∈R λ∈R
a −Z a −Z m
2 2

384
Then applying the result for the previous case yields the following for some constant
C ∈ (0, 1):

       2

a  Z    m
kF (x)k22 ≥ C(R(ψ(w)) − R(ψ(w∗ )) + minn   +   λ √ 
λ∈R
a −Z m
2

= C(R(ψ(w)) − R(x∗ ) + 2 a 2 , m

≥ C(R(ψ(w)) − R(x∗ ) + 2 min ai ha, mi


i∈[d]

= C(R(ψ(w)) − R(x∗ ) + min ai (R(x) − R(ψ(w)))


i∈[d]
 
≥ min C, min ai (R(x) − R(x∗ )),
i∈[d]

where the first equality follows from the fact that x∗ = ψ(w∗ ) and the last inequality is
due to the fact that both R(ψ(w) − R(ψ(w∗ )) and R(x) − R(ψ(w)) are non-negative.
This completes the proof.

Now, based on the PL condition, we can show that (8.20) indeed converges.

Lemma 8.11.17. The trajectory of the flow defined in (8.20) has finite length, i.e.,
R ∞ dx
k k dt < ∞ for any x∗ ∈ Γ. Moreover, x(t) converges to some x(∞) when
t=0 dt 2

t → ∞ with F (x(∞)) = 0.

Proof of Lemma 8.11.17. Note that along the Riemannian gradient flow, R(x(t)) is
non-increasing, thus kx(t)k2 is bounded over time and {x(t)}t≥0 has at least one limit
point, which we will call x∗ . Therefore, R(x∗ ) is a limit point of R(x(t)), and again since
R(x(t)) is non-increasing, it follows that R(x(t)) ≥ R(x∗ ) and limt→∞ R(x(t)) = R(x∗ ).
Below we will show limt→∞ x(t) = x∗ .
D E
Note that dt = ∇R(x(t)), dt = − ∇R(x(t)), 14 F (x(t)) = − 41 kF (x(t))k22
dR(x(t)) dx(t)

where the last equality applies Lemma 8.11.10. By Lemma 8.11.16, there exists a
neighbourhood of x∗ , U 0 , in which PL condition holds of F . Since x∗ is a limit point,
there exists a time T0 , such that xT0 ∈ U 0 . Let T1 = inf t≥T0 {x(t) ∈
/ U 0 } (which is
385
equal to ∞ if x(t) ∈ U 0 for all t ≥ T0 ). Since x(t) is continuous in t and U is open, we

know T1 > T0 and for all t ∈ [T0 , T1 ), we have kF (x(t))k2 ≥ c(R(x(t)) − R(x∗ ))1/2 .
Thus it holds that for t ∈ [T0 , T1 ),


d(R(x(t)) − R(x∗ )) c
≤ − (R(x(t)) − R(x∗ ))1/2 kF (x(t))k2 ,
dt 4

that is,


d(R(x(t)) − R(x∗ ))1/2 c
≤− kF (x(t))k2 .
dt 8

Therefore, we have

Z T1
8
kF (x(t))k2 dt ≤ √ (R(x(T0 )) − R(x∗ ))1/2 . (8.59)
t=T0 c

Thus if we pick T0 such that R(x(T0 )) − R(x∗ ) is sufficiently small, R(T1 ) will remain
in U if T1 is finite. Contradiction! This implies that T1 has to be ∞. Therefore,
Equation (8.59) shows that the trajectory of x(t) is of finite length, so x(∞) :=
limt→∞ x(t) exists and is equal to x∗ . As a by-product, F (x∗ ) must be 0.

Finally, collecting all the above lemmas, we are able to prove Lemma 8.7.5. In
Lemma 8.11.17 we already show the convergence of x(t) as t → ∞, the main part
of the proof of Lemma 8.7.5 is to show the x(∞) cannot be sub-optimal stationary
points of R on Γ, the closure of Γ. The key idea here is that we can construct a
different potential φ for each such sub-optimal stationary point x∗ , such that (1) φ(xt )
is locally increasing in a sufficiently neighborhood of x∗ and (2) limx→x∗ φ(x) = −∞.

Lemma 8.7.5. Let {xt }t≥0 ⊆ RD be generated by the flow defined in (8.20) with any
initialization x0 ∈ Γ. Then x∞ = limt→∞ xt exists. Moreover, x∞ = x∗ is the optimal
solution of (8.21).

386
u(∞)

Proof of Lemma 8.7.5. We will prove by contradiction. Suppose x(∞) = v(∞)
=
2
limt→∞ x(t) is not the optimal solution to (8.21). Denote w(t) = (u(t)) − (v(t)) 2 ,
then w(∞) = limt→∞ w(t) is not the optimal solution to (8.35). Thus we have
R(w(t)) > R(w∗ ). Without loss of generality, suppose there is some q ∈ [d] such
that (ui (∞))2 + (vi (∞))2 > 0 for all i = 1, . . . , q and ui (∞) = vi (∞) = 0 for all
i = q + 1, . . . , d. Again, as argued in the proof of Lemma 8.11.12, we can assume that,
for some q 0 ∈ [q],

{zi,1:q }i∈[q0 ] is linearly independent and zi,1:q = 0 for all i = q 0 + 1, . . . , n. (8.60)

Since both w(∞) and w∗ satisfy the constraint that Zw(∞) = Zw∗ = Y , we further
have

0 = hzi , w(∞)i = hzi , w∗ i = hzi,(q+1):d , w(q+1):d



i, for all i = q 0 + 1, . . . , n. (8.61)

Consider a potential function ϕ : U → R defined as

d
X
wj∗ ln(uj )2 1{wj∗ > 0} − ln(vj )2 1{wj∗ < 0} .
 
ϕ(x) = ϕ(u, v) =
j=q+1

Clearly limt→∞ ϕ(x(t)) = −∞ if limt→∞ x(t) = x(∞). Below we will show contra-
diction if x(∞) is suboptimal. Consider the dynamics of ϕ(x) along the Riemannian
gradient flow:

   
dϕ dx(t) 1
(x(t)) = ∇ϕ(x(t)), = − ∇ϕ(x(t)), F (x(t)) (8.62)
dt dt 4

387
where F is defined previously in Lemma 8.11.10. Recall the definition of F , and we
have
* q 0 +
1 1X
h∇ϕ(x(t)), F (x(t))i = ∇ϕ(x(t)), ∇R(x(t)) + λi (x(t))∇fi (x(t))
4 4 i=1
| {z }
I1
* n
+
1 X
+ ∇ϕ(x(t)), λi (x(t))∇fi (x(t)) . (8.63)
4 i=q0 +1
| {z }
I2

To show h∇ϕ(x(t)), F (x(t))i < 0, we analyze I1 and I2 separately. By the definition


of ϕ(x), we have

d
1{wj∗ > 0} 1{wj∗ < 0}
X  
∇ϕ(x) = 2wj∗ · ej − · eD+j
j=q+1
uj vj

zi u

where ej is the j-th canonical base of Rd . Recall that ∇fi (x) = 2 −zi v
, and we
further have

n d
1{wj∗ > 0} 1{wj∗ < 0}
X X  
I2 = λi (x(t)) wj∗ ui +hej , zi hej , zi vi
0
i=q +1 j=q+1
uj vj

n d
1{wj∗ > 0} 1{wj∗ < 0}
X X  

= λi (x(t)) wj zi,j uj + zi,j vj
i=q 0 +1 j=q+1
uj vj
n
X d
X n
X
= λi (x(t)) wj∗ zi,j = ∗
λi (x(t))hzi,(q+1):d , w(q+1):d i=0 (8.64)
i=q 0 +1 j=q+1 i=q 0 +1

where the last equality follows from (8.61).


Next, we show that I1 < 0 by utilizing the fact that w∗ − w(∞) is a descent
direction of R0 (w). For w ∈ Rd , define fei (w) = zi> w and

q 0
X
R(w)
e = R(w) + λi (x(∞))(fei (w) − yi ).
i=1

388
Clearly, for any w ∈ RD satisfying Zw = Y , it holds that fei (w) − yi = 0 for each
i ∈ [n], and thus R(w) = R(w).
e In particular, we have R(w(∞))
e = R(w(∞)) >
R(w∗ ) = R(w
e ∗ ). Since R(w)
e is a convex function, it follows that R(w(∞)
e + s(w∗ −
e ∗ ) + (1 − s)R(∞)
w(∞))) ≤ sR(w e < R(w(∞))
e for all 0 < s ≤ 1, which implies
dR
(w(∞) + s(w∗ − w(∞)))|s=0 < −2c < 0+ for some constant c > 0. Note that, for
e
dt

small enough s > 0, we have

d n
!
∗ 4X X 2
R(w(∞) + s(w − w(∞))) = z |wj (∞) + s(wj∗ − wj (∞))|
n j=1 i=1 i,j
q n
!
4X X 2
= z sign(wj (∞))(wj (∞) + s(wj∗ − wj (∞)))
n j=1 i=1 i,j
d n
!
4 X X 2
+ z s|wj∗ |.
n j=q+1 i=1 i,j

Therefore, we can compute the derivative with respect to s at s = 0 as

dR
e
−2c > (w(∞) + s(w∗ − w(∞)))
dt s=0
q n
! d n
!
4X X 2 4 X X
= z sign(wj (∞))(wj∗ − wj (∞)) + z 2 |wj∗ |
n j=1 i=1 i,j n j=q+1 i=1 i,j
q0
X
+ λi (x(∞))zi> (w∗ − wj (∞))
i=1
qn
! d n
!
4X X 2 4 X X 2
= z sign(wj (∞))(wj∗ − w(∞)) + z |wj∗ |
n j=1 i=1 i,j n j=q+1 i=1 i,j
q q0 q0
X X d
X X
+ (wj∗ − wj (∞)) λi (x(∞))zi,j + wj∗ λi (x(∞))zi,j (8.65)
j=1 i=1 j=q+1 i=1

where the second equality follows from the fact that w(q+1):d (∞) = 0. Since x(t)
converges to x(∞), we must have F (x(∞)) = 0, which implies that for each j ∈

389
{1, . . . , q},

q0 q0
" n #
∂R X ∂fi 4X 2 X
0= (x(∞)) + λi (x(∞)) (x(∞)) = 2uj (∞) zi,j + λi (x(∞))zi,j ,
∂uj i=1
∂u j n i=1 i=1
q0 q0
" n #
∂R X ∂fi 4X 2 X
0= (x(∞)) + λi (x(∞)) (x(∞)) = 2vj (∞) z − λi (x(∞))zi,j .
∂vj i=1
∂vj n i=1 i,j i=1

Combining the above two equalities yields

n q0
4X 2 X
zi,j = − sign(wj (∞)) λi (x(∞))zi,j , for all j ∈ [q].
n i=1 i=1

Apply the above identity together with (8.65), and we obtain

q q0 !
d n
X X 4 X X 2
−2c > − sign(wj (∞))2 (wj∗ − w(∞)) λi (x(∞))zi,j + z |wj∗ |
j=1 i=1
n j=q+1 i=1 i,j
q q0 d q0
X X X X
+ (wj∗ − wj (∞)) λi (x(∞))zi,j + wj∗ λi (x(∞))zi,j
j=1 i=1 j=q+1 i=1
! q 0
d n d
4 X X 2 X X
= z |wj∗ | + wj∗ λi (x(∞))zi,j (8.66)
n j=q+1 i=1 i,j j=q+1 i=1

390
On the other hand, by directly evaluating ∇R(x(t)) and each ∇fi (x(t)), we can
compute I1 as

q0
d
" n #
X wj∗ 1{wj∗ > 0} 2 X 2 1 X
I1 = zi,j uj (t) + λi (x(t))zi,j uj (t)
j=q+1
uj (t) n i=1
2 i=1
q0
d
" n #
X wj∗ 1{wj∗ < 0} 2 X 1 X
2
− zi,j vj (t) − λi (x(t))zi,j vj (t)
j=q+1
vj (t) n i=1
2 i=1
q0
d n
! d
2 X X 2 1 X X
= z |wj∗ | + w∗ λi (x(t))zi,j
n j=q+1 i=1 i,j 2 j=q+1 j i=1
q0
d n
! d
2 X X 2 ∗ 1 X ∗X
= z |wj | + w λi (x(∞))zi,j
n j=q+1 i=1 i,j 2 j=q+1 j i=1
q 0
d
1 X ∗X
+ wj (λi (x(t)) − λi (x(∞))) zi,j .
2 j=q+1 i=1

We already know that λ1:q0 (x) is continuous at x(∞) by the proof of Lemma 8.11.12,
so the third term converges to 0 as x(t) tends to x(∞). Now, applying (8.66), we
immediately see that there exists some δ > 0 such that I1 < −c for x(t) ∈ Bδ (x(∞)).
As we have shown in the above that I2 = 0, it then follows from (8.62) and (8.63) that


(x(t)) > c, for all x(t) ∈ Bδ (x(∞)). (8.67)
dt

Since limt→∞ x(t) = x(∞), there exists some T > 0 such that x(t) ∈ Bδ (x(∞)) for all
t > T . By the proof ofLemma 8.11.13, we know that ϕ(x(T )) > −∞, then it follows
from (8.67) that

Z ∞ Z ∞
dϕ(x(t))
lim ϕ(x(t)) = ϕ(x(T )) + dt > ϕ(x(T )) + cdt = ∞
t→∞ T dt T

which is a contradiction. This finishes the proof.

391
8.11.6 Proof of Theorem 8.7.7

Here we present the lower bound on the sample complexity of GD in the kernel regime.

i.i.d.
Theorem 8.7.7. Assume z1 , . . . , zn ∼ N (0, Id ) and yi = zi> w∗ , for all i ∈ [n].
Define the loss with linearized model as L(x) = ni=1 (fi (x0 ) + h∇fi (x0 ), x − x0 i − yi )2 ,
P

where x = uv and x0 = uv00 = α 11 . Then for any groundtruth w∗ , any learning


  

rate schedule {ηt }t≥1 , and any fixed number of steps T , the expected `2 loss of x(T )
is at least (1 − nd ) kw∗ k22 , where x(T ) is the T -th iterate of GD on L, i.e., x(t + 1) =
x(t) − ηt ∇L(x(t)), for all t ≥ 0.

Proof of Theorem 8.7.7. We first simplify the loss function by substituting x0 = x −


x(0), so correspondingly x00 = 0 and we consider L0 (x0 ) := L(x) = (h∇fi (x(0)), x0 i−yi )2 .
We can think as if GD is performed on L0 (x0 ). For simplicity, we still use the x and
L(x) notation in below.
In order to show test loss lower bound against a single fixed target function, we
must take the properties of the algorithm into account. The proof is based on the
observation that GD is rotationally equivariant [24, 82] as an iterative algorithm,
i.e., if one rotates the entire data distribution (including both the training and test
data), the expected loss of the learned function remains the same. Since the data
distribution and initialization are invariant under any rotation, it means the expected
loss of x(T ) with ground truth being w∗ is the same as the case where the ground
truth is uniformly randomly sampled from all vectors of `2 -norm kw∗ k2 .
Thus the test loss of x(T ) is

Ez (h∇fz (x(0)), x(T )i − hz, w∗ i)2


 
(8.68)
=Ez (hz, w∗ − (u(T ) − v(T ))i)2 = kw∗ − (u(T ) − v(T ))k22 .
 

Note x(T ) ∈ span{∇fx (x(0))}, which is at most an n-dimensional space spanned by


the gradients of model output at x(0), so is u(T ) − v(T ). We denote the corresponding
392
space for u(T ) − v(T ) by S, so dim(S) ≤ n and it holds that kw∗ − (u(T ) − v(T ))k22 ≥
k(ID − PS )w∗ k22 , where PS is projection matrix onto space S.
The expected test loss is lower bounded by

Ew∗ Ezi kw∗ − (u(T ) − v(T ))k22 = Ezi Ew∗ kw∗ − (u(T ) − v(T ))k22
     

≥ min Ew∗ k(ID − PS )w∗ k22


 
{zi }i∈[n]
 n ∗ 2
≥ 1− kw k2 .
d

393
Chapter 9

Implicit Bias of Gradient Descent


Operating on Edge of Stability:
Sharpness Reduction

Deep learning experiments by Cohen et al. [234] using deterministic Gradient Descent
(GD) revealed an Edge of Stability (EoS) phase when learning rate (LR) and sharpness
(i.e., the largest eigenvalue of Hessian) no longer behave as in traditional optimization.
Sharpness stabilizes around 2/LR and loss goes up and down across iterations, yet
still with an overall downward trend. This chapter mathematically analyzes a new
mechanism of implicit regularization in the EoS phase, whereby GD updates due
to non-smooth loss landscape turn out to evolve along some deterministic flow on
the manifold of minimizers as introduced in Chapter 8. This is in contrast to many
previous results about implicit bias either relying on infinitesimal updates or noise in
gradient. Formally, for any smooth function L with certain regularity condition, this
η
effect is demonstrated for (1) Normalized GD, i.e., GD with a varying LR ηt = k∇L(x(t))k
p
and loss L; (2) GD with constant LR and loss L − minx L(x). Both provably enter

394
the Edge of Stability, with the associated flow on the manifold minimizing λ1 (∇2 L).
The above theoretical results have been corroborated by an experimental study.

9.1 Introduction

Traditional convergence analyses of gradient-based algorithms assume learning rate η


is set according to the basic relationship η < 2/λ where λ is the largest eigenvalue
of the Hessian of the objective, called sharpness1 . Descent Lemma says that if this
relationship holds along the trajectory of Gradient Descent, loss drops during each
iteration. In deep learning where objectives are nonconvex and have multiple optima,
similar analyses can show convergence towards stationary points and local minima. In
practice, sharpness is unknown and η is set by trial and error. Since deep learning works,
it has been generally assumed that this trial and error allows η to adjust to sharpness
so that the theory applies. But recent empirical studies [234, 235] showed compelling
evidence to the contrary. On a variety of popular architectures and training datasets,
GD with fairly small values of η displays following phenomena that they termed Edge
of Stability (EoS): (a) Sharpness rises beyond 2/η, thus violating the above-mentioned
relationship. (b) Thereafter sharpness stops rising but hovers noticeably above 2/η and
even decreases a little. (c) Training loss behaves non-monotonically over individual
iterations, yet consistently decreases over long timescales.
Note that (a) was already pointed out by Li et al. [50]. Specifically, in modern
deep nets, which use some form of normalization combined with weight decay, training
to near-zero loss must lead to arbitrarily high sharpness. (However, Cohen et al. [234]
show that the EoS phenomenon appears even without normalization.) Phenomena
(b), (c) are more mysterious, suggesting that GD with finite η is able to continue
decreasing loss despite violating η < 2/λ, while at the same time regulating further
increase in value of sharpness and even causing a decrease. These striking inter-related
1
Confusingly, another traditional name for λ is smoothness.
395
phenomena suggest a radical overhaul of our thinking about optimization in deep
learning. At the same time, it appears mathematically challenging to analyze such
phenomena, at least for realistic settings and losses (as opposed to toy examples with
2 or 3 layers). The current paper introduces frameworks for doing such analyses.
We start by formal definition of stableness, ensuring that if a point + LR combi-
nation is stable then a gradient step is guaranteed to decrease the loss by the local
version of Descent Lemma.

Definition 9.1.1 (Stableness). Given a loss function L, a parameter x ∈ Rd and LR


η > 0 we define the stableness of L at (x, η) be SL (x, η) := η · sup0≤s≤η λ1 (∇2 L(x −
s∇L(x))). We say L is stable at (x, η) iff the stableness of L at (x, η) is smaller than
or equal to 2; otherwise we say L is unstable at (x, η).

The above defined stableness is a better indicator for EoS than only using the
sharpness at a specific point x, i.e. ηλ1 (∇2 L(x)) < 2, because the loss can still oscillate
2
in the latter case. A concrete example is L(x) = |x|, x ∈ R. For any c ∈ (0, 1) and
LR η > 0, the GD iterates x(2k) = cη and x(2k + 1) = −(1 − c)η, always have zero
sharpness for all k ∈ N, but Descent Lemma doesn’t apply because the gradient is not
continuous around x = 0 (i.e. the sharpness is infinity when x = 0). As a result, the
loss is not stable and oscillates between cη and (1 − c)η.

9.1.1 Two Provable Mechanisms for Edge of Stability: Non-

smoothness and Adaptivity

In this chapter we identify two settings where GD provably operates on Edge of


Stability. The intuition is from Definition 9.1.1, which suggests that either sharpness
or learning rate has to increase to avoid GD converge and stays at Edge of Stability.

2
See such experiments (e.g., ReLU CNN (+BN), Figure 75) in Appendix of in Cohen et al. [234].
396
102 102

100 100
101 101
L 50 L 50
12 12
x
0
x
100 0 100
4 4
4 4
y0 4 4
y0 4 4


(a) GD on L (b) Normalized GD on L

Figure 9.1: GD operating on EoS oscillates around the zero loss manifold Γ = {(x, y) |
y = 0} while slowly moving towards flatter local minima. Here L(x, y) = (1 + x2 )y 2
and the sharpness of L decreases as |x| decreases.

The first setting, which is simple yet quite general, is to consider a modified
training loss f (L) where f : R → R is a monotone increasing but non-smooth function.

For concreteness, assume GD is performed on L e := L where L is a smooth loss
∇L
function with minx L(x) = 0 and ∇2 L 6= 0 at its minimizers. Note that ∇L
e= √
2 L
2L∇2 L−∇L∇L>
and ∇2 L
e= √ 3 , which implies ∇2 L
e must diverge whenever x converges to
4 L

any minimizer where ∇2 L has rank at least 2, since ∇L∇L> is rank-1. (An analysis
is also possible when ∇2 L is rank-1, which is the reason for Definition 9.1.1.)
The second setting assumes that the loss is smooth but learning rate is effectively
adaptive. We focus a concrete example, Normalized Gradient Descent, x ← x −
η∇L/k∇Lk, which exhibits EoS behavior as ∇L → 0. We can view Normalized GD
η
as GD with a varying LR ηt = k∇L(x(t))k
, which goes to infinity when ∇L → 0.
3
These analyses will require (1) The zero-loss solution set {x | L(x) = 0} contains
a (D − M ) dimensional submanifold of RD for some 1 ≤ M ≤ D and we denote it by
Γ and (2) ∇2 L(x) is rank-M for any x ∈ Γ. Note that while modern deep learning

3
Without loss of generality, we assume minx0 L(x0 ) = 0 throughout the paper. The main results
for Normalized GD still hold √ if we relax the assumption
√ and√only √assume Γ to be a manifold of
local minimizers. For GD on L, we need to replace L by L − Lmin where Lmin is the local
minimum.
397
evolved using non-differentiable losses, the recent use of activations such as Swish
[236] instead of ReLU has allowed differentiable losses without harming performance.

Our Contribution: We show that Normalized GD on L (Section 9.5.2) and GD



on L (Section 9.5.3) exhibit similar two-phase dynamics with sufficiently small LR
η. In the first phase, GD tracks gradient flow (GF), with a monotonic decrease in loss
until getting O(η)-close to the manifold (Theorems 9.5.3 and 9.5.5) and the stableness
becomes larger than 2. In the second phase, GD no longer tracks GF and loss is not
monotone decreasing due to the high stableness. Repeatedly overshooting, GD iterate
jumps back and forth across the manifold while moving slowly along the direction in
the tangent space of the manifold which decreases the sharpness. (See Figure 9.1 for a
graphical illustration) Formally, we prove when η → 0, the trajectory of GD converges
to some limiting flow on the manifold. (Theorems 9.5.4 and 9.5.6) We further prove
that in both settings GD in the second phase operates on EOS, and loss decreases in
a non-monotone manner. Formally, we show that the average stableness over any two

consecutive steps is at least 2 and that the average of L/η over two consecutive is
proportional to sharpness or square root of sharpness. (Theorems 9.5.7 and 9.5.8)
Though many works have suggested (primarily via experiments and some intuition)
that the training algorithm in deep learning implicitly selects out solutions of low
sharpness in some way, we are not aware of a formal setting where this had ever been
made precise. Note that this result requires no stochasticity as in SGD (c.f.Chapter 8),
though we need to inject tiny noise (e.g., of magnitude O(η 100 ) ) to GD iterates
occasionally (Algorithms 7 and 8). We believe that this is due the technical limitation
of our current analysis and can be relaxed with a more advanced analysis. Indeed, in
experiments, our theoretical predictions hold for the deterministic GD directly without
any perturbation.

398
Novelty of Our Analysis: Our analysis is inspired by the mathematical framework
of studying limiting dynamics of SGD around manifold of minimizers by Li et al. [133],
where the high-level idea is to introduce a projection function Φ mapping the current
iterate xt to the manifold and it suffices to understand the dynamics of Φ(xt ). It turns
out that the one-step update of Φ(xt ) depends on the second moment of (stochastic)
gradient at xt , E[∇L(xt )(∇L(xt ))> ]. While for SGD the second moment converges to
the covariance matrix of stochastic gradient (see Chapter 8) as xt gets close to the
√ ∇L(xt )
manifold when η → 0, for GD operating on EOS, the updates ∇ L(xt ) or k∇L(x t )k
is
non-smooth and not even defined at the manifold of the minimizers! To show Φ(xt )
moves in the direction which decreases the sharpness, the main technical difficulty is
√ ∇L(xt )
to show that ∇ L(xt ) or k∇L(x t )k
aligns to the top eigenvector of the Hessian ∇2 L(xt )
and then the analysis follows from the framework by Li et al. [133].
To prove the alignment between the gradient and the top eigenvector of Hessian, it
boils down to analyze Normalized GD on quadratic functions (9.2), which to the best
of our knowledge has not been studied before. The dynamics is like chaotic version of
power iteration, and we manage to show that the iterate will always align to the top
eigenvector of Hessian of the quadratic loss. The proof is based on identifying a novel
potential (Section 9.3) and might be of independent interest.

9.2 Related Works

Sharpness: Low sharpness has long been related to flat minima and thus to good
generalization [9, 237]. Recent study on predictors of generalization [238] does show
sharpness-related measures as being good predictors, leading to SAM algorithm that
improves generalization by explicitly controlling a parameter related to sharpness [8].
However, Dinh et al. [239] show that due to the positive homogeneity in the network
architecture, networks with rescaled parameters can have very different sharpness yet

399
be the same to the original one in function space. This observation weakens correlation
between sharpness and and generalization gap and makes the definition of sharpness
ambiguous. In face of this challenge, multiple notions of scale-invariant sharpness have
been proposed [240–243]. Especially, Yi et al. [244], Kwon et al. [245] derived new
algorithms with better generalization by explicitly regularizing new sharpness notions
aware of the symmetry and invariance in the network. He et al. [246] goes beyond
the notion of sharpness/flatness and argues that the local minima of modern deep
networks can be asymmetric, that is, sharp on one side, but flat on the other side.

Limiting Diffusion/Flow around Manifold of Minimizers: The idea of ana-


lyzing the behavior of SGD with small LR along the the manifold originates from
[165], which gives a local analysis on a special noise type named label noise, i.e. noise
covariance is equal to Hessian at minimizers. Damian et al. [166] extends this analysis
and show SGD with label noise finds approximate stationary point for original loss
plus some Hessian-related regularizer. The formal mathematical framework of approx-
imating the limiting dynamics of SGD with arbitrary noise by Stochastic Differential
Equations is later established by Li et al. [133], which is built on the convergence
result for solutions of SDE with large-drift [187].

Implicit Bias: The notion that training algorithm plays an active role in selecting
the solution (when multiple optima exist) has been termed the implicit bias of the
algorithm [105] and studied in a large number of papers [14, 20, 25, 29, 95, 98, 100,
112, 138, 142, 159, 172, 247]. In the infinite width limit, the implicit bias of Gradient
Descent is shown to be the solution with the minimal RKHS norm with respect to
the Neural Tangent Kernel (NTK) [15, 109, 148, 150–155, 213]. The implicit bias
results from these papers are typically proved by performing a trajectory analysis
for (Stochastic) Gradient Descent. Most of the results can be directly extended to
the continuous limit (i.e., GD infinitesimal LR) and even some heavily relies on the
400
conservation property which only holds for the continuous limit. In sharp contrast,
the implicit bias shown in this chapter – reducing the sharpness along the minimizer
manifold – requires finite LR and doesn’t exist for the corresponding continuous limit.
Other implicit bias results that fundamentally relies on the finiteness of LR includes
stability analysis [248, 249] and implicit gradient regularization [250], which is a special
case of approximation results for stochastic modified equation by Li et al. [190, 214].

Non-monotone Convergence of Gradient Descent : Recently, a few conver-


gence results for Gradient Descent have been made where the loss is not monotone
decreasing, meaning at certain steps the stableness can go above 2 and the descent
lemma breaks. These results typically involve a two-phase analysis where in the first
phase the sharpness decreases and the loss can oscillate and in the second phase the
sharpness is small enough and thus the loss monotone decreases. Such settings include
scale invariant functions [23, 36] and 2-homogeneous models with `2 loss [55, 251].
Different to the previous works, the non-monotone decrease of loss shown in our work
happens at Edge of Stability and doesn’t require a entire phase where descent lemma
holds.

9.3 Warm-up: Quadratic Loss Functions

To introduce ideas that will be used in the main results, we sketch analysis of
Normalized GD (9.1) on quadratic loss function L(x) = 12 x> Ax where A ∈ RD×D
is positive definite with eigenvalues λ1 > λ2 ≥ . . . ≥ λD and v1 , . . . , vD are the
corresponding eigenvectors.

∇L(x(t)) Ax(t)
x(t+1) = x(t)−η = x(t)−η . (9.1)
k∇L(x(t))k kAx(t)k

401
Our main result Theorem 9.3.1 is that the iterates of Normalized GD x(t) converge
to v1 in direction, from which the loss oscillation Corollary 9.3.2 follows, suggesting
that GD is operating in EoS. Since in quadratic case there is only one local minima,
there is of course no need to talk about implicit bias. However, the observation that
the GD iterates always align to the top eigenvector as well as the technique used in
its proof play a very important role for deriving the sharpness-reduction implicit bias
for the case of general loss functions.
Ax(t)
Define x
e(t) = η
, and the following update rule (9.2) holds. It is clear that the
convergence of x
et to v1 in direction implies the convergence of xt as well.

x
e(t)
x e(t) − A
e(t + 1) = x . (9.2)
ke
x(t)k

Theorem 9.3.1. If |hv1 , x 6 0, ∀t ≥ 0, then there exists 0 < C < 1 and s ∈ {±1}
e(t)i| =
such that limt→∞ x e(2t + 1) = (C − 1)sλ1 v1 .
e(2t) = Csλ1 v1 and limt→∞ x

As a direct corollary, the loss oscillates as between time step 2t and time step
2t + 1 as t → ∞. This shows that the behavior of loss is not monotonic and hence
indicates the edge of stability phenomena for the quadratic loss.

Corollary 9.3.2. If |hv1 , x


e(t)i| 6= 0, ∀t ≥ 0, then there exists 0 < C < 1 such that
limt→∞ L(x(2t)) = 12 C 2 λ1 η 2 and limt→∞ L(x(2t + 1)) = 12 (C − 1)2 λ1 η 2 .

We analyse the trajectory of the iterate x


e(t) in two phases. For convenience,
we define P (j:D) as the projection matrix into the space spanned by {vi }D i=j , i.e.,
PD >
P (j:D) := i=j vi vi . In the first preparation phase, x
e(t) enters the intersection
of D invariant sets {Ij }D x | P (j:D) x
j=1 around the origin, where Ij := {e e ≤ λj }.
(Lemma 9.3.3) In the second alignment phase, the projection of x
e(t) on the top
eigenvector, | he
x(t), v1 i |, is shown to increase monotonically among the steps among
the steps {t ∈ N | ke
x(t)k ≤ 0.5λ1 }. Since it is bounded, it must converge. The

402
vanishing increment over steps turns out to suggest the x
e(t) must converge to v1 in
direction.

λ1
Lemma 9.3.3 (Preparation Phase). For any j ∈ [D] and t ≥ λj
ln λλ1j +
max{ kex(0)k−λ
λD
1
e(t) ∈ Ij .
, 0}, it holds that x

Proof of Lemma 9.3.3. First, we show for any j ∈ [D], Ij is indeed an invariant set
for update rule (9.2) via Lemma 9.9.1. With straightforward calculation, one can
λD kP (j:D) x
e(t)k
show that for any j ∈ [D], P (j:D) x
e(t) decreases by kex(t)k
if P (j:D) x
e(t) ≥
λj (Lemma 9.9.2). Setting j = 1, we have ke
x(t)k decreases by λD if ke
x(t)k ≥
λ1 (Corollary 9.9.3). Thus for all t ≥ max{ kex(0)k−λ
λD
1
e(t) ∈ I1 . Finally once
, 0}, x
x x(t)k by λ1 , and thus P (j:D) x
e(t) ∈ I1 , we can upper bound ke e(t) shrinks at least
λD λ1
by a factor of λ1
e(t) will be in Ij in another
per step, which implies x λj
ln λλ1j
steps.(Corollary 9.9.4)

Once the component of x


e(t) on an eigenvector becomes 0, it stays 0. So without
loss of generality we can assume that after the preparation phase, the projection of
x
e(t) along the top eigenvector v1 is non-zero, otherwise we can study the problem in
the subspace excluding the top eigenvector.

e(T ) ∈ ∩D
Lemma 9.3.4 (Alignment Phase). If x j=1 Ij holds for some T , then for any

t0 , t such that T ≤ t ≤ t0 and ke


x(t)k ≤ 0.5λ1 , it holds |hv1 , x e(t0 )i|.
e(t)i| ≤ |hv1 , x

Below we sketch the proof of Lemma 9.3.4.

Proof of Lemma 9.3.4. First, Lemma 9.3.5 (proved in Section 9.9) shows that the
norm of the iterate x
e(t) remains above 0.5λ1 for only one time-step.

λ1
Lemma 9.3.5. For any t with x e(t) ∈ ∩D
j=1 Ij , if ke
x(t)k > 2
, then ke
x(t + 1)k ≤
 
λ2
max λ21 − 2λD1 , λ1 − ke
x(t)k .

403
Invariant sets ||x(t)|| | v1, x(t) |
2 1 1.0 100
2
||x|| = 2
1

1 2 0.8
1 10 2
0.6
0
v1
0.4 10 4

1
0.2 10 6
2
0.0
2 1 0 1 2 0 2 4 6 8 10 0 2 4 6 8 10
v2 Normalized GD steps Normalized GD steps

Figure 9.2: Visualization of key concepts and lemmas in the analysis for Normalized
GD on a 2D quadratic loss with λ1 = 1, λ2 = 0.4. Left: invariant sets (defined in
Lemma 9.3.3). Middle: ke x(t)k drops below λ21 in the next step whenever it is above
λ1
2
(Lemma 9.3.5). Right: |hv1 , xe(t)i| monotone increases among all the steps with
λ1
norm below 2 . (Lemma 9.3.6)

λ1 λ1
e(t) ∈ ∩D
Thus, for any t with x j=1 Ij and ke
x(t)k ≤ 2
, either ke
x(t + 1)k ≤ 2
, or
λ1 λ1
ke
x(t + 1)k > 2
, which in turn implies that ke
x(t + 2)k ≤ 2
by Lemma 9.3.5. The
proof of Lemma 9.3.4 is completed by induction on Lemma 9.3.6.

λ1
Lemma 9.3.6. For any step t with ke
x(t)k ≤ 2
, for any k ∈ {1, 2}, |hv1 , x
e(t + k)i| ≥
|hv1 , x
e(t)i|.

Proof of case k = 1 in Lemma 9.3.6 follows directly from plugging the assumption
λ1
ke
x(t)k ≤ 2
into (9.2) (See Lemma 9.9.5). The case of k = 2 in Lemma 9.3.6 follows
from Lemma 9.9.7. We defer the complete proof of Lemma 9.3.6 into Section 9.9.

To complete the proof for Theorem 9.3.1, we relate the increase in the projection
along v1 at any step t, |hv1 , x
e(t)i|, to the magnitude of the angle between x
e(t) and the
λ1
top eigenspace, θt . Briefly speaking, we show that if ke
x(t)k ≤ 2
, |hv1 , x
e(t)i| has to
increase by a factor of Θ(θt2 ) in two steps. Since |hv1 , x
e(t)i| is bounded and monotone
λ1
increases among {t | ke
x(t)k ≤ 2
} by Lemma 9.3.4, we conclude that θt gets arbitrarily
λ1 λ1
small for sufficiently large t with ke
x(t)k ≤ 2
, ke
x(t + 2)k ≤ 2
satisfied. Since the
one-step normalized GD update Equation (9.2) is continuous when bounded away

404
from origin, with a careful analysis, we conclude θt → 0 for all iterates. Please see
Section 9.9.3 for details.

q
1 >
√ q
1 >
Equivalence to GD on 2
x Ax: Below we show GD on loss L(x) = 2
x Ax,
Equation (9.3), follows the same update rule as Normalized GD on L(x) = 12 x> Ax,
up to a linear transformation.

√ Ax(t)
x(t + 1) = x(t) − η∇ L(x(t)) = x(t) − η p . (9.3)
2x(t)> Ax(t)

e(t) = η1 (2A)1/2 x(t), we can easily check x


Denoting x e(t) also satisfies update rule (9.2).

9.4 Notations

For any integer k, we denote C k as the set of the k times continuously differentiable
functions. For any mapping F , we use ∂F (x)[u] and ∂ 2 F (x)[u, v] to denote the first and
second order directional derivative of F at x along the derivation of u (and v). Given
the loss function L, the gradient flow (GF) governed by L can be described through a

mapping φ : RD × [0, ∞) → RD satisfying φ(x, τ ) = x − 0 ∇L(φ(x, s))ds. We further
define the limiting map of gradient flow as Φ, that is, Φ(x) = limτ →∞ φ(x, τ ).
For a matrix A ∈ RD×D , we denote its eigenvalue-eigenvector pairs by
{λi (A), vi (A))}i∈[D] . For simplicity, whenever Φ is defined at point x, we use
{(λi (x), vi (x))}D 2
i=1 to denote the eigenvector-eigenvalue pairs of ∇ L(Φ(x)), with

λ1 (x) > λ2 (x) ≥ λ3 (x) . . . ≥ λD (x). As an analog to the quadratic case, we use x
e to
1/2
denote ∇2 L(Φ(x))(x − Φ(x)) for Normalized GD on L and (2∇2 L(Φ(x))) (x − Φ(x))

for GD on L. Furthermore, when the iterates x(t) are clear in the context, we also
use shorthand λi (t) := λi (x(t)), vi (t) := vi (x(t)) and θt ∈ [0, π2 ] to denote the angle
e(t) and top eigenspace of ∇2 L(Φ(x(t))). Given a differentiable submanifold
between x
Γ of RD and point x ∈ Γ, we use Px,Γ : Γ → RD to denote the projection operator

405

onto the normal space of Γ at x, and Px,Γ := ID − Px,Γ . As before, for notational
⊥ ⊥
convenience, we use the shorthand Pt,Γ := PΦ(x(t)),Γ and Pt,Γ := PΦ(x(t)),Γ .
In this section, we focus on the setting where LR η goes to 0 and we fix the
initialization xinit and the loss function L throughout this chapter. We use O(·) to
hide constants about xinit and L.

9.5 Main Results: Sharpness Reduction

In this section we present the main results of this chapter. In Section 9.5.1, we make
our key assumptions that the minimizers of the loss function form a manifold. In
Sections 9.5.2 and 9.5.3 we present our main results for Normalized GD and GD on

L respectively. In Section 9.5.4 we show the above two settings for GD do enter the
regime of Edge of Statbility.

9.5.1 Key Assumptions on Manifold of Local Minimizers

Similar to Chapter 8, we make the following assumption throughout this chapter. The
only difference between the following assumption and Assumption 8.5.1 is below we
assume L is C 4 smooth instead of C 3 . This extra degree of differentiability allows us to
give a non-asymptotic rate instead of asymptotic convergence as shown in Chapter 9.

Assumption 9.5.1. Assume that the loss L : RD → R is a C 4 function, and that Γ


is a (D − M ) dimensional C 1 -submanifold of RD for some integer 1 ≤ M ≤ D, where
for all x ∈ Γ, x is a local minimizer of L with L(x) = 0 and rank (∇2 L(x)) = M .

Let U be the sets of points starting from which, gradient flow w.r.t. loss L
converges to some point in Γ, that is, U := {x ∈ RD | Φ(x) exists and Φ(x) ∈ Γ}.
Assumption 9.5.1 implies that U is open and Φ is C 3 on U . (By Lemma 8.8.2)
We also make the following assumption to ensure that λ1 (∇2 L(·)) is differentiable,
which is necessary for our main results, Theorems 9.5.4 and 9.5.6.
406
Assumption 9.5.2. For any x ∈ Γ, ∇2 L(x) has a positive eigengap, i.e.,
λ1 (∇2 L(x)) > λ2 (∇2 L(x)).

9.5.2 Results for Normalized GD

We first denote the iterates of Normalized GD with LR η by xη (t), with xη (0) ≡ xinit
for all η:
∇L(xη (t))
Normalized GD: xη (t + 1) = xη (t) − η (9.4)
k∇L(xη (t))k

The first theorem demonstrates the movement in the manifold, when the iterate
travels from xinit to a position that is O(η) distance closer to the manifold (more
specifically, Φ(xinit )). Moreover, just like the result in the quadratic case, we have more
fine-grained bounds on the projection of xη (t) − Φ(xη (t)) into the bottom-k eigenspace
of ∇2 L(Φ(xη (t))) for every k ∈ [D]. For convenience, we define the following quantity
for all j ∈ [d] and x ∈ U :

v v
uM uM
uX uX
Rj (x) := t hvi (x), x 2
ei − λj (x)η = t λ2i (x)hvi (x), x − Φ(x)i2 − λj (x)η
i=j i=j

In the quadratic case, Lemma 9.3.3 shows that Rj (x) will eventually become
non-positive for normalized GD iterates. Similarly, for the general loss, the following
theorem shows that Rj (xη (t)) eventually becomes approximately non-positive (smaller
than O(η 2 )) in O( η1 ) steps.

Theorem 9.5.3 (Phase I). Let {xη (t)}t∈N be the iterates of Normalized GD (9.4)
with LR η and xη (0) = xinit ∈ U . There is T1 > 0 such that for any T10 > T1 , it
holds that for sufficiently small η that (1) max kxη (t) − Φ(xinit )k ≤ O(η) and (2)
T1 ≤ηt≤T10
max Rj (xη (t)) ≤ O(η 2 ).
T1 ≤ηt≤T10 ,j∈[D]

Our main contribution is the analysis for the second phase (Theorem 9.5.4), which
says just like the quadratic case, the angle between x
eη (t) and the top eigenspace of
407

Figure 9.3: Illustration for two-phase dynamics of Normalized GD and GD on L on a
1D zero loss manifold Γ. For sufficiently small LR η, Phase I is close to Gradient Flow
and lasts for Θ(η −1 ) steps, while Phase II is close to the limiting flow which decreases
the sharpness of the loss and lasts for Θ(η −2 ) steps. GD iterate oscillates along the
top eigenvector of the Hessian with the period equal to two steps. (cf. Figure 8.1)

∇2 L(Φ(xη (t))), denoted by θt , will be O(η) on average. And as a result, the dynamics
of Normalized GD tracks the riemannian gradient flow with respect to log(λ1 (∇2 L(·)))

on manifold, that is, the unique solution of Equation (9.5), where Px,Γ is the projection
matrix onto the tangent space of manifold Γ at x ∈ Γ.

Z τ
1 ⊥
Limiting Flow: X(τ ) = Φ(xinit ) − PX(s),Γ ∇ log λ1 (X(s))ds, X(τ ) ∈ Γ
4 s=0

(9.5)

Note Equation (9.5) is not guaranteed to have a global solution, i.e., a well-defined
solution for all τ ≥ 0, for the following two reasons: (1). when the multiplicity of
top eigenvalue is larger than 1, λ1 (∇2 L(·)) may be not differentiable and (2). the
projection matrix is only defined on Γ and the equation becomes undefined when the
solution leaves Γ, i.e., moving across the boundary of Γ. For simplicity, we make
Assumption 9.5.2 that every point on Γ has a positive eigengap. Or equivalently, we
can work with a slightly smaller manifold Γ0 = {x ∈ Γ | λ1 (x) > λ2 (x)}.
Towards a mathematical rigorous characterization of the dynamics in the second
phase, we need to make the following modifications: (1). we add negligible noise

408
Algorithm 7 Perturbed Normalized Gradient Descent
Input: loss function L : RD → R, initial point xinit , maximum number of iteration
T , LR η, Frequency parameter Tfreq = Θ(η −0.1 ), noise parameter r = Θ(η 100 ).
for t = 1 to T do
Generate n(t) ∼ B0 (r) if t mod Tfreq = 0, else set n(t) = 0.
∇L(x(t))
x(t) ← x(t − 1) − η k∇L(x(t))k + n(t).

of magnitude O(η 100 ) every η −0.1 steps, (2). we assume for each η > 0, there exist
some step t = Θ(1/η) in phase I, except the guaranteed condition (1) and (2) (by
Theorem 9.5.3, the additional condition (3) also holds. This assumption is mild
T1 T10
because we only require (3) to hold for one step among Θ(1/η) steps from η
to η
,
where T1 is the constant given by Theorem 9.5.3 and T10 is arbitrary constant larger
than T1 . This assumption also holds empirically for all our experiments in Section 9.7.

Theorem 9.5.4 (Phase II). Let {xη (t)}t∈N be the iterates of perturbed Normalized GD
(Algorithm 7) with LR η. Under Assumptions 9.5.1 and 9.5.2, if the initialization xη (0)
satisfy that (1) kxη (0) − Φ(xinit )k ≤ O(η) where xinit ∈ U , (2) maxj∈[D] Rj (xη (t)) ≤
O(η 2 ), and additionally (3) min{|hv1 (xη (0)), xη (0) − Φ(xη (0))i| , −R1 (xη (0))} ≥ Ω(η),
then for any time T2 > 0 till which the solution of (9.5) exists, it holds for suffi-
ciently small η, with probability at least 1 − O(η 10 ), that kΦ(xη (bT2 /η 2 c)) − X(T2 )k =
P 2 /η2 c
O(η) and bT21/η2 c bTt=0 θt ≤ O(η), where θt ∈ [0, π2 ] denotes the angle between
∇2 L(Φ(xη (t)))(xη (t) − Φ(xη (t))) and top eigenspace of ∇2 L(Φ(xη (t))).


9.5.3 Results for GD on L

In this subsection, we denote the iterates of GD on L with LR η by xη (t), with
xη (0) ≡ xinit for all η:

√ √
GD on L: xη (t + 1) = xη (t) − η∇ L(xη (t)) (9.6)

409

Algorithm 8 Perturbed Gradient Descent on L
Input: loss function L : RD → R, initial point xinit , maximum number of iteration
T , LR η, Frequency parameter Tfreq = Θ(η −0.1 ), noise parameter r = Θ(η 100 ).
for t = 1 to T do
Generate n(t) ∼ B0 (r)√if t mod Tfreq = 0, else set n(t) = 0.
x(t) ← x(t − 1) − η∇ L(x(t)) + n(t).

Similar to Normalized GD, we will have two phases. The first theorem demonstrates
the movement in the manifold, when the iterate travels from xinit to a position that is
O(η) distance closer to the manifold. For convenience, we will denote the quantity
qP
M
p
2
i=j λi (x)hvi (x), x − Φ(x)i − η 1/2λj (x) by Rj (x) for all j ∈ [M ] and x ∈ U .

Theorem 9.5.5 (Phase I). Let {xη (t)}t∈N be the iterates of Normalized GD (9.6)
with LR η and xη (0) = xinit ∈ U . There is T1 ∈ R+ such that for any T10 ∈ R+ ,
it holds for sufficiently small η that (1) max kxη (t) − Φ(xinit )k ≤ O(η) and (2)
T1 ≤ηt≤T10
max
0
Rj (xη (t)) ≤ O(η 2 ).
T1 ≤ηt≤T1 ,j∈[D]

The next result demonstrates that close to the manifold, the trajectory implicitly
minimizes sharpness.

Theorem 9.5.6 (Phase II). Let {xη (t)}t∈N be the iterates of perturbed GD on L
(Algorithm 8). Under Assumptions 9.5.1 and 9.5.2, if the initialization xη (0) sat-
isfy that (1) kxη (0) − Φ(xinit )k ≤ O(η), where xinit ∈ U , (2) maxj∈[D] Rj (xη (t)) ≤
O(η 2 ), and additionally (3) min{|hv1 (xη (0)), xη (0) − Φ(xη (0))i| , −R1 (xη (t))} ≥ Ω(η),
then for any time T2 > 0 where the solution of (9.7) exists, it holds for suffi-
ciently small η, with probability at least 1 − O(η 10 ), that kΦ(xη (bT2 /η 2 c)) − X(T2 )k =
P 2 /η2 c
O(η 1/2 ) and bT21/η2 c bT
t=0 θt ≤ O(η 1/2 ), where θt ∈ [0, π2 ] denotes the angle between
p
∇2 L(Φ(xη (t)))(xη (t) − Φ(xη (t))) and top eigenspace of ∇2 L(Φ(xη (t))).

Z τ
1 ⊥
X(τ ) = Φ(xinit ) − PX(s),Γ ∇λ1 (X(s))ds, X(τ ) ∈ Γ. (9.7)
8 s=0

410
9.5.4 Operating on the Edge of Stability

In this section, we show that both Normalized GD on L and GD on L is on Edge
of Stability in their phase II, that is, at least in one of every two consecutive steps,
the stableness is at least 2 and the loss oscillates in every two consecutive steps.
Interestingly, the average loss over two steps still monotonically decreases, even when
operating on the edge of Stability (see Figure 9.1 for illustration), as indicated by
the following theorems. Note that Theorems 9.5.4 and 9.5.6 ensures that the average

of θt are O(η) and O( η). We defer their proofs into Sections 9.13.5 and 9.15.4
respectively.

Theorem 9.5.7 (Stableness, Normalized GD). Under the setting of Theorem 9.5.4,
η
by viewing Normalized GD as GD with time-varying LR ηt := k∇L(xη (t))k
, we have
p
[SL (xη (t), ηt )]−1 +[SL (xη (t+1), ηt+1 )]−1 = 1+O(θt +η). Moreover, we have L(xη (t))+
q
2
L(xη (t + 1)) = η λ1 (∇ L(x η (t)))
p
2
+ O(ηθt ).

Theorem 9.5.8 (Stableness, GD on L). Under the setting of Theorem 9.5.6, we
p p
have [S√L (xη (t), ηt )] ≥ Ω( θ1t ). Moreover, we have L(xη (t)) + L(xη (t + 1)) =
ηλ1 (∇2 L(xη (t))) + O(ηθt ).

9.6 Proof Overview

We sketch the proof of the Normalized GD in phase I and II respectively in Section 9.6.2.

Then we briefly discuss how to prove the results for GD with L with same analysis
in Section 9.6.3. We start by introducing the properties of limit map of gradient flow
Φ in Section 9.6.1, which plays a very important role in the analysis.

9.6.1 Properties of Φ

The limit map of gradient flow Φ lies at the core of our analysis. When LR η is
small, one can show xη (t) will be O(η) close to manifold and Φ(xη (t)). Therefore,
411
Φ(xη (t)) captures the essential part of the implicit regularization of Normalized GD
and characterization of the trajectory of Φ(xη (t)) immediately gives us that of Φ(xη (t))
up to O(η).
Below we first recap a few important properties of Φ that will be used later this
section, which makes the analysis of Φ(xη (t)) convenient.

Lemma 9.6.1. Under Assumption 9.5.1, Φ satisfies the following two properties:

1. ∂Φ(x)∇L(x) = 0 for any x ∈ U. (Lemma 9.10.14)


2. For any x ∈ Γ, if λ1 (x) > λ2 (x), ∂ 2 Φ(x)[v1 (x), v1 (x)] = − 12 Px,Γ ∇ log λ1 (x).
(Lemmas 9.10.16 and 9.10.18)

∇L(xη (t))
Note that xη (t + 1) − xη (t) = −η k∇L(x η (t))k
, using a second order taylor expansion
of Φ, we have

Φ(xη (t + 1)) − Φ(xη (t)) (9.8)


η2 2
 
∇L(xη (t)) ∇L(xη (t)) ∇L(xη (t))
= − η∂Φ(xη (t)) + ∂ Φ(xη (t)) , + O(η 3 )
k∇L(xη (t))k 2 k∇L(xη (t))k k∇L(xη (t))k
2
 
η ∇L(xη (t)) ∇L(xη (t))
= ∂ 2 Φ(xη (t)) , + O(η 3 ), (9.9)
2 k∇L(xη (t))k k∇L(xη (t))k

where we use the first claim of Lemma 9.6.1 in the final step. Therefore, we have
Φ(xη (t + 1)) − Φ(xη (t)) = O(η 2 ), which means Φ(xη (t)) moves slowly along the
manifold, at a rate of at most O(η 2 ) step. The Taylor expansion of Φ, (9.9) plays a
crucial role in our analysis for both Phase I and II and will be used repeatedly.

9.6.2 Analysis for Normalized GD

Analysis for Phase I, Theorem 9.5.3: The Phase I itself can be divided into
two subphases: (A). Normalized GD iterate xη (t) gets O(η) close to manifold; (B).
counterpart of preparation phase in the quadratic case: local movement in the O(η)-

412
neighborhood of the manifold which decreases Rj (xη (t)) to O(η 2 ). Below we sketch
their proofs respectively:

• Subphase (A): First, with a very classical result in ODE approximation theory,
normalized GD with small LR will track the normalized gradient flow, which is
a time-rescaled version of standard gradient flow, with O(η) error, and enter a
small neighborhoods of the manifold where Polyak-Lojasiewicz (PL) condition
holds. Since then, Normalized GD decreases the fast loss with PL condition and
the gradient has to be O(η) small in O( η1 ) steps. (See details in Section 9.11.1).

• Subphase (B): The result in subphase (B) can be viewed as a generalization


of Lemma 9.3.3 when the loss function is O(η)-approximately quadratic, in both
space and time. More specifically, it means k∇2 L(Φ(xη (t))) − ∇2 L(x)k ≤ O(η)
for all x which is O(η)-close to some Φ(xη (t0 )) with t0 − t ≤ O(1/η). This is be-
cause by Taylor expansion (9.9), kΦ(xη (t)) − Φ(xη (t0 ))k = O(η 2 (t0 − t)) = O(η),
and again by Taylor expansion of ∇2 L, we know k∇2 L(x) − ∇2 L(Φ(xη (t)))k =
O(kx − Φ(xη (t))k) = O(η).

With a similar proof technique, we show xη (t) enters ainvariant set around the
manifold Γ, that is, {x ∈ U | Rj (x) ≤ O(η 2 ), ∀j ∈ [D]}. Formally, we show the
following analog of Lemma 9.3.3:

Lemma 9.6.2 (Preparation Phase, Informal version of Lemma 9.11.1). Let


{xη (t)}t≥0 be the iterates of Normalized GD (9.4) with LR η. If for some step
t0 , kxη (t0 ) − Φ(xη (t0 ))k = O(η), then for sufficiently small LR η and all steps
t ∈ [t0 +Θ(1), Θ(η −2 )] steps, the iterate xη (t) satisfy maxj∈[M ] Rj (xη (t)) ≤ O(η 2 ).

Analysis for Phase II, Theorem 9.5.4: Similar to the subphase (B) in the Phase
I, the high-level idea here is again that xη (t) locally evolves like normalized GD with
quadratic loss around Φ(xη (t)) and with an argument similar to the alignment phase

413
of quadratic case (though technically more complicated), we show xη (t) − Φ(xη (t))
approximately aligns to the top eigenvector of ∇2 L(Φ(xη (t))), denoted by v1 (t) and so
does ∇L(xη (t)). More specifically, it corresponds to the second claim in Theorem 9.5.4,
P 2 /η2 c
that bT21/η2 c bT
t=0 θt ≤ O(η).
We now have a more detailed look at the movement in Φ. Since Φ(xη (t)) belongs to
the manifold, we have ∇L(Φ(xη (t))) = 0 and so ∇L(xη (t)) = ∇2 L(Φ(xη (t)))(xη (t) −
Φ(xη (t))) + O(η 2 ) using a Taylor expansion. This helps us derive a relation between
the Normalized GD update and the top eigenvector of the hessian (simplified version
of Lemma 9.10.9):

∇L(xη (t))
∃s ∈ {±1}, = sv1 (t) + O(θt + η). (9.10)
k∇L(xη (t))k

Incorporating the above into the movement in Φ(xη (t)) from Equation (9.9) gives:

η2 2
Φ(xη (t + 1)) − Φ(xη (t)) = ∂ Φ(xη (t))[v1 (t), v1 (t)] + O(η 2 θt + η 3 ) (9.11)
2

Applying the second property of Lemma 9.6.1 on Equation (9.11) above yields
Lemma 9.6.3.

Lemma 9.6.3 (Movement in the manifold, Informal version of Lemma 9.10.12).


Under the setting in Theorem 9.5.4, for sufficiently small η, we have at any step
t ≤ bT2 /η 2 c

η2 ⊥
Φ(xη (t + 1)) − Φ(xη (t)) = − Pt,Γ ∇ log λ1 (t) + O(η 3 + η 2 θt ).
4

To complete the proof of Theorem 9.5.4, we show that for small enough η, the
P 2 /η2 c
trajectory of Φ(xη (τ /η 2 )) is O(η 3 bT2 /η 2 c+η 2 bT
t=0 θt )-close to X(τ ) for any τ ≤ T2 ,

414
PbT2 /η2 c
where X(·) is the flow given by Equation (9.5). This error is O(η), since t=0 θt =
O(bT2 /η 2 cη).
One technical difficulty towards showing the average of ηt is only O(η) is that
our current analysis requires |hv1 (xη (t)), xη (t) − Φ(xη (t))i| doesn’t vanish, that is, it
remains Ω(η) large throughout the entire training process. This is guaranteed by
Lemma 9.3.4 in quadratic case – since the alignment monotone increases whenever
λ1
it’s smaller 2
, but the analysis breaks when the loss is only approximately quadratic
and the alignment |hv1 (xη (t)), xη (t) − Φ(xη (t))i|could decrease decrease by O(θt η 2 )
per step. Once the alignment becomes too small, even if the angle θt is small, the
normalized GD dynamics become chaotic and super sensitive to any perturbation.
Our current proof technique cannot deal with this case and that’s the main reason we
have to make the additional assumption in Theorem 9.5.4.
Role of η 100 noise. Fortunately, with the additional assumption that the
initial alignment is at least Ω(η), we can show adding any poly(η) perturbation
(even as small as Ω(η 100 )) suffices to prevent the aforementioned bad case, that is,
|hv1 (xη (t)), xη (t) − Φ(xη (t))i| stays Ω(η) large. The intuition why Ω(η 100 ) perturba-
e = cv1 for any |c| ≤ 1
tion works again comes from quadratic case – it’s clear that x
is a stationary point for two-step normalized GD updates for quadratic loss under
the setting of Section 9.3. But if c is smaller than critical value determined by the
eigenvalues of the hessian, the stationary point is unstable, meaning any deviation
away from the top eigenspace will be amplified until the alignment increases above
the critical threshold. Based on this intuition, the formal argument, Lemma 9.13.11
uses the techniques from the ‘escaping saddle point’ analysis [252]. Adding noise is
not necessary in experiments to observe the predicted behavior (see ‘Alignment’ in
Figure 9.4 where no noise is added). On one hand, it might be because the floating
point errors served the role of noise. On the other hand, we suspect it’s not necessary

415
even for theory, just like GD gets stuck at saddle point only when initialized from a
zero measure set even without noise [119, 121].


9.6.3 Analysis for GD on L

In this subsection we will make an additional assumption that L(x) = 0 for all x ∈ Γ.
The analysis then will follow a very similar strategy as the analysis for (Normalized)
GD. However, the major difference from the analysis for Normalized GD comes from
the update rule for xη (t) when it is O(η)-close to the manifold:

√ p
∃s ∈ {±1}, ∇ L(xη (t)) = s λ1 (t)v1 (t) + O(η + θt ).

p
Thus, the effective learning rate is λ1 (t)η at any step t. This shows up, when we
compute the change in the function Φ. Thus, we have the following lemma showcasing

the movement in the function Φ with the GD update on L:

Lemma 9.6.4 (Movement in the manifold, Informal version of Lemma 9.15.1). Under
the setting in Theorem 9.5.6, for sufficiently small η, we have at any step t ≤ bT2 /η 2 c,
2
Φ(xη (t + 1)) − Φ(xη (t)) = − η8 Pt,Γ

∇λ1 (t) + O(η 3 + η 2 θt ).

9.7 Experiments

Though our main theorems characterizes the dynamics of Nomalized GD and GD on



L for sufficiently small LR, it’s not clear if the predicted phenomena is related to
the training with practical LR as the function and initialization dependent constants
are hard to compute and could be huge. Neverthesless, in this section we show the
phenomena predicted by our theorem does occur for real-life models like VGG-16. We
further verify the predicted convergence to the limiting flow for Normalized GD on a
two-layer fully-connected network trained on MNIST.

416
Top Eigenvalue Alignment 4.0Lower bound on Stableness Test Acc
Square root Loss 1.0
20 3.5 50
Normalized GD 0.9 3.0
15 0.8 2.5 40
0.7 2.0
10 30
1.5
0.6
5 Square root Loss 1.0 Square root Loss 20 Square root Loss
0.5 Normalized GD 0.5 Normalized GD Normalized GD
10
0 5000 10000 15000 0 5000 10000 15000 0.0 0 5000 10000 15000 0 5000 10000 15000
Gradient Steps

Figure 9.4: We verify our theoretical claims in the second phase —(a) the sharpness
decreases; (b) gradient aligns with the top eigenvector of Hessian; (c) stableness will
be higher than 2 — under the setting √ of training VGG-16 on CIFAR-10 dataset with
Normalized GD on L and GD with L loss respectively.

Verification for Predicted Phenomena on Real-life Models: We first observe


the behavior of different test functions throughout the training to verify our theoretical
findings. We perform our experiments on a VGG-16 model [253] trained on CIFAR-10

dataset [254] with Normalized GD and GD with L. For efficient full-batch training,
we trained the model on a sample of randomly chosen 5000 examples from the training
dataset. To meet the smoothness requirement by our theory, we modified our network
in two ways, (a) we used GeLU activation [76] in place of the non-smooth ReLU
activation, and (b) we used average pooling in place of the non-smooth max-pooling
[255]. We used `2 loss instead of softmax loss to ensure the existence of minimizers and
thus the manifold. We plot the behavior of the following four functions in Figure 9.4:
Top eigenvalue of the Hessian, Alignment, Stableness, and Test accuracy. Alignment is
defined as 1
λ1 kgk2
g > (∇2 L)g, where ∇2 L is the Hessian, g is the gradient and λ1 is the
η
top eigenvalue of the Hessian. To check the behavior for Stableness, we plot kgk × λ1

for Normalized GD and 2√η L × λ1 for GD with L, which are lower bounds on the
Stableness of the Hessian (9.1.1).
We observe that the alignment function reaches close to 1, towards the end of
training. The top eigenvalue decreases over time (as predicted byTheorem 9.5.4 and
Theorem 9.5.6), and the stableness hovers around 2 at the end of training.

417
140
Top Eigenvalue Hessian trace Test loss Test Acc
0.32 90.2
Normalized GD 4000
120 Riemannian Flow 90.1
3500 0.31
100 90.0
80 3000
0.30 89.9
60 2500
89.8
0.29
40 2000 89.7
20 1500 0.28 89.6
0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3
Continuous time

Figure 9.5: Normalized GD and Riemannian flow have almost the same behavior
under proper time scalings, for a 2-layer network on MNIST initialized with tiny loss.

Relative Parameter Difference

0.020

0.015

0.010

0 1 2 3
Continuous time
Figure 9.6: The trajectory of Normalized GD is very close to that of the limiting flow
minimizing the sharpness on manifold, as predicted by our theory. Absolute difference
is the norm of the difference between the parameters of the two trajectories at the
same continuous time, while relative parameter difference is the ratio of the norm of
the difference to the norm of parameters of each runs.

418
Verifying Convergence to Limiting Flow on MNIST: We further verify the
closeness between the Riemannian gradient flow w.r.t. the top eigenvalue and Nor-
malized GD, as predicted by Theorem 9.5.4, on a 1 hidden-layer fully connected
network on MNIST [256]. The network had 784 hidden units, with GeLU activation
function. We use `2 loss to ensure the existence of minimizers, which is necessary for
the existence of the manifold. For efficient training on a single GPU, we train on a
random subset of training data of size 1000.
We first trained the model with full to reach loss of order 10−3 . Starting from
this checkpoint, we make two different runs, one for Normalized GD and another for
Riemannian gradient flow w.r.t. the top eigenvalue (see Section 9.16 for details). We
plot the behavior of the network w.r.t. continuous time defined for Normalized GD as
#GradientSteps × η 2 /4, and for Riemannian flow as #GradientSteps × η, where η is
the learning rate. We track the behavior of Test Loss, Test accuracy, the top eigenvalue
of the Hessian and also the trace of the Hessian in Figure 9.5. We see that there is an
exact match between the behavior of the four functions, which supports our theory.
Moreover, Figure 9.6 computes the norm of the difference in the parameters between
the two runs, and shows that the runs stay close to each other in the parameter space
throughout training.

9.8 Limitation and Future Work

One limitation of our analysis is that it only applies close to the manifold of local
minimizers. In contrast, in experiments the EoS phenomenon, including the control of
sharpness, begins much sooner. Addressing this gap, as well as analying the EoS for

the loss L itself (as opposed to L as done here) is left for future work. Very likely
this will require novel understanding of properties of deep learning losses, which we

were able to circumvent by looking at L instead. Exploration of EoS-like effects in

419
SGD setting would also be interesting, although we first need definitive experiments
analogous to Cohen et al. [234].

9.9 Proofs for Results for Quadratic Loss Func-

tions

We first recall the settings and notations. Let A be a positive definite matrix. Without
loss of generality, we can assume A is diagonal, i.e., A = diag(λ1 , λ2 , . . . , λD ) ∈ RD×D ,
where λ1 > λ2 ≥ λ3 ≥ . . . ≥ λD > 0 and the eigenvectors are the standard basis
vectors e1 , · · · , eD of the D-dimensional space. We will denote P (j:D) = D >
P
i=j ei ei as

the projection matrix onto the subspace spanned by ej , . . . , eD .


Recall the loss function L is defined as L(x) = 12 x> Ax. The Normalized GD update
Ax(t) Ax(t)
(LR= η )is given by x(t + 1) = x(t) − η kAx(t)k . A substitution x
e(t) := η
gives the
following update rule:

x
e(t)
x e(t) − A
e(t + 1) = x . (9.2)
ke
x(t)k

Note Normalized GD (9.2) is not defined at ke


x(t)k = 0. Moreover, it’s easy to
check that if at some time step t |hv1 , x e(t0 )i| = 0 holds for any t0 ≥ t.
e(t)i| = 0, |hv1 , x
Thus it’s necessary to assume |hv1 , x 6 0 for all t ∈ N in order to prove alignment
e(t)i| =
to the top eigenvector of A for Normalized GD (9.2).
Now we recall the main theorem for Normalized GD on quadratic loss functions:

Theorem 9.3.1. If |hv1 , x 6 0, ∀t ≥ 0, then there exists 0 < C < 1 and s ∈ {±1}
e(t)i| =
such that limt→∞ x e(2t + 1) = (C − 1)sλ1 v1 .
e(2t) = Csλ1 v1 and limt→∞ x

We also note that GD on L with any LR η can also be reduced to update
rule (9.2), as shown in the discussion at the end of Section 9.3.

420
9.9.1 Proofs for Preparation Phase

In this subsection, we show (1). Ij is indeed an invariant set for normalized GD


∀j ∈ [D] and (2). from any initialization, normalized GD will eventually go into their
intersection ∩D
j=1 Ij .

Lemma 9.9.1. For any t ∈ N and j ∈ [D], P (j:D) x


e(t) ≤ λj =⇒ P (j:D) x
e(t + 1) ≤
λj . In other words, {Ij }D
j=1 are invariant sets of update rule Equation (9.2).

Proof of Lemma 9.9.1. Note P (j:D) A = P (j:D) AP (j:D) , by definition of Normalized


GD (9.2), we have

P (j:D) A
 
(j:D) (j:D) (j:D) x
e(t)
P x
e(t + 1) = P e(t) − P
x A = I− P (j:D) x
e(t),
ke
x(t)k kex(t)k

which implies

P (j:D) A
P (j:D) x
e(t + 1) ≤ I − P (j:D) x
e(t) . (9.12)
kex(t)k

Note that P (j:D) A 4 λj I, P (j:D) x x(t)k and P (j:D) x


e(t) ≤ ke e(t) ≤ λj by assump-
tion, we have

λj P (j:D) A P (j:D) A λj
− I4− 4I− 4I4 I.
kP (j:D) x
e(t)k kex(t)k kex(t)k kP (j:D) x
e(t)k

P (j:D) A λj
Therefore I − ≤ and thus we conclude P (j:D) x
e(t + 1) ≤ λj .
kex(t)k kP (j:D) xe(t)k

Lemma 9.9.2. For any t ∈ N and j ∈ [D], if P (j:D) x


e(t) ≥ λj , then
λD
P (j:D) x
e(t + 1) ≤ (1 − ke
x(t)k
) P (j:D) x
e(t) .

P (j:D) A
Proof of Lemma 9.9.2. Since λj ≤ P (j:D) x
e(t) ≤ ke
x(t)k, we have 0 4 I − kex(t)k
4
λD P (j:D) A λD
1− ke
x(t)k
. Therefore I − kex(t)k
≤1− ke
x(t)k
. The proof is completed by plugging
this into Equation (9.12).

421
Lemma 9.9.2 has the following two direct corollaries.

ke
x(0)k−λ1
e(0) and t ≥
Corollary 9.9.3. For any initialization x λD
, ke
x(t)k ≤ λ1 , that is,
e(t) ∈ I1 .
x

Proof of Corollary 9.9.3. Set j = 1 in Lemma 9.9.2, it holds that ke x(t + 1)k ≤
 kex(0)k−λ1 
ke
x(t)k − λD whenever ke x(t)k ≥ λ1 . Thus x
e( λD
) ≤ λ1 . The proof is
completed as I1 is an invariant set by Lemma 9.9.1.

Corollary 9.9.4. For any coordinate j ∈ [D] and initial point x


e(0) ∈ I1 , if t ≥
λ1
λD
ln λλ1j then P (j:D) x
e(t) ≤ λj .

Proof of Corollary 9.9.4. Since I1 is an invariant set, we have ke


x(t)k ≤ λ1 for all
t ≥ 0. Thus let T = b λλD1 ln λλ1j c, we have

−T
λD λj
P (j:D) x
e(T ) ≤ e λ1
P (j:D) x
e(0) ≤ ke
x(0)k ≤ λj .
λ1

The proof is completed since Ij is a invariant set for any j ∈ [D] by Lemma 9.9.1.

9.9.2 Proofs for Alignment Phase

In this subsection, we analyze how normalized GD align to the top eigenvector once
e(t) ∈ ∩D
it goes through the preparation phase, meaning x j=1 Ij for all t in alignment

phase.

λ1
Lemma 9.3.5. For any t with x e(t) ∈ ∩D
j=1 Ij , if ke
x(t)k > 2
, then ke
x(t + 1)k ≤
 
λ2
max λ21 − 2λD1 , λ1 − ke
x(t)k .

422
Proof. The update at step t as:

 
x(t)k − λ1 )e
 (ke x1 (t) 
 
1 1  (ke
x (t)k − λ 2 )e
x 2 (t) 
x
e(t + 1) = x(t)k I − A) x
(ke e(t) = .
 
ke
x(t)k ke
x(t)k 
 ..
 . 

 
x(t)k − λD )e
(ke xD (t)

Let the index k be the smallest integer such that λk+1 < 2 ke
x(t)k − λ1 . If no such
index exists, then one can observe that ke
x(t + 1)k ≤ λ1 − ke
x(t)k. Assuming that such
an index exists in [D], we have λk ≥ 2 ke
x(t)k − λ1 and ke
x(t)k − λj ≤ λ1 − ke
x(t)k,
∀j ≤ k. Now consider the following vectors:

v (1) (t) := (λ1 − ke


x(t)k)e
x(t),

v (2) (t) := (2 ke
x(t)k − λ1 − λk )P (k:D) x
e(t),

v (2+j) (t) := (λk+j−1 − λk+j )P (k+j:D) x


e(t), ∀1 ≤ j ≤ D − k.

By definition of k, | ke
x(t)k − λj | ≤ | ke
x(t)k − λ1 |. Thus

 
 (ke x(t)k − λ1 )e x1 (t) 
 .. 

 . 

 
(ke
x (t)k − λ )e
x (t)
 
1  1 k 
ke
x(t + 1)k ≤  
ke
x(t)k  
 x(t)k − λk+1 )e
(ke xk+1 (t)


 .
..


 
 
(kex(t)k − λD )e xD (t)
1
= v (1) (t) + v (2) (t) + . . . + v (D−k+2) (t)
ke
x(t)k
1
v (1) (t) + v (2) (t) + . . . + v (D−k+2) (t) .


ke
x(t)k

423
e(t) ∈ ∩D
By assumption, we have x j=1 Ij . Thus

v (1) (t) = (λ1 − ke


x(t)k) ke
x(t)k

v (2) (t) ≤ (2 ke
x(t)k − λ1 − λk )λk

v (2+j) (t) ≤ (λk−1+j − λk+j )λk+j , for all j ≥ 1.

Hence,

X X
v (j) (t) = (2 ke
x(t)k − λ1 − λk )λk + (λj − λj+1 )λj+1
j≥2 j≥k
X X
= (2 ke
x(t)k − λ1 )λk + λj λj+1 − λ2j
j≥k j≥k

(2 ke 2
x(t)k − λ1 ) + λ2k X λ2j + λ2j+1 X
≤ + − λ2j
2 j≥k
2 j≥k
2
(2 ke
x(t)k − λ1 ) λ2D
≤ − ,
2 2

where we applied AM-GM inequality multiple times in the pre-final step.


Thus,

1
v (1) (t) + v (2) (t) + . . . + v (D−k+1) (t)

ke
x(t + 1)k ≤
kex(t)k
x(t)k − λ1 )2
(2 ke λ2D
≤ − + λ1 − ke x(t)k
2 ke
x(t)k 2 ke
x(t)k
λ2 − λ2D
x(t)k + 1
= ke − λ1
2 kex(t)k
λ1 λ2
≤ − D,
2 2λ1

λ1
where the final step is because 2
≤ ke
x(t)k ≤ λ1 and that the maximal value of a
convex function is attained at the boundary of an interval.

424
λi
Lemma 9.9.5. At any step t and i ∈ [D], if ke
x(t)k T 2
, then |e
xi (t + 1)| S |e
xi (t)|,
where T denotes larger than, equal to and smaller than respectively. (Same for S, but
in the reverse order)
 
λi
Proof. From the Normalized GD update rule, we have x ei (t) 1 −
ei (t+1) = x ke
x(t)k
, for all i ∈
[D]. Thus

λ1 λ1
S 2 ⇐⇒ 1 − S 1 ⇐⇒ |e
xi (t + 1)| S |e
xi (t)| ,
ke
x(t)k ke
x(t)k

which completes the proof.

λ1
Lemma 9.9.6. At any step t, if ke
x(t)k ≤ 2
, then

 
λ λ
(λ1 − ke
x(t)k) cos θt ≤ ke
x(t + 1)k ≤ λ1 − ke
x(t)k − 1− λ1 sin2 θt ,
2λ1 λ1

kP (2:D) xe(t)k
where θt = arctan and λ = min(λ1 − λ2 , λD ).
|e>1 xe(t)|

Proof. We first show that the left side inequality holds by the following update rule
for he1 , x
e(t)i:

he1 , x
e(t)i
he1 , x x(t)k − λ1 )
e(t + 1)i = (ke .
kex(t)k

Since ke
x(t + 1)k ≥ |he1 , x
e(t + 1)i| and θt denotes the angle between e1 and x
e(t + 1),
we get the left side inequality.
Now, we focus on the right hand side inequality. First of all, the update in the
coordinate j ∈ [2, D] is given by

hej , x
e(t)i
hej , x x(t)k − λj )
e(t + 1)i = (ke .
kex(t)k

425
Then, we have

D
X
2
ke
x(t + 1)k = e(t + 1)i2
hej , x
j=1
D  2
X
2 hej , x
e(t)i
= x(t)k − λj )
(ke
j=1
kex(t)k
D 2 
2 2
X hej , x
e(t)i 2
x(t)k − λ1 ) cos θt +
= (ke x(t)k − λj )
(ke
j=2
kex(t)k
D  2
2 2 2
X hej , x
e(t)i
≤ (ke
x(t)k − λ1 ) cos θt + (ke
x(t)k − λ)
j=2
kex(t)k

x(t)k − λ1 )2 cos2 θt + (ke


= (ke x(t)k − λ)2 sin2 θt

x(t)k − λ1 )2 + (λ1 − λ)(2 ke


= (ke x(t)k − λ − λ1 ) sin2 θt

x(t)k − λ1 )2 − λ(λ1 − λ) sin2 θt ,


≤ (ke

where in the fourth step, we have used λ = argmaxλi |2≤i≤D |ke x(t)k − λi | . The final

x(t)k < λ21 . Hence, using the fact that 1 − y ≤ 1 − y/2 for any y ≤ 1, we
step uses ke
have

1
ke
x(t + 1)k ≤ λ1 − ke
x(t)k − λ(λ1 − λ) sin2 θt
2(λ1 − ke
x(t)k)
 
λ λ
≤ λ1 − ke
x(t)k − 1− λ1 sin2 θt ,
2λ1 λ1

λ1
where again in the final step, we have used ke
x(t)k < 2
. The above bound can be
further bounded by

 
λ λ
ke
x(t + 1)k ≤ λ1 − ke
x(t)k − 1− λ1 sin2 θt
2λ1 λ1
λ0 λ0
  
1
≤ λ1 − ke
x(t)k − min 1− λ1 sin2 θt
2 λ ∈{λ2 ,λD } λ1
0 λ1
  
1 λ λ
= λ1 − ke
x(t)k − 1− λ1 sin2 θt ,
2 λ1 λ1

426
where we have used λ = min(λ1 − λ2 , λD ).

Lemma 9.9.7. If at some step t, ke


x(t + 1)k + ke
x(t)k ≤ λ1 , then |e
x1 (t + 2)| ≥
|e
x1 (t)|, where the equality holds only when ke
x(t + 1)k + ke
x(t)k = λ1 . Therefore,
by Lemma 9.9.6, we have :

λ1 λ λ
ke
x(t)k ≤ =⇒ |e x1 (t)| (1 + 2 (1 − ) sin2 θt ),
x1 (t + 2)| ≥ |e
2 λ1 λ1

kP (2:D) xe(t)k
where θt = arctan , and λ = min(λ1 − λ2 , λD ).
|e>1 xe(t)|

Proof of Lemma 9.9.7. Using the Normalized GD update rule, we have

   
λ1 λ1
e1 (t + 1) = 1 −
x x
e1 (t), e1 (t + 2) = 1 −
x x
e1 (t + 1).
ke
x(t)k ke
x(t + 1)k

Combining the two updates, we have

  
λ1 λ1
|e
x1 (t + 2)| = 1 − 1− |e
x1 (t)|
ke
x(t)k ke
x(t + 1)k
λ2 − λ1 (ke
x(t)k − ke
x(t + 1)k)
= 1+ 1 |e
x1 (t)|
ke
x(t)k ke
x(t + 1)k
≥ |e
x1 (t)| ,

where the equality holds only when ke


x(t + 1)k + ke
x(t)k = λ1 .
λ1
Moreover, with the additional condition that ke
x(t)k < 2
, we have from
Lemma 9.9.6, ke x(t)k − λ(λ1 − λ) sin2 θt , where λ = min(λ1 − λ2 , λD ).
x(t + 1)k ≤ λ1 − ke

427
Hence, retracing the steps we followed before, we have

λ21 − λ1 (ke
x(t)k + kex(t + 1)k)
|e
x1 (t + 2)| = 1 + |e
x1 (t)|
ke
x(t)k kex(t + 1)k
λ(λ1 − λ) sin2 θt
≥ 1+ |e
x1 (t)|
kex(t)k kex(t + 1)k
λ λ
≥ 1 + 2 (1 − ) sin2 θt |e x1 (t)| ,
λ1 λ1

where the final step follows from ke


x(t + 1)k ≤ λ1 − ke
x(t)k and therefore
λ21
ke
x(t + 1)k ke
x(t)k ≤ 4
.

9.9.3 Proof of Main theorems for Quadratic Loss

Proof of Theorem 9.3.1. The analysis will follow in two phases:

1. Preparation phase: e(t) enters and stays in an invariant set around the
x
origin, that is, ∩D x| D e(t)i2 ≤ λ2j }. (See Lemma 9.3.3,
P
j=1 Ij , where Ij := {e i=j hei , x

which is a direct consequence of Lemmas 9.9.1 and 9.9.1 and corollary 9.9.3.)

e(t) on the top eigenvector, | he


2. Alignment phase: The projection of x x(t), e1 i |,
is shown to increase monotonically among the steps among the steps {t | ke
x(t)k ≤
0.5}, up until convergence, since it’s bounded. (Lemma 9.3.4)

By Lemma 9.9.7, the convergence of | he


x(t), e1 i | would imply the convergence of
x
e(t) to e1 in direction.

Below we elaborate the convergence argument in the alignment phase. For con-
venience, we will use θt to denote the angle between e1 and x
e(t) and we assume
e(0) ∈ ∩D Ij without loss of generality. We first define S := {t ∈ N | ke λ1
j=1 x(t)k ≤ 2
}
and S 0 := {t ∈ S | t + 2 ∈ S}. The result in alignment phase says that 1
λ1
|e
x1 (t)|
monotone increases and converges to some constant C ∈ (0, 12 ] among all t ∈ S, thus
|e
x1 (t+2)|
lim |e
x1 (t)|
= 1. By Lemma 9.9.7, we have lim θt = 0. Since the one-step
t→∞,t∈S 0 t→∞,t∈S 0

428
update function F (e e − A kexxek is uniformly lipschitz when ke
x) = x xk is bounded away
from zero, we know lim θt+k = 0, ∀k ∈ N.
t→∞,t∈S 0

Now we claim ∀t ≥ 3, there is some k ∈ {0, 1, 3} such that t − k ∈ S 0 . This is


because Lemma 9.3.5 says that if t ∈
/ S, then both t − 1, t + 1 ∈ S. Thus for any
/ S, t − 1 ∈ S 0 . Therefore, for any t ∈ S/S 0 , if t − 2 ∈
t∈ / S, then t − 3 ∈ S 0 . Thus we
conclude that ∀t ≥ 3, there is some k ∈ {0, 1, 3} such that t − k ∈ S 0 , which implies
lim θt = 0. Hence lim ke
x(t + 1) − x
e(t)k = λ1 , meaning for sufficiently large t, x
e1 (t)
t→∞ t→∞

e(t + 2) − x
flips its sign per step and thus lim x e(t) = 0, lim ke
x(t + 1)k + ke
x(t)k = λ1 .
t→∞ t→∞
1 λ1
If C = 2
, then we must have lim ke
x(t)k = 2
and we are done in this case. If C < 12 ,
t→∞

note that lim |e


x1 (t)| = Cλ1 , it must hold that lim ke
x(t + 1)k = (1 − C)λ1 ,
t→∞,t∈S 0 t→∞,t∈S 0

thus there is some large T ∈ S such that for all t ∈ S, t ≥ T , t+1 ∈


/ S. By Lemma 9.3.5,
t + 2 ∈ S. Thus we conclude lim x
e(T + 2t) = Cλse1 for some s ∈ {−1, 1} and thus
t→∞

e(T + 2t + 1) = (C − 1)λse1 . This completes the proof.


lim x
t→∞

9.9.4 Some Extra Lemmas (only used in the general loss

case)

For a general loss function L satisfying Assumption 9.5.1, the loss landscape looks like
a strongly convex quadratic function locally around its minimizer. When sufficient
small learning rate, the dynamics will be sufficiently close to the manifold and behaves
like that in quadratic case with small perturbations. Thus it will be very useful to
have more refined analysis for the quadratic case, as they allow us to bound the error
in the approximate quadratic case quantitatively. Lemmas 9.9.8 to 9.9.11 are such
examples. Note that they are only used in the proof of the general loss case, but not
in the quadratic loss case.
Lemma 9.9.8 is a slightly generalized version of Lemma 9.3.5.

429
λ2D
Lemma 9.9.8. Suppose at time t, P (j:D) x
e(t) ≤ λj (1 + λ21
), for all j ∈ [D], if
λ1 λ1
ke
x(t)k > 2
, then ke
x(t + 1)k ≤ 2
.

Proof of Lemma 9.9.8. The proof is similar to the proof of Lemma 9.3.5. Let the
index k be the smallest integer such that λk+1 < 2 ke
x(t)k − λ1 . If no such index exists,
then one can observe that ke
x(t + 1)k ≤ λ1 − ke
x(t)k. Assuming that such an index
exists in [D], we have λk ≥ 2 ke
x(t)k − λ1 and ke
x(t)k − λj ≤ λ1 − ke
x(t)k, ∀j ≤ k. With
λ2D
e(t) ∈ ∩D
the same decomposition and estimation, since x j=1 (1 + λ21
)Ij , we have

v (1) (t) = (λ1 − ke


x(t)k) ke
x(t)k
λ2D
v (2) (t) ≤ (1 + )(2 ke
x(t)k − λ1 − λk )λk
λ21
λ2
v (2+j) (t) ≤ (1 + D2 )(λk−1+j − λk+j )λk+j , for all j ≥ 1.
λ1

Thus we conclude

1
v (1) (t) + v (2) (t) + . . . + v (D−k+1) (t)

ke
x(t + 1)k ≤
ke
x(t)k
λ1 λ2 λ2 λ1
≤ (1 − D2 )(1 + 21 ) ≤ ,
2 λ1 λD 2

which completes the proof.


 r  
λ1
Lemma 9.9.9. Consider the function g : R → R, with g(λ) = 2
1 − 1 − 2 λλ1 1 − λ
λ1
.
For any small constant c > 0 and coordinate 1 ≤ k ≤ D, consider any t with
e(t) ∈ ∩D
x j=1 Ij . If x
e(t) satisfies that

• |he1 , x
e(t)i| ≤ (1 − 2c)g(λk ).
p
• θt ≤ c |he1 , x
e(t)i|,

kP (2:D) (ex(t))k
where θt = arctan |he1 ,e
x(t)i|
.

430
Then, we have

hek , x
e(t + 2)i hek , x
e(t)i
≥ (1 + c) .
he1 , x
e(t + 2)i he1 , x
e(t)i

Proof of Lemma 9.9.9. From the quadratic update, we have the update rule as:

 
λk
x ek (t) 1 −
ek (t + 1) = x , for all k ∈ {1, . . . , D}.
ke
x(t)k

Thus, we have for any 1 ≤ k ≤ d,

  
hek , x
e(t + 2)i λ1 − λk λ1 − λk hek , x
e(t)i
= 1− 1−
he1 , x
e(t + 2)i λ1 − ke
x(t)k λ1 − kex(t + 1)k he1 , x e(t)i
 
(λ1 − λk )(λ1 + λk − ke
x(t)k − kex(t + 1)k) hek , x e(t)i
= 1− .
(λ1 − kex(t + 1)k)(λ1 − kex(t)k) he1 , x
e(t)i

Thus, as long as, the following holds true:

(λ1 − λk )(λ1 + λk − ke
x(t)k − ke
x(t + 1)k)
≥ 2 + c,
(λ1 − kex(t + 1)k)(λ1 − ke
x(t)k)

we must have

hek , x
e(t + 2)i hek , x
e(t)i
≥ (1 + c) .
he1 , x
e(t + 2)i he1 , x
e(t)i

 
We can use (λ1 − ke
x(t)k) cos θt ≤ ke
x(t + 1)k ≤ λ1 − ke
x(t)k − λ
2λ1
1− λ
λ1
λ1 sin2 θt ,
where λ = min(λ1 − λ2 , λD ) from Lemma 9.9.6 to show the following with additional
algebraic manipulation:

(λ1 − λk )(λ1 + λk − ke
x(t)k − ke
x(t + 1)k) (λ1 − λk )λk
≥ .
(λ1 − kex(t + 1)k)(λ1 − ke
x(t)k) (λ1 − (λ1 − ke
x(t)k) cos θt )(λ1 − ke
x(t)k)

431
Hence, it suffices to show that

(λ1 − λk )λk
≥ 2 + c.
(λ1 − (λ1 − ke
x(t)k) cos θt )(λ1 − ke
x(t)k)

The left hand side can be simplified as

(λ1 − λk )λk (λ1 − λk )λk


= 2
(λ1 − (λ1 − ke
x(t)k) cos θt )(λ1 − ke
x(t)k) (2λ1 sin (θt /2) + |he1 , xe(t)i|)(λ1 − ke x(t)k)
(λ1 − λk )λk
≥ 2
λ1 θt /2 + |he1 , x
e(t)i|)(λ1 − |he1 , xe(t)i|)
(λ1 − λk )λk
≥ ,
e(t)i| (λ1 + 2c λ1 − |he1 , x
|he1 , x e(t)i|)

p
where the last step we use that |θt | ≤ c |he1 , x
e(t)i|, we only need

e(t)i|2 − 2λ1 (1 + c/2)(2 + c) |he1 , x


(2 + c) |he1 , x e(t)i| + (λ1 − λk )λk ≥ 0.

The above inequality is true when |he1 , x


e(t)i| ≤ (1 − 2c) g(λk ).

 r  
λ1
Lemma 9.9.10. Consider the function g : R → R, with g(λ) = 2
1− 1− 2 λλ1 1− λ
λ1
.
Consider any coordinate 2 ≤ k ≤ D. For any constant 0 < c < 4 λλk1 (1 − λk
λ1
), consider
e(t) ∈ ∩D
any t with x j=1 Ij , with x
e(t) satisfying

0.5λ1 ≥ ke
x(t)k ≥ (1 + c)g(λk ).

Then, the following must hold true at time t.

hek , x
e(t + 2)i hek , x
e(t)i
≤ (1 − 0.5c) ,
he1 , x
e(t + 2)i he1 , x
e(t)i

432
Proof. By the Normalized GD update, we have:

λk λk
! !
hek , x
e(t + 2)i 1− ke
x(t+1)k
1− ke
x(t)k hek , x
e(t)i
= λ1 λ1
he1 , x
e(t + 2)i 1− ke
1−
x(t+1)k ke
x(t)k
he1 , x
e(t)i
 
(λ1 − λk )(λ1 + λk − ke
x(t)k − ke
x(t + 1)k) hek , xe(t)i
= 1− .
(λ1 − kex(t + 1)k)(λ1 − ke
x(t)k) he1 , x
e(t)i
(9.13)

(λ1 −λk )(λ1 +λk −ke


x(t)k−kex(t+1)k)
Now, we focus on the term (λ1 −ke
x(t+1)k)(λ1 −ke
x(t)k)
. For simplicity, we will
denote the term as ratio(λ1 , λk , ke
x(t)k , ke
x(t + 1)k). The term behaves differently,
depending on whether ke
x(t)k ≥ λk or ke
x(t)k ≤ λk :

λ1
1. If ke
x(t)k ≥ λk , which is only possible when λk ≤ 2
, we find that
ratio(λ1 , λk , ke
x(t)k , ke
x(t + 1)k) is a monotonically decreasing function w.r.t.
ke
x(t + 1)k, keeping other terms fixed. Using the fact that ke
x(t + 1)k ≤
λ1 − ke
x(t)k from Lemma 9.9.6, we can bound the term as:

min ratio(λ1 , λk , a, λ1 − a) ≤ ratio(λ1 , λk , ke


x(t)k , ke
x(t + 1)k)
λk ≤a≤0.5λ1

≤ max ratio(λ1 , λk , a, 0).


λk ≤a≤0.5λ1

(λ1 +λk −a)(λ1 −λk )


We can simplify ratio(λ1 , λk , a, 0) as λ1 (λ1 −a)
for any a, and can be shown
λk
to be at most 1 + λ1
(≤ 3/2) for any a in the range (λk , 0.5λ1 ). Furthermore,
λk (λ1 −λk )
ratio(λ1 , λk , a, λ1 − a) simplifies as a(λ1 −a)
for any a, and can be shown to be
at least 4 λλk1 (1 − λk /λ1 ) in the range (λk , 0.5λ1 ), which it attains at a = λk .

2. If ke
x(t)k ≤ λk , we find that ratio(λ1 , λk , ke
x(t)k , ke
x(t + 1)k) is a monotonically
increasing function w.r.t. ke
x(t + 1)k, keeping other terms fixed. Using the fact

433
that ke
x(t + 1)k ≤ λ1 − ke
x(t)k from Lemma 9.9.6, we can bound the term as:

min ratio(λ1 , λk , a, 0) ≤ ratio(λ1 , λk , ke


x(t)k , ke
x(t + 1)k)
(1+c)g(λk )≤a≤min(0.5λ1 ,λk )

≤ max ratio(λ1 , λk , a, λ1 − a).


(1+c)g(λk )≤a≤min(0.5λ1 ,λk )

Continuing in the similar way as the previous case, we show that ratio(λ1 , λk , a, 0)
is at least 1−(λk /λ1 )2 in the range ((1+c)g(λk ), min(0.5λ1 , λk )). ratio(λ1 , λk , a, λ1 −
a) is maximized in the range ((1 + c)g(λk ), min(0.5λ1 , λk )) at a = (1 + c)g(λk )
λk (λ1 −λk )
and is given by (1+c)g(λk )(λ1 −(1+c)g(λk ))
. From the definition of λk , we observe that
λ1 − (1 + c)g(λk ) is atleast (1 − 4c )(λ1 − (1 + c)g(λk )) for any c ∈ (0, 1). Thus,
we have

λk (λ1 − λk ) 1 λk (λ1 − λk )
≤ c
(1 + c)g(λk )(λ1 − (1 + c)g(λk )) (1 + c)(1 − 4 ) g(λk )(λ1 − g(λk ))
2
= ≤ 2 − 0.5c,
(1 + c)(1 − 4c )

where the final step holds true for any c ∈ (0, 1).

Thus, we have shown that

 
λk λk λk λk λk 2
2 (1 − ) ≤ min 4 (1 − ), 1 − ( )
λ1 λ1 λ1 λ1 λ1
(λ1 − λk )(λ1 + λk − ke
x(t)k − ke
x(t + 1)k)
≤ ≤ 2 − 0.5c.
(λ1 − kex(t + 1)k)(λ1 − ke
x(t)k)

The result follows after substituting this bound in Equation (9.13).

λ1
Lemma 9.9.11. At any step t, if ke
x(t)k ≤ 2
,

x(t + 1), e1 ))| ≤ max( λλ21 , 1 − 2 λλD1 ) |tan(∠(e


1. |tan(∠(e x(t), e1 ))|.

λ1
2. |tan(∠(e
x(t + 2), e1 ))| ≤ ke
x(t)k
|tan(∠(e
x(t), e1 ))|.
434
Proof of Lemma 9.9.11. From the Normalized GD update rule, we have

 
λi
x ei (t) 1 −
ei (t + 1) = x , for all i ∈ [D],
ke
x(t)k

 
1
implying |e
xi (t + 1)| < 1− ke
x(t)k
|e
xi (t)| for all i ∈ [2, D], since λi < 1.
λ1
Since λi < λ1 and ke
x(t)k ≤ 2
, it holds that

λi
|e
xi (t + 1)| 1− ke
x(t)k |e
xi (t)| λ1 − λi |e
xi (t)| λi λi |e
xi (t)|
= λ1
= 1− ≤ max( , 1 − 2 ) .
|e
x1 (t + 1)| 1− ke
x(t)k
|e
x1 (t)| λ1 − ke
x(t)k |e
x1 (t)| λ1 λ1 |e
x1 (t)|

Finally we conclude

P (2:D) x
e(t + 1) λ2 λD P (2:D) x
e(t)
≤ max( , 1 − 2 ) .
|e
x1 (t + 1)| λ1 λ1 |e
x1 (t)|

kP (2:D) vk
Recall |tan(∠(v, e1 ))| = |he1 ,vi|
for any vector v, the first claim follows from re-
arranging the terms.
For the second claim, it suffices to apply the above inequality to t + 1, which yields
that

λ2 λ1 − λi λ1
|tan(∠(e
x(t + 2), e1 ))| ≤ max( , − 1) |tan(∠(e
x(t + 1), e1 ))| ≤ |tan(∠(e
x(t +
λ1 λ1 − kxt+1 k λ1 − kxt+1 k

The proof is completed by noting ke


x(t + 1)k ≤ λ1 −ke
xtk (Lemma 9.9.6) and tan(∠(e
x(t+
1), e1 )) ≤ tan(∠(e
x(t), e1 )).

9.10 Setups for General Loss Functions

Before we start the analysis for Normalized GD for general loss functions in Section 9.11,
we need to introduce some new notations and terminologies to complete the formal

435
setup. We start by first recapping some core assumptions and definitions in the main
paper and provide the missing proof in the main paper.

Assumption 9.5.1. Assume that the loss L : RD → R is a C 4 function, and that Γ


is a (D − M ) dimensional C 1 -submanifold of RD for some integer 1 ≤ M ≤ D, where
for all x ∈ Γ, x is a local minimizer of L with L(x) = 0 and rank (∇2 L(x)) = M .

Assumption 9.5.2. For any x ∈ Γ, ∇2 L(x) has a positive eigengap, i.e.,


λ1 (∇2 L(x)) > λ2 (∇2 L(x)).

Notations: We define Φ as the limit map of gradient flow below. We summarize


various properties of Φ from Chapter 8 in Section 9.10.2.

Z τ
Φ(x) = lim φ(x, τ ), where φ(x, τ ) = x − ∇L(φ(x, s))ds. (9.14)
τ →∞ 0

Let U be the sets of points starting from which, gradient flow w.r.t. loss L converges
to some point in Γ, that is, U := {x ∈ RD | Φ(x) exists and Φ(x) ∈ Γ}. We have that
U is open and Φ is C 3 on U . (By Lemma 8.8.2)
For a matrix A ∈ RD×D , we denote its eigenvalue-eigenvector pairs by
{λi (A), vi (A))}i∈[D] . For simplicity, whenever Φ is defined and C 2 at point x,
we use {(λi (x), vi (x))}D 2
i=1 to denote the eigenvector-eigenvalue pairs of ∇ L(Φ(x)),

with λ1 (x) > λ2 (x) ≥ λ3 (x) . . . ≥ λD (x). Given a differentiable submanifold Γ of RD


and point x ∈ Γ, we use Nx Γ and Tx Γ to denote the normal space and the tangent
space of the manifold Γ for any point x ∈ Γ. We use Px,Γ : Γ → RD to denote the

projection operator onto the normal space of Γ at x, and Px,Γ := ID − Px,Γ . Similar
e to denote ∇2 L(Φ(x))(x − Φ(x)) for
to quadratic case, for any x ∈ U , we use x
p √
Normalized GD on L and to denote 2∇2 L(Φ(x))(x − Φ(x)) for GD on L. We also
(j:D)
use Px,Γ to denote the projection matrix M >
P
j=i vi (x)vi (x) for x ∈ Γ and j ∈ [M ].

1:D
Therefore, Px,Γ = Px,Γ by Lemma 9.10.17. Additionally, for any x ∈ U , we use θ(x)

436
to denote the angle between x
e and the top eigenspace of the hessian at Φ(x), i.e.
(2:M )
PΦ(x),Γ x
e
θ(x) = arctan |hv1 (x),e
xi|
. Furthermore, when the iterates x(t) is clear in the context,
⊥ ⊥
we use shorthand λi (t) := λi (x(t)), vi (t) := vi (x(t)), Pt,Γ := PΦ(x(t)),Γ , Pt,Γ := PΦ(x(t)),Γ
and θt to denote θ(x(t)). We define the function gt : R → R for every t ∈ N as

s  !
1 λ λ
gt (λ) = 1− 1−2 1− .
2 λ1 (t) λ1 (t)

Given any two points x, y, we use xy to denote the line segment between x and y,
i.e., {z | ∃λ ∈ [0, 1], z = (1 − λ)x + λy}.
The main result of this chapter focuses on the trajectory of Normalized GD
from fixed initialization xinit with LR η converges to 0, which can be roughly split
into two phases. In the first phase, Theorem 9.5.3 shows that the normalized GD
trajectory converges to the gradient flow trajectory, φ(xinit , ·). In second phase,
Theorem 9.5.4 shows that the normalized GD trajectory converges to the limiting flow
which decreases sharpness on Γ, (9.5). Therefore, for sufficiently small η, the entire
trajectory of normalized GD will be contained in a small neighbourhood of gradient
flow trajectory Z and limiting flow trajectory Y . The convergence rate given by our
proof depends on the various local constants like smoothness of L and Φ in this small
neighbourhood, which intuitively can be viewed as the actual ”working zone” of the
algorithm. The constants are upper bounded or lower bounded from zero because this
”working zone” is compact after fixing the stopping time of (9.5), which is denoted by
T2 .

Z τ
1 ⊥
X(τ ) = Φ(xinit ) − PX(s),Γ ∇ log λ1 (X(s))ds, X(τ ) ∈ Γ (9.5)
4 s=0

Below we give formal definitions of the ”working zones” and the corresponding
properties. For any point y ∈ RD and positive number r, we define Br (y) := {x ∈
RD | ky − xk < r} as the open `2 norm ball centered at y and B r (y) as its closure. For
437
any set S and positive number r, we define S r := ∪y∈S Br (y) and B r (S) := ∪y∈S B r (y).
Given the stopping time T2 > 0, we denote the trajectory of limiting flow Equation (9.5)
{X(τ )}Tτ =0
2
by Y and we use the notation Y r := ∪y∈Y B y (r) for any r > 0. By definition,
Y r are compact for any r > 0.
We construct the ”working zone” of the second phase, Y ρ and Y  in Lemmas 9.10.2
and 9.10.5 respectively, where 0 <  < ρ, implying Y  ⊂ Y ρ . The reason that we
need the two-level nested ”working zones” is that even though we can ensure all the
points in Y ρ have nice properties as listed in Lemma 9.10.2, we cannot ensure the
trajectory of gradient flow from x ∈ Y ρ to Φ(x) or the line segment xΦ(x) is in Y ρ ,
which will be crucial for the geometric lemmas (in Section 9.10.1) that we will heavily
use in the trajectory analysis around the manifold. For this reason we further define
Y  and Lemma 9.10.5 guarantees the trajectory of gradient flow from x to Φ(x) or
the line segment xΦ(x) whenever x ∈ Y ρ .

Definition 9.10.1 (PL condition). A function L is said to be µ-PL in a set U iff for
all x ∈ U ,
k∇L(x)k2 ≥ 2µ(L(x) − inf L(x)).
x∈U

For convenience, we define ∆ := 12 inf x∈Y λ1 (∇2 L(x)) − λ2 (∇2 L(x))) and µ :=


1
4
inf x∈Y λM (∇2 L(x)). By Assumption 9.5.1, we have µ > 0. By Assumption 9.5.2,
∆ > 0.

Lemma 9.10.2. Given Y , there are sufficiently small ρ > 0 such that
1. Y ρ ∩ Γ is compact;
2. Y ρ ⊂ U ;
3. L is µ-PL on Y ρ ; (see Definition 9.10.1)

4. inf x∈Y ρ λ1 (∇2 L(x)) − λ2 (∇2 L(x))) ≥ ∆ > 0;
5. inf x∈Y ρ λM (∇2 L(x)) ≥ µ > 0.

438
Proof of Lemma 9.10.2. We first claim for every y ∈ Y , for all sufficiently small
ρy > 0 (i.e. for all ρy smaller than some threshold depending on y), the following
three properties hold (1) B y (ρy ) ∩ Γ is compact; (2) B y (ρy ) ∩ Γ ⊂ U and (3) L is
µ-PL on B y (ρy ∩ Γ).
Among the above three claims, (2) is immediate. (1) holds because B y (ρy ) ∩ Γ
is bounded and we can make ρy small enough to ensure B y (ρy ) ∩ Γ is closed. For
(3), by Proposition 7 of [203], we define p(y) := argminx∈Γ kx − yk which is uniquely
defined and C 1 in B y (ρy ) for sufficiently small ρy . Moreover, Lemma 14 in [203] shows
that k∇L(x) − ∇2 L(p(x))(x − p(x))k ≤ c kx − p(x)k22 for all x in By (ρy ) uniformly
and some constant c. Thus for small enough ρy ,

k∇L(x)k2 ≥(x − p(x))> (∇2 L(p(x)))2 (x − p(x)) − O(kx − p(x)k3 ) (9.15)

Furthermore, by Lemma 10 in [203], it holds that x − p(x) ∈ Np(x) Γ =


span({vi (p(x))}M
i=1 ), which implies

(x − p(x))> (∇2 L(p(x)))2 (x − p(x)) ≥ λM (∇2 L(p(x)))(x − p(x))> ∇2 L(p(x))(x − p(x)),


(9.16)

and that

(x − p(x))> ∇2 L(p(x))(x − p(x)) ≥ λM (∇2 L(p(x))) kx − p(x)k22 .

Thus for any c0 > 0, for sufficiently small ρy , (x − p(x))> ∇2 L(p(x))(x − p(x)) ≥
c0 kx − p(x)k3 . Combining Equations (9.15) and (9.16), we conclude that for sufficiently
small ρy ,

k∇L(x)k2 ≥ λM (∇2 L(p(x)))(x − p(x))> ∇2 L(p(x))(x − p(x)) − O(kx − p(x)k32 )

439
Again for sufficiently small ρy , by Taylor expansion of L at p(x), we have

1
(x − p(x))> ∇2 L(p(x))(x − p(x)) ≥ L(x) − O(kx − p(x)k3 ).
2

Thus we conclude

k∇L(x)k2 ≥ 2λM (∇2 L(p(x)))L(x)−O(kx − p(x)k3 ) ≥ λM (∇2 L(p(x)))L(x) ≥ 2µL(x).

Meanwhile, since λM (∇2 L(p(x))) and λ1 (∇2 L(p(x))) − λ2 (∇2 L(p(x))) are
continuous functions in x, we can also choose a sufficiently small ρy such that
1 1
for all x ∈ B y (ρy ), λM (∇2 L(p(x))) ≥ λ (∇2 L(p(y)))
2 M
= λ (∇2 L(y))
2 M
> ∆
and λ1 (∇2 L(p(x))) − λ2 (∇2 L(p(x))) ≥ 12

λ1 (∇2 L(p(y))) − λ2 (∇2 L(p(y))) =
1

2
λ1 (∇2 L(y)) − λ2 (∇2 L(y)) ≥ µ. Further note Y ⊂ ∪y∈Y By (ρy ) and Y is a compact
set, we can take a finite subset of Y , Y 0 , such that Y ⊂ ∪y∈Y 0 By (ρy ). Taking
ρy
ρ := miny∈Y 0 2
completes the proof.

Definition 9.10.3. The spectral 2-norm of a k-order tensor T = (ti1 i2 ···ik ) ∈


Rd1 ×d2 ×···dk is defined as the maximum of the following constrained multilinear
optimization problem:

kT k = max T x(1) , · · · , x(k) : x(i) = 1, x(i) ∈ Rdi , i = 1, 2, . . . , k .


 
2

(1) (2) (k)


Here, T x1 , · · · , x(k) = di11=1 di22=1 . . . dikk=1 ti1 i2 ...id xi1 xi2 . . . xik .
 P P P

Definition 9.10.4. We define the following constants regarding smoothness of L and


Φ of various orders over Y ρ .

ζ = sup ∇2 L(x) , ν = sup ∇3 L(x) , Υ = sup ∇4 L(x) ,


x∈Y ρ x∈Y ρ x∈Y ρ

ξ = sup ∇2 Φ(x) , χ = sup ∇3 Φ(x) ,


x∈Y ρ x∈Y ρ

440
We assume each of the constants ζ, ν, Υ, ξ, χ are at least 1 for simplicity (otherwise
we can set them to be 1 and our bound still holds)

Lemma 9.10.5. Given ρ as defined in Lemma 9.10.2, there is an  ∈ (0, ρ) such that
2 5
1. supx∈Y  L(x) − inf L(x) < min( µρ8 , ν2µ
2 ζ 2 );
x∈Y
ρ

2. ∀x ∈ Y , Φ(x) ∈ Y 2 .

Proof of Lemma 9.10.5. For every y ∈ Y , there is an y , such that ∀x ∈ By (y ),


2 4 ρ
it holds that L(x) < min( µρ8 , ν2µ
2 ζ 2 ) and Φ(x) ∈ Y
2 , as both L(x) and Φ(x) are

continuous. Further note Y ⊂ ∪y∈Y By (y ) and Y is a compact set, we can take a
y
finite subset of Y , Y 0 , such that Y ⊂ ∪y∈Y 0 By (y ). Taking  := miny∈Y 0 2
completes
the proof.

Summary for Setups: The initial point xinit is chosen from an open neighborhood
of manifold Γ, U , where the infinite-time limit of gradient flow Φ is well-defined and
for any x ∈ U , Φ(x) ∈ Γ. We consider normalized GD with sufficiently small LR
η such that the trajectory enters a small neighborhood of limiting flow trajectory,
Y ρ . Moreover, L is µ-PL on Y ρ and the eigengaps and smallest eigenvalues are
uniformly lower bounded by positive ∆, µ respectively on Y ρ . Finally, we consider
a proper subset of Y ρ , Y  , as the final ”working zone” in the second phase (defined
in Lemma 9.10.5), which enjoys more properties than Y ρ , including Lemmas 9.10.7
to 9.10.10.

9.10.1 Geometric Lemmas

In this subsection we present several geometric lemmas which are frequently used in
the trajectory analysis of normalized GD. In this section, O(·) only hides absolute
constants. Below is a brief summary:

441
• Lemma 9.10.6: Inequalities connecting various terms: the distance between x
and Φ(x), the length of GF trajectory from x to Φ(x), square root of loss and
gradient norm;

• Lemma 9.10.7: For any x ∈ Y  , the gradient flow trajectory from x to Φ(x) and
the line segment between x and Φ(x) are all contained in Y ρ , so it’s ”safe” to
use Taylor expansions along GF trajectory or xΦ(x) to derive properties;

• Lemmas 9.10.8 to 9.10.10: for any x ∈ Y  , the normalized GD dynamics at x


can be roughly viewed as approximately quadratic around Φ(x) with positive
definite matrix ∇2 L(Φ(x)).

• Lemma 9.10.11: In the ”working zone”, Y ρ , one-step normalized GD update


with LR η only changes Φ(xt ) by O(η 2 ).

• Lemma 9.10.13: In the ”working zone”, Y ρ , one-step normalized GD update



with LR η decreases L(x) − miny∈Y L(y) by η 42µ if k∇L(x)k ≥ ζη .
p

Lemma 9.10.6. If the trajectory of gradient flow starting from x, φ(x, t), stays in
Y ρ for all t ≥ 0, then we have

s

2(L(x) − L(Φ(x))) k∇L(x)k
Z
dφ(x, t)
kx − Φ(x)k ≤ dt ≤ ≤ .
t=0 dt µ µ

Proof of Lemma 9.10.6. Since Φ(x) is defined as limt→∞ φ(x, t) and φ(x, 0) = x, the
left-side inequality follows immediately from triangle inequality. The right-side in-
equality is by the definition of PL condition. Below we prove the middle inequality.
Since ∀t ≥ 0, φ(x, t) ∈ Y ρ , it holds that k∇L(φ(x, t))k2 ≥ 2µ(L(φ(x, t))−L(Φ(x)))
by the choice of ρ in Lemma 9.10.2. Without loss of generality, we assume L(y) =

442
0, ∀y ∈ Γ. Thus we have

∞ ∞
k∇L(φ(x, t))k2
Z Z
k∇L(φ(x, t))k dt ≤ p dt.
t=0 t=0 2µL(φ(x, t))

Since dφ(x, t) = −∇L(φ(x, t))dt, if holds that

s

k∇L(φ(x, t))k2 ∞ ∞
r
−dL(φ(x, t))
Z Z Z
2 p 2L(φ(x, 0))
p dt ≤ p = d L(φ(x, t)) = .
t=0 2µL(φ(x, t)) t=0 2µL(φ(x, t)) t=0 µ µ

The proof is complete since φ(x, 0) = x and we assume L(Φ(x)) is 0.

Lemma 9.10.7. Let ρ,  be defined in Lemmas 9.10.2 and 9.10.5. For any x ∈ Y  ,
we have

1. The entire trajectory of gradient flow starting from x is contained in Y ρ , i.e.,


φ(x, t) ∈ Y ρ , ∀t ≥ 0;

2
2. Moreover, kΦ(x) − φ(x, t)k ≤ min(ρ, 2µ
νζ
), ∀t ≥ 0.

Proof of Lemma 9.10.7. Let time τ ∗ ≥ 0 be the smallest time after which the trajec-
tory of GF is completely contained in Y ρ , that is, τ ∗ := inf{t ≥ 0 | ∀t0 ≥ t, φ(x, t0 ) ∈
Y ρ }. Since Y ρ is closed and φ(x, ·) is continuous, we have φ(x, τ ∗ ) ∈ Y ρ .
Since ∀τ ≥ τ ∗ , φ(x, τ ) ∈ Y ρ , by Lemma 9.10.6, it holds that kφ(x, τ ∗ ) − Φ(x)k ≤
q
2(L(φ(x,τ ∗ ))−L(Φ(x)))
µ
.
Note that loss doesn’t increase along GF, we have L(φ(x, τ ∗ )) − L(Φ(x)) ≤ L(x) −
µρ2
L(Φ(x)) ≤ 8
, which implies that kφ(x, τ ∗ ) − Φ(x)k ≤ ρ2 . Therefore τ ∗ must be 0,
otherwise there exists a 0 < τ 0 < τ ∗ such that kφ(x, τ ) − Φ(x)k ≤ ρ for all τ 0 < τ < τ ∗
by the continuity of φ(x, ·). This proves the first claim.
Given the first claim is proved, the second claim follows directly from Lemma 9.10.6.

The following theorem shows that the projection of x in the tangent space of
Φ(x) is small when x is close to the manifold. In particular if we can show that in a
443
discrete trajectory with a vanishing learning rate η, the iterates {xη (t)} stay in Y  ,
we can interchangeably use kxη (t) − Φ(xη (t))k with kPt,Γ (xη (t) − Φ(xη (t)))k, with an
additional error of O(η 3 ), when kPt,Γ (xη (t) − Φ(xη (t)))k ≤ O(η).

Lemma 9.10.8. For all x ∈ Y  , we have that

νζ

PΦ(x),Γ (x − Φ(x)) ≤ kx − Φ(x)k2 ,
4µ2

and that

 
2 2 νζ 1
PΦ(x),Γ (x − Φ(x)) ≥ kx − Φ(x)k 1 − 2
kx − Φ(x)k ≥ kx − Φ(x)k2 .
4µ 2

Proof of Lemma 9.10.8. First of all, we can track the decrease in loss along the
Gradient flow trajectory starting from x. At any time τ , we have

d d
L(φ(x, τ )) = h∇L(φ(x, τ )), φ(x, τ )i = − k∇L(φ(x, τ ))k2 ,
dτ dτ

where φ(x, 0) = x. Without loss of generality, we assume L(y) = 0, ∀y ∈ Γ. Using the


fact that L is µ-PL on Y ρ and the GF trajectory starting from any point in Y  stays
inside Y ρ (from Lemma 9.10.7), we have

d
L(φ(x, τ )) ≤ −2µL(φ(x, τ )),

which implies

L(φ(x, τ )) ≤ L(φ(x, 0))e−2µτ

444
By Lemma 9.10.6, we have

s
2√
r
2L(φ(x, 0))e−2µτ
kφ(x, τ ) − Φ(x)k ≤ L(φ(x, τ )) ≤ . (9.17)
µ µ

Moreover, we can relate L(φ(x, 0) with kΦ(x) − xk with a second order taylor
expansion:

L(x) =L(Φ(x)) + h∇L(Φ(x)), x − Φ(x)i


Z 1
+ (1 − s)(x − Φ(x))> ∇2 L(sx + (1 − s)Φ(x))(x − Φ(x))ds
s=0

where in the final step, we have used the fact that L(Φ(x)) = 0 and ∇L(Φ(x)) = 0. By
Lemma 9.10.7, we have xΦ(x) ⊂ Y ρ . Thus maxs∈[0,1] k∇2 L(sx + (1 − s)Φ(x))k ≤ ζ
from Definition 9.10.4 and it follows that

Z 1
ζ
L(x) ≤ (1 − s)ζ kx − Φ(x)k2 ds = kΦ(x) − xk2 , (9.18)
s=0 2

Finally we focus on the movement in the tangent space. It holds that

Z ∞ Z ∞
⊥ ⊥ ⊥
PΦ(x),Γ (φ(x, ∞) − φ(x, 0)) ≤ PΦ(x),Γ ∇L(φ(x, τ )) dτ ≤ PΦ(x),Γ ∇L(φ(x, τ )) dτ.
0 0

(9.19)

By Lemma 9.10.7, we have φ(x, τ )Φ(x) ⊂ Y ρ for all τ ≥ 0 and thus

ν
∇L(φ(x, τ )) − ∇2 L(Φ(x)) Φ(x) φ(x, τ ) − Φ(x) ≤ kφ(x, τ ) − Φ(x)k2 .
 
2

445
⊥ ⊥
Since PΦ(x),Γ is the projection matrix for the tangent space, PΦ(x),Γ ∇2 L(Φ(x)) = 0 and
thus by Equation (9.17)

⊥ ν 2 νL(φ(x, 0))e−2µτ
PΦ(x),Γ ∇L(φ(x, τ )) ≤ kφ(x, τ ) − Φ(x)k ≤ (9.20)
2 µ

Plug Equation (9.20) into Equation (9.19), we conclude that


νL(φ(x, 0))e−2µτ νζ kx − Φ(x)k2
Z
⊥ νL(x)
PΦ(x),Γ (φ(x, ∞) − x) ≤ = ≤
τ =0 µ 2µ2 4µ2
(9.21)
For the second claim, simply note that


PΦ(x),Γ (x − Φ(x))
q
2
= kx − Φ(x)k2 − PΦ(x),Γ (x − Φ(x))
2
PΦ(x),Γ (x − Φ(x))
≥ kx − Φ(x)k − .
kx − Φ(x)k

The left-side inequality of the second inequality is proved by plugging the first
claim into the above inequality Equation (9.21) and rearranging the terms. Note by
νζ
the second claim in Lemma 9.10.7, 4µ2
kx − Φ(x)k ≤ 12 , the right-side inequality is
also proved.

Lemma 9.10.9. At any point x ∈ Y  , we have

1
∇L(x) − ∇2 L(Φ(x))(x − Φ(x)) ≤ ν kx − Φ(x)k2 .
2

and

k∇L(x)k ν
− 1 ≤ kx − Φ(x)k ,
k∇2 L(Φ(x))(x − Φ(x))k µ

446
Moreover, the normalized gradient of L can be written as

∇L(x) ∇2 L(Φ(x))(x − Φ(x)) ν


= 2
+ O( kx − Φ(x)k). (9.22)
k∇L(x)k k∇ L(Φ(x))(x − Φ(x))k µ

Proof of Lemma 9.10.9. Using taylor expansion at x, we have using ∇L(Φ(x)) = 0:

∇L(x) − ∇2 L(Φ(x))(x − Φ(x))


Z 1
= (1 − s)∂ 2 (∇L)(sx + (1 − s)Φ(x))[x − Φ(x), x − Φ(x)]ds
0
Z 1
≤ (1 − s)ds max ∂ 2 (∇L)(sx + (1 − s)Φ(x)) kx − Φ(x)k2
0 0≤s≤1

1
≤ ν kx − Φ(x)k2 .
2

Further note that

∇2 L(Φ(x))(x − Φ(x)) ≥ PΦ(x),Γ ∇2 L(Φ(x))(x − Φ(x)) = ∇2 L(Φ(x))PΦ(x),Γ (x − Φ(x))

≥µ PΦ(x),Γ (x − Φ(x)) ,

we have

k∇L(x)k ν kx − Φ(x)k2 ν
2
−1 ≤ ≤ kx − Φ(x)k ,
k∇ L(Φ(x))(x − Φ(x))k 2µ PΦ(x),Γ (x − Φ(x)) µ

where we use Lemma 9.10.8 since x ∈ Y  . Thus, the normalized gradient at any step
t can be written as

∇L(x) ∇2 L(Φ(x))[x − Φ(x)] + O(ν kx − Φ(x)k2 )


=    .
k∇L(x)k 2 ν
k∇ L(Φ(x))(x − Φ(x))k 1 + O µ kx − Φ(x)k
∇2 L(Φ(x))[x − Φ(x)] ν
= 2
+ O( kx − Φ(x)k),
k∇ L(Φ(x))[x − Φ(x)]k µ

which completes the proof.

447
Lemma 9.10.10. Consider any point x ∈ Y  . Then,

 
∇L(x) ν
v1 (x), ≥ cos θ − O( kx − Φ(x)k),
k∇L(x)k µ

(2:M )
PΦ(x),Γ x
e = ∇2 L(Φ(x))(x − Φ(x)).
e
where θ = arctan |hv1 (x),e
xi|
, with x

Proof of Lemma 9.10.10. From Lemma 9.10.9, we have that

∇L(x) ∇2 L(Φ(x))(x − Φ(x)) ν


= 2
+ O( kx − Φ(x)k).
k∇L(x)k k∇ L(Φ(x))(x − Φ(x))k µ

Hence, we have that

|hv1 (x), ∇L(x)i| |hv1 (x), ∇2 L(Φ(x))(x − Φ(x))i| ν


= + O( kx − Φ(x)k)
k∇L(x)k k∇2 L(Φ(x))(x − Φ(x))k µ
ν
≥ cos θ − O( kx − Φ(x)k),
µ

which completes the proof.

∇L(x)
Lemma 9.10.11. For any xy ∈ Y  where y = x − η k∇L(x)k is the one step Normalized
GD update from x, we have

1
kΦ(y) − Φ(x)k ≤ ξη 2 .
2

Moreover, we must have for every 1 ≤ k ≤ M ,

1
λk (∇2 L(Φ(x))) − λk (∇2 L(Φ(y))) ≤ νξη 2 ,
4

and

1 νξη 2 νξη 2 ν 2 ξ2 η4
v1 (∇2 L(Φ(x))) − v1 (∇2 L(Φ(y))) ≤ = + O( ).
2 ∆ − 14 νξη 2 2∆ ∆

448
Proof. By Lemma 9.10.14, we have ∂Φ(x)∇L(x) = 0 for all x ∈ U . Thus we have

1 
∇L(x) ∇L(x)
Z
kΦ(y) − Φ(x)k =η ∂Φ x − sη ds
s=0 k∇L(x)k k∇L(x)k
Z 1    
∇L(x) ∇L(x)
=η ∂Φ x − sη − ∂Φ(x) ds
s=0 k∇L(x)k k∇L(x)k
Z 1  
∇L(x)
≤η ∂Φ x − sη − ∂Φ(x) ds
s=0 k∇L(x)k
Z 1
≤η 2
s sup ∇2 Φ((1 − s0 )x + s0 y) ds
s=0 s0 ∈[0,s]
2
η
= sup ∇2 Φ((1 − s0 )x + s0 y)
2 s0 ∈[0,1]
1
≤ ξη 2 ,
2

where the final step follows from using Definition 9.10.4.


For the second claim, we have for every 1 ≤ k ≤ M ,

λk (∇2 L(Φ(x))) − λk (∇2 L(Φ(y)))

≤ ∇2 L(Φ(x)) − ∇2 L(Φ(y))
Z 1
= (1 − s)∂ 2 (∇L)(Φ(sx + (1 − s)y))(Φ(x) − Φ(y))ds
s=0
Z 1
≤ (1 − s)ds max ∂ 2 (∇L)(Φ(sx + (1 − s)y)) kΦ(x) − Φ(y))k
s=0 s∈[0,1]

1
≤ νξη 2 ,
4

where the first step involves Theorem 9.14.2.

449
The third claim follows from using Theorem 9.14.4. Again,

k∇2 L(Φ(x)) − ∇2 L(Φ(y))k


v1 (∇2 L(Φ(x))) − v1 (∇2 L(Φ(y))) ≤
λ1 (∇2 LΦ(x)) − λ2 (∇2 L(Φ(y)))
1 νξη 2

2 λ1 (∇2 L(Φ(x))) − λ2 (∇2 L(Φ(y))
1 νξη 2

2 λ1 (∇2 L(Φ(x))) − λ2 (∇2 L(Φ(x)) − 41 νξη 2
1 νξη 2
≤ ,
2 ∆ − 14 νξη 2

where we borrow the bound on k∇2 L(Φ(x)) − ∇2 L(Φ(y))k from our previous calcula-
tions. The final step follows from the constants defined in Definition 9.10.4.

∇L(x)
Lemma 9.10.12. For any xy ∈ Y  where y = x − η k∇L(x)k is the one step Normalized
GD update from x, we have that

η2 ⊥
Φ(y) − Φ(x) = − P ∇(log λ1 (∇2 L(Φ(x))))
4 Φ(x),Γ
νξ kx − Φ(x)k η 2
+ O(η 2 ξθ) + O( ) + O(χ kx − Φ(x)k η 2 ) + O(χη 3 ).
µ

(2:M )
PΦ(x),Γ x
e = ∇2 L(Φ(x))(x − Φ(x)). Additionally, we have that
e
Here θ = arctan |hv1 (x),e
xi|
, with x

νξ
PΦ(x),Γ (Φ(y) − Φ(x)) ≤ O(χ kx − Φ(x)k η 2 ) + O(χη 3 ) + O( kx − Φ(x)k η 2 ).
µ

Proof of Lemma 9.10.12. By Taylor expansion for Φ at x, we have

1
Φ(y) − Φ(x) =∂Φ(x) (y − x) + ∂ 2 Φ(x)[y − x, y − x] + O(χ ky − xk3 )
2 
η2 2
  
∇L(x) ∇L(x) ∇L(x)
=∂Φ(x) −η + ∂ Φ(x) , + O(χη 3 )
k∇L(x)k 2 k∇L(x)k k∇L(x)k
η2 2
 
∇L(x) ∇L(x)
= ∂ Φ(x) , + O(χη 3 ),
2 k∇L(x)k k∇L(x)k

450
where in the pre-final step, we used the property of Φ from Lemma 9.10.14. In the
final step, we have used a second order taylor expansion to bound the difference
∇L(x)
between ∂ 2 Φ(x) and ∂ 2 Φ(Φ(x)). Additionally, we have used y − x = η k∇L(x)k from the
Normalized GD update rule.
Applying Taylor expansion on Φ again but at Φ(x), we have that

η2 2
 
∇L(x) ∇L(x)
Φ(y) − Φ(x) = ∂ Φ(Φ(x)) , + O(χ kx − Φ(x)k η 2 ) + O(χη 3 )
2 k∇L(x)k k∇L(x)k
(9.23)

Also, at Φ(x), since v1 (x) is the top eigenvector of the hessian ∇2 L, we have that
from Corollary 9.10.21,

1
∂ 2 Φ(Φ(x)) v1 (x)v1 (x)> = − ∂Φ(Φ(x))∂ 2 (∇L)(Φ(x))[v1 (x), v1 (x)].
 
(9.24)
2λ1 (x)

By Lemma 9.10.10, it holds that

 
∇L(x) ∇L(x)
sign , v1 (x) − v1 (x)
k∇L(x)k k∇L(x)k
θ ν kx − Φ(x)k ν kx − Φ(x)k
≤2 sin + O( ) ≤ θ + O( ). (9.25)
2 µ µ

Plug Equations (9.24) and (9.25) into Equation (9.23), we have that

η2 1
Φ(y) − Φ(x) = − ∂Φ(Φ(x))∂ 2 (∇L)(Φ(x))[v1 (x), v1 (x)]
2 2λ1 (x)
νξ kx − Φ(x)k η 2
+ O(η 2 ξθ) + O( ) + O(χ kx − Φ(x)k η 2 ) + O(χη 3 ).
µ

By Lemma 9.10.16, for any x ∈ Γ, ∂Φ(x) is the projection matrix onto the tangent

space TΦ(x) Γ. Thus, ∂Φ(Φ(x)) = PΦ(x),Γ . Thus the proof of the first claim is completed

by noting that ∂Φ(Φ(x))∂ 2 (∇L)(Φ(x))[v1 (x), v1 (x)] = PΦ(x),Γ ∇λ1 (∇2 L(Φ(x))) by
Corollary 9.10.22.

451
For the second claim, continuing from Equation (9.23), we have that

η2 2
 
∇L(x) ∇L(x)
Φ(y) − Φ(x) = ∂ Φ(Φ(x)) , + O(χ kx − Φ(x)k η 2 ) + O(χη 3 )
2 k∇L(x)k k∇L(x)k
2
η νξ
= ∂ 2 Φ(Φ(x)) [Σ] + O(χ kx − Φ(x)k η 2 ) + O(χη 3 ) + O( kx − Φ(x)k η 2 ),
2 µ

 >
∇L(x) ∇L(x)
where Σ = PΦ(x),Γ k∇L(x)k PΦ(x),Γ k∇L(Φ(x))k and the last step is by Lemma 9.10.9.
Here PΦ(x),Γ denotes the projection matrix of the subspace spanned by v1 (x), . . . , vM (x).
By Lemmas 9.10.16, 9.10.17 and 9.10.20, we have that PΦ(x),Γ ∂ 2 Φ(Φ(x)) [Σ] =
−PΦ(x),Γ ∂Φ(x)∂ 2 (∇L)(x)[L−1
∇2 L(x) Σ] = 0, we conclude that

νξ
PΦ(x),Γ (Φ(y) − Φ(x)) ≤ O(χ kx − Φ(x)k η 2 ) + O(χη 3 ) + O( kx − Φ(x)k η 2 ),
µ

which completes the proof.

∇L(x)
Lemma 9.10.13. Let Lmin = miny∈U L(y). For any xy ∈ Y  where y = x − η k∇L(x)k
is the one step Normalized GD update from x, if k∇L(xη (t))k ≥ ζη, we have that


p p 2µ
L(y) − Lmin ≤ L(x) − Lmin − η .
4

Proof of Lemma 9.10.13. By Taylor expansion, we have that

ζη 2
L(y) ≤ L(x) − η k∇L(x)k + .
2

Thus for k∇L(xη (t))k ≥ ζη, we have that


η 2µ p
L(y) − L(x) ≤ − k∇L(x)k ≤ −η L(x) − Lmin ≤ 0,
2 2

452
where the last step is because L is µ-PL on Y  . In other words, we have that

p √ √
p p L(x) − Lmin 2µ 2µ
L(y) − Lmin − L(x) − Lmin ≤ −η p p ≤ −η ,
L(y) − Lmin + L(x) − Lmin 2 4

where in the last step we use L(y) − L(x) ≤ 0. This completes the proof.

9.10.2 Properties of limiting map of gradient flow, Φ

The following results Lemmas 9.10.14 to 9.10.18 and 9.10.20 and definition 9.10.19
are from Chapter 8.

Lemma 9.10.14. For any x ∈ U , it holds that (1). ∂Φ(x)∇L(x) = 0 and (2).
∂ 2 Φ(x)[∇L(x), ∇L(x)] = −∂Φ(x)∇2 L(x)∇L(x).

Lemma 9.10.15. For any x ∈ Γ and any v ∈ Tx Γ, it holds that ∇2 L(x)v = 0.

Lemma 9.10.16. For any x ∈ Γ, ∂Φ(x) ∈ RD×D is the projection matrix onto the

tangent space Tx Γ, i.e. ∂Φ(x) = Px,Γ .

Lemma 9.10.17. For any x ∈ Γ, if v1 , . . . , vM denote the non-zero eigenvectors of


the hessian ∇2 L(Φ(x)), then v1 , . . . , vM ∈ Nx Γ.

Lemma 9.10.18. For any x ∈ Γ and u ∈ Nx Γ, it holds that

∂ 2 Φ(x) uu> + ∇2 L(x)† uu> ∇2 L(x) = −∂Φ(x)∂ 2 (∇L)(x) ∇2 L(x)† uu> .


   

Definition 9.10.19 (Lyapunov Operator). For a symmetric matrix H, we define


WH = Σ ∈ RD×D | Σ = Σ> , HH † Σ = Σ = ΣHH † and Lyapunov Operator LH :


WH → WH as LH (Σ) = H > Σ + ΣH. It’s easy to verify L−1


H is well-defined on WH .

Lemma 9.10.20. For any x ∈ Γ and Σ = span{uu> | u ∈ Nx Γ},

h∂ 2 Φ(x), Σi = −∂Φ(x)∂ 2 (∇L)(x)[L−1


∇2 L(x) (Σ)].
453
We will also use the following two corollaries of Lemma 9.10.20.

Corollary 9.10.21. For any x ∈ Γ, let v1 be the unit top eigenvector of ∇2 L(x), then

1
∂ 2 Φ(x)[v1 v1> ] = − ∂Φ(x)∂ 2 (∇L)(x)[v1 , v1 ]
2λ1 (∇2 L(x))

Proof of Corollary 9.10.21. Simply note that L−1 >


∇2 L(x) (v1 v1 ) =
1
v v>
2λ1 (∇2 L(x)) 1 1
and
apply Lemma 9.10.20.

Corollary 9.10.22. For any x ∈ Γ, let v1 be the unit top eigenvector of ∇2 L(x), then

1 ⊥
∂ 2 Φ(x)[v1 v1> ] = − Px,Γ ∇ log(λ1 (∇2 L(x))).
2

Proof of Corollary 9.10.22. The proof follows from using Corollary 9.10.21 and the
derivative of λ1 from Theorem 9.14.1.

As a variant of Corollary 9.10.22, we have the following lemma.

Corollary 9.10.23. For any x ∈ Γ, let v1 be the unit top eigenvector of ∇2 L(x), then

1 ⊥
∂ 2 Φ(x)[λ1 (∇2 L(x))v1 v1> ] = − Px,Γ ∇(λ1 (∇2 L(x))).
2

9.11 Analysis of Normalized GD on General Loss

Functions

9.11.1 Phase I, Convergence

We restate the theorem concerning Phase I for the Normalized GD algorithm. Recall
the following notation for each 1 ≤ j ≤ M :

v
uM
uX
Rj (x) := t λ2i (x)hvi (x), x − Φ(x)i2 − λj (x)η, for all x ∈ U.
i=j

454
Theorem 9.5.3 (Phase I). Let {xη (t)}t∈N be the iterates of Normalized GD (9.4)
with LR η and xη (0) = xinit ∈ U . There is T1 > 0 such that for any T10 > T1 , it
holds that for sufficiently small η that (1) max kxη (t) − Φ(xinit )k ≤ O(η) and (2)
T1 ≤ηt≤T10
max Rj (xη (t)) ≤ O(η 2 ).
T1 ≤ηt≤T10 ,j∈[D]

The intuition behind the above theorem is that for sufficiently small LR η, xη (t)
will track the normalized gradient flow starting from xinit , which is a time-rescaled
version of the standard gradient flow. Thus the normalized GF will enter Y  and so
does normalized GD. Since L satisfies PL condition in Y  , the loss converges quickly
and the iterate xη (t) gets η to manifold. To finish, we need the following theorem,
which is the approximately-quadratic version of Lemma 9.3.3 when the iterate is O(η)
close to the manifold.

Lemma 9.11.1. Suppose {xη (t)}t≥0 are iterates of Normalized GD (9.4) with a
learning rate η and xη (0) = xinit . There is a constant C > 0, such that for any
kxη (t0 )−Φ(xη (t0 ))k
constant ς > 1, if at some time t0 , xη (t0 ) ∈ Y  and satisfies η
≤ ς, then
for all t̄ ≥ t0 + C ζς
µ
log ςζ
µ
, the following must hold true for all 1 ≤ j ≤ M :

v
uM
uX
e(t̄)i2 ≤ ηλj (t̄) + O(η 2 ),
t hvi (t̄), x (9.26)
i=j

provided that for all steps t ∈ {t0 , . . . , t̄ − 1}, xη (t)xη (t + 1) ⊂ Y  .

The proof of the above theorem is in Section 9.12.1.

Proof of Theorem 9.5.3. We define the Normalized gradient flow as φ(x, τ ) = x −


R τ ∇L(φ(x,s))
0 k∇L(φ(x,s))k
ds. Since φ(x, ·) is only a time rescaling of φ(x, ·), they have the same
limiting mapping, i.e., Φ(x) = limτ →Tx φ(x, τ ), where Tx is the length of the trajectory
of the gradient flow starting from x.

455
Let Tx be the length of the GF trajectory starting from x, and we know
limτ →Tx φ(x, τ ) = Φ(x), where φ(x, τ ) is defined as the Normalized gradient flow
starting from x. In Lemmas 9.10.2 and 9.10.5 we show there is a small neigh-
bourhood around Φ(xinit ), Y  such that L is µ-PL in Y  . Thus we can take some
time T0 < Txinit such that φ(xinit , T0 ) ∈ Y /2 and L(φ(xinit ), T0 ) ≤ 12 Lcritical , where
2 µ
Lcritical := 8
. (Without loss of generality, we assume miny∈Y L(y) = 0) By standard
ODE approximation theory, we know there is some small η0 , such that for all η ≤ η0 ,
xη (dT0 /ηe) − φ(xinit , T0 ) = O(η), where O(·) hides constants depending on the
initialization xinit and the loss function L.
Without loss of generality, we can assume η0 is small enough such that xη (dT0 /ηe) ∈
Y  and L(xη (dT0 /ηe)) ≤ Lcritical . Now let tη be the smallest integer (yet still
larger than dT0 /ηe) such that xη (tη )xη (tη − 1) 6⊂ Y  and we claim that there is
t ∈ {dT0 /ηe, . . . , tη }, k∇L(xη (t))k < ζη. By the definition of tη , we know for any t ∈
{dT0 /ηe + 1, . . . , tη − 1}, by Lemma 9.10.11 we have kΦ(xη (t)) − Φ(xη (t − 1))k ≤ ξη 2 ,

and by Lemma 9.10.13, L(xη (t)) − xη (t − 1) ≤ −η 42µ if k∇L(xη (t))k ≥ ζη. If
p p

the claim is not true, since L(xη (t)) decreases η 42µ per step, we have
p



q q
0 ≤ L(xη (tη − 1)) ≤ L(xη (dT0 /ηe)) − (tη − dT0 /ηe − 1)η ,
4

which implies that tη − dT0 /ηe − 1 ≤ η , and therefore by Lemma 9.10.11,

ξη 2 ξη
kΦ(xη (tη − 1)) − Φ(xη (dT0 /ηe))k ≤ (tη − dT0 /ηe − 1) =
2 2

Thus we have

kΦ(xη (tη − 1)) − Φ(xinit )k

≤ kΦ(xη (tη − 1)) − Φ(xη (dT0 /ηe))k + Φ(xη (dT0 /ηe) − Φ(φ(xinit , T0 ))) = O(η).

456
q
2L(xη (tη −1))
Meanwhile, by Lemma 9.10.6, we have kΦ(xη (tη − 1)) − xη (tη − 1)k ≤ µ

q
2L(xη (dT0 /ηe))
µ
= 2 . Thus for any κ ∈ [0, 1], we have kκxη (tη ) + (1 − κ)xη (tη − 1) − Φ(xinit )k
is upper bounded by

κ kxη (tη ) − xη (tη − 1)k + kΦ(xη (tη − 1)) − xη (tη − 1)k + kΦ(xη (tη − 1)) − Φ(xinit )k

=κη + + O(η),
2

which is smaller than  since we can set η0 sufficiently small. In other words,
Φ(xη (tη ))Φ(xη (tη − 1)) ⊂ Y  , which contradicts with the definition of tη . So far we
have proved our claim that there is some t0η ∈ {dT0 /ηe, . . . , tη }, ∇L(xη (t0η )) < ζη.

Moreover, since L(xη (t)) decreases η 42µ per step before t0η , we know t0η −dT0 /ηe ≤ η .
p

ζη
By Lemma 9.10.6, we know xη (t0η ) − Φ(xη (t0η )) ≤ µ
.
Now we claim that for any T10 , there is some sufficiently small threshold
T10
η0 , tη ≥ η
+ 1 if η ≤ η0 . Below we prove this claim by contradiction. If
T10
the claim is not true, that is, tη < η
+ 1. if tη ≤ C ζς
µ
log ςζ
µ
+ t0η with
ζ
ς = µ
, we know kxη (tη ) − Φ(xinit )k ≤ xη (tη ) − xη (t0η ) + xη (t0η ) − Φ(xη (t0η )) +
Φ(xη (t0η )) − Φ(xinit ) = O(η), which implies that xη (tη )xη (tη − 1) ∈ Y . If
tη ≥ C ζς
µ
log ςζ
µ
+ t0η , by Lemma 9.11.1, we have kxη (tη ) − Φ(xη (tη ))k = O(η). By
Lemma 9.10.11, we have kΦ(xη (tη )) − Φ(xη (dT0 /ηe))k ≤ O(η). Thus again we
have that kxη (tη ) − Φ(xinit )k ≤ kxη (tη ) − Φ(xη (tη ))k + kΦ(xη (tη )) − Φ(xη (dT0 /ηe))k
+ kΦ(xη (dT0 /ηe)) − Φ(xinit )k = O(η), which implies that xη (tη )xη (tη − 1) ∈ Y . In
both cases, the implication is in contradiction to the definition of tη .
T10
Thus for any T10 , tη ≥ η
+ 1 for sufficiently small threshold η0 and η ≤ η0 . To
complete the proof of Theorem 9.5.3, we pick T1 to be any real number strictly
larger than  + T0 , as T1
η
> C ζς
µ
log ςζ
µ
+ 
η
+ dT0 /ηe ≥ C ζς
µ
log ςζ
µ
+ t0η when η is
sufficiently small with ς = µζ . By Lemma 9.11.1 the second claim of Theorem 9.5.3

457
T10
is proved. Using the same argument again, we know ∀ Tη1 ≤ t ≤ η
, it holds that
kΦ(xη (t)) − Φ(xinit )k ≤ O(η).

9.11.2 Phase II, Limiting Flow

We first restate the main theorem that demonstrates that the trajectory implicitly
minimizes sharpness.

Theorem 9.5.4 (Phase II). Let {xη (t)}t∈N be the iterates of perturbed Normalized GD
(Algorithm 7) with LR η. Under Assumptions 9.5.1 and 9.5.2, if the initialization xη (0)
satisfy that (1) kxη (0) − Φ(xinit )k ≤ O(η) where xinit ∈ U , (2) maxj∈[D] Rj (xη (t)) ≤
O(η 2 ), and additionally (3) min{|hv1 (xη (0)), xη (0) − Φ(xη (0))i| , −R1 (xη (0))} ≥ Ω(η),
then for any time T2 > 0 till which the solution of (9.5) exists, it holds for suffi-
ciently small η, with probability at least 1 − O(η 10 ), that kΦ(xη (bT2 /η 2 c)) − X(T2 )k =
P 2 /η2 c
O(η) and bT21/η2 c bTt=0 θt ≤ O(η), where θt ∈ [0, π2 ] denotes the angle between
∇2 L(Φ(xη (t)))(xη (t) − Φ(xη (t))) and top eigenspace of ∇2 L(Φ(xη (t))).

To show the closeness between the continuous and the discrete dynamic, we
will need to use the following classic differential inequality from [257]. The original
statement is for differential equations defined on RD . Without loss of generality, we
can restrict it to an open subset of RD with the same proof.

Theorem 9.11.2. [Adaption of “Variant form of Theorem 10.2”, p.59, [257]] Let U
be an open subset of RD . Suppose that {y(τ ) ∈ U }Tτ=0 is a solution of the differential
dy
equation dτ
= f (y(τ )), y(τ ) = y0 , and that v(τ ) ∈ U is a piecewise linear curve. If
f (y) is βlip -Lipschitz in y, that is, ∀y, y 0 ∈ U , τ ∈ [0, T ], kf (y) − f (y 0 )k ≤ βlip , then

458
for any 0 ≤ τ ≤ T , it holds that for any τ ∈ [0, T ],

 Z τ 
−βlip τ
ky(τ ) − v(τ )k ≤e T βlip
kv(0) − y(0)k + e kv (τ + 0) − f (v(τ ))k dτ 0
0 0
0
 Zτ τ=0 
≤e T βlip
kv(0) − y(0)k + kv (τ + 0) − f (v(τ ))k dτ 0 ,
0 0
τ 0 =0

v(τ 0 +δ)−v(τ 0 )
where v 0 (τ 0 + 0) := limδ→0 δ
is the right time derivative of v at τ 0 .

Proof of Theorem 9.5.4. Without loss of generality, we can change assumption (3) in
xη (0)k ≤ ηλ1 (0)/2 + Ψnorm η 2 and |hv1 (xη (0)), x
the theorem statement into ke eη (0)i| ≥
Ω(η). (Constant Ψnorm is defined in Lemma 9.12.1) This is because we know from
λ1 (·)
Lemma 9.12.1, that the norm can’t stay above 2
η + Ψnorm η 2 for two consecutive
eη (0)| ≥ Ω(η) but ηλ1 (0)/2 + Ω(Ψnorm η 2 ) ≤ ke
steps. Moreover, if |v1 (0), x xη (0)k ≤
ηλ1 (0) − Ω(η), we can further show that |v1 (1), x
eη (1)| ≥ Ω(η) from the update rule of
Normalized GD (Lemma 9.10.9). Thus, we can shift our analysis by one time-step
if our assumption isn’t true at step 0. This simplification of assumption helps us to
prove the second claim using Lemma 9.13.5.
To prove the first claim, we first show the movement in the manifold for the
discrete trajectory for Algorithm 7 by Lemma 9.10.12: for each step t, provided
Φ(xη (t))Φ(xη (t + 1)) ∈ Y  , it holds that

η2 ⊥
Φ(xη (t + 1)) − Φ(xη (t)) = − PΦ(xη (t)),Γ ∇ log λ1 (xη (t)) + O((θt + kxη (t) − Φ(xη (t))k)η 2 ).
4
(9.27)

To recall, the limiting flow is given by

Z τ
1 ⊥
X(τ ) = Φ(xinit ) − PX(s),Γ ∇ log λ1 (X(s))ds, X(τ ) ∈ Γ (9.5)
4 s=0

The high-level idea for the proof of the first claim is to bound the gap between
Equation (9.27) and Equation (9.5) using Theorem 9.11.2. And the first claim
459
eventually boils down to upper bound the average angle by O(η), which is exactly the
second claim.
Formally, let t2 be the largest integer no larger than bT2 /η 2 c such that for any
0 ≤ t ≤ t2 , it holds that Φ(xη (t))Φ(xη (t + 1)) ∈ Y  .
To apply Theorem 9.11.2, we let y(τ ) = X(τ ), f : U → RD , f (y) =

PΦ(y),Γ ∇ log λ1 (y), βlip be an upper bound for lipschitzness of f on compact set Y 
and v(τ ) = Φ(xη (bτ /η 2 c)) + (τ /η 2 − bτ /η 2 c)) (Φ(xη (bτ /η 2 c + 1)) − Φ(xη (bτ /η 2 c))).
In other words, v is a piecewise linear curve interpolating all xη (t) at time tη 2 .
Therefore, by Equation (9.27), note Φ(x) = Φ(Φ(x)) for all x ∈ U , it holds that

v 0 (tη 2 + 0) − f (v(tη 2 )) = 1/η 2 (Φ(xη (t + 1)) − Φ(xη (t))) − PΦ(x



η (t))),Γ
∇ log λ1 (Φ(xη (t))))

=O(θt + kxη (t) − Φ(xη (t))k).

Since we started from a point that has max1≤j≤M Rj (xη (0)) ≤ O(η 2 ), we have from
Lemma 9.11.1, that the iterate satisfies the condition max1≤j≤M Rj (xη (t)) ≤ O(η 2 ) at
step t as well, meaning that kxη (t) − Φ(xη (t))k ≤ O(η).
Therefore, for any τ ≤ t2 η 2 , note that v 0 (τ + 0) = v 0 (bτ /η 2 c + 0) and that
kf (v(bτ /η 2 c + 0)) − f (v(τ ))k = O(kΦ(xη (bτ /η 2 c + 1)) − Φ(xη (bτ /η 2 c))k) = O(η 2 ),
we have that

kv 0 (τ + 0) − f (v(τ ))k = v 0 (bτ /η 2 c + 0) − f (v(bτ /η 2 c)) + f (v(bτ /η 2 c + 0)) − f (v(τ ))

≤O(θbτ /η2 c + η) + O(η 2 ) = O(θbτ /η2 c + η).

Thus by Theorem 9.11.2, we conclude that

Z t2 η 2 t2
X
2 2 2 2
 
X(t2 η )−xη (t2 ) = y(t2 η )−v(t2 η ) = O η+ (θbτ /η2 c +η) = O(η+ η (θt +η) ) = O(η),
τ =0 t=0

460
where in the last step we use the second claim. This implies that t2 must be equal
to bT2 /η 2 c for sufficiently small η otherwise xη (t2 )xη (t2 + 1) ⊆ Y  . This is because
kxη (t2 + 1) − xη (t2 )k = O(η) and X(t2 η 2 ) ∈ Y . The proof is completed by noting
that kX(T2 ) − X(bT2 /η 2 c)k = O(η 2 ).

9.12 Phase I, Proofs of the Main Lemmas

9.12.1 Proof of Lemma 9.11.1

Proof of Lemma 9.11.1. The Normalized GD update at any step t can be written as
(from Lemma 9.10.9)

∇2 L(Φ(xη (t)))[xη (t) − Φ(xη (t))] ν


xη (t + 1) − xη (t) = −η + O( η kxη (t) − Φ(xη (t))k).
k∇2 L(Φ(xη (t)))[xη (t) − Φ(xη (t))]k µ
(9.28)

From Lemma 9.10.11, we have kΦ(xη (t)) − Φ(xη (t + 1))k ≤ O(ξη 2 ), which further
implies,
k∇2 L(Φ(xη (t + 1))) − ∇2 L(Φ(xη (t)))k ≤ O(νξη 2 ). Thus, using the notation x
e =
∇2 L(Φ(x))(x − Φ(x)), we have

eη (t + 1) − x
x eη (t + 1) − ∇2 L(Φ(xη (t)))(xη (t + 1) − Φ(xη (t)))
eη (t) = x

+ ∇2 L(Φ(xη (t)))(xη (t + 1) − Φ(xη (t))) − x


eη (t)

= ∇2 L(Φ(xη (t + 1)))(Φ(xη (t)) − Φ(xη (t + 1)))

+ (∇2 L(Φ(xη (t + 1))) − ∇2 L(Φ(xη (t))))(xη (t + 1) − Φ(xη (t)))

+ ∇2 L(Φ(xη (t)))(xη (t + 1) − xη (t))


x
eη (t)
= −η∇2 L(Φ(xη (t))) + err + O(η 2 + η kxη (t) − Φ(xη (t))k),
ke
xη (t)k

461
That is,

∇2 L(Φ(xη (t)))
 
eη (t + 1) = I − η
x eη (t) + O(η 2 ) + O(η kxη (t) − Φ(xη (t))k).
x
ke
xη (t)k
(9.29)

Below we will show that kxη (t) − Φ(xη (t))k ≤ O(η), and thus the trajectory of
eη is similar to the trajectory in the qudratic model with an O(η 2 ) error, with the
x
hessian fixed at ∇2 L(Φ(xη (t))), and hence we can apply the same techniques from
Corollary 9.9.4 and Lemma 9.9.1.
eη (t) for t0 + 1 ≤ t ≤ t. We will show
First, we consider the norm of the vector x
the following induction hypothesis:

ke
xη (t)k ≤ 1.01ηζς.

1. Base case: (t = t0 ). We have ke


xη (t0 )k = k∇2 L(Φ(xη (t0 )))[xη (t0 ) − Φ(xη (t0 ))]k ≤
ηλ1 (t)ς ≤ ηζς.

2. Induction case:(t > t0 ). Suppose the hypothesis holds true for t − 1. Then,

1 1.01ηςζ
kxη (t − 1) − Φ(xη (t − 1))k ≤ ke
xη (t − 1)k ≤ .
λM (t) µ

We consider the following two cases:

462
(a) If ke
xη (t − 1)k ≥ ηλ1 (t). We can directly apply Corollary 9.9.3 on (9.29) to
show that

 
ηλM (t − 1)
ke
xη (t)k ≤ 1 − xη (t − 1)k + O(νξη 2 )
ke
2 ke
xη (t − 1)k
νζ
+ O( η kxη (t − 1) − Φ(xη (t − 1))k)
µ
ηλM (t − 1) νζ
≤ ke
xη (t − 1)k − + O(νξη 2 ) + O( 2 ςη 2 )
2 µ
ηλM (t − 1)
≤ ke
xη (t − 1)k − ,
4

where the final step follows if η is sufficiently small. Hence, ke


xη (t)k <
ke
xη (t − 1)k ≤ ηζς.

(b) If ke
xη (t − 1)k ≤ ηλ1 (t). Then, we can directly apply Lemma 9.9.1 on (9.29)
to show that

νζ
xη (t)k ≤ ηλ1 (t) + O(νξη 2 ) + O(
ke η kxη (t − 1) − Φ(xη (t − 1))k)
µ
νζς
≤ ηλ1 (t) + O(νξη 2 ) + O( 2 η 2 )
µ
≤ 1.01ηλ1 (t).

1 1.01ηςζ
Hence, we have shown that, kxη (t) − Φ(xη (t))k ≤ λM (t)
ke
xη (t)k ≤ µ
for all
time t0 ≤ t ≤ t.
We complete the proof of Lemma 9.11.1 with a similar argument as that for the
quadratic model (see Corollary 9.9.4 and Lemma 9.9.1). The major difference from the
quadratic model is that here the hessian changes over time, along with its eigenvectors
and eigenvalues. Hence, we need to take care of the errors introduced in each step by
the change of hessian.
The high-level idea is to divide the eigenvalues at each step t into groups such that
eigenvalues in the same group are O(η) close and eigenvalues from different groups

463
are at least 2η far away from each other. Formally, we divide [M ] into disjoint subsets
(t) (t)
S1 , · · · , Sp(t) (with 1 ≤ p(t) ≤ M ) such that

∀k, ` ∈ [p(t)], k 6= ` min |λi (t) − λj (t)| > η


i∈Sk ,j∈S`

and
(t)
∀k ∈ [p(t)], i, i + 1 ∈ Sk λi (t) − λi+1 (t) ≤ η.

For S ⊂ [M ], we denote by PtS the projection matrix at time t onto the subspace
spanned by {vi (t)}i∈S . From Lemma 9.10.11, we have kΦ(xη (t + 1)) − Φ(xη (t + 1))k ≤
ξη 2 , which further implies, k∇2 L(Φ(xη (t + 1))) − ∇2 L(Φ(xη (t)))k ≤ O(νξη 2 ). That
implies, using Theorem 9.14.2, |λj (t) − λj (t)| ≤ O(νξη 2 ) for any j ∈ [M ]. Therefore,
(t) (t)
S` S
we can use Theorem 9.14.4 to have for any ` ∈ [p] Pt − Pt+1
`
≤ O(νξη), since
we have created the eigen subspaces such that the eigenvalue gap between any two
distinct eigen subspaces is at least 0.5η in the desired interval.
(t) (t)
Thus for any t0 ≤ t ≤ t − 1 and k ∈ [p(t)], suppose i ∈ Sk and j = min Sk , we
have that
v
uM
uX
eη (t + 1)i2
t hv (t + 1), x
h
h=i
v v
uM u p(t) 2
uX uX S (t)
eη (t + 1)i2 = t
≤t hvh (t + 1), x Pt+1
`
x
eη (t + 1)
h=j `=k
v v
u p(t) 2 uM
uX S (t) uX
2
≤t Pt ` x
eη (t + 1) eη (t + 1)i2 + O(η 2 )
+ O(η ) = t hvh (t), x
`=k h=j

and ηλi (t + 1) ≥ ηλi (t) − O(η 2 ) ≥ ηλj (t) − O(η 2 ).

464
Therefore, we have that

v v
uM uM
1 uX 1 uX
eη (t + 1)i2 ≤
t hv (t + 1), x
h eη (t + 1)i2 + O(η)
t hvh (t), x
ηλi (t + 1) h=i
ηλj (t) h=j

Next we will use the results from the quadratic case to upper bound
qP qP
M 2 M
h=j hvh (t), x
eη (t + 1)i using eη (t)i2 . For all 1 ≤ j ≤ M , we
h=j hvh (t), x

consider the following two cases for any time t0 + 1 ≤ t ≤ t:


qP
M
1. If eη (t)i2
h=j hvh (t), x > ηλj (t), then we can apply Lemma 9.9.2 on (9.29) to
show that
v
uM
uX
eη (t + 1)i2
t hvh (t), x
h=j
v
  uX M
ηλM (t) u νζ
≤ 1− t eη (t)i2 + O(νξη 2 ) + O( η kxη (t) − Φ(xη (t))k)
hvh (t), x
ke
xη (t)k h=j
µ
v
 uX M
νζ 2 ς 2

µ u 2 2
≤ 1− t hvh (t), x
eη (t)i + O(νξη ) + O( 2 η ).
2ζς h=j
µ

qP
M
2. If eη (t)i2
h=j hvh (t), x ≤ ηλj (t), then we can apply Lemma 9.9.1 on (9.29) to
show that
v
uM
uX νζ
eη (t + 1)i2 ≤ ηλj (t) + O(νξη 2 ) + O( η kxη (t) − Φ(xη (t))k)
t hvh (t), x
h=j
µ
νζ 2 ς 2
≤ ηλj (t) + O(νξη 2 ) + O( η ).
µ2

465
Thus we conclude that
v
uM
1 uX
max h eη (t + 1)i2
t hv (t + 1), x
i∈[M ] ηλi (t + 1)
h=i
 v 
uM
 µ 1 u X 
≤ max 1, (1 − )· eη (t)i2 + O(η),
max t hvh (t), x
 2ζς ηλj (t) j∈[M ] h=j 

and therefore following the same proof of quadratic case Corollary 9.9.4, for t ≥
qP
M
t0 + Ω( ςζ
µ
log ζς
µ
), it holds that ∀j ∈ [M ], e(t̄)i2 ≤ ηλj (t̄) + O(η 2 ).
i=j hvi (t̄), x

9.12.2 Properties of the condition in Equation (9.26)

By Lemma 9.11.1, the following condition will continue to hold true for all 1 ≤ j ≤ M
eη (t) leaves Y  :
before x

v
uM
uX
eη (t)i2 ≤ λj (t)η + O(η 2 ),
t hvi (t), x (9.30)
i=j

eη (t) = ∇2 L(Φ(xη (t)))(xη (t) − Φ(xη (t))). We will call the above condition as
where x
the alignment condition from now onwards.
From the alignment condition (9.30), we can derive the following property that
continues to hold true throughout the trajectory, once the condition is satisfied:

Lemma 9.12.1. There is some constant Ψnorm > 0, such that if the condition (9.30)
ηλ1 (t)
holds true and ke
xη (t)k > 2
, we have:

ηλ1 (t)
ke
xη (t + 1)k ≤ + Ψnorm η 2 .
2

The proof follows from applying Lemma 9.9.8 using the alignment condition (9.30).
eη (t) can’t stay at norm larger than 0.5ηλ1 (t) + Ψnorm η 2 for time
Hence, the iterate x
larger than 1.
466
Another useful lemma is to about the change of the angle between x
eη (t) and the top
ηλ1 (t)
eigenvector when ke
xη (t)k ≤ 2
+ Ψnorm η 2 , which is a noisy version of Lemma 9.9.11
for a quadratic model.

ηλ1 (t)
Lemma 9.12.2. Consider any time t such that ke
xη (t)k ≤ 2
+ Ψnorm η 2 , and the
condition (9.30) holds true, then we have that

 
min(∆, 2µ) η
tan θt+1 ≤ 1− tan θt + O( ).
ζ ke
xη (t)k

and that

ηλ1 η2
tan θt+2 ≤ tan θt + O( ).
ke
xη (t)k ke
xη (t)k

ηλk (t)
Corollary 9.12.3. If for some 1 ≤ k ≤ M , ke
xη (t)k ≤ 2
+ Ψnorm η 2 and condition
(9.30) holds true, the following must hold true:

vk (t + 1)> x
eη (t + 1) ≥ vk (t)> x
eη (t) − O(η 2 ).

The proof follows from using the noisy quadratic update for Normalized GD in
Lemma 9.10.9 (Equation (9.22)) and the behavior in a quadratic model along the
non-top eigenvectors in Lemma 9.9.5.

9.13 Phase II, Proofs of the Main Lemmas

The main lemma in this section is Lemma 9.13.5 in Section 9.13.2, which says the
sum of the angles across the entire trajectory in any interval [0, t2 ] with t2 = Ω(1/η 2 ),
is at most O(ηt2 ). Before proving the main lemma, we will first recap and introduce
some notations that will be used.

467
In Phase II, we start from a point xη (0), such that (1) kxη (0) − Φ(xinit )k ≤ O(η),
(2) maxj∈[D] Rj (xη (t)) ≤ O(η 2 ), and additionally (3) |hv1 (xη (0)), xη (0) − Φ(xη (0))i| =
Ω(η).
(2:M )
Pt,Γ x
eη (t)
Formally, recall our notation on θt as θt = arctan |hv1 (t),e
xη (t)i|
, with our notation
eη (t) as ∇2 L(Φ(xη (t)))(xη (t) − Φ(xη (t))). Moreover, recall the definition of the
of x
function gt : R → R as

s  !
λ1 (t) λ λ
gt (λ) = 1− 1−2 1− .
2 λ1 (t) λ1 (t)

For convenience, we also define Gt := | hv1 (t), x


e(t)i |, gmax (t) := maxk∈[M ] gt (λk (t)) and
gmin (t) := mink∈[M ] gt (λk (t)).
The condition (9.30) that was shown to hold true in Phase II is:

v
uM
uX
eη (t)i2 ≤ λj (t)η + O(η 2 ).
t hvi (t), x
i=j

Further, we had proved in Lemma 9.12.1 that:

ηλ1 (t) ηλ1 (t + 1)


ke
xη (t)k > =⇒ ke
xη (t + 1)k ≤ + Ψnorm η 2 .
2 2

ηλ1 (t)
Thus, the iterate can’t stay greater than 2
+ Ψnorm η 2 for more than 1 step. This
lemma allows us to divice all time steps into groups of length 1 and 2.

9.13.1 Dividing Time Steps Into Cycles of Length 1 or 2

The properties of the three sets N0 , N1 , N2 in an interval (e


t, t) given directly by
the algorithm in Algorithm 9 include:

xη (t + 2)k ≥ 0.5λ1 (t + 2)η + Ψnorm η 2 .


1. ∀t ∈ N0 , ke

xη (t + 2)k ≤ 0.5λ1 (t + 2)η + Ψnorm η 2 .


2. ∀t ∈ N1 , ke
468
Algorithm 9 Grouping into 1-cycles and 2-cycles
Input: Interval [e t, t], Iterates of Algorithm 7: {e
xη (t)}t∈[et,t] where xη (t)xη (t + 1) ⊆
Y , Top eigenvalue λ1 (t) = ∇2 L(Φ(xη (t))).
Requires: x t) ≤ 0.5λ1 (e
eη (e t)η + Ψnorm η 2 .
Initialize: N0 , N1 , N2 ← ∅, t ← e t.
while t ≤ t do
xη (t + 2)k > 0.5λ1 (t + 2)η + Ψnorm η 2 then
if ke
N0 ← N0 ∪ {t}
t←t+1
else if kexη (t + 2)k ≤ 0.5λ1 (t + 2)η + Ψnorm η 2 then
N1 ← N1 ∪ {t}
N2 ← N2 ∪ {t + 1}
t←t+2
Return: N0 , N1 , N2

3. ∀t, t ∈ N1 ⇐⇒ t + 1 ∈ N2 .

4. N0 ∪ N1 ∪ N2 = [e
t, t], and the intersection between each pair of them is empty.

We also have the following lemmas, which is less direct:

Lemma 9.13.1. For any step t in N0 , t − 1, t + 2 ∈ N2 and t − 2, t + 1 ∈ N1

Proof of Lemma 9.13.1. Suppose t is in N0 , and therefore ke


xη (t + 2)k ≥ 0.5λ1 (t +
2)η + Ψnorm η 2 . By Lemma 9.12.1 we know that ke
xη (t + 3)k < 0.5λ1 (t + 3)η + Ψnorm η 2 ,
thus t + 1 ∈ N1 and t + 2 ∈ N2 . Applying similar argument on t − 1, we know if t − 1
is in N0 , t cannot be in N0 . Meanwhile, t − 1 cannot be in N1 , which would imply
t ∈ N2 . Thus t − 1 must be in N2 and therefore t − 2 is in N1 .

xη (t)k ≤ 0.5λ1 (t)η + Ψnorm η 2 .


Lemma 9.13.2. ∀t ∈ N0 ∪ N1 , ke

Proof of Lemma 9.13.2. By Lemma 9.13.1, we know if t ∈ N0 , then t − 2 ∈ N1 and


xη (t)k ≤ 0.5λ1 (t)η + Ψnorm η 2 .
thus by the definition of N1 , we have ke
If t ∈ N1 , then we consider the three possibilities for t − 2. If t − 2 ∈ N1 , then
the proof is done by definition of N1 . If t − 2 ∈ N0 , then by Lemma 9.13.1, we have
t ∈ N2 , contradiction! Thus this case is not possible. If t − 2 ∈ N2 , then t − 1 cannot

469
be in N1 as t is not in N2 . Thus t − 2 must be in N0 and Lemma 9.12.1 implies that
xη (t)k ≤ 0.5λ1 (t)η + Ψnorm η 2 .
ke

Lemma 9.13.3. For any step t in N0 and N1 ,


  2
min(∆,2µ)
1. tan θt+1 ≤ 1 − ζ
tan θt + O( Gη t )

ηλ1 2
2. tan θt+2 ≤ Gt
tan θt + O( Gη t )
  
3. If Gt ≥ 1.02gmax (t)η, tan θt+2 ≤ 1 − min 0.01, min 2λi (t)
λ1 (t)
(1 − λi (t)
λ1 (t)
) tan θt +
i≤M

O(η).

Proof of Lemma 9.13.3. The proof follows from using the noisy update rule for Nor-
malized GD, as derived in Equation (9.29), which says that the Normalized GD
update is very close to the update in a quadratic model with an additional O(η 2 )
error. Using the property of N0 and N1 outlined above, we have the norm of x
eη (t)
at most 0.5λ1 (t)η + Ψnorm η 2 . The result then follows from using Lemma 9.9.11 and
Lemma 9.9.10, that computes the convergence rate towards the top eigenvector for a
quadratic model. (The first two properties are stated in Lemma 9.12.2).

As a direct consequence of Lemma 9.13.3, we have the following lemma:

t = max N1 ∩ {e
Lemma 9.13.4. Given any t with θt = Ω(1), let e t | e
t ≤ t}. If
Get ≥ Ω(η), then θet = Ω(1).

Proof of Lemma 9.13.4. The claim is clearly true if t ∈ N1 . If t ∈ N0 , then


Lemma 9.13.1 shows that t − 1 ∈ N2 , t − 2 ∈ N1 and thus e
t = t − 2. The claim is true
because of the second property of Lemma 9.13.3. If t ∈ N2 , then e
t = t − 1 ∈ N1 and
the proof is completed by applying the first property of Lemma 9.13.3.

9.13.2 Time Average of Angles Against Top Eigenspace

Lemma 9.13.5 (Average of the Angles). For any T2 > 0 for which solution of
Equation (9.5) exists, consider an interval [0, t2 ], with Ω(1/η 2 ) ≤ t2 ≤ bT2 /η 2 c.
470
Suppose Algorithm 7 is run with learning rate η for t2 steps, starting from a point
xη (0) that satisfies (1) maxj∈[D] Rj (xη (0)) ≤ O(η 2 ), and (2) G0 := |hv1 (0), x
e(0)i| ≥
ηλ1 (0)
βη, ke
xη (0)k ≤ 2
+ Ψnorm η 2 for some constant β independent of η. The following
holds true with probability at least 1 − η 10 :

t2
1X
θ` ≤ O (η) ,
t2 `=0

provided η is set sufficiently small, and for all time 0 ≤ t ≤ t2 −1, xη (t)xη (t + 1) ⊂ Y  .

Proof. We split the entire interval [0, t2 ) into small trunks in the following way,
0=e
t0 < e
t1 < e t` = t2 with e
t2 . . . e t` denoting the starting step of each trunk. Each e
ti is
defined from e
ti−1 for i > 0. The behavior of each trunk depends on the magnitude of
the iterate along the top eigenvector of hessian. We classify the trunks on the basis of
3 possibilities: Consider a general e
ti ,

A. If Geti ≥ 1.02gmax (e
ti ), then we define e
ti+1 as

x(t)k ≤ 0.5λ1 (t)η + Ψnorm η 2 }.


ti+1 = min{t | Gt ≤ 1.01gmax (t), ke
e
t>e
ti

ti ) ≤ Geti ≤ 1.02gmax (e
B. If 0.98gmin (e ti ) then we define e
ti+1 as

x(t)k ≤ 0.5λ1 (t)η+Ψnorm η 2 }.


ti+1 = min{t | (0.97gmin (t) ≥ Gt ∨ Gt > 1.03gmax (t))∧ke
e
t>e
ti

C. If Geti ≤ 0.98gmax (e
ti ), then we define e
ti+1 as

x(t)k ≤ 0.5λ1 (t)η + Ψnorm η 2 }


ti+1 = min{t | Gt ≥ 0.99gmax (t), ke
e
t>e
ti

We analyze the behavior of a general e


ti when it falls in any of the above cases:

471
Case (A). First of all, since Gt ≥ 1.02gmax (t) for all e
ti ≤ t < e
ti+1 we can show from
Lemma 9.13.3 that the angle with the top eigenvector quickly drops to O(η) in at
most O(ln 1/η) time-steps. Moreover, the iterate’s magnitude can only drop along the
top eigenvector when the angle with the top eigenvector is smaller than O(η), and the
drop is at most O(η 3 ) (Lemma 9.13.10). Thus, during alignment of the iterate to the
top eigenvector, Gt never drops. Moreover, after the alignment, it takes Ω( η12 ) steps for
the iterate’s magnitude along the top eignvector to drop below 1.01 maxk∈[M ] gt (λk (t)).
Hence,

  ti+1
e
1 X  
e ti ≥ Ω
ti+1 − e , θt ≤ O (e
ti+1 − e ti+1 − e
ti )η + log 1/η = O (e ti )η .
η2
t=e
ti

After Gt drops below 1.01gmax (t), it moves to case B(1).

Case (B). From Lemma 9.13.6 we have that the sum of angle over this time is

ti+1
e q 
X
θt = O ti+1 − e
e ti+1 − e
ti + η(e ti ) .
t=e
ti

Case (C). We claim Gt will become larger than 0.99gmax (t) in O(η −0.1 ) steps with
probability at least 1 − O(η 12 ), because of the η 100 perturbation added per Θ(η −0.1 )
ti , θt ≤ Ω(η), then by Lemma 9.13.11, we know that in
steps. If for some t > e
O(log 1/η) steps after the perurbation, with probability at least 1 − O(η 12 ), we have
θt0 ≥ Ω(η) for some t0 ≤ t + O(log 1/η). And thus we can apply Lemma 9.13.7
and θt0 = Ω(1). By Lemma 9.13.4, we know the θet = Ω(1) as well, where e
t is the
t. Then by Lemma 9.13.10, Get+2 ≥ Get + Ω(η). If
largest step in N1 yet smaller than e
t+2 ∈
e / N1 , then e t + 3 ∈ N1 . Again by Lemma 9.13.9, we have
t + 2 must be in N0 and e
Get+3 ≥ Get+2 − O(η 2 ) ≥ Get + Ω(η). Thus Gt will increase Ω(η) every O(log 1/η) steps
among those steps in N1 (among the steps in N1 and N0 , Gt decreases at most O(η 3 )

472
ti + O(log 1/η) + O(η −0.1 ) = e
ti+1 ≤ e
by Lemma 9.13.10. Thus e ti + O(η −0.1 ). Thus,
Peti+1
t=e
θ = O(η −0.1 ).
t t i

Now it remains to upper bound the number of occurrence of (A),(B) and (C).
Since our goal is to show average angle is O(η), which is equal to the average angle in
case (A), so the number of occurrence of case (A) doesn’t matter. For case (B), if it is
followed by case (A), then there is an Ω(1/η 2 ) gap before next occurrence of (B). If
(B) is followed by case (C), then by Lemma 9.13.10, it takes at least Ω(1/η 2 ) steps to
escape from (B). Thus we can have O(1) occurrence of case (B). For the same reason,
there could be at most O(1) occurrence of case (C).
All in all, with probability at least 1 − O(η 12 · η 12 ) = 1 − O(η 10 ), we must have

 
t2
X X q
t0 )η + O(1) · O(η −0.1 ) + O(

θt ≤ O (e
t` − e ti+1 − e
e ti 
t=0 i:case (B)
 
s X X
t0 )η + O(η −0.1 ) + O 

≤ O (e
t` − e ti+1 − e
(e ti ) 1
i:case (B) i:case (B)

= O (t2 η) + O (t2 )

= O(t2 η)

where we use t2 ≥ Ω( η12 ) in the last step and and the number of occurrence of case
(B) is O(1) in the second to the last step.

Lemma 9.13.6. Consider the setting of Lemma 9.13.5. Consider any time interval
[t, t0 ], where t ≤ ` < t0 , xη (`)xη (` + 1) ⊂ Y  and Ω(η) ≤ Gl := |hv1 (t), x
eη (t)i| ≤
λ1 (`)η
2
− Ω(η), we have that

X X p
θt = θt ≤ O( t0 − t + (t0 − t)η).
t∈[t,t0 ] t∈N0 ∪N1 ∪N2

Proof of Lemma 9.13.6. The noisy update rule for Normalized GD, as derived in
Lemma 9.10.9, which says that the Normalized GD update is very close to the update
473
in a quadratic model with an additional O(η 2 ) error. Keeping this in mind, we then
divide our trajectory in the interval (t, t0 ) as per Algorithm 9 into three subsets
N0 , N1 , N2 . (Please see Section 9.13.1 for a summary on the properties of these 3 sets.)
Consider any t ∈ N1 . Using the behavior of Gt from Lemma 9.13.10, we can
show that in each of the time-frames, Gt+2 ≥ (1 + Ω(sin2 θt ))Gt − O(η 2 (η + ηt )) ≥
Gt + Ω(θt2 η) − O(η 2 (η + θt )).
P
Next we want to telescope over Gt+2 − Gt to get an upper bound for t∈N1 θt . If
t + 2 is also in N1 then it’s fine. If t + 2 ∈ N0 , then t + 3 ∈ N1 by Lemma 9.13.1 and
we proceed in the following two cases.

• If θt+2 ≤ C for some sufficiently small constant C, since Gt+2 ≤ λ1 (t + 2)η/2 −


Gt+2
Ω(η), we have ke
xη (t + 2)k ≤ cos θt+2
= λ1 (t + 2)η/2 − Ω(η), and thus by
Lemma 9.13.9, we have Gt+3 ≤ Gt+2 and therefore, Gt+3 ≥ Gt + Ω(θt2 η) −
O(η 2 (η + θt )).

• If θt+2 ≥ C, then by Lemma 9.12.2, we have θt = Ω(1), thus Gt+2 ≥ Gt + Ω(η)


by Lemma 9.13.10. Again by Lemma 9.13.9, we have Gt+3 ≥ Gt+2 − O(η 2 ).
Thus again we conclude Gt+3 ≥ Gt + Ω(η) ≥ Gt + Ω(θt2 η) − O(η 2 (η + θt )), since
θt is always O(1).

Since total increase in Gt during this interval can is most O(η), we conclude that
2
P P
t∈N1 θt = O(1) + η t∈N1 (η + θt ) and thus it holds that

 
X s X p s X
θt ≤ (t0 − t) θt2 ≤ t0 − t · O  1 + η θt + η 2 (t0 − t)
t∈N1 t∈N1 t∈N1
p
≤O( t0 − t + (t0 − t)η)

Moreover, by Lemma 9.13.3, we must have θt < θt−1 + O(η) for any time t ∈ N2 ,
and t − 1 must be in N1 . By Lemma 9.13.9, we have θt ≤ Ω(θt−2 ) for any t ∈ N0 and

474
t − 2 must be in N2 . That implies,

X X p
= θt ≤ O( t0 − t + (t0 − t)η),
t∈[t,t0 ] t∈N0 ∪N1 ∪N2

which completes the proof.

Lemma 9.13.7. Consider any coordinate 2 ≤ k ≤ M . For any constants 0 < β,


t) is in Y  ,
there is some constant α > 0 such that for any time step t where xη (e
Get ≥ βη, condition (9.30) holds and vk (e
t), x
e(e
t) ≥ αη 2 , then there is some time
t ≤ t0 < t, xη (t0 )xη (t0 + 1) ⊂ Y  , then
t ≤ t + O (ln 1/η) such that if for all time e
condition (9.30) holds at time t and at least one of the following two conditions hold:

1. Gt ≥ 0.99gt (λk (t)).

2. θt ≥ Ω(1).

Proof of Lemma 9.13.7. We will prove by contradiction. Suppose neither of the two
condition happens, we will show θt grows exponentially and thus the condition (2)
must be false in O(log 1/η) steps.
ηλ1 (t) ηλ1 (t)
First of all, because θt = O(1) and Gt ≤ 2
whenever ke
xη (t)k ≤ 2
+ Ψnorm η 2 ,
ηλ1 (t) ηλ1 (t+1)
by Lemma 9.13.9, we know if ke
xη (t)k ≤ 2
+ Ψnorm η 2 , then ke
xη (t + 1)k > 2
+
ηλ1 (t+2)
Ψnorm η 2 . And thus ke
xη (t + 2)k ≤ 2
+ Ψnorm η 2 . In other words, t ∈ N1 ∪ N2 for
t + 2k ∈ N1 for all natural numbers k with e
all t. Therefore, for e t + 2k ≤ t. Moreover,
Get+2k ≥ Gt − O(kη 2 ) = Ω(η) by Lemma 9.13.10 for k ≤ O( η1 ).
Now, we can use Equation (9.22) (Lemma 9.10.9) to show that the Normalized GD
update is equivalent to update in quadratic model, up to an additional O(η 2 ) error.

∇2 L(Φ(xη (t)))[xη (t) − Φ(xη (t))]


xη (t + 1) − xη (t) = −η + O(η 2 ).
k∇2 L(Φ(xη (t)))[xη (t) − Φ(xη (t))]k

475
Similar to Lemma 9.9.9, consider the coordinate k, we have that

|hvk (t), ∇2 L(Φ(xη (t)))(xη (t + 2) − Φ(xη (t)))i|


|hv1 (t), ∇2 L(Φ(xη (t)))(xη (t + 2) − Φ(xη (t)))i|
 
λ1 (t) − λk (t)
= 1−η
λ1 (t)η − k∇2 L(Φ(xη (t)))(xη (t + 1) − Φ(xη (t)))k
 
λ1 (t) − λk (t) |hvk (t), x
eη (t)i|
· 1−η − O(η) (9.31)
λ1 (t)η − ke xη (t)k |hv1 (t), x eη (t)i|
 
1 |hvk (t), x
eη (t)i|
≥ 1+ − O(η)
100 |hv1 (t), x eη (t)i|
 
1 |hvk (t), x
eη (t)i|
≥ 1+ ,
200 |hv1 (t), x eη (t)i|

The third step follows from using the same argument as the one used for the quadratic
update in Lemma 9.9.9 and the assumption that Gt ≥ 0.99gt (λk (t)). The final step
holds true because we can pick α as a large enough constant and by assumption
|hvk (t),e
xη (t)i|
|hv1 (t),e
xη (t)i|
≥ αη.
We then bound vk (t) − vk (e t)) by O(η 2 (t − e
t) and Φ(xη (t) − Φ(xη (e t)) using
Lemma 9.10.11. Combining everything, we conclude that at least one of the two
assumptions has to break for some t ≤ e
t + O(log 1/η).

Lemma 9.13.8. Consider any coordinate 2 ≤ k ≤ M . For any constants 0 < β,


suppose at time step t, xη (t) is in Y  , (1.01) gt (λk (t))η ≤ Gt < 0.5ηλ1 (t) and condition
(9.30) holds, then there is some time t ≤ t + O (ln 1/η) such that if for all time
t ≤ t0 < t, xη (t0 )xη (t0 + 1) ⊂ Y  , then the following two conditions hold:

eη (t)i ≤ O(η 2 ).
hvk (t), x

and
∇2 L(Φ(xη (t)))(xη (t) − Φ(xη (t))) ≤ 0.5λ1 (t)η + Ψnorm η 2 .

476
The proof of Lemma 9.13.8 is very similar to the proof of Lemma 9.13.7 and thus
we omit the proof. The only difference will be that we need to use Lemma 9.9.10 in
place of Lemma 9.9.9, when we use the result for the quadratic model.

9.13.3 Dynamics in the Top Eigenspace

Here, we will state two important lemmas that we used for the proof of Lemma 9.13.5,
which is about the behavior of the iterate along the top eigenvector. Lemma 9.13.10
can be viewed as perturbed version for Lemma 9.9.7 in the quadratic case, and
We assume in all the lemmas, that Equation (9.30) holds true for the time under
consideration, which we showed in Lemma 9.11.1, and also that we start Phase II
from a point where the alignment along the top eigenvector is non negligible.
The following lemmas give the properties of dynamics in the top eigenspace in
Phase II for one-step and two-step updates respectively. Recall we use Gt to denote
the quantity |hv1 (t), x
e(t)i|.

Lemma 9.13.9 (Behavior along the top eigenvector, one step). For sufficiently small
1
η, consider any time t, such that xη (t) ∈ Y  , ke
xη (t)k ≤ 2
ηλ1 (t) + Ψnorm η 2 and
Gt ≥ Ω(η) holds true , the following holds:

 
λ1 (t)
Gt+1 ≥ − 1 Gt − O(η 2 )
ke
xi (t)k

provided that xη (t)xη (t + 1), xη (t + 1)xη (t + 2) ⊂ Y  .

Lemma 9.13.10 (Behavior along the top eigenvector, two steps). For sufficiently
xη (t)k ≤ 21 ηλ1 (t) + Ψnorm η 2 and
small η, consider any time t, such that xη (t) ∈ Y  , ke

477
Gt ≥ Ω(η) holds true, the following holds:

λj (t)(λ1 (t) − λj (t)) 2


Gt+2 ≥ (1 + 2 min 2
sin θt )Gt − O((θt + η)η 2 )
2≤j≤M λ1 (t)
≥ Gt − O(η 3 )

provided that xη (t)xη (t + 1), xη (t + 1)xη (t + 2) ⊂ Y  .

xη (t), ∇L(xη (t))) = O( kexη (t)−∇L(x


Proof of Lemma 9.13.10. First note that ∠(e ke
xη (t)k
η (t))k
)=
2
O( Gη t ) = O(η), where the last step we use Lemma 9.10.9. Let δ = ∠(v1 (t), ∇L(xη (t)))−
eη (t)) and we have |δ| ≤ ∠(e
∠(v1 (t), x xη (t), ∇L(xη (t))) = O(η). Therefore, it holds that

 
∇L(xη (t))
v1 (t), = cos ∠(v1 (t), ∇L(xη (t)))
k∇L(xη (t))k

= cos ∠(v1 (t), x


eη (t)) cos δ + sin ∠(v1 (t), x
eη (t)) sin δ

= cos ∠(v1 (t), x eη (t)) + sin2 δ)


eη (t)) + O(sin δ sin ∠(v1 (t), x
 
x
eη (t)
= v1 (t), + O((θt + η)η)
kexη (t)k

Using the Normalized GD update, we have

hv1 (t), xη (t + 1) − Φ(xη (t))i − hv1 (t), xη (t) − Φ(xη (t))i


 
∇L(xη (t))
= − η v1 (t),
k∇L(xη (t))k
 
x
eη (t)
= − η v1 (t), + O((θt + η)η 2 ) (9.32)
ke
xη (t)k

From Lemma 9.10.11, we have kΦ(xη (t)) − Φ(xη (t + 1))k ≤ O(ξη 2 ), which further
implies, k∇2 L(Φ(xη (t + 1))) − ∇2 L(Φ(xη (t)))k ≤ O(η 2 ). Thus, we can use Theo-
νξη 2
rem 9.14.4 to have kv1 (t) − v1 (t + 1)k ≤ O( λ1 (t)−λ 2 (t)
) = O(η 2 ). From Lemma 9.10.12,

478
we have |hv1 (t), Φ(xη (t + 1)) − Φ(xη (t))i| ≤ O(η 3 ). Thus we have that

hv1 (t), x eη (t)i = v1 (t), ∇2 L(Φ(xη (t)))(xη (t + 1) − xη (t)) + O(η 3 ),


eη (t + 1) − x

and therefore,

hv1 (t + 1), x
eη (t + 1)i − hv1 (t), x
eη (t)i

= hv1 (t), x eη (t)i + O(η 3 )


eη (t + 1)i − hv1 (t), x
 
∇L(xη (t))
= − ηλ1 (t) v1 (t), + O(η 3 )
k∇L(xη (t))k
 
xeη (t)
= − ηλ1 (t) v1 (t), + O((θt + η)η 2 )
ke
xη (t)k

Thus we have that

λ1 (t)
Gt+1 = 1 − η Gt + O((θt + η)η 2 )
ke
x(t)k

Therefore, we have the following inequality by applying the same argument above
to t + 1:

λ1 (t + 1)
Gt+2 = 1 − η Gt+1 + O((θt+1 + η)η 2 )
ke
x(t + 1)k
λ1 (t + 1) λ1 (t)
= 1−η 1−η Gt
ke
x(t + 1)k kex(t)k
λ1 (t + 1)
+ 1−η · O((θt + η)η 2 ) + O((θt+1 + η)η 2 )
ke
x(t + 1)k

By Lemma 9.12.1, we know ke


x(t + 1)k = Ω(η). By Lemma 9.12.2, we know that
θt+1 ≤ θt + O(η). Thus

ηλ1 (t + 1) ηλ1 (t)


Gt+2 = 1 − 1− Gt + O((θt + η)η 2 ). (9.33)
ke
x(t + 1)k ke
x(t)k

479
η∇2 L(Φ(xη (t)))
Next we will show ke
xη (t + 1)k − (I − ke
xη (t)k
)e
xη (t) = O(η 2 θt ). For con-
venience, we denote ∇2 L(Φ(xη (t))) by H. First we have that

2 2
∇L(xη (t)) x
eη (t)
eη (t) − ηH
x − x eη (t) − ηH
k∇L(xη (t))k ke
xη (t)k
    
∇L(xη (t)) x
eη (t) ∇L(xη (t)) x
eη (t)
= 2exη (t) − ηH + , ηH −
k∇L(xη (t))k ke xη (t)k k∇L(xη (t))k ke xη (t)k
    
2 ∇L(xη (t)) x
eη (t) ∇L(xη (t)) x
eη (t)
= 2H x eη (t) − ηH + ,η −
k∇L(xη (t))k ke xη (t)k k∇L(xη (t))k ke xη (t)k
   
2 ∇L(xη (t)) x
eη (t) ∇L(xη (t)) x
eη (t)
eη (t) − ηH
= 2H x + η − cos α
k∇L(xη (t))k ke xη (t)k k∇L(xη (t))k ke xη (t)k

=O(η 3 cos α),

 
∇L(xη (t)) ∇L(xη (t))
where α is the angle between − xxeηη (t)k
k∇L(xη (t))k ke
(t)
eη (t)−ηH 2
and 2H x k∇L(xη (t))k
+ x
eη (t)
ke
xη (t)k
.
Note that and that both ∠(e xη (t), v1 (t)), ∠(∇L(xη (t)), v1 (t)) = O(ηt + η), we have
 
∇L(xη (t)) x
eη (t) 2 ∇L(xη (t)) x
eη (t)
that the angle between k∇L(x η (t))k
+ ke
xη (t)k
and 2H x
e η (t) − ηH k∇L(xη (t))k
+ ke
xη (t)k
∇L(xη (t)) x
eη (t)
is at most O(ηt + η). Further note that k∇L(xη (t))k
− ke
xη (t)k
is perpendicular to
∇L(xη (t)) x
eη (t)
k∇L(xη (t))k
+ ke
xη (t)k
, we know cos α ≤ O(θt + η). Therefore we have that

∇L(xη (t)) x
eη (t)
eη (t) − ηH
x − x
eη (t) − ηH
k∇L(xη (t))k ke
xη (t)k
2 2
∇L(xη (t))
eη (t) − ηH k∇L(x
x η (t))k
eη (t) − ηH kexxeηη (t)k
− x (t)

=
∇L(xη (t))
eη (t) − ηH k∇L(x
x η (t))k
eη (t) − ηH kexxeηη (t)k
+ x (t)

=O(η 2 (η + θt )).

By Lemma 9.9.6, we have that

η∇2 L(Φ(xη (t)))


 
1 λj (t)(λ1 (t) − λj (t)) 2
ke
xη (t)k + (I − xη (t) ≤ ηλ1 (t) 1 −
)e min sin θt .
ke
xη (t)k 2λ1 (t) 2≤j≤M λ21 (t)

480
Thus we have proved a perturbed version of Lemma 9.9.6, that is,

 
1 λj (t)(λ1 (t) − λj (t)) 2
ke
xη (t)k + ke
xη (t + 1)k ≤ ηλ1 (t) 1 − min sin θt + O(η 2 (η + θt )).
2λ1 (t) 2≤j≤M λ21 (t)

Therefore a perturbed version of Lemma 9.9.7 would give us:

ηλ1 (t + 1) ηλ1 (t) λj (t)(λ1 (t) − λj (t)) 2


(1 − )(1 − ) ≥ 2 min sin θt + O((η + θ)η).
ke
x(t + 1)k ke
x(t)k 2≤j≤M λ21 (t)

The proof of the first inequality is completed by plugging the above equation into
Equation (9.33).
The second inequality is immediate by noting that ηθt2 + C 2 η 3 ≥ 2Cη 2 θt for any
C > 0.

9.13.4 Dynamics in Top Eigenspace When Dropping Below

Threshold

In this section, we will show that the projection along the top eigenvector cannot
drop below a certain threshold. Formally, we will show the following lemma that
predicts the increase in the projection Gt = |hv1 (t), x
eη (t)i| along the top eigenvector in
O(log 1/η) steps, whenever the projection drops below a certain threshold gmax (t) :=
maxk∈[M ] gt (λk (t)).

Lemma 9.13.11. Denote r = η 100 . For any constant 0 < β, there is a constant
α > 0, such that for any step t and xη (t) ∈ Y  with the following conditions hold:

1. βη ≤ Gt ≤ 0.98gmax (t)η.

eη (t)i| ≤ O(η 2 ), for all 2 ≤ i ≤ M.


2. |hvi (t), x

Then, with probability at least 1 − η 12 , after perturbing xη (t) with noise generated
uniformly from B0 (r) followed by tesc + 2 = Θ(log 1/η) steps of Normalized GD
481
(2:M )
(t = t+tesc +2), it holds that Pt,Γ eη (t) ≥ Ω(η 2 ) provided that xη (t0 )xη (t0 + 1) ⊂ Y 
x
for all time t ≤ t0 ≤ t.

Lemma 9.13.11 is a direct consequence of the following lemma.

Lemma 9.13.12. Consider any time t, with xη (t) ∈ Y  . Suppose xη (t) satisfies the
conditions in Lemma 9.13.11. The constants cesc , gmax (t), r, α, and β have been taken
from Lemma 9.13.11. Define Xstuck as the region in Bxη (t) (r) such that starting from
any point u ∈ Xstuck , the points {u(e
t)}et∈[tesc ] , with u(0) := u, obtained using tesc steps
of Normalized GD satisfy:

(2:M )
Pt,Γ t) − Φ(xη (t))) ≤ αη 2 ,
(u(e t ∈ [tesc ],
for all e (9.34)

(2:M )
where Pt,Γ denotes the subspace spanned by v2 (t), . . . , vM (t).
Consider two points u and w in Bxη (t) (r), with the property w = u + Kη 12 rvk (t), 4

where K ≥ 1 can be arbitrary number and vk (t) denotes the eigenvector corresponding
to the eigenvalue λk (t) = argmaxλi (t)|1≤i≤M gt (λi (t)). Then, at least one of u and w is
not present in the region Xstuck .

We will first prove Lemma 9.13.11 and then we turn to the proof of Lemma 9.13.12.

Proof of Lemma 9.13.11. Lemma 9.13.12 shows that if some point u ∈ Bxη (r) is in
Xstuck , then it holds that

Xstuck ∩ {u + λvk (t) | λ ∈ R} ⊂ {u + λvk (t) | λ ∈ R, |λ| ≤ η 12 r}.

The other words, Xstuck is only a thin slice of width at most η 12 r of Bxη (t) (r), which
implies vol(Xstuck )/vol(Bxη (t) (t)) = O(η 1 2), where vol(·) denotes the volume of the
set.
4 12
η can be replaced by any η p , and the final success probability in Lemma 9.13.11 becomes
1 − η p−2 .
482
Proof of Lemma 9.13.12. We will prove by contradiction. Consider the two sequences
obtained with tesc steps of Normalized GD, {u(e
t), w(e
t)}et∈[tesc ] :

∇L(u(e
t)) ∇L(w(e
t))
u(0) = u, w(0) = w, u(e t) − η
t) = u(e , w(e t) − η
t + 1) = w(e .
∇L(u(e
t)) ∇L(w(e
t))

For convenience, we denote ∇2 L(Φ(xη (t)))[u(e


t) − Φ(xη (t))] by u
e(e
t) and
(2:M ) e (2:M )
∇2 L(Φ(xη (t)))[w(e
t) − Φ(xη (t))] by w(
ee t). Suppose both Pt,Γ u e(t) , Pt,Γ ve(et)
are O(η 2 ), we will show the following, which indicates the contradiction:

(2:M )
Pt,Γ (u(tesc ) − w(tesc )) ≥ Ω(η 2 ).

An important claim to note is the following:

Lemma 9.13.13. Both the trajectories {u(e


t), w(e
t)}et≤tesc satisfy a modified version of
the alignment condition (Equation (9.30)), i.e. for all 1 ≤ j ≤ M :

v v 
uM uM
uX uX
max t hvi (t), w(
ee t)i2 , t hvi (t), u t)i2  ≤ λj (t)η + O(Ψnorm η 2 ).
e(e
i=j i=j

Note that the condition has been slightly changed to use {vi (t)} as reference
coordinate system and Φ(xη (t)) as reference point. The above lemma follows from the
fact that both u(0) and w(0) are r-close to xη (t), which itself satisfies the alignment
condition (Equation (9.30)). Thus, both u(0) and w(0) initially follow the desired
condition. Since, both the trajectories follow Normalized GD updates, the proof
will follow from applying the same technique used in the proof of Lemma 9.11.1.
Another result to keep in mind is the following modified version of Corollary 9.12.3,
Lemma 9.13.14.

483
Lemma 9.13.14. If u t) ≤ η λ12(t) +Ψnorm η 2 , then v1 (t)> u
e(e t + 1) ≥ v1 (t)> u
e(e t) −
e(e
O(η 2 ). The same results hold for w(
ee t) as well.
If w(
eet) ≤ η λ12(t) + Ψnorm η 2 , u t) ≤ η λ12(t) + Ψnorm η 2 , and z(γ) denotes γu(0) +
e(e
∇L(x)
(1 − γ)w(0) for any γ ∈ [0, 1], let F (x) = x − η k∇L(x)k , we have

v1 (t)> ∇2 L(Φ(xη (t)))(F (z(γ)) − Φ(xη (t)))

≥ v1 (t)> ∇2 L(Φ(xη (t)))(z(γ) − Φ(xη (t))) − O(η 2 ).

The above lemma uses {vi (t)} as reference coordinate system and Φ(xη (t)) as
reference point. The above lemma follows from showcasing Normalized GD updates
of u(e
t) and w(e
t) as equivalent to the update in a quadratic model, with an additional
noise of O( νζ
µ
η 2 ), similar to Equation (9.29).
Continuing with the proof of Lemma 9.13.12, we first consider the behavior of u.
Since u ∈ Xstuck , we have for any time-step e
t:

t) − Φ(xη (t))
u(e
min − sv1 (t) ≤ η (9.35)
s∈{±1} t) − Φ(xη (t))
u(e

Further, applying the same technique from Lemma 9.13.10, we can show that

hv1 (t), u t + 2)i − hv1 (t), u


e(e t))i = O(η 3 ).
e(e (9.36)

Initially, because u was initialized close to xη (t), we must have

|hv1 (t), u
e(0)i − hv1 (t), x
eη (t)i| ≤ O(r).

484
t) − Φ(xη (t))i − hv1 (t), u(0) − Φ(xη (t))i ≤ O(η 3 tesc ) for all even e
Hence, hv1 (t), u(e t∈
[tesc ]. With tesc ∼ O(log 1/η) , we must have

0.99gmax (t) ≥ hv1 (t), u t))i ≥ 0.5βη,


e(e (9.37)

for any t ≤ e
t ≤ t + tesc . The same argument applies to w(·) as well. By Equation (9.34),
we know u
e(e
t) , w(
eet) = o(η).
Now, we consider the behavior of w(·) and u(·). Consider an even time step
0≤e
t ≤ tesc . From the update rule of w and u, we have

t + 2) = F (2) (w(e
t + 2) − u(e
w(e t)) − F (2) (u(e
t)),

∇L(v)
where the function F : RD → RD , F (v) = v − η k∇L(v)k is the one-step update rule of
Normalized GD and F (2) = F ◦ F .
Now, we use taylor expansion of F around u(e
t) to get

w(e t + 2) = F (2) (w(e


t + 2) − u(e t)) − F (2) (u(e
t)) = ∇F (2) (u(e t) − u(e
t))(w(e t)) + err,

where kerrk can be bounded as follows, with z(γ) defined as γu(e


t) + (1 − γ)w(e
t):

1 2
max ∇2 F (2) (z(γ))) w(e t) − u(e
t)
γ∈[0,1] 2
1 2
= max k∇[∇F (F (z(γ))))∇F (z(γ))]k w(e t) − u(e
t)
γ∈[0,1] 2
 
1 1 2
≤ max η · O 2 + 2 t) − u(e
w(e t)
γ∈[0,1] k∇L(z(γ))k k∇L(F (z(γ)))k
 
2 2
· max ∂ 2 (∇L)(z(γ)) , ∇2 L(z(γ)) , ∂ 2 (∇L)(F (z(γ))) , ∇2 L(F (z(γ)))
 
1 1 2
≤ max 2 + 2 · O(η w(e t) − u(e
t) ).
γ∈[0,1] k∇L(z(γ))k k∇L(F (z(γ)))k

485
Using taylor expansion: ∇L(z(γ)) = ∇2 L(Φ(xη (t)))(z(γ)−Φ(xη (t)))+O(ν kz(γ) − Φ(xη (t))k2 )
and hence, we must have k∇L(z(γ))k ≥ Ω(η).
With u
e(e
t) = o(η), we can apply Lemma 9.13.14 to show

hv1 (t), ∇2 L(Φ(xη (t)))[F (z(γ)) − Φ(xη (t))]i ≥ hv1 (t), ∇2 L(Φ(xη (t)))[z(γ) − Φ(xη (t))]i +O(η 2 ).

That implies, k∇L(F (z(γ)))k ≥ Ω(η)


1 1
Hence, µ(e
t) = maxγ k∇L(z(γ))k2
+ k∇L(F (z(γ)))k2
≤ Ω(1/η 2 ).
Thus we conclude

1 2
w(e t + 2) − ∂F (2) (u(e
t + 2) − u(e t) − u(e
t))(w(e t)) ≤ O( w(e
t) − u(e
t) ), (9.38)
η

where ∂F (2) (u(e


t)) = Aet+1 Aet with

" #
∇L(u(e
t))∇L(u(et))> ∇2 L(u(et))
t)) = I − η I −
Aet := ∂F (u(e 2 ,
∇L(u(et)) ∇L(u(e t))

and µ(e
t) is given by

 
1 1
µ(e
t) = max 2 + , (9.39)
γ∈[0,1]:z(γ)=γu(e
t)+(1−γ)w(e
t) k∇L(z(γ))k k∇L(F (z(γ)))k2

Now we define Bet and claim Aet can be approximated as below with kBet k = O(η).
Furthermore, kAet k ≤ O(1).

∇2 L(Φ(xη (t)))
Bet = Aet − I − η I − v1 (t)v1 (t)>
 
,
hv1 (t), ∇2 L(Φ(xη (t)))[u(e
t) − Φ(xη (t))]i

The following strategies have been used to obtain the above approximation. First,
∇2 L(u(e
t)) − ∇2 L(Φ(xη (t))) ≤ O( u(e
t) − Φ(xη (t)) ) = O( u(e
t) − Φ(u(0)) ) +
O(kΦ(u(0)) − Φ(xη (t))k) = O(η). Therefore, Using taylor expansion, ∇L(u(e
t)) =
∇2 L(Φ(xη (t)))(u(e
t) − Φ(xη (t))) + O(η 2 ) = u t) + O(η 2 ).
e(e Using the update
486
from Equation (9.36) and note tesc = O(log 1/η), we must have ∇L(u(e
t)) ≥
t) − Φ(xη (t))i − O(η 2 ) ≥ βη − O(η 3 tesc ) ≥ Ω(η). Finally we use the condi-
hv1 (t), u(e
 >
∇L(u(e
t)) ∇L(u(e
t))
tion from Equation (9.35) to show that ∇L(u(et)) = v1 (t)v1 (t)> + O(η).
k k k∇L(u(et))k
Similarly, we can show that:

∇2 L(Φ(xη (t)))
= I − η I − v1 (t)v1 (t)>
 
Aet+1 + Bet+1 ,
ηλ1 (t) − hv1 (t), u
e(e
t)i

with Aet+1 ≤ O(1) and Bet+1 ≤ O(η).


Now we define the following error term

Y
err(e t + 2) − u(e
t) := w(e t + 2) − H(u(i))(w(0) − u(0)), (9.40)
0≤i≤e
t:2|i

Finally, we use Lemma 9.13.15 and Lemma 9.13.16 to handle the main and error
terms in Equation (9.42),

Y
|hvk (t), w(tesc ) − u(tesc )i| = vk (t)> t))(w(0) − u(0)) + vk (t)> err(tesc )
H(u(e
t≤tesc :2|e
0≤e t

Y
≥ vk (t)> t))(w(0) − u(0)) − kerr(tesc )k
H(u(e
t≤tesc :2|e
0≤e t

≥ Ω(η 2 ) − O(η 3 ) = Ω(η 2 ).

which completes the proof of Lemma 9.13.12.

Lemma 9.13.15.

Y
vk (t)> t))(w(0) − u(0)) ≥ Ω(η 2 ).
H(u(e
t≤tesc :2|e
0≤e t

487
Lemma 9.13.16.

kerr(tesc )k ≤ O(η kw(t) − u(t)k) = O(η 3 ).

Proof of Lemma 9.13.15. For simplicity of presentation, we have used Met to define
" #" #
h i ∇2 L(Φ(xη (t))) h i ∇2 L(Φ(x (t)))
η
I − η I − v1 (t)v1 (t)> I − η I − v1 (T )v1 (t) >
.
ηλ1 (t) − hv1 (t), u
e(e
t)i hv1 (t), u
e(et)i

Thus, the term under consideration can be simplified as follows,

Y
t))(w(0) − u(0))
H(u(e
t≤tesc :2|e
0≤e t
Y
= η3r H(u(e
t))vk (t)
t≤tesc :2|e
0≤e t
Y
= η3r Aet+1 Aet vk (t)
t≤tesc :2|e
0≤e t
Y
= η3r Met vk (t) + rem,
t≤tesc :2|e
0≤e t

where using the bounds on {Aet , Aet+1 , Bet , Bet+1 }0≤et≤tesc , we have

X Y
kremk ≤ max (kBet k + Bet+1 ) · max (kAet k + Aet+1 ) · Mj
t≤tesc
e t≤tesc
e
t≤tesc :2|e
0≤e t 0≤j≤tesc :2|j,j6=e
t
X Y
≤ O(kη 1 2r) · Mj . (9.41)
t≤tesc :2|e
0≤e t 0≤j≤tesc :2|j,j6=e
t

From the behavior of u(e t) from Equation (9.37), we have hv1 (t), u
e(et)i ≤
1

1 − 200M gmax (t)η. Recall that gmax (t) was chosen as max1≤k≤M gt (λk (t)). It
turns out that for the chosen upper bound of gmax (t), vk (t) acts as the top eigenvector
t ≤ tesc .
of Met for any e

488
For all j ∈ [2, M ] and e
t ∈ [tesc ], we have:

" #" #
λj (t)/λ1 (t) λj (t)/λ1 (t)
Met vj (t) = 1 − η 1−η vj (t),
η − hv1 (t), u(e
t) − Φ(xη (t))i hv1 (t), u(et) − Φ(xη (t))i

with Met v1 (t) = v1 (t). When hv1 (t), u t)i ≤ gmax (t), kMet vk t)k ≥ kMet v1 (t)k, for all
e(e
j ≥ 2. Furthermore, kMet vj (t)k maximizes when j = k. Therefore,

" #" #
λk (t)/λ1 (t) λk (t)/λ1 (t)
kMet k = 1 − η 1−η
η − hv1 (t), u(e
t) − Φ(xη (t))i hv1 (t), u(e
t) − Φ(xη (t))i
  
λk (t) λk (t)
≥ λ1 (t) − 1− , for all e t ∈ [tesc ],
λ1 (t) − 0.99gmax (t) 0.99gmax (t)

since we showed before that hv1 (t), u t)i ≤ 0.99gmax (t)η .


e(e
Now, we explain our choice of tesc . We select tesc s.t.

* +
Y
1
vk (t), Kη 2r Met vk (t) = Θ(η 2 ).
t≤tesc :2|e
0≤e t

That is, we select the time step e t, where the magnitude of the useful term
Q
t Me
t≤tesc :2|e
0≤e t vk (t) along the eigenvector vk (t) reaches cesc η. With gmax (t) = gt (λk (t)),
h ih i
λk (t)/λ1 (t) λk (t)/λ1 (t)
we have 1 − 1−0.99g max (t)
1 − 0.99gmax (t)
≥ 1.001 and so, we just need tesc ≤
O(log(cesc /η)).
With this choice of tesc , we must have from Equation (9.41), kremk ≤ O(η 3 ) and
therefore

* +
Y
vk (t), t))(w(0) − u(0))
H(u(e ≥ Ω(η 2 ) − O(η 3 ) ≥ Ω(η 2 ),
t≤tesc :2|e
0≤e t

Thus, we have shown that with the appropriate choice of tesc , the magnitude of
2
Q
t H(u(t))(w(0) − u(0)) can reach at least Ω(η ) along the eigenvector vk (t).
t≤tesc :2|e
0≤e
e

489
Proof of Lemma 9.13.16. We first recall the definition of the error term:

Y
err(e t) − u(e
t) := w(e t) − H(u(i))(w(0) − u(0)), (9.42)
t−2:2|i
0≤i≤]

t ≥ 0 for some % > 0


By Equation (9.38), the following property holds for all e

2
t + 2) − H(u(e
err(e t) ≤ %( w(e
t))err(e t) − u(e
t) /η)

We will use induction hypothesis to show for all even t ≤ tesc , err(e
t) ≤
2
t) − w(e
C u(e t) /η for some sufficiently large constant C. The base case is t = 0
which holds by definition. Now suppose the induction hypothesis holds for all even
0 ≤ t0 ≤ e
t} and below we will show for e
t + 2.
t − 2, we know
First, by induction hypothesis at e

Y 2
t) − u(e
w(e t) − H(u(i))(w(0) − u(0)) ≤ err(e
t) ≤ C( w(e
t) − u(e
t) /η).
t−2:2|i
0≤i≤e

Since w(e t) = O(η 2 ), we have err(e


t) − u(e t) ≤ O(η w(e
t) − u(e
t) ) and that

Y
t) − u(e
w(e t) ≤ (1 + O(η)) H(u(i))(w(0) − u(0)) .
0≤i≤e
t:2|i

490
Meanwhile, we have

err(e
t + 2)

Y
≤ u(e
t + 2) − w(e
t + 2) + H(u(i))(w(0) − u(0))
0≤i≤e
t+2:2|i
 
Y
≤ u(e
t + 2) − w(e
t + 2) + H(u(e t) − u(e
t)) w(e t) − H(u(i))(w(0) − u(0))
0≤i≤e
t:2|i

Thus we have that

H(u(e
t)) t) − u(e
w(e t)

Y
≤(1 + O(η)) H(u(e
t)) H(u(i))(w(0) − u(0))
0≤i≤e
t:2|i

Y
=(1 + O(η)) H(u(i))(w(0) − u(0))
0≤i≤e
t+2:2|i

≤(1 + O(η)) t + 2) − w(e
u(e t + 2) + err(e
t + 2) .

t) is also O(η 2 ), we have


Note err(e

2 2 2
H(u(e
t)) t) − u(e
w(e t) ≤ (1 + O(η)) u(e
t + 2) − w(e
t + 2) + O(η 2 ) err(e
t + 2)

491
Denote by ϕ = minet≤tesc kMet k. From previous analysis we know ϕ > 1 and thus we
have

η err(e
t + 2)
2
≤% w(e
t + 2) − u(e
t + 2) + H(u(e
t))err(e
t) η
2 2
≤% w(e
t + 2) − u(e
t + 2) + C H(u(e
t)) t) − u(e
w(e t)
2 C 2
≤% w(e
t + 2) − u(e
t + 2) +( t + 2) − w(e
+ O(η)) u(e t + 2) + O(η 2 ) err(e
t + 2) .
ϕ

Thus we conclude that

C 2
t + 2) ≤ (% + O(η)) w(e
η err(e t + 2) − u(e
t + 2) .
ϕ

The proof is completed by picking C large enough such that % Cϕ + O(η) ≤ C.

9.13.5 Proof for Operating on Edge of Stability

Proof of Theorem 9.5.7. According to the proof of Theorem 9.5.4, we know for all t,
it holds that Rj (xη (t)) ≤ O(η 2 ). Thus SL (xη (t), ηt ) = ηt · sup0≤s≤ηt λ1 (∇2 L(xη (t) −
k∇L(xη (t))k
s∇L(xη (t)))) = ηt (λ1 (t) + O(η)), which implies that [SL (xη (t), ηt )]−1 = ηλ1 (t)
+
ke
xη (t)k
O(η) = ηλ1 (t)
+ O(η). The proof for the first claim is completed by noting that
1
η
(ke
xη (t)k+ ke
xη (t + 1)k) = λ1 (t) + O(η + θt ) as an analog of the quadratic case.
ke
xη (t)k
p
For the second claim, it’s easy to check that L(xη (t)) = √ + O(ηθt ).
2λ1 (t)
kexη (t)k ke
xη (t+1)k
p p
Thus have L(xη (t)) + L(xη (t + 1)) = √ +√ + O(η(θt + θt+1 )). Note
2λ1 (t) 2λ1 (t+1)
p
that λ1 (t) − λ1 (t + 1) = O(η 2 ) and θt+1 = O(θt ), we conclude that L(xη (t)) +
q
2
L(xη (t + 1)) = η λ1 (∇ L(x η (t))
p
2
) + O(ηθt ).

492
9.14 Some Useful Lemmas About Eigenvalues and

Eigenvectors

Theorem 9.14.1 (Derivative of eigenvalues and eigenvectors of a matrix, Theorem 1


in [258]). Let X0 be a real symmetric n × n matrix. Let u0 be a normalized eigenvector
associated with an eigenvalue λ0 of X0 with multiplicity 1. Then a real valued function
λ and a vector valued function u are defined for all X in some neighborhood N (X0 ) ⊂
Rn×n of X0 , such that

λ(X0 ) = λ0 , u(X0 ) = u0 ,

and

Xu = λu, u> u = 1, X ∈ N (X0 ).

Moreover, the functions λ and u are C ∞ on N (X0 ) and the differentials at X0 are

dλ = u>
0 (dX)u0 , du = (λ0 In − X0 )† (dX)u0 .

Theorem 9.14.2. [Eigenvalue perturbation for symmetric matrices, Cor. 4.3.15 in


b ∈ Rp×p be symmetric, with eigenvalues λ1 ≥ . . . ≥ λp and λ
[259]] Let Σ, Σ b1 ≥ . . . ≥ λ
bp

respectively. Then, for any i ≤ p, we have

λi − λ
bi ≤ Σ − Σ
b .
2

The next theorem is the Davis-Kahan sin(θ) theorem, that bounds the change in
the eigenvectors of a matrix on perturbation. Before presenting the theorem, we need

493
to define the notion of unitary invariant norms. Examples of such norms include the
frobenius norm and the spectral norm.

Definition 9.14.3 (Unitary invariant norms). A matrix norm k · k∗ on the space of


matrices in Rp×d is unitary invariant if for any matrix K ∈ Rp×d , kU KW k∗ = kKk∗
for any unitary matrices U ∈ Rp×p , W ∈ Rd×d .

b ∈ Rp×p be symmetric,
Theorem 9.14.4. [Davis-Kahan sin(θ) theorem [260]] Let Σ, Σ
with eigenvalues λ1 ≥ . . . ≥ λp and λ
b1 ≥ . . . ≥ λ
bp respectively. Fix 1 ≤ r ≤ s ≤ p, let

d := s − r + 1 and let V = (vr , vr+1 , . . . , vs ) ∈ Rp×d and Vb = (b


vr , vbr+1 , . . . , vbs ) ∈ Rp×d
have orthonormal columns satisfying Σvj = λj vj and Σb b vj = λbj vbj for j = r, r +
n o
1, . . . , s. Define ∆ := min max{0, λs − λ
bs+1 }, max{0, λ
br−1 − λr } , where λ b0 := ∞
bp+1 := −∞, we have for any unitary invariant norm k · k∗ ,
and λ

∆ · k sin Θ(Vb , V )k∗ ≤ kΣ


b − Σk∗ .

Here Θ(Vb , V ) ∈ Rd×d , with Θ(Vb , V )j,j = arccos σj for any j ∈ [d] and Θ(Vb , V )i,j = 0
for all i 6= j ∈ [d]. σ1 ≥ σ2 ≥ · · · ≥ σd denotes the singular values of Vb > V. [sin Θ]ij is
defined as sin(Θij ).


9.15 Analysis of L

The analysis will follow the same line of proof used for the analysis of Normalized
GD. Hence, we write down the main lemmas that are different from the analysis of
Normalized GD. Rest of the lemmas are nearly the same and hence, we have omitted
them.

The major difference between the results of Normalized GD and GD with L is in
the behavior along the manifold Γ (for comparison, see Lemma 9.10.12 for Normalized

GD and Lemma 9.15.10 for GD with L). Another difference between the results of

494

Normalized GD and GD with L is in the error rates mentioned in Theorem 9.5.4
and Theorem 9.5.6. The difference comes from the stronger behavior of the projection
along the top eigenvector that we showed for Normalized GD in Lemma 9.13.10, but

doesn’t hold for GD with L (see Lemma 9.15.6). This difference shows up in the
sum of angles across the trajectory (for comparison, see Lemma 9.13.5 for Normalized

GD and Lemma 9.15.4 for GD with L), and is finally reflected in the error rates.

9.15.1 Notations

The notations will be the same as Section 9.10 . However, here we will use x
eη (t) to
1/2
denote (2∇2 L(Φ(xη (t)))) (xη (t) − Φ(xη (t))). We will now denote Y as the limiting
flow given by Equation (9.7).

Z τ
1 ⊥
X(τ ) = Φ(xinit ) − PX(s),Γ ∇λ1 (X(s))ds, X(τ ) ∈ Γ. (9.7)
8 s=0

9.15.2 Phase I, convergence



Here, we will show a very similar stability condition for the GD update on L as
the one (Lemma 9.11.1) derived for Normalized GD. Recall our notation x
eη (t) =
p
2∇2 L(Φ(xη (t)))(xη (t) − Φ(xη (t))).

Lemma 9.15.1. Suppose {xη (t)}t≥0 are iterates of GD with L (9.6) with a learning
rate η and xη (0) = xinit . There is a constant C > 0, such that for any constant
kxη (t0 )−Φ(xη (t0 ))k
ς > 0, if at some time t0 , xη (t0 ) ∈ Y  and satisfies η
≤ ς, then for all
t̄ ≥ t0 + C ζς
µ
log ςζ
µ
, the following must hold true for all 1 ≤ j ≤ M :

v
uM
uX
eη (t̄)i2 ≤ ηλj (t̄) + O(η 2 ),
t hvi (t̄), x (9.43)
i=j

provided that for all steps t ∈ {t, . . . , t̄ − 1}, xη (t)xη (t + 1) ⊂ Y  .

495
Proof. The proof exactly follows the strategy used in Lemma 9.11.1. We can use the
noisy update formulation from Lemma 9.15.7 and the bound on the movement in Φ
from Lemma 9.15.10 to get for any time t with t̄ ≥ t ≥ t0 (similar to Equation (9.29)):

∇2 L(Φ(xη (t)))
 
eη (t + 1) = I − η
x eη (t) + O(η 2 ) + O(kxη (t) − Φ(xη (t))k η),
x
ke
xη (t)k

and the rest proof are the same.

Hence, similar to Lemma 9.12.1, we can derive the following property that continues
to hold true throughout the trajectory, once the condition Equation (9.43) is satisfied:

Lemma 9.15.2. There is some constant Ψnorm > 0, such that if the condition Equa-
ηλ1 (t)
tion (9.43) holds true and ke
xη (t)k > 2
, the following must hold true:

ηλ1 (t)
ke
xη (t + 1)k ≤ + Ψnorm η 2 .
2

We also have the counterpart of Corollary 9.12.3 with the same proof, which follows

from using the noisy update of GD on L from Lemma 9.15.7 and using the quadratic
update result from Lemma 9.9.5.

ηλ1 (t)
Lemma 9.15.3. If at time t, ke
xη (t)k ≤ 2
+ Ψnorm η 2 and stability condition
(Equation (9.43)) holds true, the following must hold true:

v1 (t + 1)> x
eη (t + 1) ≥ v1 (t)> x
eη (t) − O(η 2 ).

9.15.3 Phase II, limiting flow

To recall, the limiting flow given by

Z τ
1 ⊥
X(τ ) = Φ(xinit ) − PX(s),Γ ∇λ1 (X(s))ds, X(τ ) ∈ Γ. (9.7)
8 s=0

496
Let T2 be the time up until which solution to the limiting flow exists.
Lemma 9.15.10 shows the movement in Φ, which can be informally given as follows:
in each step t,

η2 ⊥
Φ(xη (t + 1)) − Φ(xη (t)) = − Pt,Γ ∇λ1 (∇2 L(Φ(xη (t)))) + O(η 2 (θt + kxη (t) − Φ(xη (t))k)),
8
(9.44)

provided Φ(xη (t))Φ(xη (t + 1)) ∈ Y  .


Motivated by this update rule, we show that the trajectory of Φ(xη (·)) is close to
the limiting flow in Equation (9.7), for a small enough learning rate η. The major
difference from Theorem 9.5.4 comes from the fact that the total error introduced
in Equation (9.44) over an interval [0, t2 ] is tt=0 O(η 2 θt + η 3 ), which is of the order
P2

O(η 1/2 ) using the result of Lemma 9.15.4.

Average of the angles The first lemma shows that the sum of the angles in an
interval [0, t2 ] of length Ω(1/η 2 ) is at most O(t2 η 1/2 ).

Lemma 9.15.4. For any T2 > 0 for which solution of Equation (9.7) exists, con-
sider an interval [0, t2 ], with Ω(η −2 ) = t2 ≤ bT2 /η 2 c. Suppose Algorithm 8 is
run with learning rate η for t2 steps, starting from a point xη (0) that satisfies (1)
maxj∈[D] Rj (xη (0)) ≤ O(η 2 ), and (2) |v1 (0), xη (0) − Φ(xη (0))| ≥ βη for some constant
√ 2
0.5λ1 (0)
0 < β independent of η, with ke xη (0)k ≤ 2
η + Ψnorm η 2 , the following holds true
with probability at least 1 − η 10 :

t2
1X √
θ` ≤ O ( η) ,
t2 `=0

provided η is sufficiently small and for all time 0 ≤ t ≤ t2 − 1, xη (t)xη (t + 1) ⊂ Y  .

497
Proof of Lemma 9.15.4. The proof is very similar to the proof of Lemma 9.13.5, except
we replace Lemma 9.13.6 by Lemma 9.15.5 in the analysis of case (B). Hence the final

average angle becomes O( η).

Lemma 9.15.5. Consider the setting of Lemma 9.15.4. Consider any time interval
[t, t0 ], where t ≤ ` < t0 , xη (`)xη (` + 1) ⊂ Y  and Ω(η) ≤ Gl := |hv1 (t), x
eη (t)i| ≤
λ1 (`)η
2
− Ω(η), we have that

X X p √
θt = θt ≤ O( t0 − t + (t0 − t) η).
t∈[t,t0 ] t∈N0 ∪N1 ∪N2

Proof. The proof will follow exactly as Lemma 9.13.6, except we replace Lemma 9.13.10
√ √
by Lemma 9.15.6, which changes the rate into O( t0 − t + (t0 − t) η)

Lemma 9.15.6. [Behavior along the top eigenvector] Consider any time t, such that
xη (t)k ≤ 21 ηλ1 (t) + Ψnorm η 2 holds true, then the following holds
xη (t) ∈ Y  , where ke
true:

1 λj (t)(λ1 (t) − λj (t)) 2


Gt+2 ≥ (1 + min sin θt )Gt − O(η 2 ),
2 2≤j≤M λ21 (t)

provided Gt ≥ Ω(η) and xη (t)xη (t + 1), xη (t + 1)xη (t + 2) ⊂ Y  .

Proof. Here, we will follow a much simpler approach than Lemma 9.13.10 to have a
weaker error bound. The stronger error bounds in Lemma 9.13.10 were due to the
very specific update rule of Normalized GD.
p
eη (t) = 2∇2 L(Φ(x))(xη (t) − Φ(xη (t))). By Lemma 9.15.10, we have
First recall x
p
kΦ(xη (t + 1)) − Φ(xη (t))k = O(η 2 ), thus x eη (t) = 2∇2 L(Φ(x))(xη (t +
eη (t + 1) − x
1) − xη (t)) = η 2∇2 L(Φ(x)) ∇L(x
√ η (t)) . From Lemma 9.15.7, we have
p
2 L(xη (t))

∇L(xη (t)) ∇2 L(Φ(xη (t)))(xη (t) − Φ(xη (t)))


p = q + O(η),
L(xη (t)) 1 2
∇ L(Φ(x η (t)))[x η (t) − Φ(x η (t)), xη (t) − Φ(x η (t))]
2

498
where we have used the fact that ke
xη (t)k = O(η).
Hence, the update is similar to the update in a quadratic model, with ∇2 L(Φ(xη (t)))
guiding the updates with an additional O(η 2 ) perturbation. As a result we also get a
O(η 2 ) perturbation in Gt . Here we use the assumption Gt = Ω(η) so that GD updates
are O(1)-lipschitz.

9.15.4 Omitted Proof for Operating on Edge of Stability

This proof is similar to that of Theorem 9.5.7.

Proof of Theorem 9.5.8. If M = 1, that is, the dimension of manifold Γ is D − 1, we



know xη (t)xη (t + 1) will cross Γ, making the ∇2 L diverges at the intersection and
√ 2 >
the first claim becomes trivial. If M ≥ 2, we have ∇2 L = 2L∇ L−∇L∇L
√ 3 diverges at
4 L
1
the rate of .
It turns out that using basic geometry, one can show that the distance
k∇Lk

from Φ(xη (t)) to xη (t)xη (t + 1) is O(η(θt + θt+1 )), thus sup0≤s≤η λ1 (∇2 L(xη (t) −
√ 1
s∇ L(xη (t)))) = Ω( η(θt +θ t+1 )
). The proof of the first claim is completed by noting
that θt+1 = O(θt ).
p
For the second claim, it’s easy to check that L(xη (t)) = ke
xη (t)k+O(η). The proof
for the first claim is completed by noting that ke
xη (t)k+ke
xη (t + 1)k = ηλ1 (t)+O(η +θt )
as an analog of the quadratic case.


9.15.5 Geometric Lemmas for L
(2:M )
p PΦ(x),Γ x
e
e = 2∇2 L(Φ(x))(x − Φ(x)) and θ = arctan
First recall our notations, x |hv1 (x),e
xi|
.

Lemma 9.15.7. At any point x ∈ Y  , we have

∇L(x) ∇2 L(Φ(x))(x − Φ(x)) ζ 1/2 ν


p =q + O( kx − Φ(x)k).
L(x) 1 2
∂ L(Φ(x))[x − Φ(x), x − Φ(x)] µ
2

499
And therefore,

∇L(x) p ζ 1/2 ν
p ≤ 2λ1 (Φ(x)) + O( kx − Φ(x)k) = O(ζ 1/2 ).
L(x) µ

Proof. By Lemma 9.10.9

1
∇L(x) − ∇2 L(Φ(x))(x − Φ(x)) ≤ ν kx − Φ(x)k2 .
2

Since Φ(x) is a local minimizer of zero loss, we have ∇L(Φ(x)) = 0, we have that

1
L(x) = ∂ 2 L(Φ(x))[x − Φ(x), x − Φ(x)] + O(ν kx − Φ(x)k3 ).
2

2
By Lemma 9.10.8, we know ∂ 2 L(Φ(x))[x−Φ(x), x−Φ(x)] ≥ Ω( kx−Φ(x)k
µ
) and therefore

q
1 2
2
∂ L(Φ(x))[x − Φ(x), x − Φ(x)] ν
p = 1 + O( kx − Φ(x)k).
L(x) µ

Thus we conclude that


q
1 2
2
∇L(x) ∇ L(Φ(x))(x − Φ(x)) + O(ν kx − Φ(x)k ) 2
2
∂ L(Φ(x))[x − Φ(x), x − Φ(x)]
p = q · p
L(x) 1 2
∂ L(Φ(x))[x − Φ(x), x − Φ(x)] L(x)
2

∇2 L(Φ(x))(x − Φ(x)) ζ 1/2 ν


=q + O( kx − Φ(x)k).
1 2
∂ L(Φ(x))[x − Φ(x), x − Φ(x)] µ
2

p
For the second claim, with x
e= 2∇2 L(Φ(x))(x − Φ(x)), we have that

r
∇2 L(Φ(x))(x − Φ(x)) 1 2 x p
∇ L(Φ(x)) ≤ 2λ1 (x).
e
q =
1 2
∂ L(Φ(x))[x − Φ(x), x − Φ(x)] 2 ke
xk
2

2µ ζ 1/2 ν µ
By Lemma 9.10.7, we have kx − Φ(x)k ≤ ζν
, thus µ
kx − Φ(x)k = O( ζ 1/2 )=
O(ζ 1/2 ).

500
The following two lemmas are direct implications of Lemma 9.15.7.

Lemma 9.15.8. At any point x ∈ Y  , we have

(∇2 L(Φ(x)))−1/2 ∇L(x) x ζ 1/2 ν


+ O( 3/2 kx − Φ(x)k).
e
p =
2L(x) ke
xk µ

And therefore,

(∇2 L(Φ(x)))−1/2 ∇L(x)


p = O(1).
2L(x)

Lemma 9.15.9. Consider any point x ∈ Y  . Then,


* +!
∇L(x) (∇2 L(Φ(x)))−1/2 ∇L(x) ζ 1/2 ν
sign v1 (x), p v1 (x) − p ≤ θ + O( 3/2 kx − Φ(x)k),
L(x) 2L(x) µ

(2:M )
PΦ(x),Γ x
e p
where θ = arctan |hv1 (x),e
xi|
, with x
e= 2∇2 L(Φ(x))(x − Φ(x)).
p
Lemma 9.15.10. For any xy ∈ Y  where y = x − η∇ L(x) is the one step update

on L loss from x, we have

η2 ⊥
Φ(y) − Φ(x) = − Px,Γ ∇(λ1 (∇2 L(x)))
8
ζ 3/2 νξ
+O(η 2 ζξθ) + O( 3/2 kx − Φ(x)k η 2 ) + O(ζχ kx − Φ(x)k η 2 ).
µ

(2:M )
PΦ(x),Γ x
e p
Here θ = arctan |hv1 (x),e
xi|
, with x
e= 2∇2 L(Φ(x))(x − Φ(x)).

501
Proof of Lemma 9.15.10. We outline the major difference from the proof of
Lemma 9.10.12. Using taylor expansion for the function Φ, we have

Φ(y) − Φ(x)
1
= ∂Φ(x) (y − x) + ∂ 2 Φ(x)[y − x, y − x] + err
2
! " # 3
∇L(x) η2 2 ∇L(x) ∇L(x) ∇L(x)
= ∂Φ(x) −η p + ∂ Φ(x) p , p + O(χη 3 p )
2 L(x) 2 2 L(x) 2 L(x) L(x)
" #
η2 ∇L(x) ∇L(x)
= ∂ 2 Φ(x) p , p + O(χζ 3/2 η 3 ),
2 2 L(x) 2 L(x)

where in the final step, we used the property of Φ from Lemma 9.10.14 to kill the first
∇L(x)
term and use the bound on √ from Lemma 9.15.7 for the third term.
L(x)

Since the function Φ ∈ C , hence ∂ 2 Φ(x) = ∂ 2 Φ(Φ(x)) + O(χ kx − Φ(x)k).


3

Also, at Φ(x), since v1 (x) is the top eigenvector of the hessian ∇2 L, we have from
Corollary 9.10.21,

1
∂ 2 Φ(Φ(x)) v1 (x)v1 (x)> = − ∂Φ(Φ(x))∂ 2 (∇L)(Φ(x))[v1 (x), v1 (x)].
 
2λ1 (x)

From Lemma 9.15.9, we have

∇L(x)(∇L(x))> ζ 3/2 ν
λ1 (x)v1 (x)v1 (x)> − ≤ ζθ + O( 3/2 kx − Φ(x)k)
2L(x) µ

(2:M )
PΦ(x),Γ (x−Φ(x))
where recall our notation of θ = arctan |hv1 (x),x−Φ(x)i|
.
With further simplification, it turns out that

η2 2
∂ Φ(Φ(x)) λ1 (x)v1 (x)v1 (x)>
 
Φ(y) − Φ(x) = −
8
ζ 3/2 νξ
+O(η 2 ζξθ) + O( 3/2 kx − Φ(x)k η 2 ) + O(ζχ kx − Φ(x)k η 2 ).
µ

The proof is completed by using Corollary 9.10.23.


502
9.16 Additional Experimental Details

9.16.1 Experimental details



For Figure 9.1: For running GD on L, we start from (x, y) = (14.7, 3.), and
use a learning rate η = 0.5. For running Normalized GD on L, we start from
(x, y) = (14.7, −3), and use a learning rate η = 5.

For Figure 9.2: e(0)i = 10−4 , hv2 , x


We start Normalized GD from hv1 , x e(0)i = 0.45.
We use a learning rate of 1 for the optimization updates.

9.16.2 Implementation Details for Simulation for the Limit-

ing Flow of Normalized GD

We provide the code for running a single step of the riemannian flow (9.5) corresponding
to Normalized GD. The pseudocode is given in Algorithm 10.

Loss setting: The algorithm described in Algorithm 10 works for the following
scenario. The loss L is equal to the average of n loss functions `i : RD → R+ and
each `i is defined as follows. Suppose we use n functions fi : RD → R, that share a
common parameter x ∈ RD , to approximate n true labels {bi ∈ R}ni=1 Then, we define
each `i (x) = `(fi (x), bi ), where ` : R × R → R+ denotes a general loss function, that
takes in the prediction of a function and the true label and returns a score. ` should
have the following properties:

1. `(y, b) ≥ 0 for all y, b ∈ R. `(y, b) = 0 iff y = b.

2. ∇2y `(y, b) > 0 for all y ∈ R.

Example of such a loss function ` is the `2 loss function. The scenario described above
contains regression tasks. Moreover, it can also represent binary classification tasks,
since binary classification can be viewed as regression with 0, 1 label.
503
We can also represent the multi-class classification tasks, which we use for our
experiments. Consider the setting, where we are trying to train some function (e.g. a
neural network) f : RD × Rd → R|C| with the parameter space in RD , input examples
from Rd , and the set of classes being denoted by C. For each class c, we can think of
fc (x, a) as the likelihood score for label c to an input a returned by the function with
parameter x. If S = {(a, b) ∈ Rd × C} denotes the set of all input and label pairs in
the training set, we define our loss function as

1 X X
L(x) = |fc (x, a) − I(b = c)|2 .
|S||C| c∈C
(a,b)∈S

Thus, each `i in Algorithm 10 represents one of the terms {|fc (x, a) − I(b = c)|2 }(a,b)∈S,c∈C
in the multi-class classification setting.

Riemannian gradient update details: Each update comprises of three major


steps: a) computing ∇3 L(x)[v1 (x), v1 (x)], b) a projection onto the tangent space of
the manifold, and c) few steps of gradient descent with small learning rate to drop
back to manifold.
To get v1 (x), we follow a power iteration on ∇2 L(x). We use 100 iterations
to compute v1 (x) at each point x. For the plots in Figures 9.5 and 9.6, we run
Algorithm 10 starting from a point x(0) with L(x(0)) = 3.803 × 10−3 , with η = 10−2 ,
ηproj = 10−2 , and Tproj = 103 . We use a learning rate of 10−2 for the Normalized GD
trajectory.

504
Algorithm 10 Simulation for the limiting flow (9.5) of Normalized GD
Input: n loss functions `i : RD → R+ , initial point x(0) with L(x(0)) ≈ 0, maximum
number of iteration T , LR η, Projection LR ηproj , maximum number of projection
iterations Tproj . P
Define L(x) as n1 ni=1 `i (x) and Px,Γ as projection matrix onto the subspace spanned
by ∇f1 (x), · · · , ∇fn (x) for any x ∈ RD .
for t = 0 to T − 1 do
Compute v1 , the top eigenvector of ∇2 L(x(t)).
Compute ∇λ1 (x(t)) = ∇3 L(x(t))[v1 , v1 ]. //This is by Theorem 9.14.1.
Compute Px(t),Γ ∇λ1 (x(t)) by solving least square.
η
y(0) ← x(t) − λ1 (x(t)) (I − Px(t),Γ )∇λ1 (x(t)).
t = 0 to Tproj − 1 do
for e
y(e
t +1) = y(e t)−ηproj ∇L(y(et)). // Inner loop: project GD back to manifold.
x(t + 1) ← y(Tproj ).

505
Bibliography

[1] Diederik P. Kingma and Jimmy Ba. Adam: A method for stochastic optimization.
In 3rd International Conference on Learning Representations, ICLR 2015, San
Diego, CA, USA, May 7-9, 2015, Conference Track Proceedings, 2015. Cited on
pages 2 & 90.

[2] John Duchi, Elad Hazan, and Yoram Singer. Adaptive subgradient methods
for online learning and stochastic optimization. Journal of Machine Learning
Research, 12(Jul):2121–2159, 2011. Cited on page 2.

[3] Yu Nesterov. Introductory lectures on convex programming, 1998. Cited on


page 3.

[4] Chiyuan Zhang, Samy Bengio, Moritz Hardt, Benjamin Recht, and Oriol Vinyals.
Understanding deep learning requires rethinking generalization. In International
Conference on Learning Representations, 2017. Cited on pages 4 & 158.

[5] Shengchao Liu, Dimitris Papailiopoulos, and Dimitris Achlioptas. Bad global
minima exist and sgd can reach them. Advances in Neural Information Processing
Systems, 33:8543–8552, 2020. Cited on page 4.

[6] Behnam Neyshabur. Implicit regularization in deep learning. arXiv preprint


arXiv:1709.01953, 2017. Cited on page 4.

[7] Yiding Jiang, Pierre Foret, Scott Yak, Daniel M Roy, Hossein Mobahi,
Gintare Karolina Dziugaite, Samy Bengio, Suriya Gunasekar, Isabelle Guyon,
and Behnam Neyshabur. Neurips 2020 competition: Predicting generalization
in deep learning. arXiv preprint arXiv:2012.07976, 2020. Cited on page 4.

[8] Pierre Foret, Ariel Kleiner, Hossein Mobahi, and Behnam Neyshabur. Sharpness-
aware minimization for efficiently improving generalization. In International
Conference on Learning Representations, 2021. URL https://openreview.net/
forum?id=6Tm1mposlrM. Cited on pages 5 & 399.

[9] Nitish Shirish Keskar, Dheevatsa Mudigere, Jorge Nocedal, Mikhail Smelyanskiy,
and Ping Tak Peter Tang. On large-batch training for deep learning: General-
ization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016. Cited
on pages 5, 312 & 399.

506
[10] David McAllester. Simplified pac-bayesian margin bounds. In Learning theory
and Kernel machines, pages 203–215. Springer, 2003. Cited on page 5.
[11] Gintare Karolina Dziugaite and Daniel M Roy. Computing nonvacuous general-
ization bounds for deep (stochastic) neural networks with many more parameters
than training data. arXiv preprint arXiv:1703.11008, 2017. Cited on page 5.
[12] Yiding Jiang*, Behnam Neyshabur*, Hossein Mobahi, Dilip Krishnan, and
Samy Bengio. Fantastic generalization measures and where to find them. In
International Conference on Learning Representations, 2020. URL https://
openreview.net/forum?id=SJgIPJBFvH. Cited on page 5.
[13] Daniel Soudry, Elad Hoffer, and Nathan Srebro. The implicit bias of gradient
descent on separable data. In International Conference on Learning Representa-
tions, 2018. Cited on pages 5 & 160.
[14] Kaifeng Lyu and Jian Li. Gradient descent maximizes the margin of homogeneous
neural networks. arXiv preprint arXiv:1906.05890, 2019. Cited on pages 5, 252,
306 & 400.
[15] Arthur Jacot, Franck Gabriel, and Clément Hongler. Neural tangent kernel: Con-
vergence and generalization in neural networks. arXiv preprint arXiv:1806.07572,
2018. Cited on pages 5, 121, 252, 312, 333 & 400.
[16] Suriya Gunasekar, Blake E Woodworth, Srinadh Bhojanapalli, Behnam
Neyshabur, and Nati Srebro. Implicit regularization in matrix factorization. In
I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan,
and R. Garnett, editors, Advances in Neural Information Processing Systems
30, pages 6151–6159. Curran Associates, Inc., 2017. Cited on pages 6, 10, 157,
159, 160, 161, 162, 173, 193 & 306.
[17] Jeremy Cohen, Simran Kaur, Yuanzhi Li, J Zico Kolter, and Ameet Talwalkar.
Gradient descent on neural networks typically occurs at the edge of stabil-
ity. In International Conference on Learning Representations, 2020. Cited on
pages 7 & 65.
[18] Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating deep
network training by reducing internal covariate shift. In International conference
on machine learning, pages 448–456. PMLR, 2015. Cited on pages 8, 14, 19,
91 & 119.
[19] Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E Hinton. Layer normalization.
arXiv preprint arXiv:1607.06450, 2016. Cited on pages 8, 15, 19 & 91.
[20] Blake Woodworth, Suriya Gunasekar, Jason D Lee, Edward Moroshko, Pedro
Savarese, Itay Golan, Daniel Soudry, and Nathan Srebro. Kernel and rich
regimes in overparametrized models. In Conference on Learning Theory, pages
3635–3673. PMLR, 2020. Cited on pages 11, 249, 250, 251, 252, 261, 262, 271,
306, 311, 312, 326, 328, 330, 332 & 400.
507
[21] Jeff Z HaoChen, Colin Wei, Jason D Lee, and Tengyu Ma. Shape mat-
ters: Understanding the implicit bias of the noise covariance. arXiv preprint
arXiv:2006.08680, 2020. Cited on pages 11, 252, 306, 311, 312, 326 & 329.

[22] Zhiyuan Li and Sanjeev Arora. An exponential learning rate schedule for deep
learning. In International Conference on Learning Representations, 2019. Cited
on pages 12, 91, 92 & 97.

[23] Zhiyuan Li, Srinadh Bhojanapalli, Manzil Zaheer, Sashank Reddi, and Sanjiv
Kumar. Robust training of neural networks using scale invariant architectures.
In International Conference on Machine Learning, pages 12656–12684. PMLR,
2022. Cited on pages 12 & 401.

[24] Zhiyuan Li, Yi Zhang, and Sanjeev Arora. Why are convolutional nets more
sample-efficient than fully-connected nets? In International Conference on
Learning Representations, 2020. Cited on pages 12 & 392.

[25] Zhiyuan Li, Yuping Luo, and Kaifeng Lyu. Towards resolving the implicit
bias of gradient descent for matrix factorization: Greedy low-rank learning. In
International Conference on Learning Representations, 2020. Cited on pages 12,
249, 252, 312 & 400.

[26] Zhiyuan Li, Tianhao Wang, and Sanjeev Arora. What happens after SGD reaches
zero loss? A mathematical framework. In International Conference on Learning
Representations, 2021. Cited on page 12.

[27] Sanjeev Arora, Zhiyuan Li, and Abhishek Panigrahi. Understanding gradient
descent on the edge of stability in deep learning. In Kamalika Chaudhuri,
Stefanie Jegelka, Le Song, Csaba Szepesvari, Gang Niu, and Sivan Sabato,
editors, Proceedings of the 39th International Conference on Machine Learning,
volume 162 of Proceedings of Machine Learning Research, pages 948–1024. PMLR,
17–23 Jul 2022. URL https://proceedings.mlr.press/v162/arora22a.html.
Cited on page 12.

[28] Shibani Santurkar, Dimitris Tsipras, Andrew Ilyas, and Aleksander Madry.
How does batch normalization help optimization? In S. Bengio, H. Wallach,
H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett, editors, Ad-
vances in Neural Information Processing Systems 31, pages 2488–2498. Curran
Associates, Inc., 2018. Cited on pages 14, 18 & 31.

[29] Sanjeev Arora, Nadav Cohen, and Elad Hazan. On the optimization of deep
networks: Implicit acceleration by overparameterization. In International Con-
ference on Machine Learning, pages 244–253. PMLR, 2018. Cited on pages 14,
51, 177, 312 & 400.

[30] Yuxin Wu and Kaiming He. Group normalization. In Proceedings of the European
conference on computer vision (ECCV), pages 3–19, 2018. Cited on pages 15,
19 & 91.
[31] Dmitry Ulyanov, Andrea Vedaldi, and Victor Lempitsky. Instance normalization:
The missing ingredient for fast stylization. arXiv preprint arXiv:1607.08022,
2016. Cited on pages 15 & 19.

[32] Elad Hoffer, Itay Hubara, and Daniel Soudry. Fix your classifier: the marginal
value of training the last weight layer. In International Conference on Learning
Representations, 2018. URL https://openreview.net/forum?id=S1Dh8Tg0-.
Cited on pages 15 & 54.

[33] Leslie N Smith. Cyclical learning rates for training neural networks. In 2017
IEEE Winter Conference on Applications of Computer Vision (WACV), pages
464–472. IEEE, 2017. Cited on page 16.

[34] Ilya Loshchilov and Frank Hutter. SGDR: Stochastic gradient descent with
warm restarts. arXiv preprint arXiv:1608.03983, 2016. Cited on
pages 16 & 33.

[35] Minhyung Cho and Jaehyung Lee. Riemannian approach to batch normalization.
In I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan,
and R. Garnett, editors, Advances in Neural Information Processing Systems
30, pages 5225–5235. Curran Associates, Inc., 2017. Cited on page 17.

[36] Sanjeev Arora, Zhiyuan Li, and Kaifeng Lyu. Theoretical analysis of auto
rate-tuning by batch normalization. In International Conference on Learning
Representations, 2018. Cited on pages 18, 19, 29, 31, 52, 62, 91, 93, 94 & 401.

[37] Xiaoxia Wu, Rachel Ward, and Léon Bottou. WNGrad: Learn the Learning
Rate in Gradient Descent. arXiv preprint arXiv:1803.02865, 2018. Cited on
pages 18 & 51.

[38] Jonas Kohler, Hadi Daneshmand, Aurelien Lucchi, Ming Zhou, Klaus Neymeyr,
and Thomas Hofmann. Exponential convergence rates for batch normalization:
The power of length-direction decoupling in non-convex optimization. arXiv
preprint arXiv:1805.10694, 2018. Cited on page 18.

[39] Nils Bjorck, Carla P Gomes, Bart Selman, and Kilian Q Weinberger. Understand-
ing batch normalization. In S. Bengio, H. Wallach, H. Larochelle, K. Grauman,
N. Cesa-Bianchi, and R. Garnett, editors, Advances in Neural Information
Processing Systems 31, pages 7705–7716. Curran Associates, Inc., 2018. Cited
on page 18.

[40] Guodong Zhang, Chaoqi Wang, Bowen Xu, and Roger Grosse. Three mecha-
nisms of weight decay regularization. In International Conference on Learning
Representations, 2018. Cited on pages 18 & 91.

[41] Elad Hoffer, Ron Banner, Itay Golan, and Daniel Soudry. Norm matters:
efficient and accurate normalization schemes in deep networks. arXiv preprint
arXiv:1803.01814, 2018. Cited on pages 18 & 91.
[42] Twan Van Laarhoven. L2 regularization versus batch and weight normalization.
arXiv preprint arXiv:1706.05350, 2017. Cited on pages 18, 91 & 95.
[43] Ilya Sutskever, James Martens, George Dahl, and Geoffrey Hinton. On the
importance of initialization and momentum in deep learning. In Proceedings
of the 30th International Conference on International Conference on Machine
Learning - Volume 28, ICML’13, pages III–1139–III–1147. JMLR.org, 2013. URL
http://dl.acm.org/citation.cfm?id=3042817.3043064. Cited on page 19.
[44] Adam Paszke, Sam Gross, Soumith Chintala, Gregory Chanan, Edward Yang,
Zachary DeVito, Zeming Lin, Alban Desmaison, Luca Antiga, and Adam Lerer.
Automatic differentiation in PyTorch. 2017. Cited on pages 19 & 58.
[45] Robert Mansel Gower, Nicolas Loizou, Xun Qian, Alibek Sailanbayev, Egor
Shulgin, and Peter Richtárik. SGD: General analysis and improved rates. arXiv
preprint arXiv:1901.09401, 2019. Cited on page 46.
[46] Sanjoy Dasgupta and Anupam Gupta. An elementary proof of a theorem of
Johnson and Lindenstrauss. Random Structures & Algorithms, 22(1):60–65, 2003.
Cited on page 46.
[47] Yang You, Igor Gitman, and Boris Ginsburg. Large batch training of convolutional
networks. arXiv preprint arXiv:1708.03888, 2017. Cited on
page 53.
[48] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning
for image recognition. In Proceedings of the IEEE conference on computer vision
and pattern recognition, pages 770–778, 2016. Cited on pages 53 & 117.
[49] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Identity mappings
in deep residual networks. In European conference on computer vision, pages
630–645. Springer, 2016. Cited on page 53.
[50] Zhiyuan Li, Kaifeng Lyu, and Sanjeev Arora. Reconciling modern deep learning
with traditional optimization analyses: The intrinsic learning rate. Advances
in Neural Information Processing Systems, 33, 2020. Cited on pages 65, 91, 93,
306, 312 & 395.
[51] Ekaterina Lobacheva, Maxim Kodryan, Nadezhda Chirkova, Andrey Malinin,
and Dmitry P Vetrov. On the periodic behavior of neural network training
with batch normalization and weight decay. Advances in Neural Information
Processing Systems, 34, 2021. Cited on page 65.
[52] Ramon van Handel. Probability in high dimension. 2016. Cited on page 74.
[53] Liyuan Liu, Xiaodong Liu, Jianfeng Gao, Weizhu Chen, and Jiawei Han. Under-
standing the difficulty of training transformers. In Bonnie Webber, Trevor Cohn,
Yulan He, and Yang Liu, editors, Proceedings of the 2020 Conference on Empir-
ical Methods in Natural Language Processing, EMNLP 2020, Online, November
16-20, 2020, pages 5747–5763. Association for Computational Linguistics, 2020.
Cited on pages 86, 88 & 90.

[54] Jingzhao Zhang, Sai Praneeth Karimireddy, Andreas Veit, Seungyeon Kim,
Sashank J. Reddi, Sanjiv Kumar, and Suvrit Sra. Why are adaptive methods
good for attention models? In Hugo Larochelle, Marc’Aurelio Ranzato, Raia
Hadsell, Maria-Florina Balcan, and Hsuan-Tien Lin, editors, Advances in Neural
Information Processing Systems 33: Annual Conference on Neural Information
Processing Systems 2020, NeurIPS 2020, December 6-12, 2020, virtual, 2020.
Cited on pages 86, 90 & 94.

[55] Aitor Lewkowycz, Yasaman Bahri, Ethan Dyer, Jascha Sohl-Dickstein, and Guy
Gur-Ari. The large learning rate phase of deep learning: the catapult mechanism.
arXiv preprint arXiv:2003.02218, 2020. Cited on pages 88 & 401.

[56] Boris T Polyak. Introduction to Optimization. Optimization Software, Inc.,
Publications Division, New York, 1, 1987. Cited on page 90.

[57] John Duchi, Elad Hazan, and Yoram Singer. Adaptive subgradient methods
for online learning and stochastic optimization. Journal of machine learning
research, 12(Jul):2121–2159, 2011. Cited on page 90.

[58] Tijmen Tieleman and Geoffrey Hinton. Lecture 6.5-rmsprop: Divide the gradient
by a running average of its recent magnitude. COURSERA: Neural networks
for machine learning, 4(2):26–31, 2012. Cited on pages 90, 184 & 188.

[59] Sashank J Reddi, Satyen Kale, and Sanjiv Kumar. On the convergence of ADAM
and beyond. arXiv preprint arXiv:1904.09237, 2019. Cited on page 90.

[60] Yang You, Jing Li, Sashank J. Reddi, Jonathan Hseu, Sanjiv Kumar, Srinadh
Bhojanapalli, Xiaodan Song, James Demmel, Kurt Keutzer, and Cho-Jui Hsieh.
Large batch optimization for deep learning: Training BERT in 76 minutes. In
8th International Conference on Learning Representations, ICLR 2020, Addis
Ababa, Ethiopia, April 26-30, 2020. OpenReview.net, 2020. Cited on page 90.

[61] Noam Shazeer and Mitchell Stern. Adafactor: Adaptive learning rates with
sublinear memory cost. In International Conference on Machine Learning, pages
4596–4604. PMLR, 2018. Cited on page 90.

[62] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need.
In Advances in neural information processing systems, pages 5998–6008, 2017.
Cited on pages 90 & 92.

[63] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: Pre-
training of deep bidirectional transformers for language understanding. arXiv
preprint arXiv:1810.04805, 2018. Cited on pages 90, 92 & 101.

[64] Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang,
Michael Matena, Yanqi Zhou, Wei Li, and Peter J Liu. Exploring the limits
of transfer learning with a unified text-to-text transformer. arXiv preprint
arXiv:1910.10683, 2019. Cited on page 90.

[65] Rohan Anil, Vineet Gupta, Tomer Koren, and Yoram Singer. Memory efficient
adaptive optimization. Advances in Neural Information Processing Systems, 32,
2019. Cited on page 90.

[66] Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav
Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton,
Sebastian Gehrmann, et al. PaLM: Scaling language modeling with Pathways.
arXiv preprint arXiv:2204.02311, 2022. Cited on page 90.

[67] Elad Hazan, Kfir Levy, and Shai Shalev-Shwartz. Beyond convexity: Stochas-
tic quasi-convex optimization. In Advances in Neural Information Processing
Systems, pages 1594–1602, 2015. Cited on page 90.

[68] Kfir Y Levy. The power of normalization: Faster evasion of saddle points. arXiv
preprint arXiv:1611.04831, 2016. Cited on page 90.

[69] Lei Huang, Xianglong Liu, Bo Lang, and Bo Li. Projection based weight
normalization for deep neural networks. ArXiv, abs/1710.02338, 2017. Cited on
page 90.

[70] Razvan Pascanu, Tomas Mikolov, and Yoshua Bengio. On the difficulty of
training recurrent neural networks. In Sanjoy Dasgupta and David McAllester,
editors, Proceedings of the 30th International Conference on Machine Learning,
volume 28 of Proceedings of Machine Learning Research, pages 1310–1318,
Atlanta, Georgia, USA, 17–19 Jun 2013. PMLR. URL https://proceedings.mlr.press/v28/pascanu13.html. Cited on page 90.

[71] Xiangyi Chen, Zhiwei Steven Wu, and Mingyi Hong. Understanding gradient
clipping in private SGD: A geometric perspective. CoRR, abs/2006.15429, 2020.
URL https://arxiv.org/abs/2006.15429. Cited on page 90.

[72] Jingzhao Zhang, Tianxing He, Suvrit Sra, and Ali Jadbabaie. Why gradient
clipping accelerates training: A theoretical justification for adaptivity. In 8th
International Conference on Learning Representations, ICLR 2020, Addis Ababa,
Ethiopia, April 26-30, 2020. OpenReview.net, 2020. Cited on page 90.

[73] Tim Salimans and Durk P Kingma. Weight normalization: A simple reparame-
terization to accelerate training of deep neural networks. Advances in neural
information processing systems, 29:901–909, 2016. Cited on page 91.

[74] Dmitry Ulyanov, Andrea Vedaldi, and Victor Lempitsky. Instance normalization:
The missing ingredient for fast stylization. arXiv preprint arXiv:1607.08022,
2016. Cited on page 91.
[75] Ruosi Wan, Zhanxing Zhu, Xiangyu Zhang, and Jian Sun. Spherical motion
dynamics: Learning dynamics of neural network with normalization, weight
decay, and SGD. arXiv preprint arXiv:2006.08419, 2020. Cited on pages 91 & 93.

[76] Dan Hendrycks and Kevin Gimpel. Gaussian error linear units (gelus). arXiv
preprint arXiv:1606.08415, 2016. Cited on pages 97, 320 & 417.

[77] Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, and Percy Liang. SQuAD:
100,000+ questions for machine comprehension of text. arXiv preprint
arXiv:1606.05250, 2016. Cited on page 103.

[78] Pranav Rajpurkar, Robin Jia, and Percy Liang. Know what you don’t know:
Unanswerable questions for SQuAD. arXiv preprint arXiv:1806.03822, 2018. Cited
on page 103.

[79] Adina Williams, Nikita Nangia, and Samuel Bowman. A broad-coverage chal-
lenge corpus for sentence understanding through inference. In Proceedings of the
2018 Conference of the North American Chapter of the Association for Compu-
tational Linguistics: Human Language Technologies, Volume 1 (Long Papers),
pages 1112–1122. Association for Computational Linguistics, 2018. Cited on
page 103.

[80] Alex Krizhevsky, Ilya Sutskever, and Geoffrey E Hinton. ImageNet classification
with deep convolutional neural networks. In Advances in neural information
processing systems, pages 1097–1105, 2012. Cited on page 117.

[81] Gao Huang, Zhuang Liu, Laurens Van Der Maaten, and Kilian Q Weinberger.
Densely connected convolutional networks. In Proceedings of the IEEE conference
on computer vision and pattern recognition, pages 4700–4708, 2017. Cited on
page 117.

[82] Andrew Y Ng. Feature selection, L1 vs. L2 regularization, and rotational
invariance. In Proceedings of the twenty-first international conference on Machine
learning, page 78, 2004. Cited on pages 118, 132, 133 & 392.

[83] Simon S Du, Yining Wang, Xiyu Zhai, Sivaraman Balakrishnan, Russ R Salakhut-
dinov, and Aarti Singh. How many samples are needed to estimate a convolutional
neural network? In Advances in Neural Information Processing Systems, pages
373–383, 2018. Cited on page 120.

[84] Yossi Arjevani and Ohad Shamir. On the iteration complexity of oblivious
first-order optimization algorithms. In International Conference on Machine
Learning, pages 908–916, 2016. Cited on page 120.

[85] Colin Wei, Jason D Lee, Qiang Liu, and Tengyu Ma. Regularization matters:
Generalization and optimization of neural nets vs their induced kernel. In
Advances in Neural Information Processing Systems, pages 9709–9721, 2019.
Cited on page 121.
[86] Zeyuan Allen-Zhu and Yuanzhi Li. What can ResNet learn efficiently, going
beyond kernels? In Advances in Neural Information Processing Systems, pages
9015–9025, 2019. Cited on page 121.
[87] Anselm Blumer, A. Ehrenfeucht, David Haussler, and Manfred K. Warmuth.
Learnability and the Vapnik-Chervonenkis dimension. J. ACM, 36(4):929–965,
October 1989. ISSN 0004-5411. doi: 10.1145/76359.76371. URL https://doi.org/10.1145/76359.76371. Cited on page 123.
[88] Gyora M Benedek and Alon Itai. Learnability with respect to fixed distributions.
Theoretical Computer Science, 86(2):377–389, 1991. Cited on pages 134 & 137.
[89] Philip M. Long. On the sample complexity of PAC learning half-spaces against
the uniform distribution. IEEE Transactions on Neural Networks, 6(6):1556–
1559, 1995. Cited on page 134.
[90] Michel Talagrand. Upper and lower bounds for stochastic processes: modern
methods and classical problems, volume 60. Springer Science & Business Media,
2014. Cited on page 141.
[91] Stanislaw J Szarek. Metric entropy of homogeneous spaces. arXiv preprint
math/9701213, 1997. Cited on page 145.
[92] Zongming Ma and Yihong Wu. Volume ratio, sparsity, and minimaxity under
unitarily invariant norms. IEEE Transactions on Information Theory, 61(12):
6939–6956, 2015. Cited on page 146.
[93] Roman Vershynin. High-Dimensional Probability: An Introduction with Ap-
plications in Data Science. Cambridge Series in Statistical and Probabilistic
Mathematics. Cambridge University Press, 2018. doi: 10.1017/9781108231596.
Cited on page 146.
[94] Yuejie Chi, Yue M Lu, and Yuxin Chen. Nonconvex optimization meets low-rank
matrix factorization: An overview. IEEE Transactions on Signal Processing, 67
(20):5239–5269, 2019. Cited on page 158.
[95] Sanjeev Arora, Nadav Cohen, Wei Hu, and Yuping Luo. Implicit regularization
in deep matrix factorization. arXiv preprint arXiv:1905.13655, 2019. Cited on
pages 159, 161, 249, 252, 312 & 400.
[96] Gauthier Gidel, Francis Bach, and Simon Lacoste-Julien. Implicit regulariza-
tion of discrete gradient dynamics in linear neural networks. In H. Wallach,
H. Larochelle, A. Beygelzimer, F. d’ Alché-Buc, E. Fox, and R. Garnett, editors,
Advances in Neural Information Processing Systems 32, pages 3196–3206. Curran
Associates, Inc., 2019. Cited on pages 159, 161 & 165.
[97] Daniel Gissin, Shai Shalev-Shwartz, and Amit Daniely. The implicit bias of depth:
How incremental learning drives generalization. In International Conference on
Learning Representations, 2020. Cited on pages 159, 160, 161, 164 & 165.
[98] Noam Razin and Nadav Cohen. Implicit regularization in deep learning may
not be explainable by norms. arXiv preprint arXiv:2005.06398, 2020. Cited on
pages 159, 161, 252 & 400.

[99] Ashia C Wilson, Rebecca Roelofs, Mitchell Stern, Nathan Srebro, and Benjamin
Recht. The marginal value of adaptive gradient methods in machine learning.
arXiv preprint arXiv:1705.08292, 2017. Cited on page 160.

[100] Daniel Soudry, Elad Hoffer, Mor Shpigel Nacson, Suriya Gunasekar, and Nathan
Srebro. The implicit bias of gradient descent on separable data. The Journal of
Machine Learning Research, 19(1):2822–2878, 2018. Cited on pages 160, 252,
306, 312 & 400.

[101] Mor Shpigel Nacson, Jason Lee, Suriya Gunasekar, Pedro Henrique Pamplona
Savarese, Nathan Srebro, and Daniel Soudry. Convergence of gradient descent
on separable data. In Kamalika Chaudhuri and Masashi Sugiyama, editors,
Proceedings of Machine Learning Research, volume 89 of Proceedings of Ma-
chine Learning Research, pages 3420–3428. PMLR, 16–18 Apr 2019. Cited on
pages 160 & 252.

[102] Mor Shpigel Nacson, Nathan Srebro, and Daniel Soudry. Stochastic gradient
descent on separable data: Exact convergence with a fixed learning rate. In
Kamalika Chaudhuri and Masashi Sugiyama, editors, Proceedings of Machine
Learning Research, volume 89 of Proceedings of Machine Learning Research,
pages 3051–3059. PMLR, 16–18 Apr 2019. Cited on page 160.

[103] Ziwei Ji and Matus Telgarsky. A refined primal-dual analysis of the implicit
bias. arXiv preprint arXiv:1906.04540, 2019. Cited on page 160.

[104] Ziwei Ji and Matus Telgarsky. Gradient descent aligns the layers of deep linear
networks. arXiv preprint arXiv:1810.02032, 2018. Cited on page 160.

[105] Suriya Gunasekar, Blake Woodworth, Srinadh Bhojanapalli, Behnam Neyshabur,
and Nathan Srebro. Implicit regularization in matrix factorization. In 2018
Information Theory and Applications Workshop (ITA), pages 1–10. IEEE, 2018.
Cited on pages 160, 250, 251, 252, 261, 263, 271, 312 & 400.

[106] Mor Shpigel Nacson, Suriya Gunasekar, Jason Lee, Nathan Srebro, and Daniel
Soudry. Lexicographic and depth-sensitive margins in homogeneous and non-
homogeneous deep models. In Kamalika Chaudhuri and Ruslan Salakhutdinov,
editors, Proceedings of the 36th International Conference on Machine Learning,
volume 97 of Proceedings of Machine Learning Research, pages 4683–4692, Long
Beach, California, USA, 09–15 Jun 2019. PMLR. Cited on page 160.

[107] Kaifeng Lyu and Jian Li. Gradient descent maximizes the margin of homogeneous
neural networks. In International Conference on Learning Representations, 2020.
Cited on page 160.

[108] Arthur Jacot, Franck Gabriel, and Clement Hongler. Neural tangent kernel:
Convergence and generalization in neural networks. In S. Bengio, H. Wallach,
H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett, editors, Ad-
vances in Neural Information Processing Systems 31, pages 8571–8580. Curran
Associates, Inc., 2018. Cited on page 160.
[109] Sanjeev Arora, Simon S Du, Wei Hu, Zhiyuan Li, Ruslan Salakhutdinov, and
Ruosong Wang. On exact computation with an infinitely wide neural net. arXiv
preprint arXiv:1904.11955, 2019. Cited on pages 160 & 400.
[110] Lénaïc Chizat and Francis Bach. Implicit bias of gradient descent for wide two-
layer neural networks trained with the logistic loss. volume 125 of Proceedings
of Machine Learning Research, pages 1305–1338. PMLR, 09–12 Jul 2020. Cited
on page 160.
[111] Lénaïc Chizat, Edouard Oyallon, and Francis Bach. On lazy training in differen-
tiable programming. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d’ Alché-
Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing
Systems 32, pages 2937–2947. Curran Associates, Inc., 2019. Cited on page 161.
[112] Yuanzhi Li, Tengyu Ma, and Hongyang Zhang. Algorithmic regularization in over-
parameterized matrix sensing and neural networks with quadratic activations. In
Conference On Learning Theory, pages 2–47. PMLR, 2018. Cited on pages 161,
252, 312 & 400.
[113] Mohamed Ali Belabbas. On implicit regularization: Morse functions and appli-
cations to matrix factorization. arXiv preprint arXiv:2001.04264, 2020. Cited
on page 161.
[114] Zheng Wang, Ming-Jun Lai, Zhaosong Lu, Wei Fan, Hasan Davulcu, and Jieping
Ye. Rank-one matrix pursuit for matrix completion. In International Conference
on Machine Learning, pages 91–99, 2014. Cited on pages 168 & 189.
[115] Quanming Yao and James Tin Yau Kwok. Greedy learning of generalized low-
rank models. In IJCAI International Joint Conference on Artificial Intelligence,
2016. Cited on page 168.
[116] Shai Shalev-Shwartz and Yoram Singer. On the equivalence of weak learnabil-
ity and linear separability: New relaxations and efficient boosting algorithms.
Machine learning, 80(2-3):141–163, 2010. Cited on page 168.
[117] Rajiv Khanna, Ethan Elenberg, Alexandros G Dimakis, and Sahand Negahban.
On approximation guarantees for greedy low rank optimization. arXiv preprint
arXiv:1703.02721, 2017. Cited on page 168.
[118] Benjamin D. Haeffele and René Vidal. Structured Low-Rank Matrix Factor-
ization: Global Optimality, Algorithms, and Applications. IEEE Transactions
on Pattern Analysis and Machine Intelligence (PAMI), 42(6):1468–1482, 2019.
Cited on page 168.
[119] Jason D Lee, Ioannis Panageas, Georgios Piliouras, Max Simchowitz, Michael I
Jordan, and Benjamin Recht. First-order methods almost always avoid saddle
points. arXiv preprint arXiv:1710.07406, 2017. Cited on pages 172, 174, 222,
331 & 416.

[120] Ioannis Panageas, Georgios Piliouras, and Xiao Wang. First-order methods
almost always avoid saddle points: The case of vanishing step-sizes. In H. Wallach,
H. Larochelle, A. Beygelzimer, F. d’ Alché-Buc, E. Fox, and R. Garnett, editors,
Advances in Neural Information Processing Systems 32, pages 6474–6483. Curran
Associates, Inc., 2019. Cited on page 172.

[121] Jason D Lee, Max Simchowitz, Michael I Jordan, and Benjamin Recht. Gradient
descent only converges to minimizers. In Conference on learning theory, pages
1246–1257. PMLR, 2016. Cited on pages 174, 331 & 416.

[122] Jeff Bezanson, Stefan Karpinski, Viral B Shah, and Alan Edelman. Julia: A
fast dynamic language for technical computing. arXiv preprint arXiv:1209.5145,
2012. Cited on page 187.

[123] Adam Paszke, Sam Gross, Francisco Massa, Adam Lerer, James Bradbury,
Gregory Chanan, Trevor Killeen, Zeming Lin, Natalia Gimelshein, Luca Antiga,
Alban Desmaison, Andreas Kopf, Edward Yang, Zachary DeVito, Martin Raison,
Alykhan Tejani, Sasank Chilamkurthy, Benoit Steiner, Lu Fang, Junjie Bai,
and Soumith Chintala. Pytorch: An imperative style, high-performance deep
learning library. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d’ Alché-Buc,
E. Fox, and R. Garnett, editors, Advances in Neural Information Processing
Systems 32, pages 8024–8035. Curran Associates, Inc., 2019. Cited on page 187.

[124] Steven Diamond and Stephen Boyd. CVXPY: A Python-embedded modeling
language for convex optimization. Journal of Machine Learning Research, 17
(83):1–5, 2016. Cited on page 189.

[125] Akshay Agrawal, Robin Verschueren, Steven Diamond, and Stephen Boyd. A
rewriting system for convex optimization problems. Journal of Control and
Decision, 5(1):42–60, 2018. Cited on page 189.

[126] Stanislaw Lojasiewicz. Ensembles semi-analytiques. IHES notes, 1965. Cited on
page 220.

[127] Simon S Du and Jason D Lee. On the power of over-parametrization in neural
networks with quadratic activation. In International Conference on Machine
Learning, pages 1328–1337, 2018. Cited on page 222.

[128] Lawrence Perko. Differential equations and dynamical systems, volume 7.
Springer Science & Business Media, 2013. Cited on page 223.

[129] Frank H. Clarke. Generalized gradients and applications. Transactions of the
American Mathematical Society, 205:247–262, 1975. Cited on page 227.
[130] Frank H Clarke. Optimization and Nonsmooth Analysis. Society for Industrial
and Applied Mathematics, 1990. doi: 10.1137/1.9781611971309. Cited on
page 227.

[131] Francis H Clarke, Yuri S Ledyaev, Ronald J Stern, and Peter R Wolenski.
Nonsmooth analysis and control theory, volume 178. Springer Science & Business
Media, 2008. Cited on page 227.

[132] Jean-Baptiste Hiriart-Urruty and A. S. Lewis. The Clarke and Michel-Penot
subdifferentials of the eigenvalues of a symmetric matrix. Comput. Optim. Appl.,
13(1-3):13–23, 1999. doi: 10.1023/A:1008644520093. Cited on page 227.

[133] Zhiyuan Li, Tianhao Wang, and Sanjeev Arora. What happens after SGD
reaches zero loss? A mathematical framework. In International Conference on
Learning Representations, 2022. Cited on pages 249, 252, 271, 399 & 400.

[134] Tomas Vaskevicius, Varun Kanade, and Patrick Rebeschini. Implicit regulariza-
tion for optimal sparse recovery. Advances in Neural Information Processing
Systems, 32:2972–2983, 2019. Cited on pages 250, 251, 252, 261 & 312.

[135] Chulhee Yun, Shankar Krishnan, and Hossein Mobahi. A unifying view on
implicit bias in training linear neural networks. arXiv preprint arXiv:2010.02501,
2020. Cited on pages 250, 251 & 252.

[136] Ehsan Amid and Manfred K Warmuth. Winnowing with gradient descent. In
Conference on Learning Theory, pages 163–182. PMLR, 2020. Cited on pages 250,
251, 252, 261 & 274.

[137] Ehsan Amid and Manfred K Warmuth. Reparameterizing mirror descent
as gradient descent. Advances in Neural Information Processing Systems, 33:
8430–8439, 2020. Cited on pages 250, 251, 252, 253, 261, 274 & 275.

[138] Shahar Azulay, Edward Moroshko, Mor Shpigel Nacson, Blake Woodworth,
Nathan Srebro, Amir Globerson, and Daniel Soudry. On the implicit bias
of initialization shape: Beyond infinitesimal mirror descent. arXiv preprint
arXiv:2102.09769, 2021. Cited on pages 250, 251, 252, 271, 274, 330 & 400.

[139] Arkadij Semenovič Nemirovskij and David Borisovich Yudin. Problem complexity
and method efficiency in optimization. 1983. Cited on pages 250 & 258.

[140] Amir Beck and Marc Teboulle. Mirror descent and nonlinear projected sub-
gradient methods for convex optimization. Operations Research Letters, 31(3):
167–175, 2003. Cited on pages 250 & 258.

[141] Udaya Ghai, Zhou Lu, and Elad Hazan. Non-convex online learning via algo-
rithmic equivalence. arXiv preprint arXiv:2205.15235, 2022. Cited on pages 251,
253 & 274.

[142] Suriya Gunasekar, Jason Lee, Daniel Soudry, and Nathan Srebro. Characterizing
implicit bias in terms of optimization geometry. In International Conference
on Machine Learning, pages 1832–1841. PMLR, 2018. Cited on pages 252, 259,
274, 312 & 400.

[143] Ziwei Ji and Matus Telgarsky. Characterizing the implicit bias via a primal-dual
analysis. In Algorithmic Learning Theory, pages 772–804. PMLR, 2021. Cited
on page 252.

[144] Edward Moroshko, Suriya Gunasekar, Blake Woodworth, Jason D Lee, Nathan
Srebro, and Daniel Soudry. Implicit bias in deep linear classification: Initializa-
tion scale vs training accuracy. arXiv preprint arXiv:2007.06738, 2020. Cited
on page 252.

[145] Ziwei Ji and Matus Telgarsky. Directional convergence and alignment in deep
learning. In H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, and H. Lin,
editors, Advances in Neural Information Processing Systems, volume 33, pages
17176–17186. Curran Associates, Inc., 2020. Cited on page 252.

[146] Ziwei Ji and Matus Telgarsky. Risk and parameter convergence of logistic
regression. arXiv preprint arXiv:1803.07300, 2018. Cited on page 252.

[147] Ziwei Ji and Matus Telgarsky. The implicit bias of gradient descent on nonsepa-
rable data. In Conference on Learning Theory, pages 1772–1798. PMLR, 2019.
Cited on page 252.

[148] Lenaic Chizat, Edouard Oyallon, and Francis Bach. On lazy training in differen-
tiable programming. arXiv preprint arXiv:1812.07956, 2018. Cited on pages 252,
312 & 400.

[149] Simon S Du, Xiyu Zhai, Barnabas Poczos, and Aarti Singh. Gradient de-
scent provably optimizes over-parameterized neural networks. arXiv preprint
arXiv:1810.02054, 2018. Cited on pages 252 & 312.

[150] Simon Du, Jason Lee, Haochuan Li, Liwei Wang, and Xiyu Zhai. Gradient
descent finds global minima of deep neural networks. In International Conference
on Machine Learning, pages 1675–1685. PMLR, 2019. Cited on pages 252,
312 & 400.

[151] Zeyuan Allen-Zhu, Yuanzhi Li, and Zhao Song. A convergence theory for deep
learning via over-parameterization. In International Conference on Machine
Learning, pages 242–252. PMLR, 2019. Cited on pages 252, 312 & 400.

[152] Zeyuan Allen-Zhu, Yuanzhi Li, and Yingyu Liang. Learning and generalization
in overparameterized neural networks, going beyond two layers. Advances in
neural information processing systems, 2019. Cited on pages 252, 312 & 400.

[153] Difan Zou, Yuan Cao, Dongruo Zhou, and Quanquan Gu. Gradient descent
optimizes over-parameterized deep ReLU networks. Machine Learning, 109(3):
467–492, 2020. Cited on pages 252, 312 & 400.
[154] Sanjeev Arora, Simon Du, Wei Hu, Zhiyuan Li, and Ruosong Wang. Fine-
grained analysis of optimization and generalization for overparameterized two-
layer neural networks. In International Conference on Machine Learning, pages
322–332. PMLR, 2019. Cited on pages 252, 312 & 400.
[155] Greg Yang. Scaling limits of wide neural networks with weight sharing: Gaussian
process behavior, gradient independence, and neural tangent kernel derivation.
arXiv preprint arXiv:1902.04760, 2019. Cited on pages 252, 312 & 400.
[156] Arthur Jacot, François Ged, Franck Gabriel, Berfin Şimşek, and Clément Hongler.
Deep linear networks dynamics: Low-rank biases induced by initialization scale
and l2 regularization. arXiv preprint arXiv:2106.15933, 2021. Cited on page 252.
[157] Suriya Gunasekar, Jason D Lee, Daniel Soudry, and Nati Srebro. Implicit
bias of gradient descent on linear convolutional networks. Advances in Neural
Information Processing Systems, 31, 2018. Cited on page 252.
[158] Lenaic Chizat and Francis Bach. Implicit bias of gradient descent for wide two-
layer neural networks trained with the logistic loss. In Conference on Learning
Theory, pages 1305–1338. PMLR, 2020. Cited on page 252.
[159] Kaifeng Lyu, Zhiyuan Li, Runzhe Wang, and Sanjeev Arora. Gradient descent
on two-layer nets: Margin maximization and simplicity bias. Advances in Neural
Information Processing Systems, 34, 2021. Cited on pages 252 & 400.
[160] Noam Razin, Asaf Maman, and Nadav Cohen. Implicit regularization in hi-
erarchical tensor factorization and deep convolutional neural networks. arXiv
preprint arXiv:2201.11729, 2022. Cited on page 252.
[161] Dominik Stöger and Mahdi Soltanolkotabi. Small random initialization is akin
to spectral learning: Optimization and generalization guarantees for overpa-
rameterized low-rank matrix reconstruction. Advances in Neural Information
Processing Systems, 34, 2021. Cited on page 252.
[162] Rong Ge, Yunwei Ren, Xiang Wang, and Mo Zhou. Understanding deflation pro-
cess in over-parametrized tensor decomposition. Advances in Neural Information
Processing Systems, 34, 2021. Cited on page 252.
[163] Greg Yang and Edward J Hu. Tensor programs IV: Feature learning in infinite-
width neural networks. In International Conference on Machine Learning, pages
11727–11737. PMLR, 2021. Cited on page 252.
[164] Yuanzhi Li, Colin Wei, and Tengyu Ma. Towards explaining the regularization
effect of initial large learning rate in training neural networks. arXiv preprint
arXiv:1907.04595, 2019. Cited on pages 252, 306 & 312.
[165] Guy Blanc, Neha Gupta, Gregory Valiant, and Paul Valiant. Implicit regulariza-
tion for deep neural networks driven by an Ornstein-Uhlenbeck like process. In
Conference on learning theory, pages 483–513. PMLR, 2020. Cited on pages 252,
305, 307, 308, 311, 325, 326 & 400.

[166] Alex Damian, Tengyu Ma, and Jason Lee. Label noise SGD provably prefers flat
global minimizers. arXiv preprint arXiv:2106.06530, 2021. Cited on pages 252,
309, 311, 325, 326, 327 & 400.

[167] Difan Zou, Jingfeng Wu, Vladimir Braverman, Quanquan Gu, Dean P Foster,
and Sham Kakade. The benefits of implicit regularization from SGD in least
squares problems. Advances in Neural Information Processing Systems, 34:
5456–5468, 2021. Cited on page 252.

[168] Qian Qian and Xiaoyuan Qian. The implicit bias of AdaGrad on separable
data. Advances in Neural Information Processing Systems, 32, 2019. Cited on
page 252.

[169] Bohan Wang, Qi Meng, Huishuai Zhang, Ruoyu Sun, Wei Chen, and Zhi-
Ming Ma. Momentum doesn’t change the implicit bias. arXiv preprint
arXiv:2110.03891, 2021. Cited on page 252.

[170] Bohan Wang, Qi Meng, Wei Chen, and Tie-Yan Liu. The implicit bias for adap-
tive optimization algorithms on homogeneous neural networks. In International
Conference on Machine Learning, pages 10849–10858. PMLR, 2021. Cited on
page 252.

[171] Ziwei Ji, Nathan Srebro, and Matus Telgarsky. Fast margin maximization via
dual acceleration. In International Conference on Machine Learning, pages
4860–4869. PMLR, 2021. Cited on page 252.

[172] Suriya Gunasekar, Blake Woodworth, and Nathan Srebro. Mirrorless mirror
descent: A natural derivation of mirror descent. In International Conference on
Artificial Intelligence and Statistics, pages 2305–2313. PMLR, 2021. Cited on
pages 252 & 400.

[173] John M Lee. Introduction to Smooth Manifolds. Springer, 2013. Cited on
pages 254, 261, 272 & 284.

[174] Héctor J Sussmann. Orbits of families of vector fields and integrability of
distributions. Transactions of the American Mathematical Society, 180:171–188,
1973. Cited on page 258.

[175] Sébastien Bubeck et al. Convex optimization: Algorithms and complexity.
Foundations and Trends® in Machine Learning, 8(3-4):231–357, 2015. Cited on
page 258.

[176] Scott Pesme, Loucas Pillaud-Vivien, and Nicolas Flammarion. Implicit bias
of SGD for diagonal linear networks: a provable benefit of stochasticity. arXiv
preprint arXiv:2106.09524, 2021. Cited on pages 271, 313 & 330.

[177] John Nash. The imbedding problem for Riemannian manifolds. Annals of
Mathematics, pages 20–63, 1956. Cited on page 274.

[178] Matthias Gunther. Isometric embeddings of Riemannian manifolds, Kyoto, 1990.
In Proc. Intern. Congr. Math., pages 1137–1143. Math. Soc. Japan, 1991. Cited
on page 274.

[179] Ralph Tyrrell Rockafellar. Convex Analysis. Princeton University Press, 2015.
Cited on pages 277 & 278.

[180] Heinz H Bauschke, Jonathan M Borwein, et al. Legendre functions and the
method of random Bregman projections. Journal of Convex Analysis, 4(1):27–67,
1997. Cited on pages 277 & 278.

[181] Lev M Bregman. The relaxation method of finding the common point of convex
sets and its application to the solution of problems in convex programming.
USSR computational mathematics and mathematical physics, 7(3):200–217, 1967.
Cited on page 278.

[182] Yair Censor and Arnold Lent. An iterative row-action method for interval convex
programming. Journal of Optimization theory and Applications, 34(3):321–353,
1981. Cited on page 278.

[183] Felipe Alvarez, Jérôme Bolte, and Olivier Brahic. Hessian Riemannian gradient
flows in convex programming. SIAM Journal on Control and Optimization, 43(2):
477–501, 2004. Cited on pages 278 & 279.

[184] Serge Lang. Introduction to differentiable manifolds. Springer Science & Business
Media, 2006. Cited on page 279.

[185] Jean-Pierre Crouzeix. A relationship between the second derivatives of a convex
function and of its conjugate. Mathematical Programming, 13(1):364–365, 1977.
Cited on page 286.

[186] Robert L Foote. Regularity of the distance function. Proceedings of the American
Mathematical Society, 92(1):153–155, 1984. Cited on page 301.

[187] Gary Shon Katzenberger. Solutions of a stochastic differential equation forced
onto a manifold by a large drift. The Annals of Probability, pages 1587–1628,
1991. Cited on pages 305, 307, 310, 315, 319, 321, 322, 333, 336, 337, 340 & 400.

[188] Stanislaw Jastrzebski, Zachary Kenton, Devansh Arpit, Nicolas Ballas, Asja
Fischer, Yoshua Bengio, and Amos Storkey. Three factors influencing minima in
SGD. arXiv preprint arXiv:1711.04623, 2017. Cited on pages 306 & 312.

[189] Bin Shi, Weijie J Su, and Michael I Jordan. On learning rates and Schrödinger
operators. arXiv preprint arXiv:2004.06977, 2020. Cited on pages 306 & 325.

[190] Qianxiao Li, Cheng Tai, and E Weinan. Stochastic modified equations and
adaptive stochastic gradient algorithms. In International Conference on Machine
Learning, pages 2101–2110. PMLR, 2017. Cited on pages 307, 312, 335 & 401.

[191] Xiang Cheng, Dong Yin, Peter Bartlett, and Michael Jordan. Stochastic gradient
and Langevin processes. In International Conference on Machine Learning, pages
1810–1819. PMLR, 2020. Cited on pages 307 & 312.

[192] Garvesh Raskutti, Martin J Wainwright, and Bin Yu. Minimax-optimal rates
for sparse additive models over kernel classes via convex programming. Journal
of machine learning research, 13(2), 2012. Cited on page 311.

[193] C Daniel Freeman and Joan Bruna. Topology and geometry of half-rectified
network optimization. arXiv preprint arXiv:1611.01540, 2016. Cited on page 311.

[194] Timur Garipov, Pavel Izmailov, Dmitrii Podoprikhin, Dmitry Vetrov, and
Andrew Gordon Wilson. Loss surfaces, mode connectivity, and fast ensembling
of DNNs. arXiv preprint arXiv:1802.10026, 2018. Cited on page 311.

[195] Felix Draxler, Kambis Veschgini, Manfred Salmhofer, and Fred Hamprecht.
Essentially no barriers in neural network energy landscape. In International
conference on machine learning, pages 1309–1318. PMLR, 2018. Cited on
page 311.

[196] Luca Venturi, Afonso S Bandeira, and Joan Bruna. Spurious valleys in two-layer
neural network optimization landscapes. arXiv preprint arXiv:1802.06384, 2018.
Cited on page 311.

[197] Shiyu Liang, Ruoyu Sun, Yixuan Li, and Rayadurgam Srikant. Understanding
the loss surface of neural networks for binary classification. In International
Conference on Machine Learning, pages 2835–2843. PMLR, 2018. Cited on
page 311.

[198] Quynh Nguyen, Mahesh Chandra Mukkamala, and Matthias Hein. On the loss
landscape of a class of deep neural networks with no bad local valleys. arXiv
preprint arXiv:1809.10749, 2018. Cited on page 311.

[199] Quynh Nguyen. On connected sublevel sets in deep learning. In International


Conference on Machine Learning, pages 4790–4799. PMLR, 2019. Cited on
page 311.

[200] Rohith Kuditipudi, Xiang Wang, Holden Lee, Yi Zhang, Zhiyuan Li, Wei Hu,
Sanjeev Arora, and Rong Ge. Explaining landscape connectivity of low-cost
solutions for multilayer nets. arXiv preprint arXiv:1906.06247, 2019. Cited on
page 311.

[201] Yaim Cooper. The loss landscape of overparameterized neural networks. arXiv
preprint arXiv:1804.10200, 2018. Cited on page 312.

[202] Y Cooper. The critical locus of overparameterized neural networks. arXiv
preprint arXiv:2005.04210, 2020. Cited on page 312.

[203] Benjamin Fehrman, Benjamin Gess, and Arnulf Jentzen. Convergence rates
for the stochastic gradient descent method for non-convex objective functions.
Journal of Machine Learning Research, 21, 2020. Cited on pages 312, 320 & 439.

[204] Yann A LeCun, Léon Bottou, Genevieve B Orr, and Klaus-Robert Müller.
Efficient backprop. In Neural networks: Tricks of the trade, pages 9–48. Springer,
2012. Cited on page 312.

[205] Elad Hoffer, Itay Hubara, and Daniel Soudry. Train longer, generalize better:
closing the generalization gap in large batch training of neural networks. arXiv
preprint arXiv:1705.08741, 2017. Cited on page 312.

[206] Zhanxing Zhu, Jingfeng Wu, Bing Yu, Lei Wu, and Jinwen Ma. The anisotropic
noise in stochastic gradient descent: Its behavior of escaping from sharp minima
and regularization effects. arXiv preprint arXiv:1803.00195, 2018. Cited on
page 312.

[207] Jian Li, Xuanyuan Luo, and Mingda Qiao. On generalization error bounds of
noisy gradient methods for non-convex learning. arXiv preprint arXiv:1902.00621,
2019. Cited on page 312.

[208] Yeming Wen, Kevin Luk, Maxime Gazeau, Guodong Zhang, Harris Chan, and
Jimmy Ba. Interplay between optimization and generalization of stochastic
gradient descent with covariance noise. arXiv preprint arXiv:1902.08234, 2019.
Cited on page 312.

[209] Jingfeng Wu, Wenqing Hu, Haoyi Xiong, Jun Huan, Vladimir Braverman, and
Zhanxing Zhu. On the noisy gradient descent that generalizes as SGD. In
International Conference on Machine Learning, pages 10367–10376. PMLR,
2020. Cited on page 312.

[210] Jianqing Fan, Zhuoran Yang, and Mengxin Yu. Understanding implicit reg-
ularization in over-parameterized nonlinear statistical model. arXiv preprint
arXiv:2007.08322, 2020. Cited on page 312.

[211] Peng Zhao, Yun Yang, and Qiao-Chu He. Implicit regularization via Hadamard
product over-parametrization in high-dimensional linear regression. arXiv
preprint arXiv:1903.09367, 2019. Cited on page 312.

[212] Amit Daniely. SGD learns the conjugate kernel class of the network. arXiv
preprint arXiv:1702.08503, 2017. Cited on page 312.

[213] Yuanzhi Li and Yingyu Liang. Learning overparameterized neural networks via
stochastic gradient descent on structured data. arXiv preprint arXiv:1808.01204,
2018. Cited on pages 312 & 400.

[214] Qianxiao Li, Cheng Tai, and E Weinan. Stochastic modified equations and dy-
namics of stochastic gradient algorithms i: Mathematical foundations. The
Journal of Machine Learning Research, 20(1):1474–1520, 2019. Cited on
pages 312 & 401.

[215] Alex Krizhevsky. One weird trick for parallelizing convolutional neural networks.
arXiv preprint arXiv:1404.5997, 2014. Cited on page 312.

[216] Priya Goyal, Piotr Dollár, Ross Girshick, Pieter Noordhuis, Lukasz Wesolowski,
Aapo Kyrola, Andrew Tulloch, Yangqing Jia, and Kaiming He. Accurate, large
minibatch SGD: Training ImageNet in 1 hour. arXiv preprint arXiv:1706.02677,
2017. Cited on page 312.

[217] Zhiyuan Li, Sadhika Malladi, and Sanjeev Arora. On the validity of modeling SGD
with stochastic differential equations (SDEs). arXiv preprint arXiv:2102.12470,
2021. Cited on page 312.

[218] Zeke Xie, Issei Sato, and Masashi Sugiyama. A diffusion theory for deep learning
dynamics: Stochastic gradient descent escapes from sharp minima exponentially
fast. arXiv preprint arXiv:2002.03495, 2020. Cited on page 312.

[219] Stephan Wojtowytsch. Stochastic gradient descent with noise of machine learning
type. part ii: Continuous time analysis. arXiv preprint arXiv:2106.02588, 2021.
Cited on page 312.

[220] Ioannis Karatzas and Steven Shreve. Brownian motion and stochastic calculus,
volume 113. Springer, 2014. Cited on page 313.

[221] Patrick Billingsley. Convergence of probability measures. John Wiley & Sons,
2013. Cited on page 313.

[222] David Pollard. Convergence of stochastic processes. Springer Science & Business
Media, 2012. Cited on pages 313 & 319.

[223] Augustin Banyaga and David Hurtubise. Lectures on Morse homology, volume 29.
Springer Science & Business Media, 2013. Cited on pages 320 & 360.

[224] Stanislaw Lojasiewicz. A topological property of real analytic subsets. Coll.
du CNRS, Les équations aux dérivées partielles, 117(87-89):2, 1963. Cited on
page 331.

[225] Boris T Polyak. Gradient methods for solving equations and inequalities. USSR
Computational Mathematics and Mathematical Physics, 4(6):17–32, 1964. Cited
on page 331.

[226] Joel A Tropp. Convex recovery of a structured signal from independent random
linear measurements. In Sampling Theory, a Renaissance, pages 67–101. Springer,
2015. Cited on pages 332, 363, 365, 367 & 368.

[227] K. J. Falconer. Differentiation of the limit mapping in a dynamical system.
Journal of the London Mathematical Society, s2-27(2):356–372, 1983. ISSN
0024-6107. doi: 10.1112/jlms/s2-27.2.356. Cited on page 334.

[228] Ward Whitt. Stochastic-process limits: an introduction to stochastic-process
limits and their application to queues. Springer Science & Business Media, 2002.
Cited on page 339.

[229] Elton P Hsu. Stochastic analysis on manifolds. Number 38. American Mathe-
matical Soc., 2002. Cited on page 343.

[230] Manfredo P Do Carmo. Riemannian geometry. Springer Science & Business
Media, 2013. Cited on page 355.

[231] Andrew Holbrook. Differentiating the pseudo determinant. Linear Algebra and
its Applications, 548:293–304, 2018. Cited on page 355.

[232] Jeff Kahn, János Komlós, and Endre Szemerédi. On the probability that a
random ±1 matrix is singular. Journal of the American Mathematical Society, 8
(1):223–240, 1995. Cited on page 361.

[233] Lawrence M. Perko. Differential equations and dynamical systems. 2001. Cited
on page 378.

[234] Jeremy Cohen, Simran Kaur, Yuanzhi Li, J Zico Kolter, and Ameet Talwalkar.
Gradient descent on neural networks typically occurs at the edge of stability.
In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=jh-rTtvkGeM. Cited on pages 394, 395, 396 & 420.

[235] Kwangjun Ahn, Jingzhao Zhang, and Suvrit Sra. Understanding the unstable
convergence of gradient descent. arXiv preprint arXiv:2204.01050, 2022. Cited
on page 395.

[236] Prajit Ramachandran, Barret Zoph, and Quoc V Le. Searching for activation
functions. arXiv preprint arXiv:1710.05941, 2017. Cited on page 398.

[237] S Hochreiter and J Schmidhuber. Flat minima. Neural Computation, 1997.
Cited on page 399.

[238] Yiding Jiang, Behnam Neyshabur, Hossein Mobahi, Dilip Krishnan, and
Samy Bengio. Fantastic generalization measures and where to find them.
In International Conference on Learning Representations, 2020. URL https://openreview.net/forum?id=SJgIPJBFvH. Cited on page 399.

[239] Laurent Dinh, Razvan Pascanu, Samy Bengio, and Yoshua Bengio. Sharp
minima can generalize for deep nets. In International Conference on Machine
Learning, pages 1019–1028. PMLR, 2017. Cited on page 399.

[240] Mingyang Yi, Qi Meng, Wei Chen, Zhi-ming Ma, and Tie-Yan Liu. Positively
scale-invariant flatness of ReLU neural networks. arXiv preprint arXiv:1903.02237,
2019. Cited on page 400.

[241] Mingyang Yi, Huishuai Zhang, Wei Chen, Zhi-Ming Ma, and Tie-Yan Liu. BN-
invariant sharpness regularizes the training model to better generalization. In
Proceedings of the Twenty-Eighth International Joint Conference on Artificial
Intelligence, IJCAI-19, pages 4164–4170. International Joint Conferences on
Artificial Intelligence Organization, 7 2019. Cited on page 400.

[242] Yusuke Tsuzuku, Issei Sato, and Masashi Sugiyama. Normalized flat minima:
Exploring scale invariant definition of flat minima for neural networks using PAC-
Bayesian analysis. In Hal Daumé III and Aarti Singh, editors, Proceedings of the
37th International Conference on Machine Learning, volume 119 of Proceedings
of Machine Learning Research, pages 9636–9647. PMLR, 13–18 Jul 2020. Cited
on page 400.

[243] Akshay Rangamani, Nam H. Nguyen, Abhishek Kumar, Dzung Phan, Sang Peter
Chin, and Trac D. Tran. A scale invariant measure of flatness for deep network
minima. In ICASSP 2021 - 2021 IEEE International Conference on Acoustics,
Speech and Signal Processing (ICASSP), pages 1680–1684, 2021. Cited on
page 400.

[244] Mingyang Yi, Qi Meng, Wei Chen, and Zhi-Ming Ma. Towards accelerat-
ing training of batch normalization: A manifold perspective. arXiv preprint
arXiv:2101.02916, 2021. Cited on page 400.

[245] Jungmin Kwon, Jeongseop Kim, Hyunseo Park, and In Kwon Choi. ASAM:
Adaptive sharpness-aware minimization for scale-invariant learning of deep
neural networks. In Marina Meila and Tong Zhang, editors, Proceedings of the
38th International Conference on Machine Learning, volume 139 of Proceedings
of Machine Learning Research, pages 5905–5914. PMLR, 18–24 Jul 2021. Cited
on page 400.

[246] Haowei He, Gao Huang, and Yang Yuan. Asymmetric valleys: Beyond sharp and
flat local minima. arXiv preprint arXiv:1902.00744, 2019. Cited on page 400.

[247] Suriya Gunasekar, Jason Lee, Daniel Soudry, and Nathan Srebro. Implicit bias
of gradient descent on linear convolutional networks. In Advances in Neural
Information Processing Systems, 2018. Cited on page 400.

[248] Lei Wu, Zhanxing Zhu, et al. Towards understanding generalization of deep
learning: Perspective of loss landscapes. arXiv preprint arXiv:1706.10239, 2017.
Cited on page 401.
[249] Chao Ma and Lexing Ying. On linear stability of SGD and input-smoothness
of neural networks. In A. Beygelzimer, Y. Dauphin, P. Liang, and J. Wortman
Vaughan, editors, Advances in Neural Information Processing Systems, 2021.
Cited on page 401.

[250] David Barrett and Benoit Dherin. Implicit gradient regularization. In Interna-
tional Conference on Learning Representations, 2021. Cited on page 401.

[251] Yuqing Wang, Minshuo Chen, Tuo Zhao, and Molei Tao. Large learning
rate tames homogeneity: Convergence and balancing effect. arXiv preprint
arXiv:2110.03677, 2021. Cited on page 401.

[252] Chi Jin, Rong Ge, Praneeth Netrapalli, Sham M. Kakade, and Michael I. Jordan.
How to escape saddle points efficiently. In Proceedings of the 34th International
Conference on Machine Learning, pages 1724–1732, 2017. Cited on page 415.

[253] Karen Simonyan and Andrew Zisserman. Very deep convolutional networks for
large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014. Cited on
page 417.

[254] Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton. CIFAR-10 (Canadian Institute
for Advanced Research). URL http://www.cs.toronto.edu/~kriz/cifar.html.
Cited on page 417.

[255] Y-Lan Boureau, Jean Ponce, and Yann LeCun. A theoretical analysis of feature
pooling in visual recognition. In Proceedings of the 27th international conference
on machine learning (ICML-10), pages 111–118, 2010. Cited on page 417.

[256] Yann LeCun and Corinna Cortes. MNIST handwritten digit database. 2010.
URL http://yann.lecun.com/exdb/mnist/. Cited on page 419.

[257] E. Hairer, S. P. Nørsett, and G. Wanner. Solving Ordinary Differential Equations
I (2nd Revised. Ed.): Nonstiff Problems. Springer-Verlag, Berlin, Heidelberg,
1993. ISBN 0387566708. Cited on page 458.

[258] Jan R Magnus. On differentiating eigenvalues and eigenvectors. Econometric
Theory, 1(2):179–191, 1985. Cited on page 493.

[259] Roger A Horn and Charles R Johnson. Matrix analysis. Cambridge university
press, 2012. Cited on page 493.

[260] Chandler Davis and William Morton Kahan. The rotation of eigenvectors by a
perturbation. III. SIAM Journal on Numerical Analysis, 7(1):1–46, 1970. Cited
on page 494.
