
A mistake in PyTorch Docs for nn.RNN #129446

@shehper

Description


📚 The doc issue

The reference code snippet in the PyTorch docs for nn.RNN appears to contain a mistake.

Inside the forward function, every layer computes h_t[layer] from the raw input x[t]:

h_t[layer] = torch.tanh(
    x[t] @ weight_ih[layer].T
    + bias_ih[layer]
    + h_t_minus_1[layer] @ weight_hh[layer].T
    + bias_hh[layer]
)

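For layer > 0, the input-to-hidden term should use h_t[layer-1] instead of x[t]: in a multi-layer (stacked) RNN, each layer after the first takes the current hidden state of the layer below as its input. Besides giving the wrong result, using x[t] fails with a shape mismatch whenever input_size != hidden_size, because weight_ih[layer] has shape (hidden_size, hidden_size) for layer > 0.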
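One way to see the problem concretely is to inspect the parameter shapes of a stacked nn.RNN. A minimal sketch, with illustrative sizes chosen so that input_size differs from hidden_size (the sizes and the mismatch_raised flag are hypothetical, not part of the docs):

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumed); input_size deliberately differs from hidden_size
input_size, hidden_size, num_layers = 3, 5, 2
rnn = nn.RNN(input_size, hidden_size, num_layers)

# weight_ih_l0 maps the input sequence; weight_ih_l1 maps layer 0's hidden state
print(tuple(rnn.weight_ih_l0.shape))  # (5, 3)
print(tuple(rnn.weight_ih_l1.shape))  # (5, 5)

# Feeding x[t] (shape: batch x input_size) to layer 1, as the current
# snippet does, cannot be multiplied by weight_ih_l1.T
x_t = torch.randn(4, input_size)
mismatch_raised = False
try:
    x_t @ rnn.weight_ih_l1.T
except RuntimeError:
    mismatch_raised = True
print(mismatch_raised)  # True
```

When input_size happens to equal hidden_size the shapes line up, so the mistake produces silently wrong values instead of an error.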

Suggest a potential alternative/fix

We could change the code snippet to the following.

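import torch

# Efficient implementation equivalent to the following with bidirectional=False
# (num_layers, hidden_size, batch_first and the per-layer weight_ih, bias_ih,
# weight_hh, bias_hh are assumed to be defined, as in the original snippet)
def forward(x, h_0=None):
    if batch_first:
        x = x.transpose(0, 1)
    seq_len, batch_size, _ = x.size()
    if h_0 is None:
        h_0 = torch.zeros(num_layers, batch_size, hidden_size, dtype=x.dtype, device=x.device)
    # Clone so the caller's h_0 is not modified by the in-place writes below
    h_t_minus_1 = h_0.clone()
    h_t = h_0.clone()
    output = []
    for t in range(seq_len):
        for layer in range(num_layers):
            # Layer 0 reads the input sequence; deeper layers read the
            # freshly computed hidden state of the layer below
            ih_input = x[t] if layer == 0 else h_t[layer - 1]
            h_t[layer] = torch.tanh(
                ih_input @ weight_ih[layer].T
                + bias_ih[layer]
                + h_t_minus_1[layer] @ weight_hh[layer].T
                + bias_hh[layer]
            )
        output.append(h_t[-1].clone())
        # Snapshot this step's states for the next step's recurrent term
        h_t_minus_1 = h_t.clone()
    output = torch.stack(output)
    if batch_first:
        output = output.transpose(0, 1)
    return output, h_t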
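To check the proposed recurrence end to end, the sketch below binds the snippet's free variables to the parameters of a real nn.RNN (looked up by their registered names weight_ih_l{k}, weight_hh_l{k}, bias_ih_l{k}, bias_hh_l{k}) and compares the outputs. The sizes, the seed and the name reference_forward are hypothetical; the comparison runs under torch.no_grad() because the in-place writes into h_t are not meant for autograd.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Illustrative sizes (assumed)
input_size, hidden_size, num_layers = 3, 5, 2
seq_len, batch_size = 7, 4
rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
params = dict(rnn.named_parameters())

def reference_forward(x, h_0=None, batch_first=True):
    """Hypothetical re-implementation of the proposed snippet using rnn's parameters."""
    if batch_first:
        x = x.transpose(0, 1)
    seq_len, batch_size, _ = x.size()
    if h_0 is None:
        h_0 = torch.zeros(num_layers, batch_size, hidden_size, dtype=x.dtype)
    h_t_minus_1 = h_0.clone()
    h_t = h_0.clone()
    output = []
    for t in range(seq_len):
        for layer in range(num_layers):
            # Layer 0 consumes x[t]; layer k > 0 consumes layer k-1's new state
            ih_input = x[t] if layer == 0 else h_t[layer - 1]
            h_t[layer] = torch.tanh(
                ih_input @ params[f"weight_ih_l{layer}"].T
                + params[f"bias_ih_l{layer}"]
                + h_t_minus_1[layer] @ params[f"weight_hh_l{layer}"].T
                + params[f"bias_hh_l{layer}"]
            )
        output.append(h_t[-1].clone())
        h_t_minus_1 = h_t.clone()
    output = torch.stack(output)
    if batch_first:
        output = output.transpose(0, 1)
    return output, h_t

x = torch.randn(batch_size, seq_len, input_size)
h0 = torch.randn(num_layers, batch_size, hidden_size)
h0_before = h0.clone()

with torch.no_grad():
    expected_out, expected_h = rnn(x)       # zero initial state
    out, h_n = reference_forward(x)
    expected_out2, expected_h2 = rnn(x, h0)  # non-zero initial state
    out2, h_n2 = reference_forward(x, h0)

print(torch.allclose(out, expected_out, atol=1e-5), torch.allclose(h_n, expected_h, atol=1e-5))
print(torch.allclose(out2, expected_out2, atol=1e-5), torch.equal(h0, h0_before))
```

Passing a non-zero h0 also exercises the clone of the initial state: h0 should come back unchanged. Note that h_n stays in (num_layers, batch, hidden_size) layout even with batch_first=True, matching nn.RNN.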

cc @svekars @sekyondaMeta @AlannaBurke @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @brycebortree

Labels

docathon-h1-2025, medium, module: docs, module: nn, module: rnn, triaged
