We are training BERT with PT-XLA. The training script is here: https://github.com/codeislife99/xla/blob/master/test/test_train_mp_bert_mlm.py
End to End results
| | Iteration time (sec) | GPU time (sec) |
|---|---|---|
| PT-native | 0.17 | 0.1654 |
| PT-XLA | 0.188 | 0.1105 |
| Speedup PT-native vs PT-XLA | 0.90426 | 1.49683 |
These are the numbers we observe. Iteration time is forward pass + backward pass + update.
In addition to iteration time, we use dlprof to measure CUDA time. CUDA time is substantially lower with PT-XLA (49% faster), yet the overall iteration time is about 10% slower.
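For reference, here is a minimal sketch (not the actual test script) of how per-iteration wall time can be measured on the PT-XLA side; `model`, `optimizer`, and `batch` are placeholders, and the AMP/GradScaler path used in the real script is omitted:

```python
import time
import torch_xla.core.xla_model as xm

def timed_iteration(model, optimizer, batch):
    # One iteration = forward + backward + update, as defined above.
    start = time.time()
    loss = model(**batch).loss        # forward: traces/builds the lazy XLA graph
    loss.backward()                   # backward: extends the graph
    xm.optimizer_step(optimizer)      # update: gradient reduction + optimizer.step()
    xm.mark_step()                    # cut the graph and dispatch it for execution
    xm.wait_device_ops()              # wait for the device so wall time is meaningful
    return time.time() - start
```

Without the explicit wait, XLA's asynchronous execution would make the measured wall time misleadingly small.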
XRT Profile numbers
| Name | Num calls | Total time (ms) | Percentage | Acc. percentage |
|---|---|---|---|---|
| XrtCompile | 6 | 266252.50842 | 78.55986 | 78.55986 |
| XrtExecute | 1280 | 72450.99722 | 21.37723 | 99.93709 |
| XrtReleaseAllocation | 17679 | 212.04111 | 0.06256 | 99.99965 |
| XrtAllocateFromTensor | 3416 | 0.69785 | 0.00021 | 99.99986 |
| XrtReadLiteral | 650 | 0.4893 | 0.00014 | 100 |
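These counters come from the XRT runtime. One way to dump them (not necessarily how the table above was produced) is torch_xla's metrics report; which counters appear depends on the runtime and version:

```python
import torch_xla.debug.metrics as met

# Print the accumulated metrics/counters (e.g. CompileTime, ExecuteTime,
# TransferFromServerTime, and, depending on the runtime, XRT counters
# such as XrtCompile / XrtExecute) after a few training steps.
print(met.metrics_report())
```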
Auto-metric analysis
```
pt-xla-profiler: ================================================================================
pt-xla-profiler: Unlowered Op usage summary (more of these ops, lower performance)
pt-xla-profiler: Note: _local_scalar_dense typically indicates CPU context access
pt-xla-profiler: --------------------------------------------------------------------------------
pt-xla-profiler: FRAME (count=640):
pt-xla-profiler: Unlowered Op: "_local_scalar_dense"
pt-xla-profiler: Python Frames:
pt-xla-profiler: <genexpr> (/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch_xla/amp/grad_scaler.py:11)
pt-xla-profiler: _maybe_opt_step (/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch_xla/amp/grad_scaler.py:11)
pt-xla-profiler: step (/root/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/cuda/amp/grad_scaler.py:339)
pt-xla-profiler: loop_with_amp (/pytorch/xla/test/test_train_mp_bert_mlm.py:53)
pt-xla-profiler: train (/pytorch/xla/test/test_train_mp_bert_mlm.py:165)
pt-xla-profiler: main (/pytorch/xla/test/test_train_mp_bert_mlm.py:208)
pt-xla-profiler: <module> (/pytorch/xla/test/test_train_mp_bert_mlm.py:237)
pt-xla-profiler:
pt-xla-profiler:
pt-xla-profiler: ================================================================================
pt-xla-profiler: TransferFromServerTime too frequent: 648 counts during 1279 steps
```
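This report looks like the client-side debug analysis torch_xla prints at exit. Assuming it was produced with the `PT_XLA_DEBUG` environment variable (the issue does not say how it was generated), a minimal way to reproduce it is:

```python
import os

# Assumption: the pt-xla-profiler report above comes from torch_xla's debug
# analysis, enabled via the PT_XLA_DEBUG environment variable. Set it before
# torch_xla is imported / the training script starts.
os.environ["PT_XLA_DEBUG"] = "1"
```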
- The CPU fallback to `_local_scalar_dense` comes from the `step` of the XLA `GradScaler` here: https://github.com/pytorch/xla/blob/master/torch_xla/amp/grad_scaler.py#L11 (see the sketch after this list). So, IIUC, this is unavoidable.
- TransferFromServerTime is frequent because of `_local_scalar_dense`, so it is a byproduct of the above rather than a separate cause.
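To make the first point concrete, here is an illustrative sketch (not the actual grad_scaler.py source) of the pattern behind the fallback: reading the per-device found-inf tensors with `.item()` forces a device-to-host copy on every optimizer step, which is exactly what shows up as TransferFromServerTime.

```python
def _maybe_opt_step_sketch(optimizer, optimizer_state):
    # Illustrative only. Summing Python scalars requires materializing each XLA
    # tensor on the host: every .item() call lowers to _local_scalar_dense and
    # triggers a TransferFromServer round trip.
    found_inf = sum(v.item() for v in optimizer_state["found_inf_per_device"].values())
    if not found_inf:
        return optimizer.step()
    return None
```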
Manual Timing Analysis
We put timers in the script to measure the different portions and found that the forward and backward calls (which supposedly build the XLA graph) take substantial time:
| | Time (ms) |
|---|---|
| Forward | 24 |
| Backward | 17 |
So, around 40-45 ms (out of ~180 ms) of every iteration is spent in the forward and backward calls.
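Since XLA tensors are lazy, the forward/backward numbers above mostly measure Python-side tracing and graph construction rather than device compute. A minimal sketch of timers of this kind (hypothetical names, not the exact instrumentation used):

```python
import time

def timed_forward_backward(model, batch):
    # With lazy XLA tensors, these calls mainly trace ops and build the graph;
    # the actual device execution happens later, once the graph is cut and dispatched.
    t0 = time.time()
    loss = model(**batch).loss      # "Forward" time in the table above
    t_fwd = time.time() - t0

    t0 = time.time()
    loss.backward()                 # "Backward" time in the table above
    t_bwd = time.time() - t0
    return t_fwd, t_bwd
```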
@JackCaoG Does this analysis make sense? Can you provide any pointers?