
[Inductor] Constant folding support #93420

Closed
jgong5 opened this issue Nov 13, 2022 · 4 comments
Assignees
XiaobingSuper
Labels
enhancement (Not as big of a feature, but technically not a bug. Should be easy to fix)
module: inductor
oncall: pt2
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@jgong5
Collaborator

jgong5 commented Nov 13, 2022

Motivating Example

Below is a case from MobileBertForMaskedLM that concatenates two model parameters (hidden_states = hidden_states.matmul(torch.cat([self.decoder.weight.t(), self.dense.weight], dim=0))). This concat takes more than 20% of the single-threaded inference time on CPU, and the cost could be eliminated entirely with constant folding, since both inputs to the concat are constant at inference time.

import torch
from torch import nn

class MobileBertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = MobileBertPredictionHeadTransform(config)
        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.dense = nn.Linear(config.vocab_size, config.hidden_size - config.embedding_size, bias=False)
        self.decoder = nn.Linear(config.embedding_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.transform(hidden_states)
        hidden_states = hidden_states.matmul(torch.cat([self.decoder.weight.t(), self.dense.weight], dim=0))
        hidden_states += self.decoder.bias
        return hidden_states
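
To make the potential saving concrete: both inputs to the torch.cat above are parameters that do not change at inference time, so the concatenation can be computed once up front instead of on every forward call. A minimal hand-folding sketch of that idea (the helper name below is hypothetical, not part of the model code):

import torch

def fold_prediction_head_weight(head) -> torch.Tensor:
    # decoder.weight and dense.weight are constants during inference, so the
    # concatenation that forward() rebuilds on every call can be hoisted out
    # and computed a single time.
    with torch.no_grad():
        return torch.cat([head.decoder.weight.t(), head.dense.weight], dim=0)

# With the folded constant, the hot path in forward() reduces to
#   hidden_states = self.transform(hidden_states)
#   hidden_states = hidden_states.matmul(folded_weight) + self.decoder.bias
# which is exactly the rewrite a compiler-side constant-folding pass would apply.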

cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire @ezyang @msaroufim @wconstab @ngimel @bdhirsh

@eellison added the enhancement label Nov 14, 2022
@malfet transferred this issue from pytorch/torchdynamo Feb 1, 2023
@albanD added the triaged label Feb 7, 2023
@XiaobingSuper self-assigned this Mar 13, 2023
@ydwu4
Contributor

ydwu4 commented Nov 29, 2023

Hi, is there any update on this? Do we still want to keep it open?

@jgong5
Collaborator Author

jgong5 commented Nov 30, 2023

@Le-Zheng can you check if this issue has been addressed with freezing on?

@leslie-fang-intel
Collaborator

leslie-fang-intel commented Nov 30, 2023

I think constant folding should already be supported in fw_compiler_freezing. @Le-Zheng could you help double-check that this issue is fixed by setting TORCHINDUCTOR_FREEZING=1?
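
For reference, a minimal sketch of exercising this path end to end (the checkpoint name is only an example; TORCHINDUCTOR_FREEZING is the environment variable mentioned above, and freezing applies to inference, hence eval mode under no_grad):

import os
# Set before importing torch so the Inductor config reads it when it loads;
# torch._inductor.config.freezing = True should have the same effect.
os.environ["TORCHINDUCTOR_FREEZING"] = "1"

import torch
from transformers import MobileBertForMaskedLM

model = MobileBertForMaskedLM.from_pretrained("google/mobilebert-uncased").eval()
compiled = torch.compile(model)

with torch.no_grad():
    input_ids = torch.randint(0, model.config.vocab_size, (1, 128))
    # With freezing on, Inductor can constant-fold the torch.cat of the two
    # prediction-head weights instead of re-running it on every call.
    logits = compiled(input_ids=input_ids).logits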

@github-project-automation bot moved this from TODO to Done in PyTorch Intel Nov 30, 2023
@eellison
Contributor

eellison commented Nov 30, 2023

Should be fixed with freezing, please re-open if not.

@eellison reopened this Nov 30, 2023
Projects
Status: Done
Development

No branches or pull requests

10 participants