Extending minifier for detecting accuracy issues #1242

Merged · 1 commit · Oct 6, 2022

Conversation

anijain2305 (Contributor) commented Sep 15, 2022

Accuracy minifier for TORCHDYNAMO_REPRO_AFTER="dynamo". This likely requires more work, but it unblocks current efforts to move forward with Inductor accuracy debugging.

Example usage -

TORCHDYNAMO_REPRO_AFTER="dynamo" TORCHDYNAMO_REPRO_LEVEL=4 python benchmarks/timm_models.py --accuracy --ci -d cuda --inductor --float32 --training --only=crossvit_9_240

Remaining work tracked here - pytorch/pytorch#93673
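For reference, the two environment variables in the usage line map onto torchdynamo config fields that appear in this PR's diff (config.repro_after and config.repro_level); the annotated sketch below is an inference from that diff, not documented behavior:

    # Assumed mapping from env vars to torchdynamo config (inferred, not documented):
    #   TORCHDYNAMO_REPRO_AFTER="dynamo" -> config.repro_after = "dynamo"
    #       intercept the FX graph right after dynamo, before AOT Autograd / inductor
    #   TORCHDYNAMO_REPRO_LEVEL=4        -> config.repro_level = 4
    #       minify on an accuracy mismatch rather than on a hard compiler error
    import torchdynamo.config as config
    assert config.repro_after in ("dynamo", "aot", None)  # values accepted in the diff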

ezyang (Contributor) commented Sep 19, 2022

When I patch this into the symbolic shapes branch and then attempt to run the minifier with

TORCHDYNAMO_REPRO_AFTER="dynamo" TORCHDYNAMO_REPRO_LEVEL=3 TORCH_SHOW_CPP_STACKTRACES=1 AOT_FX_GRAPHS_JOINT=1 TORCHDYNAMO_DYNAMIC_SHAPES=1 AOT_DYNAMIC_SHAPES=1 time python benchmarks/torchbench.py --only BERT_pytorch --accuracy-aot-nop --training

I get

  File "/data/users/ezyang/torchdynamo/torchdynamo/eval_frame.py", line 251, in catch_errors
    return callback(frame, cache_size)
  File "/data/users/ezyang/torchdynamo/torchdynamo/convert_frame.py", line 374, in _convert_frame
    result = inner_convert(frame, cache_size)
  File "/data/users/ezyang/torchdynamo/torchdynamo/convert_frame.py", line 110, in _fn
    return fn(*args, **kwargs)
  File "/data/users/ezyang/torchdynamo/torchdynamo/utils.py", line 76, in time_wrapper
    r = func(*args, **kwargs)
  File "/data/users/ezyang/torchdynamo/torchdynamo/convert_frame.py", line 313, in _convert_frame_assert
    code = transform_code_object(frame.f_code, transform)
  File "/data/users/ezyang/torchdynamo/torchdynamo/bytecode_transformation.py", line 338, in transform_code_object
    transformations(instructions, code_options)
  File "/data/users/ezyang/torchdynamo/torchdynamo/convert_frame.py", line 301, in transform
    tracer.run()
  File "/data/users/ezyang/torchdynamo/torchdynamo/symbolic_convert.py", line 338, in run
    and self.step()
  File "/data/users/ezyang/torchdynamo/torchdynamo/symbolic_convert.py", line 311, in step
    getattr(self, inst.opname)(inst)
  File "/data/users/ezyang/torchdynamo/torchdynamo/symbolic_convert.py", line 180, in wrapper
    self.output.compile_subgraph(self, reason=reason)
  File "/data/users/ezyang/torchdynamo/torchdynamo/output_graph.py", line 333, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/data/users/ezyang/torchdynamo/torchdynamo/output_graph.py", line 375, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/data/users/ezyang/torchdynamo/torchdynamo/output_graph.py", line 408, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torchdynamo.exc.BackendCompilerFailed: compile_fn raised AttributeError: 'function' object has no attribute 'zero_grad'

You can suppress this exception and fall back to eager by setting:
    torchdynamo.config.raise_on_backend_error = False
ERROR
compile_fn raised AttributeError: 'function' object has no attribute 'zero_grad'

ezyang (Contributor) commented Sep 19, 2022

A more informative stack trace:

Traceback (most recent call last): 
  File "/data/users/ezyang/torchdynamo/torchdynamo/output_graph.py", line 399, in call_user_compiler    
    compiled_fn = self.compiler_fn(gm, self.example_inputs())
  File "/data/users/ezyang/torchdynamo/torchdynamo/debug_utils.py", line 626, in debug_wrapper                    
    if backend_accuracy_fails(gm, example_inputs, compiler_fn, None):
  File "/data/users/ezyang/torchdynamo/torchdynamo/debug_utils.py", line 572, in backend_accuracy_fails
    return not same_two_models(gm, compiled_gm, example_inputs)
  File "/data/users/ezyang/torchdynamo/torchdynamo/debug_utils.py", line 380, in same_two_models
    res = run_fwd_maybe_bwd(opt_gm, example_inputs)
  File "/data/users/ezyang/torchdynamo/torchdynamo/debug_utils.py", line 359, in run_fwd_maybe_bwd
    gm.zero_grad(True)                                              
AttributeError: 'function' object has no attribute 'zero_grad'      

anijain2305 (Contributor, Author) commented

@ezyang Fixed in the latest commit.
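For readers following along, the likely shape of that fix (an assumption; the commit itself isn't shown in this thread) is to guard the zero_grad call in run_fwd_maybe_bwd, since gm can be a plain callable rather than an nn.Module:

    # Sketch of a guard in torchdynamo/debug_utils.py:run_fwd_maybe_bwd.
    # gm may be an already-compiled function with no zero_grad method.
    def run_fwd_maybe_bwd(gm, args):
        if hasattr(gm, "zero_grad"):
            gm.zero_grad(True)
        ...  # run forward (and backward, when training) as before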

ezyang (Contributor) commented Sep 19, 2022

I ran it, it chundered on for a bit, and then it seems to have failed in a way where I don't have a minified copy. It is failing with:

You can suppress this exception and fall back to eager by setting:
    torchdynamo.config.raise_on_backend_error = False
ERROR
compile_fn raised TypeError: must be real number, not torch._C.SymIntNode

While executing %sqrt_330 : [#users=1] = call_function[target=math.sqrt](args = (%size_329,), kwargs = {})
Original traceback:
Module stack: {'self_transformer_blocks_6__input_sublayer': 'SublayerConnection', 'self_transformer_blocks_6__lambda_module_attention_attention': 'Attention'}
  File "/data/users/ezyang/benchmark/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/attention/single.py", line 16, in forward
    / math.sqrt(query.size(-1))
 |   File "/data/users/ezyang/benchmark/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/attention/multi_head.py", line 41, in forward
    x, attn = self.attention(query, key, value, self.dropout, mask=mask)
 |   File "/data/users/ezyang/benchmark/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/transformer.py", line 20, in forward
    return self.attention.forward(x, x, x, mask=self.mask)
 |   File "/data/users/ezyang/benchmark/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/utils/sublayer.py", line 19, in forward
    return x + self.dropout(sublayer.forward(self.norm(x)))
 |   File "/data/users/ezyang/benchmark/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/transformer.py", line 47, in forward
    x = self.input_sublayer(x, self.lambda_module)
 |   File "/data/users/ezyang/benchmark/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/bert.py", line 47, in <graph break in forward>
    x = transformer.forward(x, mask)

You can suppress this exception and fall back to eager by setting:
    torchdynamo.config.raise_on_backend_error = False

Is the minifier not catching enough exceptions?
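The proximate TypeError is independent of the minifier: under dynamic shapes, query.size(-1) traces to a SymIntNode, which CPython's math.sqrt only accepts as a real number. A standalone illustration of the failing pattern (shapes invented for the example):

    import math
    import torch

    q = torch.randn(2, 8, 128, 64)
    # Eagerly this is fine, but when dynamo traces with
    # TORCHDYNAMO_DYNAMIC_SHAPES=1, q.size(-1) becomes a torch._C.SymIntNode
    # and math.sqrt raises "TypeError: must be real number, not
    # torch._C.SymIntNode", exactly as in the log above.
    scale = math.sqrt(q.size(-1))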

desertfire (Contributor) commented

I tried to use your branch for #1039, but the generated repro.py doesn't actually give an accuracy error.

ezyang (Contributor) commented Sep 27, 2022

I pushed a merge with master, as this branch has the zero_grad fix that I still need for BERT_pytorch-based minimizations.

ezyang (Contributor) commented Sep 27, 2022

The produced minified scripts don't actually test for equality on the outputs.
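Concretely, one would expect the generated repro.py to end with an assertion along these lines (a sketch reusing the same_two_models helper from the traces above; the Repro class and the exact call are assumptions about the emitted script):

    # Hypothetical tail for the emitted repro.py: actually compare outputs.
    mod = Repro()                     # module the minifier writes out
    opt_mod = compiler_fn(mod, args)  # the backend under test
    assert same_two_models(mod, opt_mod, args), "accuracy check failed"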

anijain2305 force-pushed the accuracy-minifier branch 6 times, most recently from 19f8b9d to 56f5587 on October 4, 2022.
tensor_str = f"torch.randn({list(buffer.shape)}, dtype={buffer.dtype})"
else:
tensor_str = (
f"torch.randint(2, size={list(buffer.shape)}, dtype={buffer.dtype})"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would probably just do 1 here.
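What "just do 1" would look like in the generated repro (the reviewer's suggestion, sketched; not the PR's code):

    # torch.ones avoids sampling and keeps integer buffers deterministic.
    tensor_str = f"torch.ones({list(buffer.shape)}, dtype={buffer.dtype})"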

Review comment (Contributor) on:

    from torchinductor import config
    from torchinductor.compile_fx import compile_fx_inner

    config.triton.autotune = False

What is this for?

Review comment (Contributor) on:

    inductor or nvfuser. Intercepting after AOT Autograd presents a neat
    abstraction, where all the params are lifted as graph inputs, making it
    easy to save the graph as a string.
    """

    @functools.wraps(compiler_fn)
    def debug_wrapper(gm, example_inputs, **kwargs):
        orig_graph = copy.deepcopy(gm.graph)
        assert config.repro_after in ("dynamo", "aot", None)

        def deferred_for_real_inputs(*real_inputs):

No need to fix, but this seems kinda strange to me: we delay compilation from the first time we see inputs to slightly later after the first time we see inputs.

Reply (anijain2305, Author): We could do compilation with fake tensors first and then run the compiled model with real tensors later, but the code gets really ugly.
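A stripped-down sketch of the pattern under discussion, as I read it (names follow the excerpt above, but the body is illustrative; the PR's actual wrapper does more):

    # Compilation is deferred into a closure: the graph is captured now, but
    # compiler_fn only runs on the first call, once real tensors are in hand,
    # so the accuracy comparison can use real data instead of fake tensors.
    def debug_wrapper(gm, example_inputs, **kwargs):
        def deferred_for_real_inputs(*real_inputs):
            compiled_gm = compiler_fn(gm, real_inputs, **kwargs)
            return compiled_gm(*real_inputs)

        return deferred_for_real_inputs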

Review comment (Contributor) on:

    @@ -646,24 +762,42 @@ def debug_wrapper(gm, example_inputs, **kwargs):
        config.raise_on_backend_error = True
        if config.repro_level == 3:
            dump_to_minify_after_dynamo(gm, example_inputs, compiler_name)
        try:

            # Check for either accuracy (level 4) or other types of failures.

Would be nice if we could share more code between the two, but it's not a big deal if it's awkward.
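"The two" here are the accuracy path (level 4) and the exception path (the other levels); both end in a dump, which is the code-sharing opportunity. A sketch of that structure (control flow inferred from the diff and stack traces, not copied from the PR):

    # Level 4: run the backend, compare against eager, dump on mismatch.
    # Other levels: dump when the backend raises (level 3 already dumped
    # a minifier launcher before compiling, per the hunk above).
    try:
        compiled_gm = compiler_fn(gm, example_inputs, **kwargs)
        if config.repro_level == 4:
            if backend_accuracy_fails(gm, example_inputs, compiler_fn, None):
                dump_to_minify_after_dynamo(gm, example_inputs, compiler_name)
    except Exception:
        if config.repro_level != 3:
            dump_to_minify_after_dynamo(gm, example_inputs, compiler_name)
        raise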
