This repository was archived by the owner on Aug 1, 2025. It is now read-only.

[inductor] TypeError: function takes exactly 16 arguments (13 given) #1577

@typedfemale

I have a pretty odd one. Minimal repro below:

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torchdynamo.testing import rand_strided

# torch version: 1.14.0.dev20221009
# torch cuda version: 11.7
# torch git version: 0dbefb2414417e80371ef3d8224404d4a522f86e


# CUDA Info:
# nvcc: NVIDIA (R) Cuda compiler driver
# Copyright (c) 2005-2022 NVIDIA Corporation
# Built on Wed_Jun__8_16:49:14_PDT_2022
# Cuda compilation tools, release 11.7, V11.7.99
# Build cuda_11.7.r11.7/compiler.31442593_0

# GPU Hardware Info:
# NVIDIA A100-SXM4-40GB : 1


class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, arg0_1, new_zeros_1):
        # Returns a copy of new_zeros_1 (last dim 2112) with arg0_1 (last dim 2048)
        # written into new_zeros_1[:, :, 0:2048].
        slice_scatter = torch.ops.aten.slice_scatter.default(new_zeros_1, arg0_1, 2, 0, 2048);  new_zeros_1 = arg0_1 = None
        return (slice_scatter,)

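# (shape, stride, dtype, device) for each input: arg0_1, then new_zeros_1.
args = [((16, 128, 2048), (262144, 2048, 1), torch.float32, 'cuda'), ((16, 128, 2112), (270336, 2112, 1), torch.float32, 'cuda')]
args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
# Trace the module into an FX graph of aten ops.
mod = make_fx(Repro())(*args)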

from torchinductor.compile_fx import compile_fx_inner

# Compiling and running this raises:
#   TypeError: function takes exactly 16 arguments (13 given)
compiled = compile_fx_inner(mod, args)
compiled(*args)
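The stride tuples in `args` are ordinary row-major (contiguous) strides: each stride is the product of the sizes of the dimensions after it, so both inputs are dense tensors. A quick pure-Python check, using a hypothetical helper `contiguous_strides` that is not part of the repro:

```python
def contiguous_strides(shape):
    """Row-major (C-contiguous) strides, in elements, for a tensor shape.

    Hypothetical helper for illustration; not part of the original repro.
    """
    strides = []
    step = 1
    # Walk dimensions from innermost to outermost, accumulating the product
    # of the sizes seen so far.
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


# (shape, stride) pairs copied from the repro's `args` list.
repro_args = [
    ((16, 128, 2048), (262144, 2048, 1)),
    ((16, 128, 2112), (270336, 2112, 1)),
]

for shape, stride in repro_args:
    # Prints True for both inputs: the given strides are the contiguous ones.
    print(shape, contiguous_strides(shape) == stride)
```

For a freshly allocated tensor, `torch.empty(shape).stride()` should return the same tuple.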

Needed for fast state space models 😄
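To have an eager reference for what `Repro.forward` should produce, the op can be checked on small CPU tensors. This is a sketch assuming a recent PyTorch build; the sizes (6 vs 4 along dim 2 instead of 2112 vs 2048) are illustrative stand-ins. It shows that `aten.slice_scatter(dst, src, 2, 0, end)` matches cloning `dst` and assigning `src` into `dst[:, :, 0:end]`:

```python
import torch

# Small CPU stand-ins for the repro's CUDA inputs: the destination is wider
# than the source along dim 2 (2112 vs 2048 in the repro; 6 vs 4 here).
dst = torch.zeros(2, 3, 6)
src = torch.randn(2, 3, 4)

# Same aten call as Repro.forward: write src into dst[:, :, 0:4].
out = torch.ops.aten.slice_scatter.default(dst, src, 2, 0, 4)

# Eager reference: clone the destination and assign into the slice.
ref = dst.clone()
ref[:, :, 0:4] = src

assert torch.equal(out, ref)
# The untouched tail of dim 2 keeps the destination's values (zeros here).
assert torch.equal(out[:, :, 4:], torch.zeros(2, 3, 2))
# slice_scatter returns a new tensor; the destination itself is unchanged.
assert torch.equal(dst, torch.zeros(2, 3, 6))
```

The same clone-and-assign reference, run on the original CUDA inputs, gives a baseline to compare the compiled output against once the call no longer raises.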
