
TPU dynamo speed up on inference analysis #4328

@JackCaoG

Description


Context:
I am running inference benchmarks using the dynamo bridge + torch_bench on a single TPU v4 device. This thread is mostly to record the current status and some TODOs. We did a similar benchmark in https://docs.google.com/document/d/1xXwCDdQl1n2aCaJ8Lu3qn060Hp18pwj4MELVTZ3mP4g/edit. @shunting314 has since added an optimization to trace the model on the XLA device instead of the CPU device, which results in better performance.
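
For illustration only, a minimal sketch of the dynamo bridge path used here (not the benchmark harness itself): the model and inputs live on the XLA device and dynamo captures the graph through the torchxla_trace_once backend. The model choice and input shape are assumptions.

# Minimal sketch (assumptions: resnet18, 224x224 input). Keeping the model and
# data on the XLA device is what lets dynamo trace on XLA instead of CPU.
import torch
import torch._dynamo as dynamo
import torch_xla.core.xla_model as xm
import torchvision

device = xm.xla_device()
model = torchvision.models.resnet18().eval().to(device)
data = torch.randn(1, 3, 224, 224, device=device)

compiled = dynamo.optimize("torchxla_trace_once")(model)
with torch.no_grad():
    out = compiled(data)
xm.wait_device_ops()  # block until device execution finishes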

PyTorch branch:

pytorch/pytorch#88449 + some profiler code (caveat: it used avg_pool instead of maxpool; this is fixed now)

XLA branch:

nightly + a patch (check #4306 (comment))

diff --git a/third_party/xla_client/pjrt_computation_client.cc b/third_party/xla_client/pjrt_computation_client.cc
index 207c8874..fa847c0d 100755
--- a/third_party/xla_client/pjrt_computation_client.cc
+++ b/third_party/xla_client/pjrt_computation_client.cc
@@ -308,6 +308,7 @@ PjRtComputationClient::ExecuteComputation(
   std::vector<DataPtr> datas;
   datas.reserve(results.size());
   for (auto& result : results) {
+    auto status = result->GetReadyFuture().Await();
     std::unique_ptr<xla::PjRtBuffer> buffer = std::move(result);
 
     std::shared_ptr<PjRtData> data = std::make_shared<PjRtData>(

TorchBench + TorchAudio + TorchText branch:

nightly

Runtime

PJRT, check https://github.com/pytorch/xla/blob/master/docs/pjrt.md
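
For quick reference, a minimal sketch of selecting the PJRT runtime (see the pjrt.md doc above for the full instructions); PJRT_DEVICE=TPU is what the runs above use:

# Select the PJRT runtime before torch_xla initializes its client.
import os
os.environ["PJRT_DEVICE"] = "TPU"

import torch_xla.core.xla_model as xm
print(xm.xla_device())  # should come up backed by the PJRT TPU client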

Command

XLA_HLO_DEBUG=0 XLA_IR_DEBUG=0 USE_FAKE_TENSOR=0 python benchmarks/dynamo/torchbench.py --randomize-input --performance --trace-on-xla --backend=torchxla_trace_once --only $MODEL -n 30

Sample profiles

gs://tpu-pytorch/tmp/dynamo_profile/dynamo_tracing/try13/
I believe this one is a resnet50
[screenshot: profiler trace of the lazy run followed by the dynamo run]

The first part of the trace (before wait_device_ops) is the lazy run and the remainder is the dynamo run. Lazy spends some time tracing the graph before execution, while dynamo's wall time is mostly just device execution.
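
A rough sketch of reproducing this lazy-vs-dynamo wall-time comparison outside the benchmark harness (model, input shape, and step count are illustrative assumptions):

# Lazy re-traces the graph every step before executing it; dynamo captures the
# graph once, so timed steps are mostly device execution.
import time
import torch
import torch._dynamo as dynamo
import torch_xla.core.xla_model as xm
import torchvision

device = xm.xla_device()
model = torchvision.models.resnet50().eval().to(device)
data = torch.randn(1, 3, 224, 224, device=device)

def time_it(fn, steps=30):
    xm.wait_device_ops()
    start = time.perf_counter()
    for _ in range(steps):
        fn()
    xm.wait_device_ops()  # wait for all async device work before stopping the clock
    return (time.perf_counter() - start) / steps

def lazy_step():
    with torch.no_grad():
        model(data)
    xm.mark_step()  # cut and execute the lazily traced graph

compiled = dynamo.optimize("torchxla_trace_once")(model)
with torch.no_grad():
    compiled(data)  # warm up so compilation is not timed
xm.wait_device_ops()

def dynamo_step():
    with torch.no_grad():
        compiled(data)

print("lazy   per step:", time_it(lazy_step))
print("dynamo per step:", time_it(dynamo_step))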

Result

cpu  eval  resnet18                           1.768x p=0.00
cpu  eval  resnet50                           1.610x p=0.00
cpu  eval  resnext50_32x4d                    1.328x p=0.00
cpu  eval  alexnet                            1.261x p=0.00
cpu  eval  mobilenet_v2                       2.017x p=0.00
cpu  eval  mnasnet1_0                         1.686x p=0.00
cpu  eval  vgg16                              1.155x p=0.00
cpu  eval  BERT_pytorch                       3.502x SAME

squeezenet1_1 --> RuntimeError: Fail to extact the compiled graph because of fallback: aten::avg_pool2d=3
timm_vision_transformer --> Segmentation fault (core dumped)
geomean --> model not found (it seems to have been removed from torchbench)
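
For the fallback failure above, the torch_xla metrics counters show which aten ops were not lowered to XLA (a minimal sketch; squeezenet1_1 is used here only because it is the failing model from the report):

# Counters named aten::<op> mean the op fell back to the CPU/aten kernel,
# which is what breaks dynamo graph extraction.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torchvision

device = xm.xla_device()
model = torchvision.models.squeezenet1_1().eval().to(device)
with torch.no_grad():
    model(torch.randn(1, 3, 224, 224, device=device))
xm.mark_step()

print([name for name in met.counter_names() if name.startswith("aten::")])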

TODO

  1. Investigate why timm_vision_transformer crashes.
  2. Enable more models from torchbench.

FYI @shunting314 @wconstab @ezyang @miladm @alanwaketan @wonjoolee95

Labels

dynamo, performance, triaged (This issue has been reviewed by the triage team and the appropriate priority assigned.)
