Chapter 9: Recurrent Neural Networks
So far, we have limited our attention to domains in which each output y is assumed to
have been generated as a function of an associated input x, and our hypotheses have been
“pure” functions, in which the output depends only on the input (and the parameters we
have learned that govern the function’s behavior). In the next few chapters, we are going
to consider cases in which our models need to go beyond functions. In particular, behavior
as a function of time will be an important concept:
• In recurrent neural networks, the hypothesis that we learn is not a function of a single
input, but of the whole sequence of inputs that the predictor has received.
In this chapter, we introduce state machines. We start with deterministic state machines,
and then consider recurrent neural network (RNN) architectures to model their behavior.
Later, in Chapter 10, we will study Markov decision processes (MDPs) that extend to consider
probabilistic (rather than deterministic) transitions in our state machines. RNNs and MDPs
will enable description and modeling of temporally sequential patterns of behavior that are
important in many domains.
The basic operation of the state machine is to start with state $s_0$, then iteratively compute for $t \geq 1$:

$$s_t = f_s(s_{t-1}, x_t) \tag{9.1}$$
$$y_t = f_o(s_t) \tag{9.2}$$

(In some cases, we will pick a starting state from a set or distribution.)
The diagram below illustrates this process. Note that the "feedback" connection of $s_t$ back into $f_s$ has to be buffered or delayed by one time step; otherwise what it is computing would not generally be well defined.

[Figure: block diagram in which the input $x_t$ and the delayed previous state $s_{t-1}$ feed $f_s$, producing $s_t$, which $f_o$ maps to the output $y_t$.]
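To make the iteration concrete, here is a minimal Python sketch of this loop; `transduce` and its argument names are illustrative choices of ours, not notation from the text.

```python
def transduce(f_s, f_o, s0, xs):
    """Run the state machine of Eqs. 9.1-9.2: fold f_s over the input
    sequence, emitting f_o of each successive state."""
    s = s0
    ys = []
    for x in xs:
        s = f_s(s, x)       # Eq. 9.1: next state from previous state and input
        ys.append(f_o(s))   # Eq. 9.2: output depends only on the current state
    return ys

# Example: a machine whose state is the running sum of the inputs.
print(transduce(lambda s, x: s + x, lambda s: s, 0, [1, 2, 3]))  # [1, 3, 6]
```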
We sometimes say that the machine transduces sequence $x$ into sequence $y$. The output at time $t$ can have dependence on inputs from steps $1$ to $t$. (There are a huge number of major and minor variations on the idea of a state machine. We'll just work with one specific one in this section and another one in the next, but don't worry if you see other variations out in the world!)

One common form is finite state machines, in which $S$, $X$, and $Y$ are all finite sets. They are often described using state transition diagrams such as the one below, in which nodes stand for states and arcs indicate transitions. Nodes are labeled by which output they generate and arcs are labeled by which input causes the transition. (All computers can be described, at the digital level, as finite state machines. Big, but finite!)

One can verify that the state machine below reads binary strings and determines the parity of the number of zeros in the given string. Check for yourself that all input binary strings end in state $S1$ if and only if they contain an even number of zeros.

[Figure: state transition diagram of the parity machine, with states $S0$ and $S1$.]
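Here is one way to encode that machine as a Python sketch, assuming the even-parity state is called S1 (as in the text) and the odd-parity state S0 (a name we chose):

```python
# Parity-of-zeros finite state machine: state S1 means "even number of
# zeros seen so far" (the start state, since zero is even), S0 means "odd".
transition = {
    ("S1", "0"): "S0", ("S1", "1"): "S1",
    ("S0", "0"): "S1", ("S0", "1"): "S0",
}

def parity_state(bits, s="S1"):
    """Run the finite state machine over a binary string; return the final state."""
    for b in bits:
        s = transition[(s, b)]
    return s

assert parity_state("") == "S1"         # zero 0s: even
assert parity_state("0") == "S0"        # one 0: odd
assert parity_state("0110100") == "S1"  # four 0s: even
```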
Another common structure that is simple but powerful and used in signal processing and control is linear time-invariant (LTI) systems. In this case, all the quantities are real-valued vectors: $S = \mathbb{R}^m$, $X = \mathbb{R}^\ell$ and $Y = \mathbb{R}^n$. The functions $f_s$ and $f_o$ are linear functions of their inputs. The transition function is described by the state matrix $A$ and the input matrix $B$; the output function is defined by the output matrix $C$, each with compatible dimensions. In discrete time, they can be defined by a linear difference equation, like

$$s_t = A s_{t-1} + B x_t,$$
$$y_t = C s_t,$$

and can be implemented using state to store relevant previous input and output information. We will study recurrent neural networks, which are a lot like a non-linear version of an LTI system.
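As a quick illustration, here is a minimal numpy sketch of simulating such a discrete-time LTI system; the particular matrix values are arbitrary:

```python
import numpy as np

def lti_step(A, B, C, s_prev, x):
    """One step of a discrete-time LTI system: linear f_s and f_o."""
    s = A @ s_prev + B @ x   # state update (f_s is linear)
    y = C @ s                # output (f_o is linear)
    return s, y

# Arbitrary example with m=2, l=1, n=1.
A = np.array([[0.9, 0.1], [0.0, 0.8]])  # state matrix, m x m
B = np.array([[1.0], [0.5]])            # input matrix, m x l
C = np.array([[1.0, -1.0]])             # output matrix, n x m

s = np.zeros((2, 1))
for x in [np.array([[1.0]]), np.array([[0.0]]), np.array([[0.0]])]:
    s, y = lti_step(A, B, C, s, x)      # state stores previous information
```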
A recurrent neural network is a state machine with neural networks constituting functions $f_s$ and $f_o$:

$$s_t = f_s\left(W^{ss} s_{t-1} + W^{sx} x_t + W^{ss}_0\right)$$
$$y_t = f_o\left(W^{o} s_t + W^{o}_0\right)$$
The inputs, states, and outputs are all vector-valued:

$$x_t : \ell \times 1 \tag{9.7}$$
$$s_t : m \times 1 \tag{9.8}$$
$$y_t : v \times 1. \tag{9.9}$$

The weights in the network, then, are

$$W^{sx} : m \times \ell \tag{9.10}$$
$$W^{ss} : m \times m \tag{9.11}$$
$$W^{ss}_0 : m \times 1 \tag{9.12}$$
$$W^{o} : v \times m \tag{9.13}$$
$$W^{o}_0 : v \times 1 \tag{9.14}$$

(We are very sorry! This course material has evolved from different sources, which used $W^T x$ in the forward pass for regular feed-forward NNs and $Wx$ for the forward pass in RNNs. This inconsistency doesn't make any technical difference, but is a potential source of confusion.)
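Putting the shapes together, here is a hedged numpy sketch of one forward step; tanh and the identity are stand-in choices of ours for $f_s$ and $f_o$:

```python
import numpy as np

def rnn_step(params, s_prev, x):
    """One RNN step: s_t = f_s(W^ss s_{t-1} + W^sx x_t + W0^ss),
    y_t = f_o(W^o s_t + W0^o)."""
    Wsx, Wss, Wss0, Wo, Wo0 = params
    z1 = Wss @ s_prev + Wsx @ x + Wss0  # pre-activation for the state, m x 1
    s = np.tanh(z1)                     # f_s = tanh (illustrative choice)
    z2 = Wo @ s + Wo0                   # pre-activation for the output, v x 1
    y = z2                              # f_o = identity (illustrative choice)
    return s, y

# Shapes from Eqs. 9.7-9.14, with l=3, m=4, v=2 chosen arbitrarily.
l, m, v = 3, 4, 2
rng = np.random.default_rng(0)
params = (rng.normal(size=(m, l)),  # W^sx
          rng.normal(size=(m, m)),  # W^ss
          np.zeros((m, 1)),         # W0^ss
          rng.normal(size=(v, m)),  # W^o
          np.zeros((v, 1)))         # W0^o
s, y = rnn_step(params, np.zeros((m, 1)), rng.normal(size=(l, 1)))
```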
The per-element loss function $L_{elt}$ will depend on the type of $y_t$ and what information it is encoding, in the same way as for a supervised network. (So it could be NLL, squared loss, etc.)
Then, letting $W = (W^{sx}, W^{ss}, W^{o}, W^{ss}_0, W^{o}_0)$, our overall goal is to minimize the objective

$$J(W) = \frac{1}{q} \sum_{i=1}^{q} L_{seq}\left(\text{RNN}(x^{(i)}; W),\, y^{(i)}\right), \tag{9.16}$$

where the sequence loss is the sum of the per-element losses,

$$L_{seq}(g, y) = \sum_{t=1}^{n} L_{elt}(g_t, y_t).$$
For example, to train a language model that predicts the next token at each step, we can take a token sequence $c_1, c_2, \ldots, c_k$ and form a training pair in which the output sequence is the input sequence shifted by one:

$$x = (\langle\text{start}\rangle, c_1, c_2, \ldots, c_k) \tag{9.18}$$
$$y = (c_1, c_2, \ldots, \langle\text{end}\rangle) \tag{9.19}$$
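A training pair of this shifted-by-one form can be built from a token sequence as in this small sketch (the token strings are illustrative):

```python
def make_pair(tokens):
    """Build (x, y) where y is x shifted left by one token (Eqs. 9.18-9.19)."""
    x = ["<start>"] + tokens
    y = tokens + ["<end>"]
    return x, y

x, y = make_pair(["t", "h", "e"])
# x = ['<start>', 't', 'h', 'e'], y = ['t', 'h', 'e', '<end>']
```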
What we want you to take away from this section is that, by "unrolling" a recurrent network out to model a particular sequence, we can treat the whole thing as a feed-forward network with a lot of parameter sharing. Thus, we can tune the parameters using stochastic gradient descent, and learn to model sequential mappings. The concepts here are very important. While the details matter if you need to implement something, we present the mathematical details below primarily to convey the larger concepts.
Calculus reminder: total derivative
Most of us are not very careful about the difference between the partial derivative and the total derivative. We are going to use a nice example from the Wikipedia article on partial derivatives to illustrate the difference. The volume of a circular cone depends on its height and radius:

$$V(r, h) = \frac{\pi r^2 h}{3}. \tag{9.20}$$

The partial derivatives of volume with respect to height and radius are

$$\frac{\partial V}{\partial r} = \frac{2\pi r h}{3} \quad \text{and} \quad \frac{\partial V}{\partial h} = \frac{\pi r^2}{3}. \tag{9.21}$$

They measure the change in $V$ assuming everything is held constant except the single variable we are changing. Now assume that we want to preserve the cone's proportions in the sense that the ratio of radius to height stays constant. Then we can't really change one without changing the other. In this case, we really have to think about the total derivative. If we're interested in the total derivative with respect to $r$, we sum the "paths" along which $r$ might influence $V$:

$$\frac{dV}{dr} = \frac{\partial V}{\partial r} + \frac{\partial V}{\partial h}\frac{dh}{dr} \tag{9.22}$$
$$= \frac{2\pi r h}{3} + \frac{\pi r^2}{3}\frac{dh}{dr} \tag{9.23}$$

Or if we're interested in the total derivative with respect to $h$, we consider how $h$ might influence $V$, either directly or via $r$:

$$\frac{dV}{dh} = \frac{\partial V}{\partial h} + \frac{\partial V}{\partial r}\frac{dr}{dh} \tag{9.24}$$
$$= \frac{\pi r^2}{3} + \frac{2\pi r h}{3}\frac{dr}{dh} \tag{9.25}$$

Just to be completely concrete, let's think of a right circular cone with a fixed angle $\alpha$ between its axis and its side, so that $\tan\alpha = r/h$, and if we change $r$ or $h$ then $\alpha$ remains constant. So we have $r = h\tan\alpha$; let constant $c = \tan\alpha$, so now $r = ch$, and hence $dh/dr = 1/c$ and $dr/dh = c$. Thus, we finally have

$$\frac{dV}{dr} = \frac{2\pi r h}{3} + \frac{\pi r^2}{3}\frac{1}{c} \tag{9.26}$$
$$\frac{dV}{dh} = \frac{\pi r^2}{3} + \frac{2\pi r h}{3}c. \tag{9.27}$$
Concretely, each step of stochastic gradient descent on an RNN proceeds as follows:

(1) Sample a training pair of sequences $(x, y)$; let their length be $n$.

(2) "Unroll" the RNN to be length $n$ (picture for $n = 3$ below), and initialize $s_0$:

[Figure: the RNN unrolled for $n = 3$, with the same weights used at every step.]
Now, we can see our problem as one of performing what is almost an ordinary back-propagation training procedure in a feed-forward neural network, but with the difference that the weight matrices are shared among the layers. In many ways, this is similar to what ends up happening in a convolutional network, except in the convnet, the weights are re-used spatially, and here, they are re-used temporally.
(3) Do the forward pass, computing, for $t$ from 1 to $n$, the pre-activations, states, and predicted outputs:

$$z^1_t = W^{ss} s_{t-1} + W^{sx} x_t + W^{ss}_0$$
$$s_t = f_s(z^1_t)$$
$$z^2_t = W^{o} s_t + W^{o}_0$$
$$g_t = f_o(z^2_t)$$

(4) Do backward pass to compute the gradients. For both $W^{ss}$ and $W^{sx}$ we need to find

$$\frac{dL_{seq}(g, y)}{dW} = \sum_{u=1}^{n} \frac{dL_{elt}(g_u, y_u)}{dW}$$
Letting $L_u = L_{elt}(g_u, y_u)$ and using the total derivative, which is a sum over all the ways in which $W$ affects $L_u$, we have

$$= \sum_{u=1}^{n} \sum_{t=1}^{n} \frac{\partial s_t}{\partial W} \frac{\partial L_u}{\partial s_t}$$
Re-organizing, we have

$$= \sum_{t=1}^{n} \frac{\partial s_t}{\partial W} \sum_{u=1}^{n} \frac{\partial L_u}{\partial s_t}$$
$$= \sum_{t=1}^{n} \frac{\partial s_t}{\partial W} \sum_{u=t}^{n} \frac{\partial L_u}{\partial s_t}$$

(the second step uses the fact that the state at time $t$ cannot affect a loss at an earlier step, so $\partial L_u / \partial s_t = 0$ for $u < t$)

$$= \sum_{t=1}^{n} \frac{\partial s_t}{\partial W} \Biggl( \frac{\partial L_t}{\partial s_t} + \underbrace{\sum_{u=t+1}^{n} \frac{\partial L_u}{\partial s_t}}_{\delta^{s_t}} \Biggr). \tag{9.32}$$
where $\delta^{s_t}$ is the dependence of the future loss (incurred after step $t$) on the state $s_t$. (That is, $\delta^{s_t}$ is how much we can blame state $s_t$ for all the future element losses.)

We can compute this backwards, with $t$ going from $n$ down to 1. The trickiest part is figuring out how early states contribute to later losses. We define the future loss after step $t$ to be

$$F_t = \sum_{u=t+1}^{n} L_{elt}(g_u, y_u), \tag{9.33}$$

so

$$\delta^{s_t} = \frac{\partial F_t}{\partial s_t}. \tag{9.34}$$

At the last stage, $F_n = 0$ so $\delta^{s_n} = 0$.
Now, working backwards,

$$\delta^{s_{t-1}} = \frac{\partial}{\partial s_{t-1}} \sum_{u=t}^{n} L_{elt}(g_u, y_u) \tag{9.35}$$
$$= \frac{\partial s_t}{\partial s_{t-1}} \frac{\partial}{\partial s_t} \sum_{u=t}^{n} L_{elt}(g_u, y_u) \tag{9.36}$$
$$= \frac{\partial s_t}{\partial s_{t-1}} \frac{\partial}{\partial s_t} \Biggl[ L_{elt}(g_t, y_t) + \sum_{u=t+1}^{n} L_{elt}(g_u, y_u) \Biggr] \tag{9.37}$$
$$= \frac{\partial s_t}{\partial s_{t-1}} \Biggl[ \frac{\partial L_{elt}(g_t, y_t)}{\partial s_t} + \delta^{s_t} \Biggr] \tag{9.38}$$
Now, we can use the chain rule again to find the dependence of the element loss at time $t$ on the state at that same time,

$$\frac{\partial L_{elt}(g_t, y_t)}{\partial s_t} = \frac{\partial z^2_t}{\partial s_t} \frac{\partial L_{elt}(g_t, y_t)}{\partial z^2_t} = W^{o\,T}\, \frac{\partial L_{elt}(g_t, y_t)}{\partial z^2_t},$$

and the dependence of the state at time $t$ on the state at the previous time,

$$\frac{\partial s_t}{\partial s_{t-1}} = \frac{\partial z^1_t}{\partial s_{t-1}} \frac{\partial s_t}{\partial z^1_t} = W^{ss\,T} * f_s'(z^1_t).$$
Note that $\partial s_t / \partial z^1_t$ is formally an $m \times m$ diagonal matrix, with the values along the diagonal being $f_s'(z^1_{t,i})$, $1 \leq i \leq m$. But since this is a diagonal matrix, one could represent it as an $m \times 1$ vector $f_s'(z^1_t)$. In that case the product of the matrix $W^{ss\,T}$ by the vector $f_s'(z^1_t)$, denoted $W^{ss\,T} * f_s'(z^1_t)$, should be interpreted as follows: take the first column of the matrix $W^{ss\,T}$ and multiply each of its elements by the first element of the vector $\partial s_t / \partial z^1_t$, then take the second column of the matrix $W^{ss\,T}$ and multiply each of its elements by the second element of the vector $\partial s_t / \partial z^1_t$, and so on and so forth.
We're almost there! Now, we can describe the actual weight updates. Using Eq. 9.32 and recalling the definition of $\delta^{s_t} = \partial F_t / \partial s_t$, as we iterate backwards, we can accumulate the terms in Eq. 9.32 to get the gradient for the whole loss:

$$\frac{dL_{seq}}{dW^{ss}} = \sum_{t=1}^{n} \frac{dL_{elt}(g_t, y_t)}{dW^{ss}} = \sum_{t=1}^{n} \frac{\partial z^1_t}{\partial W^{ss}} \frac{\partial s_t}{\partial z^1_t} \frac{\partial F_{t-1}}{\partial s_t} \tag{9.42}$$
$$\frac{dL_{seq}}{dW^{sx}} = \sum_{t=1}^{n} \frac{dL_{elt}(g_t, y_t)}{dW^{sx}} = \sum_{t=1}^{n} \frac{\partial z^1_t}{\partial W^{sx}} \frac{\partial s_t}{\partial z^1_t} \frac{\partial F_{t-1}}{\partial s_t} \tag{9.43}$$
We can handle $W^{o}$ separately; it's easier because it does not affect future losses in the way that the other weight matrices do:

$$\frac{dL_{seq}}{dW^{o}} = \sum_{t=1}^{n} \frac{dL_t}{dW^{o}} = \sum_{t=1}^{n} \frac{\partial L_t}{\partial z^2_t} \frac{\partial z^2_t}{\partial W^{o}} \tag{9.44}$$

Assuming we have $\frac{\partial L_t}{\partial z^2_t} = (g_t - y_t)$ (which ends up being true for squared loss, softmax-NLL, etc.), then

$$\frac{dL_{seq}}{dW^{o}} = \sum_{t=1}^{n} \underbrace{\underbrace{(g_t - y_t)}_{v \times 1}\; \underbrace{s_t^T}_{1 \times m}}_{v \times m}. \tag{9.45}$$
Whew!
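To tie the whole backward pass together, here is a compact numpy sketch for one sequence, assuming $f_s = \tanh$, $f_o$ = identity, and squared loss so that $\partial L_t / \partial z^2_t = g_t - y_t$ as above. All variable names are ours, and the offset gradients are left out (they are the subject of the study question below).

```python
import numpy as np

def bptt(params, xs, ys, s0):
    """One forward/backward pass of backpropagation through time.
    params = (Wsx, Wss, Wss0, Wo, Wo0); xs, ys are lists of column vectors."""
    Wsx, Wss, Wss0, Wo, Wo0 = params
    n = len(xs)

    # Forward pass (step 3): record states and guesses.
    ss, gs = [s0], []
    for x in xs:
        z1 = Wss @ ss[-1] + Wsx @ x + Wss0
        ss.append(np.tanh(z1))           # f_s = tanh
        gs.append(Wo @ ss[-1] + Wo0)     # f_o = identity

    # Backward pass (step 4): delta = dF_t/ds_t, starting from delta_n = 0.
    dWsx, dWss, dWo = np.zeros_like(Wsx), np.zeros_like(Wss), np.zeros_like(Wo)
    delta = np.zeros_like(s0)
    for t in range(n - 1, -1, -1):
        dLdz2 = gs[t] - ys[t]                # squared loss: g_t - y_t
        dWo += dLdz2 @ ss[t + 1].T           # Eq. 9.45 contribution
        dFds = Wo.T @ dLdz2 + delta          # dF_{t-1}/ds_t = dL_t/ds_t + delta
        dLdz1 = (1 - ss[t + 1] ** 2) * dFds  # through f_s' (tanh' = 1 - s^2)
        dWss += dLdz1 @ ss[t].T              # Eq. 9.42 contribution
        dWsx += dLdz1 @ xs[t].T              # Eq. 9.43 contribution
        delta = Wss.T @ dLdz1                # new delta^{s_{t-1}}, Eq. 9.38
    return dWsx, dWss, dWo

# Tiny usage check with the shapes from Eqs. 9.7-9.14 (l=2, m=3, v=1).
rng = np.random.default_rng(0)
l, m, v = 2, 3, 1
params = (rng.normal(size=(m, l)), rng.normal(size=(m, m)),
          np.zeros((m, 1)), rng.normal(size=(v, m)), np.zeros((v, 1)))
xs = [rng.normal(size=(l, 1)) for _ in range(4)]
ys = [rng.normal(size=(v, 1)) for _ in range(4)]
grads = bptt(params, xs, ys, np.zeros((m, 1)))
```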
Study Question: Derive the updates for the offsets $W^{ss}_0$ and $W^{o}_0$.
Consider a case where only the output at the end of the sequence is incorrect, but it depends
critically, via the weights, on the input at time 1. In this case, we will multiply the loss at
step n by
$$\frac{\partial s_2}{\partial s_1} \cdot \frac{\partial s_3}{\partial s_2} \cdots \frac{\partial s_n}{\partial s_{n-1}}. \tag{9.47}$$
In general, this quantity will either grow or shrink exponentially with the length of the
sequence, and make it very difficult to train.
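As a tiny illustration of why, consider the scalar case: each factor in Eq. 9.47 is then a single number, and the product of $n$ of them behaves like $w^n$ for an effective per-step gain $w$. A sketch (the gains 1.1 and 0.9 are arbitrary):

```python
# With a one-dimensional state, each factor in Eq. 9.47 is a scalar,
# so the product over a length-n sequence scales like w**n.
for w, label in [(1.1, "exploding"), (0.9, "vanishing")]:
    prod = 1.0
    for _ in range(100):   # a length-100 sequence
        prod *= w
    print(label, prod)     # roughly 1.4e4 vs 2.7e-5
```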
Study Question: The last time we talked about exploding and vanishing gradients, it was to justify per-weight adaptive step sizes. Why is that not a solution to the problem this time?
An important insight that really made recurrent networks work well on long sequences is the idea of gating. One simple version adds a gating network $g$, itself a small neural network with a sigmoid activation so that its outputs lie between 0 and 1; with weight names following the pattern above, the state update takes the form

$$g_t = \text{sigmoid}\left(W^{gx} x_t + W^{gs} s_{t-1} + W^{g}_0\right)$$
$$s_t = (1 - g_t) * s_{t-1} + g_t * f_s\left(W^{sx} x_t + W^{ss} s_{t-1} + W^{ss}_0\right)$$

where $*$ is component-wise multiplication. We can see, here, that the output of the gating
network is deciding, for each dimension of the state, how much it should be updated now.
This mechanism makes it much easier for the network to learn to, for example, “store”
some information in some dimension of the state, and then not change it during future
state updates, or change it only under certain conditions on the input or other aspects of
the state.
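Expressed as a hedged Python sketch (weight names as above, tanh standing in for $f_s$), the gated update is a per-dimension convex combination of the old state and a candidate new state:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gated_step(Wgx, Wgs, Wg0, Wsx, Wss, Wss0, s_prev, x):
    """Gated state update: each state dimension i moves toward its
    candidate new value only to the extent the gate g[i] allows."""
    g = sigmoid(Wgx @ x + Wgs @ s_prev + Wg0)           # gate in (0, 1)^m
    candidate = np.tanh(Wsx @ x + Wss @ s_prev + Wss0)  # proposed new state
    return (1 - g) * s_prev + g * candidate             # * is elementwise
```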
Study Question: Why is it important that the activation function for g be a sigmoid?