Closed
Labels
docathon-h1-2025
medium (Label for medium docathon tasks)
module: docs (Related to our documentation, both in docs/ and docblocks)
module: nn (Related to torch.nn)
module: rnn (Issues related to RNN support (LSTM, GRU, etc))
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
📚 The doc issue
The code snippet in the PyTorch docs for nn.RNN seems to have a mistake. Inside the forward function, h_t[layer] is defined as
h_t[layer] = torch.tanh(
    x[t] @ weight_ih[layer].T
    + bias_ih[layer]
    + h_t_minus_1[layer] @ weight_hh[layer].T
    + bias_hh[layer]
)
For layer != 0, we should write h_t[layer-1] instead of x[t]. When layer > 0, a multi-layer RNN takes the hidden state of the previous layer at the same time step as its input, not the raw input x[t].
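One way to see the problem is to look at the parameter shapes. The sketch below uses made-up sizes (input_size=3, hidden_size=5, batch of 4 are assumptions chosen only so that input_size != hidden_size) and shows that the input-to-hidden weight of layer 1 expects a hidden_size-wide input, so multiplying x[t] by it fails.

```python
import torch
from torch import nn

# Illustrative sizes (assumptions, not taken from the issue).
# input_size != hidden_size makes the mismatch visible.
input_size, hidden_size, num_layers = 3, 5, 2
rnn = nn.RNN(input_size, hidden_size, num_layers)

shape_l0 = tuple(rnn.weight_ih_l0.shape)  # (hidden_size, input_size)
shape_l1 = tuple(rnn.weight_ih_l1.shape)  # (hidden_size, hidden_size)
print(shape_l0, shape_l1)

x_t = torch.randn(4, input_size)  # one time step, batch of 4

# This is what the documented snippet does for layer 1.
try:
    x_t @ rnn.weight_ih_l1.T
    mismatch = False
except RuntimeError:
    mismatch = True
print("shape mismatch for layer 1:", mismatch)
```

When input_size happens to equal hidden_size, the documented snippet runs without error but still feeds the wrong tensor to the upper layers.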
Suggest a potential alternative/fix
We could change the code snippet to the following.
# Efficient implementation equivalent to the following with bidirectional=False
def forward(x, h_0=None):
    if batch_first:
        x = x.transpose(0, 1)
    seq_len, batch_size, _ = x.size()
    if h_0 is None:
        h_0 = torch.zeros(num_layers, batch_size, hidden_size)
    # Clone so that the in-place writes below do not modify the caller's h_0
    h_t_minus_1 = h_0.clone()
    h_t = h_0.clone()
    output = []
    for t in range(seq_len):
        for layer in range(num_layers):
            # Layer 0 reads the input; higher layers read the layer below
            ih_input = x[t] if layer == 0 else h_t[layer-1]
            h_t[layer] = torch.tanh(
                ih_input @ weight_ih[layer].T
                + bias_ih[layer]
                + h_t_minus_1[layer] @ weight_hh[layer].T
                + bias_hh[layer]
            )
        output.append(h_t[-1].clone())
        h_t_minus_1 = h_t.clone()
    output = torch.stack(output)
    if batch_first:
        output = output.transpose(0, 1)
    return output, h_t
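
A quick way to check the proposed fix is to compare it against nn.RNN itself. The following is a minimal sketch, not the docs' code: the name manual_forward, the sizes, the seed, and the tolerance are all assumptions made for illustration, and it covers only the batch_first=False, bidirectional=False case. Weights are read from rnn.named_parameters() so that both computations use the same parameters.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Illustrative sizes (assumptions).
input_size, hidden_size, num_layers = 3, 5, 2
seq_len, batch_size = 7, 4

rnn = nn.RNN(input_size, hidden_size, num_layers)
params = dict(rnn.named_parameters())


def manual_forward(x, h_0=None):
    # x has shape (seq_len, batch, input_size), i.e. batch_first=False.
    n_steps, n_batch, _ = x.size()
    if h_0 is None:
        h_0 = torch.zeros(num_layers, n_batch, hidden_size)
    h_t_minus_1 = h_0.clone()
    h_t = h_0.clone()
    output = []
    for t in range(n_steps):
        for layer in range(num_layers):
            # Layer 0 reads the input; higher layers read the layer below.
            ih_input = x[t] if layer == 0 else h_t[layer - 1]
            h_t[layer] = torch.tanh(
                ih_input @ params[f"weight_ih_l{layer}"].T
                + params[f"bias_ih_l{layer}"]
                + h_t_minus_1[layer] @ params[f"weight_hh_l{layer}"].T
                + params[f"bias_hh_l{layer}"]
            )
        output.append(h_t[-1].clone())
        h_t_minus_1 = h_t.clone()
    return torch.stack(output), h_t


x = torch.randn(seq_len, batch_size, input_size)
with torch.no_grad():
    expected_out, expected_h = rnn(x)
    manual_out, manual_h = manual_forward(x)

# Tolerance is a loose assumption to absorb float32 rounding differences.
outputs_match = torch.allclose(manual_out, expected_out, atol=1e-5)
hidden_match = torch.allclose(manual_h, expected_h, atol=1e-5)
print("outputs match:", outputs_match)
print("final hidden states match:", hidden_match)
```

The comparison runs under torch.no_grad() because the loop writes into h_t in place; it is meant as a numerical check of the recurrence, not as a trainable implementation.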
cc @svekars @sekyondaMeta @AlannaBurke @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @brycebortree