L12_optim__slides
Sebastian Raschka
http://stat.wisc.edu/~sraschka/teaching
Lecture 12
Improving Gradient Descent-based Optimization
with Applications in Python

STAT 453: Intro to Deep Learning
Overview: Additional Tricks for
Neural Network Training (Part 2/2)
In the worst-case scenario, the test set may not contain any instance of a minority class at all. Thus, a recommended practice is to divide the dataset in a stratified fashion. Here, stratification simply means that we randomly split a dataset such that each class is correctly represented in the resulting subsets (the training and the test set). In other words, stratification is an approach to maintain the original class proportions in the resulting subsets.
batchsize-1024.ipynb batchsize-64.ipynb
1) Exponential Decay:

η_t := η_0 · e^(−k·t)
3) Inverse decay:

η_t := η_0 / (1 + k·t)
Relationship between Learning Rate and Batch Size

DON'T DECAY THE LEARNING RATE, INCREASE THE BATCH SIZE
Samuel L. Smith*, Pieter-Jan Kindermans*, Chris Ying & Quoc V. Le
Google Brain
{slsmith, pikinder, chrisying, qvl}@google.com
ABSTRACT
It is common practice to decay the learning rate. Here we show one can usually obtain the same learning curve on both training and test sets by instead increasing the batch size during training. This procedure is successful for stochastic gradient descent (SGD), SGD with momentum, Nesterov momentum, and Adam. It reaches equivalent test accuracies after the same number of training epochs, but with fewer parameter updates, leading to greater parallelism and shorter training times. We can further reduce the number of parameter updates by increasing the learning rate ε and scaling the batch size B ∝ ε. Finally, one can increase the momentum coefficient m and scale B ∝ 1/(1 − m), although this tends to slightly reduce the test accuracy. Crucially, our techniques allow us to repurpose existing training schedules for large batch training with no hyper-parameter tuning. We train ResNet-50 on ImageNet to 76.1% validation accuracy in under 30 minutes.
Smith, S. L., Kindermans, P. J., Ying, C., & Le, Q. V. (2017). Don't decay the learning rate, increase the batch size. arXiv preprint arXiv:1711.00489.
Relationship between Learning Rate and Batch Size
Figure 6: Inception-ResNet-V2 on ImageNet. Increasing the batch size during training achieves similar results to decaying the learning rate, but it reduces the number of parameter updates from just over 14000 to below 6000. We run each experiment twice to illustrate the variance.
Smith, S. L., Kindermans, P. J., Ying, C., & Le, Q. V. (2017). Don't decay the learning rate, increase the batch size. arXiv preprint arXiv:1711.00489.
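One way to mimic the idea above in PyTorch is to re-create the DataLoader with a larger batch size at chosen epoch milestones. This is only a rough sketch: the dummy dataset, the milestone epochs, and the batch sizes below are made-up placeholders, not the schedules used in the paper.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Dummy data just to make the sketch self-contained.
train_dataset = TensorDataset(torch.randn(1024, 10),
                              torch.randint(0, 2, (1024,)))

# epoch -> new batch size (placeholder schedule)
batch_size_schedule = {0: 128, 3: 256, 6: 512}

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size_schedule[0],
                          shuffle=True)

for epoch in range(10):
    if epoch in batch_size_schedule:
        # Instead of decaying the learning rate, increase the batch size.
        train_loader = DataLoader(train_dataset,
                                  batch_size=batch_size_schedule[epoch],
                                  shuffle=True)
    for features, targets in train_loader:
        pass  # forward/backward/update step would go here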
Option 1. Just call your own function at the end of each epoch:
Source: https://pytorch.org/docs/stable/optim.html
torch.manual_seed(RANDOM_SEED)
model = MLP(num_features=28*28,
            num_hidden=100,
            num_classes=10)
model = model.to(DEVICE)

# (the optimizer is assumed to be defined here; not shown in this excerpt)

#################################
### LEARNING RATE SCHEDULER
#################################
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                   gamma=0.1)
...
        cost.backward()
        minibatch_cost.append(cost)

        ### UPDATE MODEL PARAMETERS
        optimizer.step()

        ### LOGGING
        if not batch_idx % 50:
            print('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f'
                  % (epoch+1, NUM_EPOCHS, batch_idx,
                     len(train_loader), cost))

    ##########################
    ### Update Learning Rate
    scheduler.step()  # don't have to do it every epoch!
    ##########################

    model.eval()
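The snippet above uses PyTorch's built-in ExponentialLR scheduler. A hypothetical sketch of the "call your own function at the end of each epoch" approach is shown below: it simply overwrites the learning rate in each parameter group. The function name, eta_0, and k are illustrative, not from the original notebook.

import math

def adjust_learning_rate(optimizer, epoch, eta_0=0.1, k=0.1):
    # Exponential decay (see the schedule earlier in this lecture);
    # eta_0 and k are illustrative values.
    new_lr = eta_0 * math.exp(-k * epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr

# ... then, at the end of each training epoch:
# adjust_learning_rate(optimizer, epoch)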
[Figure: ResNet architecture diagram, a column of stacked 3x3 conv layers (64, 128, 256, ... filters; "/2" marks stride-2 downsampling).]

... sizes, they are performed with a stride of 2.

3.4. Implementation
Our implementation for ImageNet follows the practice in [21, 40]. The image is resized with its shorter side randomly sampled in [256, 480] for scale augmentation [40]. A 224×224 crop is randomly sampled from an image or its horizontal flip, with the per-pixel mean subtracted [21]. The standard color augmentation in [21] is used. We adopt batch normalization (BN) [16] right after each convolution and before activation, following [16]. We initialize the weights as in [12] and train all plain/residual nets from scratch. We use SGD with a mini-batch size of 256. The learning rate starts from 0.1 and is divided by 10 when the error plateaus, and the models are trained for up to 60 × 10^4 iterations. We use a weight decay of 0.0001 and a momentum of 0.9. We do not use dropout [13], following the practice in [16].
In testing, for comparison studies we adopt the standard 10-crop testing [21]. For best results, we adopt the fully-convolutional form as in [40, 12], and average the scores at multiple scales (images are resized such that the shorter side is in {224, 256, 384, 480, 640}).

http://openaccess.thecvf.com/content_cvpr_2016/html/He_Deep_Residual_Learning_CVPR_2016_paper.html
Source: https://en.wikipedia.org/wiki/Momentum
Qian, N. (1999). On the momentum term in gradient descent learning algorithms. Neural Networks: The Official Journal of the International Neural Network Society, 12(1), 145–151. http://doi.org/10.1016/S0893-6080(98)00116-6
Training with "Momentum"
Key take-away:
Not only move in the (opposite) direction of the gradient, but also
move in the "averaged" direction of the last few updates
Usually, we choose a momentum rate between 0.9 and 0.999; you can think of it as a "friction" or "dampening" parameter.

Regular partial derivative/gradient multiplied by the learning rate at the current time step t.
Weight update using the velocity vector:

w_{i,j}(t+1) := w_{i,j}(t) − Δw_{i,j}(t),
where Δw_{i,j}(t) := α · Δw_{i,j}(t−1) + η · ∂L/∂w_{i,j}(t)
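A toy NumPy sketch of the velocity update above; the quadratic loss and the constants are made up for illustration. In PyTorch the same idea is available via torch.optim.SGD(..., momentum=...).

import numpy as np

# Toy quadratic loss L(w) = 0.5 * ||w||^2, so dL/dw = w.
def grad(w):
    return w

w = np.array([5.0])
delta_w = np.zeros_like(w)   # the "velocity" vector
alpha, eta = 0.9, 0.1        # momentum rate and learning rate (illustrative)

for t in range(20):
    delta_w = alpha * delta_w + eta * grad(w)   # velocity update
    w = w - delta_w                             # weight update using the velocity
print(w)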
Source: https://distill.pub/2017/momentum/
Combining adaptive learning rates with momentum

Key take-aways:

Step 1: Define a local gain (g) for each weight (initialized with g = 1):
Δw_{i,j} := −η · g_{i,j} · ∂L/∂w_{i,j}
Step 2:

If gradient is consistent:
    g_{i,j}(t) := g_{i,j}(t−1) + β
else:
    g_{i,j}(t) := g_{i,j}(t−1) · (1 − β)

Note that multiplying by a factor has a larger impact if gains are large, compared to adding a term (dampening effect if updates oscillate in the wrong direction).
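A toy NumPy sketch of the local-gain rule above: gains grow additively while the gradient sign stays consistent and shrink multiplicatively when it flips. β, η, and the quadratic toy loss are illustrative choices, not values from the slides.

import numpy as np

beta, eta = 0.05, 0.1
w = np.array([1.0, -2.0])
gains = np.ones_like(w)           # local gains g, initialized with 1
prev_grad = np.zeros_like(w)

def grad(w):                      # toy gradient: dL/dw for L = 0.5 * ||w||^2
    return w

for t in range(10):
    g = grad(w)
    consistent = np.sign(g) == np.sign(prev_grad)
    # additive increase if the gradient sign is consistent,
    # multiplicative decrease otherwise
    gains = np.where(consistent, gains + beta, gains * (1.0 - beta))
    w = w - eta * gains * g       # delta_w = -eta * g_ij * dL/dw_ij
    prev_grad = g
print(w, gains)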
RMSProp

• Unpublished algorithm by Geoff Hinton (but very popular), based on Rprop [1]
• Very similar to another concept called AdaDelta
• Concept: divide learning rate by an exponentially decreasing moving average of
the squared gradients
• This takes into account that gradients can vary widely in magnitude
• Here, RMS stands for "Root Mean Squared"
• Also, damps oscillations like momentum (but in practice, works a bit better)
[1] Igel, Christian, and Michael Hüsken. "Improving the Rprop learning algorithm." Proceedings of the Second
International ICSC Symposium on Neural Computation (NC 2000). Vol. 2000. ICSC Academic Press, 2000.
MeanSquare(w_{i,j}, t) := β · MeanSquare(w_{i,j}, t−1) + (1 − β) · (∂L/∂w_{i,j}(t))²

w_{i,j}(t) := w_{i,j}(t) − η · ∂L/∂w_{i,j}(t) / ( √MeanSquare(w_{i,j}, t) + ε )

where β is typically a value between 0.9 and 0.999, and ε is a small term to avoid division by zero.
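A toy NumPy sketch of the two RMSProp equations above; the quadratic toy loss and the constants are illustrative.

import numpy as np

beta, eta, eps = 0.9, 0.01, 1e-8
w = np.array([5.0, -3.0])
mean_square = np.zeros_like(w)

def grad(w):                      # toy gradient for L = 0.5 * ||w||^2
    return w

for t in range(100):
    g = grad(w)
    mean_square = beta * mean_square + (1.0 - beta) * g**2
    w = w - eta * g / (np.sqrt(mean_square) + eps)
print(w)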
Momentum-like term:

Original momentum term:   Δw_{i,j}(t) := α · Δw_{i,j}(t−1) + η · ∂L/∂w_{i,j}(t)

ADAM's momentum-like term m_t:   m_t := α · m_{t−1} + (1 − α) · ∂L/∂w_{i,j}(t)

Kingma, D. P., & Ba, J. (2014). Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980.
RMSProp term:   r := β · MeanSquare(w_{i,j}, t−1) + (1 − β) · (∂L/∂w_{i,j}(t))²

ADAM update:   w_{i,j} := w_{i,j} − η · m_t / ( √r + ε )

Kingma, D. P., & Ba, J. (2014). Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980.
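Putting the momentum-like term m_t and the RMSProp term r together gives the ADAM step shown above. The toy sketch below omits the bias correction discussed in the paper and uses made-up constants and a quadratic toy loss.

import numpy as np

alpha, beta, eta, eps = 0.9, 0.999, 0.01, 1e-8
w = np.array([5.0, -3.0])
m = np.zeros_like(w)              # momentum-like term
r = np.zeros_like(w)              # RMSProp term

def grad(w):                      # toy gradient for L = 0.5 * ||w||^2
    return w

for t in range(200):
    g = grad(w)
    m = alpha * m + (1.0 - alpha) * g        # momentum-like term
    r = beta * r + (1.0 - beta) * g**2       # RMSProp term
    w = w - eta * m / (np.sqrt(r) + eps)     # ADAM update (no bias correction)
print(w)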
In section 2 we describe the algorithm and the properties of its update rule. Section 3 explains our initialization bias correction technique, and section 4 provides a theoretical analysis of Adam's convergence in online convex programming. Empirically, our method consistently outperforms other methods for a variety of models and datasets, as shown in section 6. Overall, we show that Adam is ...
Experimenting with different optimization algorithms
Remember to save the optimizer state if you are using, e.g., Momentum or
ADAM, and want to continue training later
(see earlier slides on saving states of the learning rate schedulers).
Source: https://pytorch.org/docs/stable/optim.html
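A minimal sketch of saving and restoring that state with state_dict(); it assumes the model, optimizer, and scheduler objects from the earlier code example, and the file name 'checkpoint.pt' is a placeholder.

torch.save({'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict()},
           'checkpoint.pt')

# ... later, to resume training:
checkpoint = torch.load('checkpoint.pt')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
scheduler.load_state_dict(checkpoint['scheduler'])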
sgd-scheduler-momentum.ipynb adam.ipynb
"it is known that Adam will not always give you the best performance, yet most of
the time people know that they can use it with its default parameters and get, if not
the best performance, at least the second best performance on their particular deep
learning problem. "
"Usually people try new architectures keeping the optimization algorithm xed, and most
of the time the algorithm of choice is Adam. This happens because, as explained above,
Adam is the default optimizer."
fi
... parameters θ, and D_t is the training dataset of size |D_t|. This loss is near zero when a model with parameters θ accurately classifies the training data.

Over-parameterized neural networks (i.e., those with more parameters than training data) can represent arbitrary, even random, labeling functions on large datasets [Zhang et al., 2016]. As a result, an optimizer can reliably fit an over-parameterized network to training data and achieve near zero loss [Laurent and Brecht, 2018, Kawaguchi, 2016]. However, this comes with no guarantee of generalization to unseen test data.

We illustrate the difference between model fitting and generalization with an experiment. The CIFAR-10 training dataset contains 50,000 small images. We train two over-parameterized models on this dataset. The first is a neural network (ResNet-18) with 269,722 parameters (roughly 6× the number of training images). The second is a linear model with a feature set that includes pixel intensities as well as pair-wise products of pixel intensities. This linear model has 298,369 parameters, which is comparable to the neural network, and both are trained using SGD. On the left of Figure 2, we see that over-parameterization causes both models to achieve perfect accuracy on training data. But the linear model achieves only 49% test accuracy, while ResNet-18 achieves 92%.

Figure 2: (left) CIFAR10 trained with ResNet-18 and a linear model having comparable number of parameters. Both can fit the training data well, but neural nets are able to generalize to unseen data, while linear models cannot. (right) CIFAR10 trained with various optimizers using VGG13, generalizing well irrespective of the optimizer used.

The excellent performance of the neural network model raises the question of whether bad minima exist at all. Maybe deep networks generalize because bad minima are rare and lie far away from the region of parameter space where initialization takes place? We can confirm the existence of bad minima by "poisoning" the loss function with a ...

https://arxiv.org/abs/1906.03291
"trains fast as Adam, generalizes well as SGD, and is stable to train GANs"

2 Methods
2.1 Details of AdaBelief Optimizer
Comparison with Adam. Adam and AdaBelief are summarized in Algo. 1 and Algo. 2, where all operations are element-wise, with differences marked in blue. Note that no extra parameters are introduced in AdaBelief. Specifically, in Adam, the update direction is m_t / √v_t, where v_t is the EMA of g_t²; in AdaBelief, the update direction is m_t / √s_t, where s_t is the EMA of (g_t − m_t)². Intuitively, viewing m_t as the prediction of g_t, AdaBelief takes a large step when the observation g_t is close to the prediction m_t, and a small step when the observation greatly deviates from the prediction. The hat (ˆ) represents the bias-corrected value. Note that an extra ε is added to s_t during bias-correction, in order to ...
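To make the contrast concrete, the toy sketch below tracks both Adam's v_t (EMA of g_t²) and AdaBelief's s_t (EMA of (g_t − m_t)²) and takes the AdaBelief-style step m_t / (√s_t + ε). Bias correction is omitted, and the toy loss and all constants are illustrative.

import numpy as np

beta1, beta2, eta, eps = 0.9, 0.999, 0.01, 1e-8
w = np.array([5.0, -3.0])
m = np.zeros_like(w)
v = np.zeros_like(w)   # Adam: EMA of g_t**2
s = np.zeros_like(w)   # AdaBelief: EMA of (g_t - m_t)**2

def grad(w):           # toy gradient for L = 0.5 * ||w||^2
    return w

for t in range(100):
    g = grad(w)
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g**2           # Adam's second moment
    s = beta2 * s + (1 - beta2) * (g - m)**2     # AdaBelief's "belief" term
    w = w - eta * m / (np.sqrt(s) + eps)         # AdaBelief-style step
print(w)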