
Debugging BERT Performance bottleneck #3043

@anijain2305

Description

We are training BERT with PT-XLA. The training script is here: https://github.com/codeislife99/xla/blob/master/test/test_train_mp_bert_mlm.py

End to End results

|  | Iteration time (sec) | GPU time (sec) |
| --- | --- | --- |
| PT-native | 0.17 | 0.1654 |
| PT-XLA | 0.188 | 0.1105 |
| Speedup (PT-native vs PT-XLA) | 0.90426 | 1.49683 |

We observe the performance numbers above. Iteration time covers the forward pass, backward pass, and parameter update.
In addition to the iteration time, we use dlprof to measure CUDA time. CUDA time is considerably faster for PT-XLA (49% faster), yet the overall iteration time is still about 10% slower.
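
For reference, the boundary we call one iteration corresponds roughly to the sketch below. The names (`model`, `scaler`, `optimizer`, `batch`) are placeholders rather than the script's actual variables; the script itself uses the `torch_xla.amp` GradScaler visible in the profiler frames further down.

```python
# Minimal sketch of one timed iteration (placeholder names, not the script's code).
import time
import torch
import torch_xla.core.xla_model as xm

def timed_iteration(model, scaler, optimizer, batch):
    start = time.time()
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(**batch).loss       # forward pass
    scaler.scale(loss).backward()        # backward pass
    scaler.step(optimizer)               # parameter update (via GradScaler)
    scaler.update()
    xm.mark_step()                       # cut and dispatch the XLA graph for this step
    return time.time() - start           # "iteration time" = forward + backward + update
```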

XRT Profile numbers

| Name | Num_calls | Total time (ms) | Percentage | Acc Percentage |
| --- | --- | --- | --- | --- |
| XrtCompile | 6 | 266252.50842 | 78.55986 | 78.55986 |
| XrtExecute | 1280 | 72450.99722 | 21.37723 | 99.93709 |
| XrtReleaseAllocation | 17679 | 212.04111 | 0.06256 | 99.99965 |
| XrtAllocateFromTensor | 3416 | 0.69785 | 0.00021 | 99.99986 |
| XrtReadLiteral | 650 | 0.4893 | 0.00014 | 100 |
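
One way to collect counters like these from the client side is the torch_xla debug metrics report; a minimal sketch (assuming the standard `torch_xla.debug.metrics` API):

```python
# Minimal sketch: print the client-side metrics report after some training steps.
# It includes compile/execute counters and timings comparable to the XRT numbers above.
import torch_xla.debug.metrics as met

# ... run a number of training iterations first ...
print(met.metrics_report())
```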

Auto-metric analysis

pt-xla-profiler: ================================================================================
pt-xla-profiler: Unlowered Op usage summary (more of these ops, lower performance)
pt-xla-profiler: Note: _local_scalar_dense typically indicates CPU context access
pt-xla-profiler: --------------------------------------------------------------------------------
pt-xla-profiler: FRAME (count=640):
pt-xla-profiler: Unlowered Op: "_local_scalar_dense"
pt-xla-profiler: Python Frames:
pt-xla-profiler:   <genexpr> (/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch_xla/amp/grad_scaler.py:11)
pt-xla-profiler:   _maybe_opt_step (/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch_xla/amp/grad_scaler.py:11)
pt-xla-profiler:   step (/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/cuda/amp/grad_scaler.py:339)
pt-xla-profiler:   loop_with_amp (/pytorch/xla/test/test_train_mp_bert_mlm.py:53)
pt-xla-profiler:   train (/pytorch/xla/test/test_train_mp_bert_mlm.py:165)
pt-xla-profiler:   main (/pytorch/xla/test/test_train_mp_bert_mlm.py:208)
pt-xla-profiler:   <module> (/pytorch/xla/test/test_train_mp_bert_mlm.py:237)
pt-xla-profiler:
pt-xla-profiler:
================================================================================

pt-xla-profiler: TransferFromServerTime too frequent: 648 counts during 1279 steps
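
The frames point at the AMP GradScaler: before running `optimizer.step()`, the scaler reads the found-inf check result back as a Python scalar, and on XLA that `.item()` call shows up as `_local_scalar_dense` plus one TransferFromServer per step. A hedged paraphrase of that pattern (not the exact `grad_scaler.py` source):

```python
# Hedged paraphrase of the pattern behind the "_local_scalar_dense" frames;
# not the exact torch_xla/amp/grad_scaler.py source.
def _maybe_opt_step(optimizer, optimizer_state, *args, **kwargs):
    # .item() on an XLA tensor forces graph execution and a device-to-host
    # transfer (TransferFromServer), once per training step.
    if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
        return optimizer.step(*args, **kwargs)
    return None
```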

Manual Timing Analysis

We put timers in the script to measure the different portions and found that the forward and backward calls (which presumably build the XLA graph) take substantial time.

|  | Time (ms) |
| --- | --- |
| Forward | 24 |
| Backward | 17 |

So around 40-45 ms (out of ~180 ms) of every iteration is spent in the forward and backward calls themselves.
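
For context, the timers were plain wall-clock measurements around the two calls; with lazy tensors those calls mostly trace the graph, so this time is Python/tracing overhead rather than device compute. A minimal sketch (placeholder names, not the script's code):

```python
# Minimal sketch of the manual timing (placeholder names, not the script's code).
import time

t0 = time.time()
outputs = model(input_ids, attention_mask=attention_mask, labels=labels)  # forward
t1 = time.time()
scaler.scale(outputs.loss).backward()                                     # backward
t2 = time.time()

fwd_ms = (t1 - t0) * 1000   # ~24 ms observed
bwd_ms = (t2 - t1) * 1000   # ~17 ms observed
# With lazy execution these calls mainly build/trace the XLA graph, so
# fwd_ms + bwd_ms is tracing overhead rather than GPU execution time.
```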

@JackCaoG Does this analysis make sense? Can you provide any pointers?
cc @codeislife99
