Context:
I am running an inference benchmark using the dynamo bridge + torch_bench on a single TPU v4 device. This thread is mainly to record the current status and some TODOs. We did a similar benchmark in https://docs.google.com/document/d/1xXwCDdQl1n2aCaJ8Lu3qn060Hp18pwj4MELVTZ3mP4g/edit. @shunting314 has since added an optimization that traces the model on the XLA device instead of the CPU device, which gives better performance.
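For context, here is a minimal sketch of the tracing-on-XLA setup (my illustration, not the benchmark harness itself; it assumes torchvision is available and that torch_xla registers the torchxla_trace_once dynamo backend used in the command below). The point is that the model and inputs are moved to the XLA device before dynamo runs, so tracing happens on the XLA device rather than on CPU:

import torch
import torch._dynamo as dynamo
import torch_xla.core.xla_model as xm
import torchvision

device = xm.xla_device()  # TPU v4 device under the PJRT runtime
model = torchvision.models.resnet18().to(device).eval()
example = torch.randn(1, 3, 224, 224, device=device)  # input already on the XLA device

compiled = dynamo.optimize("torchxla_trace_once")(model)  # dynamo + XLA bridge
with torch.no_grad():
    out = compiled(example)  # first call extracts and compiles the XLA graph
xm.wait_device_ops()  # block until the TPU finishes executing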
PyTorch branch:
pytorch/pytorch#88449 + some profiler code (caveat: it used avg_pool instead of maxpool; this is fixed now)
XLA branch:
nightly + a patch (check #4306 (comment))
diff --git a/third_party/xla_client/pjrt_computation_client.cc b/third_party/xla_client/pjrt_computation_client.cc
index 207c8874..fa847c0d 100755
--- a/third_party/xla_client/pjrt_computation_client.cc
+++ b/third_party/xla_client/pjrt_computation_client.cc
@@ -308,6 +308,7 @@ PjRtComputationClient::ExecuteComputation(
   std::vector<DataPtr> datas;
   datas.reserve(results.size());
   for (auto& result : results) {
+    auto status = result->GetReadyFuture().Await();
     std::unique_ptr<xla::PjRtBuffer> buffer = std::move(result);
     std::shared_ptr<PjRtData> data = std::make_shared<PjRtData>(
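(My reading of this patch: awaiting GetReadyFuture() on each result buffer makes ExecuteComputation block until device execution has actually finished, so the measured walltime reflects real TPU execution rather than just dispatch.)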
TorchBench + TorchAudio + TorchText branches:
nightly
Runtime
PJRT, check https://github.com/pytorch/xla/blob/master/docs/pjrt.md
Command
XLA_HLO_DEBUG=0 XLA_IR_DEBUG=0 USE_FAKE_TENSOR=0 python benchmarks/dynamo/torchbench.py --randomize-input --performance --trace-on-xla --backend=torchxla_trace_once --only $MODEL -n 30
Sample profiles
gs://tpu-pytorch/tmp/dynamo_profile/dynamo_tracing/try13/
I believe this one is a resnet50
The first part of the trace (before wait_device_ops) is the lazy run and the rest is the dynamo run. Note that lazy spends some time tracing the graph before execution, while dynamo's walltime is mostly just device execution.
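For reference, such profiles can be captured with the torch_xla profiler along these lines (a sketch under my assumptions; the issue does not spell out the exact capture setup, and the port and log directory below are placeholders):

import torch_xla.debug.profiler as xp

server = xp.start_server(9012)  # start the profiler server in the benchmark process

# ... run the lazy + dynamo benchmark iterations here ...

# From another process (or a separate thread), capture a trace into a logdir
# that can then be copied to GCS and inspected in TensorBoard:
# xp.trace('localhost:9012', '/tmp/dynamo_profile', duration_ms=30000)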
Result
cpu eval resnet18 1.768x p=0.00
cpu eval resnet50 1.610x p=0.00
cpu eval resnext50_32x4d 1.328x p=0.00
cpu eval alexnet 1.261x p=0.00
cpu eval mobilenet_v2 2.017x p=0.00
cpu eval mnasnet1_0 1.686x p=0.00
cpu eval vgg16 1.155x p=0.00
cpu eval BERT_pytorch 3.502x SAME
squeezenet1_1 --> RuntimeError: Fail to extact the compiled graph because of fallback: aten::avg_pool2d=3
timm_vision_transformer --> Segmentation fault (core dumped)
geomean --> model not found
(it seems to have been removed from torchbench)
TODO
- Investigate why timm_vision_transformer crashes.
- Enable more models on torch bench.
FYI @shunting314 @wconstab @ezyang @miladm @alanwaketan @wonjoolee95