Skip to content

Commit bf51948

Browse files
committed
Added torch compile row to pytorch install table
1 parent 78250ba commit bf51948

File tree

2 files changed

+143
-0
lines changed

2 files changed

+143
-0
lines changed

_includes/quick-start-module.js

+115
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ var opts = {
2121
pm: 'pip',
2222
language: 'python',
2323
ptbuild: 'stable',
24+
'torch-compile': null
2425
};
2526

2627
var supportedCloudPlatforms = [
@@ -34,6 +35,7 @@ var package = $(".package > .option");
3435
var language = $(".language > .option");
3536
var cuda = $(".cuda > .option");
3637
var ptbuild = $(".ptbuild > .option");
38+
var torchCompile = $(".torch-compile > .option")
3739

3840
os.on("click", function() {
3941
selectedOption(os, this, "os");
@@ -50,6 +52,9 @@ cuda.on("click", function() {
5052
ptbuild.on("click", function() {
5153
selectedOption(ptbuild, this, "ptbuild")
5254
});
55+
torchCompile.on("click", function() {
56+
selectedOption(torchCompile, this, "torch-compile")
57+
});
5358

5459
// Pre-select user's operating system
5560
$(function() {
@@ -168,6 +173,110 @@ function changeAccNoneName(osname) {
168173
}
169174
}
170175

176+
function getIDFromBackend(backend) {
177+
const idTobackendMap = {
178+
inductor: 'inductor',
179+
cgraphs : 'cudagraphs',
180+
onnxrt: 'onnxrt',
181+
openvino: 'openvino',
182+
tensorrt: 'tensorrt',
183+
tvm: 'tvm',
184+
};
185+
return idTobackendMap[backend];
186+
}
187+
188+
function getPmCmd(backend) {
189+
const pmCmd = {
190+
onnxrt: 'onnxruntime',
191+
tvm: 'apache-tvm',
192+
openvino: 'openvino',
193+
tensorrt: 'torch-tensorrt',
194+
};
195+
return pmCmd[backend];
196+
}
197+
198+
function getImportCmd(backend) {
199+
const importCmd = {
200+
onnxrt: 'import onnxruntime',
201+
tvm: 'import tvm',
202+
openvino: 'import openvino.torch',
203+
tensorrt: 'import torch_tensorrt'
204+
}
205+
return importCmd[backend];
206+
}
207+
208+
function getInstallCommand(optionID) {
209+
backend = getIDFromBackend(optionID);
210+
pmCmd = getPmCmd(optionID);
211+
finalCmd = "";
212+
if (opts.pm == "pip") {
213+
finalCmd = `pip3 install ${pmCmd}`;
214+
}
215+
else if (opts.pm == "conda") {
216+
finalCmd = `conda install ${pmCmd}`;
217+
}
218+
return finalCmd;
219+
}
220+
221+
function getTorchCompileUsage(optionId) {
222+
backend = getIDFromBackend(optionId);
223+
importCmd = "<br>" + getImportCmd(optionId) + "<br>";
224+
finalCmd = "";
225+
tcUsage = "# Torch Compile usage: ";
226+
backendCmd = `torch.compile(model, backend="${backend}")`;
227+
libtorchCmd = `# Torch compile ${backend} not supported with Libtorch`;
228+
229+
if (opts.pm == "libtorch") {
230+
return libtorchCmd;
231+
}
232+
if (backend == "openvino") {
233+
if (opts.pm == "source") {
234+
finalCmd += "# Follow instructions at this URL to build openvino from source: https://github.com/openvinotoolkit/openvino/blob/master/docs/dev/build.md" + "<br>" ;
235+
tcUsage += importCmd;
236+
}
237+
else if (opts.pm == "conda") {
238+
tcUsage += importCmd;
239+
}
240+
if (opts.os == "windows" && !tcUsage.includes(importCmd)) {
241+
tcUsage += importCmd;
242+
}
243+
}
244+
else{
245+
tcUsage += importCmd;
246+
}
247+
if (backend == "onnxrt") {
248+
if (opts.pm == "source") {
249+
finalCmd += "# Follow instructions at this URL to build onnxruntime from source: https://onnxruntime.ai/docs/build" + "<br>" ;
250+
}
251+
}
252+
if (backend == "tvm") {
253+
if (opts.pm == "source") {
254+
finalCmd += "# Follow instructions at this URL to build tvm from source: https://tvm.apache.org/docs/install/from_source.html" + "<br>" ;
255+
}
256+
}
257+
if (backend == "tensorrt") {
258+
if (opts.pm == "source") {
259+
finalCmd += "# Follow instructions at this URL to build tensorrt from source: https://pytorch.org/TensorRT/getting_started/installation.html#compiling-from-source" + "<br>" ;
260+
}
261+
}
262+
finalCmd += tcUsage + backendCmd;
263+
return finalCmd
264+
}
265+
266+
function addTorchCompileCommandNote(selectedOptionId) {
267+
268+
if (!selectedOptionId) {
269+
return;
270+
}
271+
272+
$("#command").append(
273+
`<pre> ${getInstallCommand(selectedOptionId)} </pre>`
274+
);
275+
$("#command").append(
276+
`<pre> ${getTorchCompileUsage(selectedOptionId)} </pre>`
277+
);
278+
}
279+
171280
function selectedOption(option, selection, category) {
172281
$(option).removeClass("selected");
173282
$(selection).addClass("selected");
@@ -208,13 +317,19 @@ function selectedOption(option, selection, category) {
208317
changeVersion(opts.ptbuild);
209318
//make sure unsupported platforms are disabled
210319
disableUnsupportedPlatforms(opts.os);
320+
} else if (category === "torch-compile") {
321+
if (selection.id === previousSelection) {
322+
$(selection).removeClass("selected");
323+
opts[category] = null;
324+
}
211325
}
212326
commandMessage(buildMatcher());
213327
if (category === "os") {
214328
disableUnsupportedPlatforms(opts.os);
215329
display(opts.os, 'installation', 'os');
216330
}
217331
changeAccNoneName(opts.os);
332+
addTorchCompileCommandNote(opts['torch-compile'])
218333
}
219334

220335
function display(selection, id, category) {

_includes/quick_start_local.html

+28
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
<div class="col-md-12 title-block">
2525
<div class="option-text">Compute Platform</div>
2626
</div>
27+
<div class="col-md-12 title-block">
28+
<div class="option-text">Torch Compile</div>
29+
</div>
2730
<div class="col-md-12 title-block command-block">
2831
<div class="option-text command-text">Run this Command:</div>
2932
</div>
@@ -103,6 +106,31 @@
103106
<div class="option-text">CPU</div>
104107
</div>
105108
</div>
109+
<div class="row torch-compile">
110+
<!-- Section Label -->
111+
<div class="col-md-12 title-block mobile-heading">
112+
<div class="option-text">Torch Compile</div>
113+
</div>
114+
<!-- Section Label -->
115+
<div class="col-md-2 option block version" id="inductor">
116+
<div class="option-text">Inductor</div>
117+
</div>
118+
<div class="col-md-2 option block version" id="cgraphs">
119+
<div class="option-text">CUDA Graphs</div>
120+
</div>
121+
<div class="col-md-2 option block version" id="openvino">
122+
<div class="option-text">OpenVINO</div>
123+
</div>
124+
<div class="col-md-2 option block version" id="onnxrt">
125+
<div class="option-text">ONNX Runtime</div>
126+
</div>
127+
<div class="col-md-2 option block version" id="tensorrt">
128+
<div class="option-text">TensorRT</div>
129+
</div>
130+
<div class="col-md-2 option block version" id="tvm">
131+
<div class="option-text">TVM</div>
132+
</div>
133+
</div>
106134
<div class="row">
107135
<div class="col-md-12 title-block command-mobile-heading">
108136
<div class="option-text">Run this Command:</div>

0 commit comments

Comments
 (0)