Introduction to RNNs: Arun Mallya
Sample RNN
(Diagram, built up over three slides: the network unrolled for t = 1, 2, 3. Starting from an initial hidden state h0, each input xt combines with the previous hidden state ht-1 to produce ht, which in turn produces the output yt.)
The Vanilla RNN Cell
(Diagram: xt and ht-1 enter the cell through the weights W, producing ht.)

$$h_t = \tanh\left( W \begin{pmatrix} x_t \\ h_{t-1} \end{pmatrix} \right)$$
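A minimal NumPy sketch of this update; the shapes and variable names are assumptions, not from the slides:

```python
import numpy as np

def vanilla_rnn_cell(x_t, h_prev, W):
    """One step of the vanilla RNN: h_t = tanh(W [x_t; h_{t-1}]).

    x_t:    (input_dim,)  input at time t
    h_prev: (hidden_dim,) previous hidden state h_{t-1}
    W:      (hidden_dim, input_dim + hidden_dim) shared weight matrix
    """
    concat = np.concatenate([x_t, h_prev])  # stack x_t on top of h_{t-1}
    return np.tanh(W @ concat)              # elementwise tanh nonlinearity
```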
The Vanilla RNN Forward
(Diagram: the unrolled network; each hidden state ht produces an output yt, which is compared with the ground truth GTt to give the cost Ct.)

$$h_t = \tanh\left( W \begin{pmatrix} x_t \\ h_{t-1} \end{pmatrix} \right)$$
$$y_t = F(h_t)$$
$$C_t = \mathrm{Loss}(y_t, GT_t)$$
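A sketch of the unrolled forward pass; F and loss_fn are hypothetical callables standing in for the output layer F and the per-step cost Loss:

```python
import numpy as np

def rnn_forward(xs, h0, W, F, loss_fn, targets):
    """Unrolled forward pass of the vanilla RNN.

    xs:      list of input vectors x_1 ... x_T
    h0:      initial hidden state
    F:       output layer, y_t = F(h_t)
    loss_fn: per-step cost C_t = Loss(y_t, GT_t)
    """
    h, hs, costs = h0, [], []
    for x_t, gt_t in zip(xs, targets):
        h = np.tanh(W @ np.concatenate([x_t, h]))  # h_t = tanh(W [x_t; h_{t-1}])
        y_t = F(h)                                 # y_t = F(h_t)
        hs.append(h)
        costs.append(loss_fn(y_t, gt_t))           # C_t = Loss(y_t, GT_t)
    return hs, sum(costs)                          # hidden states and total cost
```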
Recurrent Neural Networks (RNNs)
• Note that the weights are shared over time
Sentiment Classification
(Diagram, built up over several slides: an RNN reads the sentence one word per time step, e.g. "The", "food", ...; the per-step outputs are ignored. The hidden states h1 through hn are summed into h = Sum(h1, ..., hn), which is fed to a linear classifier.)
http://deeplearning.net/tutorial/lstm.html
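A sketch of this setup; the classifier weights W_cls and bias b_cls, and the word-vector inputs, are hypothetical names:

```python
import numpy as np

def sentiment_score(word_vectors, h0, W, W_cls, b_cls):
    """Run the RNN over a sentence, sum the hidden states, and classify.

    Per-step outputs are ignored; only h = Sum(h_1, ..., h_n) is used.
    """
    h, h_sum = h0, np.zeros_like(h0)
    for x_t in word_vectors:
        h = np.tanh(W @ np.concatenate([x_t, h]))  # same shared weights W at every step
        h_sum += h                                 # accumulate the hidden states
    return W_cls @ h_sum + b_cls                   # linear classifier on the summed state
```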
Image Captioning
• Given an image, produce a sentence describing its contents
• Inputs: Image feature (from a CNN)
• Outputs: Multiple words (let's consider one sentence)
(Diagram: the CNN image feature is fed into the RNN.)
Image Captioning
(Diagram, built up over two slides: the CNN feature conditions the RNN; at each step the hidden state ht is passed to a linear classifier that predicts the next word, e.g. "The", "dog", ...)
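A greedy-decoding sketch of this pipeline. The vocabulary handling, the start_id/end_id tokens, and the assumption that the image feature has already been projected to the RNN input size are all hypothetical, not from the slides:

```python
import numpy as np

def generate_caption(image_feature, h0, W, W_cls, word_embeddings,
                     start_id, end_id, max_len=20):
    """Generate one sentence, feeding the previously predicted word back in."""
    # condition the RNN on the image feature (assumed projected to the input size)
    h = np.tanh(W @ np.concatenate([image_feature, h0]))
    word_id, caption = start_id, []
    for _ in range(max_len):
        x_t = word_embeddings[word_id]             # embed the previous word
        h = np.tanh(W @ np.concatenate([x_t, h]))
        scores = W_cls @ h                         # linear classifier over the vocabulary
        word_id = int(np.argmax(scores))           # greedy choice of the next word
        if word_id == end_id:
            break
        caption.append(word_id)
    return caption
```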
RNN Outputs: Image Captions
http://karpathy.github.io/2015/05/21/rnn-effectiveness/
Input-Output Scenarios
(Diagram: image captioning as a Single-to-Multiple scenario.)
Note: We might deliberately choose to frame our problem as a particular input-output scenario for ease of training or better performance. For example, for image captioning, provide the previous word as input at each time step (Single-to-Multiple becomes Multiple-to-Multiple).
The Vanilla RNN Forward
(Diagram: recap of the unrolled forward pass with the costs C1, C2, C3.)
Backpropagation Refresher
(Diagram: a single layer y = f(x; W) with cost C.)

$$y = f(x; W)$$
$$C = \mathrm{Loss}(y, y_{GT})$$

SGD Update:
$$W \leftarrow W - \eta \frac{\partial C}{\partial W}$$
$$\frac{\partial C}{\partial W} = \left( \frac{\partial C}{\partial y} \right) \left( \frac{\partial y}{\partial W} \right)$$
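A minimal NumPy sketch of this update, assuming (hypothetically) a linear layer y = Wx and a squared-error loss:

```python
import numpy as np

def sgd_step(W, x, y_gt, lr=0.1):
    """One SGD update for a toy layer y = f(x; W) = W x with C = 0.5 * ||y - y_gt||^2."""
    y = W @ x                    # forward pass
    dC_dy = y - y_gt             # gradient of the squared-error loss w.r.t. y
    dC_dW = np.outer(dC_dy, x)   # chain rule: dC/dW = (dC/dy)(dy/dW)
    return W - lr * dC_dW        # SGD update: W <- W - eta * dC/dW
```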
Multiple Layers
(Diagram: a stack of two layers, x into f1(x; W1) giving y1, then f2(y1; W2) giving y2, with cost C.)

$$y_1 = f_1(x; W_1)$$
$$y_2 = f_2(y_1; W_2)$$
$$C = \mathrm{Loss}(y_2, y_{GT})$$

SGD Update:
$$W_2 \leftarrow W_2 - \eta \frac{\partial C}{\partial W_2}$$
$$W_1 \leftarrow W_1 - \eta \frac{\partial C}{\partial W_1}$$
Chain Rule for Gradient Computation
(Diagram: the same two-layer stack.)

$$y_1 = f_1(x; W_1)$$
$$y_2 = f_2(y_1; W_2)$$
$$C = \mathrm{Loss}(y_2, y_{GT})$$

Find $\frac{\partial C}{\partial W_1}$ and $\frac{\partial C}{\partial W_2}$:
$$\frac{\partial C}{\partial W_2} = \left( \frac{\partial C}{\partial y_2} \right) \left( \frac{\partial y_2}{\partial W_2} \right)$$
$$\frac{\partial C}{\partial W_1} = \left( \frac{\partial C}{\partial y_1} \right) \left( \frac{\partial y_1}{\partial W_1} \right) = \left( \frac{\partial C}{\partial y_2} \right) \left( \frac{\partial y_2}{\partial y_1} \right) \left( \frac{\partial y_1}{\partial W_1} \right)$$
(Application of the Chain Rule)
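A small NumPy sketch of these two chain-rule products, assuming (hypothetically) linear layers f1(x; W1) = W1 x and f2(y1; W2) = W2 y1 with a squared-error loss:

```python
import numpy as np

def two_layer_gradients(x, W1, W2, y_gt):
    """Gradients for y1 = W1 x, y2 = W2 y1, C = 0.5 * ||y2 - y_gt||^2."""
    y1 = W1 @ x
    y2 = W2 @ y1
    dC_dy2 = y2 - y_gt             # gradient at the output
    dC_dW2 = np.outer(dC_dy2, y1)  # dC/dW2 = (dC/dy2)(dy2/dW2)
    dC_dy1 = W2.T @ dC_dy2         # pass back through f2: dy2/dy1 = W2
    dC_dW1 = np.outer(dC_dy1, x)   # dC/dW1 = (dC/dy1)(dy1/dW1)
    return dC_dW1, dC_dW2
```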
Chain Rule for Gradient Computation
(Diagram: a generic module y = f(x; W).)

Given: $\frac{\partial C}{\partial y}$
We are interested in computing: $\frac{\partial C}{\partial W}$ and $\frac{\partial C}{\partial x}$

$$\frac{\partial C}{\partial W} = \left( \frac{\partial C}{\partial y} \right) \left( \frac{\partial y}{\partial W} \right) \qquad \frac{\partial C}{\partial x} = \left( \frac{\partial C}{\partial y} \right) \left( \frac{\partial y}{\partial x} \right)$$

Equations for common layers: http://arunmallya.github.io/writeups/nn/backprop.html
Extension to Computational Graphs
(Diagram: a single chain x into f(x; W) giving y, versus a graph where the output y of f(x; W) feeds two branches f1(y; W1) and f2(y; W2), giving y1 and y2.)
Extension to Computational Graphs
(Diagram: during the backward pass, the branch gradients $\frac{\partial C_1}{\partial y}$ and $\frac{\partial C_2}{\partial y}$ arrive at the shared output y of f(x; W); they are summed before being propagated back to $\frac{\partial C}{\partial x}$.)

$$\frac{\partial C}{\partial y} = \frac{\partial C_1}{\partial y} + \frac{\partial C_2}{\partial y} \qquad \text{(Gradient Accumulation)}$$
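A sketch of gradient accumulation, assuming for concreteness that f is a linear map y = Ax so that the chain rule through f is just a multiplication by A transposed:

```python
import numpy as np

def grad_into_shared_output(dC1_dy, dC2_dy, A):
    """y = A x feeds two branches; sum the branch gradients, then apply the chain rule."""
    dC_dy = dC1_dy + dC2_dy  # gradient accumulation at the shared output y
    return A.T @ dC_dy       # dC/dx = (dy/dx)^T dC/dy for the linear map y = A x
```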
Backpropagation Through Time (BPTT)
• One of the methods used to train RNNs
• The unfolded network (used during the forward pass) is treated as one big feed-forward network
• This unfolded network accepts the whole time series as input
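A sketch of BPTT for the vanilla RNN above. It assumes the per-step gradients dCt/dht have already been computed from the output layers; because W is shared over time, its gradient is accumulated across all steps:

```python
import numpy as np

def bptt(xs, hs, h0, W, dC_dhs):
    """Backpropagation through time for h_t = tanh(W [x_t; h_{t-1}]).

    xs:      inputs x_1 ... x_T
    hs:      hidden states h_1 ... h_T from the forward pass
    dC_dhs:  per-step gradients dC_t/dh_t from the output layers
    Returns dC/dW, accumulated over all time steps.
    """
    dC_dW = np.zeros_like(W)
    dC_dh = np.zeros_like(h0)                 # gradient flowing backwards in time
    for t in reversed(range(len(xs))):
        h_prev = hs[t - 1] if t > 0 else h0
        dC_dh = dC_dh + dC_dhs[t]             # add the gradient arriving from the output at step t
        da = (1.0 - hs[t] ** 2) * dC_dh       # through tanh: d tanh(a)/da = 1 - tanh(a)^2
        concat = np.concatenate([xs[t], h_prev])
        dC_dW += np.outer(da, concat)         # weights are shared, so gradients accumulate
        dC_dh = (W.T @ da)[len(xs[t]):]       # the part of the gradient that flows to h_{t-1}
    return dC_dW
```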
The Unfolded Vanilla RNN Backward
(Diagram: gradients from the costs C1, C2, C3 flow backwards through the unrolled network.)
The Vanilla RNN Backward
(Diagram: the unrolled network with costs C1, C2, C3 and outputs y1, y2, y3.)

$$h_t = \tanh\left( W \begin{pmatrix} x_t \\ h_{t-1} \end{pmatrix} \right)$$
$$y_t = F(h_t)$$
$$C_t = \mathrm{Loss}(y_t, GT_t)$$
Issues with the Vanilla RNNs
• In the same way a product of k real numbers can shrink to zero or explode to infinity, so can a product of matrices
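A two-line illustration of this shrink/explode behaviour with arbitrary scalar factors; the repeated Jacobians in BPTT behave the same way:

```python
# Repeatedly multiplying by a factor below 1 vanishes, above 1 explodes.
print(0.9 ** 50)  # ~0.005  (vanishing)
print(1.1 ** 50)  # ~117.4  (exploding)
```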
(Diagram: a cell with an internal memory state ct; xt and ht-1 enter through the weights W, and ct feeds back into itself.)

$$c_t = c_{t-1} + \tanh\left( W \begin{pmatrix} x_t \\ h_{t-1} \end{pmatrix} \right) \qquad h_t = \tanh(c_t)$$

* Dashed line indicates time-lag
The Original LSTM Cell
(Diagram: an input gate it with weights Wi and an output gate ot with weights Wo, both computed from xt and ht-1, control writing to and reading from the cell state ct.)

$$c_t = c_{t-1} + i_t \otimes \tanh\left( W \begin{pmatrix} x_t \\ h_{t-1} \end{pmatrix} \right) \qquad h_t = o_t \otimes \tanh(c_t)$$
$$i_t = \sigma\left( W_i \begin{pmatrix} x_t \\ h_{t-1} \end{pmatrix} + b_i \right), \text{ and similarly for } o_t$$
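A NumPy sketch of these equations; the weight and bias names mirror the slide, while shapes are assumptions:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def original_lstm_cell(x_t, h_prev, c_prev, W, Wi, bi, Wo, bo):
    """Original LSTM cell: input and output gates, no forget gate."""
    concat = np.concatenate([x_t, h_prev])
    i_t = sigmoid(Wi @ concat + bi)           # input gate
    o_t = sigmoid(Wo @ concat + bo)           # output gate
    c_t = c_prev + i_t * np.tanh(W @ concat)  # gated write into the cell state
    h_t = o_t * np.tanh(c_t)                  # gated read out of the cell state
    return h_t, c_t
```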
The Popular LSTM Cell
(Diagram: the input and output gates are as before, and a forget gate ft, computed from xt and ht-1 with weights Wf, scales the previous cell state ct-1.)

$$c_t = f_t \otimes c_{t-1} + i_t \otimes \tanh\left( W \begin{pmatrix} x_t \\ h_{t-1} \end{pmatrix} \right)$$
$$f_t = \sigma\left( W_f \begin{pmatrix} x_t \\ h_{t-1} \end{pmatrix} + b_f \right)$$
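The same sketch extended with the forget gate; again the shapes and parameter names are assumptions:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_cell(x_t, h_prev, c_prev, W, Wi, bi, Wf, bf, Wo, bo):
    """Popular LSTM cell: a forget gate f_t scales the previous cell state c_{t-1}."""
    concat = np.concatenate([x_t, h_prev])
    i_t = sigmoid(Wi @ concat + bi)                 # input gate
    f_t = sigmoid(Wf @ concat + bf)                 # forget gate
    o_t = sigmoid(Wo @ concat + bo)                 # output gate
    c_t = f_t * c_prev + i_t * np.tanh(W @ concat)  # forget old content, write new content
    h_t = o_t * np.tanh(c_t)                        # gated read out of the cell state
    return h_t, c_t
```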
LSTM – Forward/Backward
Summary
• RNNs allow for processing of variable-length inputs and outputs by maintaining state information across time steps
• Various input-output scenarios are possible (Single/Multiple)