[inductor] Accuracy issue - BertForQuestionAnswering, ElectraForQuestionAnswering #1450

Closed
anijain2305 opened this issue Oct 3, 2022 · 3 comments

@anijain2305 (Contributor) commented:

Repro


import torch
import torchdynamo
from torchdynamo.testing import rand_strided
from torchdynamo.debug_utils import same_two_models

# Each args entry is (shape, stride, dtype, device, requires_grad).
args = [((1,), (1,), torch.int64, 'cuda', False), ((1, 128, 768), (98304, 768, 1), torch.float32, 'cuda', True)]
args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]


from torch.nn import Linear

class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = Linear(in_features=768, out_features=2, bias=True)

    def forward(self, start_positions : torch.Tensor, x : torch.Tensor):
        linear = self.linear(x)
        split = linear.split(1, dim = -1)
        getitem = split[0]
        squeeze = getitem.squeeze(-1)
        clamp = start_positions.clamp(0, 128)
        # Positional args map to: (input, target, weight=None, size_average=None,
        # ignore_index=128, reduce=None, reduction='mean', label_smoothing=0.0)
        cross_entropy = torch.nn.functional.cross_entropy(squeeze, clamp, None, None, 128, None, 'mean', 0.0)
        return (cross_entropy,)



mod = Repro().cuda()
opt_mod = torchdynamo.optimize("inductor")(mod)


mod.eval()
opt_mod.eval()
with torch.cuda.amp.autocast(enabled=False):
    assert same_two_models(mod, mod, args), "Eager itself failed"
    assert same_two_models(mod, opt_mod, args), "Dynamo failed"
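
Since same_two_models only reports pass/fail, a quick way to dig further is to run a forward/backward on both modules and compare the weight gradients directly. A minimal sketch that appends to the repro above (the atol is an arbitrary choice; opt_mod shares parameters with mod because torchdynamo.optimize wraps the same module):


mod.zero_grad()
mod(*args)[0].backward()
grad_eager = mod.linear.weight.grad.clone()

mod.zero_grad()
opt_mod(*args)[0].backward()
grad_inductor = mod.linear.weight.grad.clone()

# Expected to print False while this accuracy bug reproduces.
print(torch.allclose(grad_eager, grad_inductor, atol=1e-4))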

desertfire self-assigned this Oct 3, 2022

@anijain2305 (Contributor, Author) commented:

ElectraForQuestionAnswering is failing with a similar-looking pattern.

Repro


import torch
import torchdynamo
from torchdynamo.testing import rand_strided
from torchdynamo.debug_utils import same_two_models

args = [((1,), (1,), torch.int64, 'cuda', False), ((1,), (1,), torch.int64, 'cuda', False), ((1, 512, 256), (131072, 256, 1), torch.float32, 'cuda', True)]
args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]


from torch.nn import Linear

class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.self_self_qa_outputs = Linear(in_features=256, out_features=2, bias=True)

    def forward(self, start_positions : torch.Tensor, end_positions : torch.Tensor, self_self_electra_encoder_layer_11__output_layer_norm):
        self_self_qa_outputs = self.self_self_qa_outputs(self_self_electra_encoder_layer_11__output_layer_norm);  self_self_electra_encoder_layer_11__output_layer_norm = None
        split = self_self_qa_outputs.split(1, dim = -1);  self_self_qa_outputs = None
        getitem_3 = split[0]
        getitem_4 = split[1];  split = None
        squeeze = getitem_3.squeeze(-1);  getitem_3 = None
        contiguous_12 = squeeze.contiguous();  squeeze = None
        squeeze_1 = getitem_4.squeeze(-1);  getitem_4 = None
        contiguous_13 = squeeze_1.contiguous();  squeeze_1 = None
        clamp = start_positions.clamp(0, 512);  start_positions = None
        clamp_1 = end_positions.clamp(0, 512);  end_positions = None
        cross_entropy = torch.nn.functional.cross_entropy(contiguous_12, clamp, None, None, 512, None, 'mean', 0.0);  contiguous_12 = clamp = None
        cross_entropy_1 = torch.nn.functional.cross_entropy(contiguous_13, clamp_1, None, None, 512, None, 'mean', 0.0);  contiguous_13 = clamp_1 = None
        add_37 = cross_entropy + cross_entropy_1;  cross_entropy = cross_entropy_1 = None
        return (add_37,)



mod = Repro().cuda()
opt_mod = torchdynamo.optimize("inductor")(mod)


mod.eval()
opt_mod.eval()
with torch.cuda.amp.autocast(enabled=False):
    assert same_two_models(mod, mod, args), "Eager itself failed"
    assert same_two_models(mod, opt_mod, args), "Dynamo failed"

anijain2305 changed the title from "[inductor] Accuracy issue - BertForQuestionAnswering" to "[inductor] Accuracy issue - BertForQuestionAnswering, ElectraForQuestionAnswering" on Oct 3, 2022

@anijain2305 (Contributor, Author) commented:

And it's the same for MegatronBertForQuestionAnswering. I am going to skip looking at further HF models; many of them seem to hit the same pattern.

Repro


import torch
import torchdynamo
from torchdynamo.testing import rand_strided
from torchdynamo.debug_utils import same_two_models

args = [((1,), (1,), torch.int64, 'cuda', False), ((1,), (1,), torch.int64, 'cuda', False), ((1, 128, 1024), (131072, 1024, 1), torch.float32, 'cuda', True), ((1, 128, 1024), (131072, 1024, 1), torch.float32, 'cuda', True)]
args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]


from torch.nn import Dropout, LayerNorm, Linear

class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.self_self_bert_encoder_layer_23__output_dropout = Dropout(p=0.1, inplace=False)
        self.self_self_bert_encoder_ln = LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
        self.self_self_qa_outputs = Linear(in_features=1024, out_features=2, bias=True)

    def forward(self, start_positions : torch.Tensor, end_positions : torch.Tensor, add_71, self_self_bert_encoder_layer_23__output_dense):
        self_self_bert_encoder_layer_23__output_dropout = self.self_self_bert_encoder_layer_23__output_dropout(self_self_bert_encoder_layer_23__output_dense);  self_self_bert_encoder_layer_23__output_dense = None
        add_72 = add_71 + self_self_bert_encoder_layer_23__output_dropout;  add_71 = self_self_bert_encoder_layer_23__output_dropout = None
        self_self_bert_encoder_ln = self.self_self_bert_encoder_ln(add_72);  add_72 = None
        self_self_qa_outputs = self.self_self_qa_outputs(self_self_bert_encoder_ln);  self_self_bert_encoder_ln = None
        split = self_self_qa_outputs.split(1, dim = -1);  self_self_qa_outputs = None
        getitem_2 = split[0]
        getitem_3 = split[1];  split = None
        squeeze = getitem_2.squeeze(-1);  getitem_2 = None
        contiguous_24 = squeeze.contiguous();  squeeze = None
        squeeze_1 = getitem_3.squeeze(-1);  getitem_3 = None
        contiguous_25 = squeeze_1.contiguous();  squeeze_1 = None
        clamp = start_positions.clamp(0, 128);  start_positions = None
        clamp_1 = end_positions.clamp(0, 128);  end_positions = None
        cross_entropy = torch.nn.functional.cross_entropy(contiguous_24, clamp, None, None, 128, None, 'mean', 0.0);  clamp = None
        cross_entropy_1 = torch.nn.functional.cross_entropy(contiguous_25, clamp_1, None, None, 128, None, 'mean', 0.0);  clamp_1 = None
        add_73 = cross_entropy + cross_entropy_1;  cross_entropy = cross_entropy_1 = None
        truediv_24 = add_73 / 2;  add_73 = None
        return (truediv_24, contiguous_24, contiguous_25)



mod = Repro().cuda()
opt_mod = torchdynamo.optimize("inductor")(mod)


mod.eval()
opt_mod.eval()
with torch.cuda.amp.autocast(enabled=False):
    assert same_two_models(mod, mod, args), "Eager itself failed"
    assert same_two_models(mod, opt_mod, args), "Dynamo failed"

@desertfire (Contributor) commented:

I am confident this is a Triton codegen issue, but I haven't found a good fix yet. Here are my findings so far:

  1. If I use eager forward + inductor backward, the test passes.
  2. If I use inductor forward + eager backward, the test passes.
  3. If I force a fallback for aten.sum, the test passes.
  4. The result difference for the weight tensor is exactly a sign flip, something like:
tensor([[-1.3431, -0.4431, -0.2566,  ..., -0.5244,  0.7893,  1.4079],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')
tensor([[ 1.3431,  0.4431,  0.2566,  ...,  0.5244, -0.7893, -1.4079],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')
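
To confirm that a difference like this is exactly a sign flip rather than general numerical drift, compare one tensor against the negation of the other. A self-contained illustration with made-up values (not the actual results above):


import torch

a = torch.tensor([[-1.3431, -0.4431, -0.2566], [0.0, 0.0, 0.0]])
b = -a  # stand-in for the second, sign-flipped result

# allclose against the negation holds only when every element is negated.
print(torch.allclose(a, -b))  # True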

Here is the generated forward code: https://gist.github.com/desertfire/a6a64543dda6e18af11fcdf050d084fa, and the generated backward code: https://gist.github.com/desertfire/f895fd42f8349b7439160743b07013b6.

Based on my findings, I was able to reduce the error to a small example in #1490 (test_accuracy_issue2): if I load a bool tensor from the saved forward-graph output, the test fails, but if I create the bool tensor fresh, it passes. On the surface the two tensors are equal, but their underlying storage is different.
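
To make "equal values, different underlying storage" concrete, here is a standalone construction (an illustration, not the exact tensor from the saved graph): a bool tensor whose backing byte is 2 instead of 1 behaves as True value-wise, while a kernel that reads the raw byte sees a different bit pattern.


import torch

fresh = torch.tensor([True])
# Reinterpret a uint8 holding the value 2 as bool: the itemsizes match, so
# view() succeeds and the byte is not normalized to 0/1.
weird = torch.tensor([2], dtype=torch.uint8).view(torch.bool)

print(fresh.view(torch.uint8))  # tensor([1], dtype=torch.uint8)
print(weird.view(torch.uint8))  # tensor([2], dtype=torch.uint8)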

desertfire added a commit that referenced this issue Oct 5, 2022
desertfire added a commit that referenced this issue Oct 5, 2022
desertfire added a commit that referenced this issue Oct 6, 2022
desertfire added commits that referenced this issue on Oct 11, 2022
Summary: The pinned version has a proper fix for #1450
desertfire added a commit that referenced this issue Oct 12, 2022
WARNING: you need to upgrade triton with this PR!

Summary:
* The pinned version has a proper fix for #1450
* Also fixes #1564