RNNs

The document discusses Recurrent Neural Networks (RNNs) and their relevance in sequence modeling, particularly in biomedical research. It highlights the advantages of Long Short-Term Memory (LSTM) networks over traditional RNNs, especially in preserving long-range dependencies and mitigating gradient loss. The presentation concludes that while LSTMs are still significant despite the rise of large language models (LLMs), they can be effectively combined with other architectures for improved predictive performance.


Recurrent Neural Networks

and Sequence Modeling


David Sahner, M.D.
Senior Advisor, National Center for Advancing Translational Sciences, and
Executive Advisor, Data Science and Translational Medicine, Axle Research and Technologies
Presentation Outline
• RNNs (other than LSTMs)

• Preserving long-range dependencies in sequences

• LSTMs

• Examples of applications in biomedical research

• Conclusions
Are traditional RNNs still relevant?
• Until the last few years, Long Short-Term Memory (LSTM) RNNs represented the state of the art in sequence modeling
• In Natural Language Processing, LSTMs are trained to predict the next word in a sequence or to classify text

• LSTMs have become somewhat eclipsed by LLMs but can still be useful. They require fewer parameters and are readily leveraged using tf.keras (see: https://www.tensorflow.org/api_docs/python/tf/keras/layers/LSTM)

• LSTMs can also be used to model other sequences, such as patient trajectories, to make predictions informed by the nature and order of prior events/diagnoses, etc.

[Graphic: LLM vs. LSTM – smaller and less contemporary, but still a hefty player in AI]
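A minimal tf.keras sketch of the kind of LSTM text classifier described above; the vocabulary size, embedding dimension, sequence length, layer width, and class count are illustrative assumptions rather than values from the slides.

import tensorflow as tf

# Illustrative hyperparameters (assumptions, not values from the slides)
VOCAB_SIZE, EMBED_DIM, SEQ_LEN, NUM_CLASSES = 10000, 128, 50, 2

model = tf.keras.Sequential([
    tf.keras.Input(shape=(SEQ_LEN,), dtype="int32"),            # integer token IDs
    tf.keras.layers.Embedding(VOCAB_SIZE, EMBED_DIM),           # map tokens to dense vectors
    tf.keras.layers.LSTM(64),                                   # LSTM layer with 64 hidden units
    tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),   # classification head
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

Swapping the Dense head for a per-token vocabulary softmax (with return_sequences=True on the LSTM) turns the same skeleton into a next-word predictor.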
Prefatory comments
• In diagrams, RNNs are usually depicted as operating on a series of
individually time-stamped input vectors
• Vectors may represent a sequence in time or space (e.g., word order in text)

• Typically, an RNN operates on minibatches of sequences

• By nature, the hidden vectors in RNNs are “lossy,” as a sequence is mapped to a single hidden vector at a given time point

• RNN parameter sharing across time steps is computationally economical

• Subsequent slides present RNN architectures with a single hidden representation (h) at each time step, but depth can be increased at each point in the sequence by adding another layer to the RNN
• This may improve performance, based on evidence in the literature
Basic RNN architecture and update equations
x = vector input
h = hidden state vector
o = unnormalized log probabilities used to
compute softmax outputs
L = loss function such as negative log likelihood
(computed for each step, with total loss
for series equaling sum of stepwise losses)
y = true label
U, W and V = learned weight matrices

Forward propagation update equations:

h(t) = tanh[Wh(t-1) + Ux(t) + b]

o(t) = Vh(t) + c

Backpropagation is then used to compute the gradient informing weight updates.

Image from https://www.deeplearningbook.org/contents/rnn.html
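A minimal NumPy sketch of the forward-propagation equations above, applied to a single (unbatched) sequence; the dimensions and random weights are illustrative assumptions.

import numpy as np

def rnn_forward(x_seq, U, W, V, b, c, h0):
    """Vanilla RNN: h(t) = tanh(W h(t-1) + U x(t) + b); o(t) = V h(t) + c."""
    h, outputs = h0, []
    for x_t in x_seq:                      # iterate over time steps
        h = np.tanh(W @ h + U @ x_t + b)   # hidden-state update (shared U, W, b)
        o = V @ h + c                      # unnormalized log probabilities for softmax
        outputs.append(o)
    return outputs, h

# Illustrative dimensions (assumptions)
rng = np.random.default_rng(0)
d_in, d_h, d_out, T = 4, 8, 3, 5
x_seq = [rng.normal(size=d_in) for _ in range(T)]
U = rng.normal(size=(d_h, d_in))
W = rng.normal(size=(d_h, d_h))
V = rng.normal(size=(d_out, d_h))
b, c, h0 = np.zeros(d_h), np.zeros(d_out), np.zeros(d_h)
outputs, h_T = rnn_forward(x_seq, U, W, V, b, c, h0)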
The label of the last token can also be directly leveraged in updating the hidden state

Image from https://www.deeplearningbook.org/contents/rnn.html


A single output at the end can be used
to classify text

Image from https://www.deeplearningbook.org/contents/rnn.html


Bidirectional RNN architecture: Exploiting
the past and the future

Intuition: The predicted word token at time t depends upon both the prior words and the subsequently uttered words. In essence, “both slices of bread are used to predict the content of the sandwich.”
This bidirectional RNN is most attuned to “past and future
words” in proximity to the token to be predicted at time t.
g = sequence of hidden states emanating “from the future”
h = sequence of hidden states “emanating from the past”
Linear transformations of both g and h influence output (o)
used to compute loss at each step

Image from https://www.deeplearningbook.org/contents/rnn.html
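In tf.keras, bidirectionality can be obtained by wrapping a recurrent layer in tf.keras.layers.Bidirectional, which runs one pass over the sequence forward (h, “from the past”) and one backward (g, “from the future”) and combines the two hidden-state sequences. A minimal sketch, with illustrative sizes assumed below:

import tensorflow as tf

# Illustrative sizes (assumptions)
SEQ_LEN, VOCAB_SIZE, EMBED_DIM = 50, 10000, 128

model = tf.keras.Sequential([
    tf.keras.Input(shape=(SEQ_LEN,), dtype="int32"),
    tf.keras.layers.Embedding(VOCAB_SIZE, EMBED_DIM),
    # Forward and backward LSTMs; their per-step hidden states (h and g in the slide)
    # are concatenated before being mapped to the output o(t).
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64, return_sequences=True)),
    tf.keras.layers.Dense(VOCAB_SIZE, activation="softmax"),   # per-step prediction
])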


Memory loss and the difficulty in modeling
long-range dependencies using RNNs
• As mentioned earlier, RNNs are inherently lossy, as a sequence is typically compressed
into a single hidden representation

• But there is another problem, born of parsimony


• Repeated application of the same linear transformation renders RNNs particularly susceptible to gradient decay (and/or explosion*) over time
• This can be readily appreciated if the weight matrix used to compute a hidden state from the prior state is subjected to eigendecomposition (W = SΛS⁻¹)
• h(t), prior to application of the activation function, = (SΛᵗS⁻¹)h(0)

• Note that the eigenvalues of W are raised to the power of t after t time steps, causing eigenvalues of magnitude <1 to evaporate with time, and those >1 to explode, hindering learning in general and obscuring/eroding signals related to long-range dependencies

*Gradient clipping can help address gradient explosion, but gradient loss remains problematic in vanilla RNNs
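A small NumPy illustration of this point: repeated application of the same matrix scales each eigen-direction by λᵗ, so components with |λ| < 1 evaporate and those with |λ| > 1 explode. The diagonal W below is an illustrative assumption that keeps the eigenvalues explicit.

import numpy as np

# Diagonal W keeps the eigenvalues explicit: 0.5 (decays) and 1.5 (explodes)
W = np.diag([0.5, 1.5])
h = np.array([1.0, 1.0])           # h(0)

for t in range(20):
    h = W @ h                      # h(t) = W h(t-1) (activation omitted for clarity)
print(h)                           # ~ [0.5**20, 1.5**20] = [~9.5e-07, ~3.3e+03]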
How can we improve memory in an RNN?
• Leaky units “allow more leakage of the past into the present” (a small sketch appears after this slide’s list)
• Compute the hidden state at time t
• Then combine a weighted “self-connection” of the hidden state with a weighted value of the hidden representation at h(t-1) (to “access more of the past”) → “revised” h(t)
• The updated/revised h(t) can be denoted as ρ(t) = αρ(t-1) + (1-α)h(t), where alpha is an adjustable hyperparameter - a real number within the range (0,1)
• Repeat the above at each time step
• The closer the value of α is to 1, the more of the past seeps into the present

• Addition of skip connections from the past, or replacement of a length-one connection with a (longer) skip connection

• Modify the basic RNN architecture to create a new class of RNN


• Cue LSTMs!
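A minimal sketch of the leaky-unit update described above; the value of α and the hidden-state dimension are illustrative assumptions.

import numpy as np

def leaky_update(rho_prev, h_t, alpha=0.9):
    # rho(t) = alpha * rho(t-1) + (1 - alpha) * h(t); alpha closer to 1 retains more of the past
    return alpha * rho_prev + (1.0 - alpha) * h_t

rho = np.zeros(8)                                             # illustrative hidden size
hidden_states = np.random.default_rng(0).normal(size=(5, 8))  # 5 time steps of hidden states h(t)
for h_t in hidden_states:
    rho = leaky_update(rho, h_t)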
LSTMs: General Points
• Compared with traditional vanilla RNNs, LSTMs have more
complex architecture and update equations
• The core LSTM building block is the “memory cell,” which also
controls what is forgotten.
• Aggregate “connectivity weight” between hidden vector values
changes over time because of the architecture*
• Gated recursive loops enable gradient to endure over time
• LSTMs can also accommodate bidirectionality

*Parameters remain shared, however, among weight matrices. See next slide.
Core LSTM building block: The Memory Cell
h(t) = hidden layer vector at time t
h(t-1) = hidden layer vector at time t-1
s(t) = memory cell state at time t
s(t-1) = memory cell state at time t-1
x(t) = input at time t
g(t) = output of the input gate
f(t) = output of the forget gate
q(t) = output of the output gate
U and W = learned weight matrices, with superscripts indicating distinct weight matrices for the input, forget, and output gates

The input gate controls the extent to which the current input informs the cell state. The forget gate modulates the degree to which the prior state is reflected in the new state. The output gate governs how much influence the current state has on the output to the next layer. The gates each apply scalar outputs between 0 and 1.

[Figure: LSTM memory cell with a self-loop on the state; the input, input gate, forget gate, and output gate feed the cell. Adapted from https://www.deeplearningbook.org/contents/rnn.html]


Forward propagation update equations for a cell
g(t) = σ(Uᵍx(t) + Wᵍh(t-1) + bᵍ)

f(t) = σ(Uᶠx(t) + Wᶠh(t-1) + bᶠ)

q(t) = σ(Uᵒx(t) + Wᵒh(t-1) + bᵒ)

s(t) = f(t)s(t-1) + g(t)σ(Ux(t) + Wh(t-1) + b)

h(t) = tanh(s(t))q(t)

b is a bias term specific to each gate, indicated by superscript. Simple stochastic gradient descent is often used in backpropagation for LSTMs.

[Figure: the memory cell from the previous slide, annotated with these equations. Adapted from https://www.deeplearningbook.org/contents/rnn.html]
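A NumPy sketch of a single memory-cell forward step implementing the update equations above; the dimensions and random weight initialization are illustrative assumptions.

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_cell_step(x_t, h_prev, s_prev, p):
    """One LSTM cell step following the slide's equations (p holds weight matrices and biases)."""
    g = sigmoid(p["Ug"] @ x_t + p["Wg"] @ h_prev + p["bg"])                 # input gate
    f = sigmoid(p["Uf"] @ x_t + p["Wf"] @ h_prev + p["bf"])                 # forget gate
    q = sigmoid(p["Uo"] @ x_t + p["Wo"] @ h_prev + p["bo"])                 # output gate
    s = f * s_prev + g * sigmoid(p["U"] @ x_t + p["W"] @ h_prev + p["b"])   # cell state
    h = np.tanh(s) * q                                                       # hidden state
    return h, s

# Illustrative dimensions (assumptions)
d_in, d_h = 4, 8
rng = np.random.default_rng(0)
p = {k: rng.normal(size=(d_h, d_in)) for k in ("U", "Ug", "Uf", "Uo")}
p.update({k: rng.normal(size=(d_h, d_h)) for k in ("W", "Wg", "Wf", "Wo")})
p.update({k: np.zeros(d_h) for k in ("b", "bg", "bf", "bo")})
h, s = lstm_cell_step(rng.normal(size=d_in), np.zeros(d_h), np.zeros(d_h), p)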


Examples of applications of
RNN/LSTMs in biomedical research
• LSTM trained on features from the EHR and GNN embeddings to predict opioid overdose risk (Dong et al., 2023) https://pubmed.ncbi.nlm.nih.gov/36628797/

• Prediction of post-treatment late-stage symptom ratings leveraging patient-reported outcome data from head and neck cancer patients treated at MD Anderson Cancer Center (Wang et al., 2023) https://pubmed.ncbi.nlm.nih.gov/38343586/

• Combination of LSTM and LightGBM to predict delirium in hospitalized patients using EHR data (Schlesinger et al., 2022) https://pubmed.ncbi.nlm.nih.gov/36303456/

• Hybrid CNN-LSTM models have also been explored in biomedical research
Clever use of an LSTM fed by features from a
GNN leveraging EHR data (Dong et al., 2023)
• Objective: Prediction of opioid overdose after prescription of an
opioid
• Rationale for (complementary) combination of techniques
• In general, GNNs, which capture relationships between examples, tend to
further homogenize the features of “like nodes” (examples) in a graph,
potentially improving discrimination between classes.*
• LSTMs exploit longitudinal information to make predictions

• Over 5.2 million patients in the Cerner database, 16 to 66 years of age, without a cancer diagnosis

*See GNN teaching module for in-depth description of GNNs


High-level overview of model structure
used by Dong and colleagues
• Features with co-occurrence of <1%
in overdose patients excluded
• Data from last 5* encounters ≥14
days and ≤ 12 months before target
encounter
• Team used a heterogenous GNN
(see back-up slide for added detail).
• Two-layer LSTM with 64 hidden units
• 80%-20% train-test split
• Cross-entropy loss and Adam
optimizer
*If fewer than 5 encounters, data from the last available encounter were cloned

Image from Dong et al., 2023
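A hedged tf.keras sketch of the kind of model described on this slide (two LSTM layers with 64 hidden units, cross-entropy loss, Adam optimizer). This is not the authors' code; the number of encounters and the per-encounter feature dimension are illustrative assumptions.

import tensorflow as tf

# Illustrative assumptions, not values from the paper
N_ENCOUNTERS, N_FEATURES = 5, 256

model = tf.keras.Sequential([
    tf.keras.Input(shape=(N_ENCOUNTERS, N_FEATURES)),        # one feature/embedding vector per encounter
    tf.keras.layers.LSTM(64, return_sequences=True),          # first LSTM layer (64 hidden units)
    tf.keras.layers.LSTM(64),                                  # second LSTM layer (64 hidden units)
    tf.keras.layers.Dense(1, activation="sigmoid"),            # predicted overdose risk
])
model.compile(optimizer="adam",
              loss="binary_crossentropy",                      # cross-entropy loss
              metrics=[tf.keras.metrics.AUC()])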
LSTM + GNN comes out on top

Table from Dong et al., 2023

Image from Dong et al., 2023


Conclusions
• RNNs are used in sequence modeling, but older architectures are prone to gradient loss, which challenges learning
• Steps can be taken to mitigate this liability (e.g., leaky units and skip connections)

• LSTMs remain the state of the art among RNN architectures, in large part because they are far more resistant to gradient loss, better preserving “knowledge” of long-term dependencies in a series
• Rumors of the death of LSTMs with the advent of powerful LLMs are vastly overstated
• LSTMs continue to play a role in biomedical research and may be creatively combined with other architectures to further enhance predictive performance
Supplementary Slides
Image from Dong et al., 2023
