.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "intermediate/fx_profiling_tutorial.py"

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_intermediate_fx_profiling_tutorial.py:


(beta) Building a Simple CPU Performance Profiler with FX
*********************************************************
**Author**: James Reed

In this tutorial, we are going to use FX to do the following:

1) Capture PyTorch Python code in a way that lets us inspect and gather
   statistics about the structure and execution of the code.
2) Build out a small class that will serve as a simple performance "profiler",
   collecting runtime statistics about each part of the model from actual
   runs.

For this tutorial, we are going to use the torchvision ResNet18 model
for demonstration purposes.

.. code-block:: default

    import torch
    import torch.fx
    import torchvision.models as models

    rn18 = models.resnet18()
    rn18.eval()

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer2): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer3): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
      (fc): Linear(in_features=512, out_features=1000, bias=True)
    )

Now that we have our model, we want to look more closely at its performance.
That is, for the following invocation, which parts of the model take the
longest?

.. code-block:: default

    input = torch.randn(5, 3, 224, 224)
    output = rn18(input)

A common way of answering that question is to go through the program source,
add code that collects timestamps at various points in the program, and
compare those timestamps to see how long the regions between them take.

That technique certainly applies to PyTorch code. However, it would be nicer
if we didn't have to copy over model code and edit it, especially code we
haven't written (like this torchvision model). Instead, we are going to use
FX to automate this "instrumentation" process without modifying any source.

First, let's get some imports out of the way (we will be using all of these
later in the code).

.. code-block:: default

    import statistics
    import tabulate
    import time
    from typing import Any, Dict, List

    from torch.fx import Interpreter

.. note::
    ``tabulate`` is an external library that is not a dependency of PyTorch.
    We use it to visualize the performance data more easily. Please make
    sure you've installed it from your favorite Python package source.
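
For contrast, the following is a minimal sketch of the manual approach described earlier, where timestamps are written directly into ``forward``. It does not use ResNet18. Instead it uses a small, hypothetical ``TinyNet`` module, and the module, its layer sizes, and the region names are illustrative assumptions. It also uses ``time.perf_counter``, a clock intended for timing short durations. Every timed region has to be added by hand, and the edits have to be repeated for every model we want to profile.

.. code-block:: python

    import time
    from typing import Dict, List

    import torch


    class TinyNet(torch.nn.Module):
        """Hypothetical toy model, instrumented by hand with timestamps."""

        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(16, 16)
            self.relu = torch.nn.ReLU()
            # Elapsed times per hand-labeled region, one entry per forward call
            self.timings: Dict[str, List[float]] = {"linear": [], "relu": []}

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            t0 = time.perf_counter()
            x = self.linear(x)
            t1 = time.perf_counter()
            x = self.relu(x)
            t2 = time.perf_counter()
            # The difference between consecutive timestamps is each region's runtime
            self.timings["linear"].append(t1 - t0)
            self.timings["relu"].append(t2 - t1)
            return x


    net = TinyNet().eval()
    with torch.no_grad():
        out = net(torch.randn(4, 16))
    for region, times in net.timings.items():
        print(f"{region}: {times[-1]:.6f} s")

Applying this to ResNet18 would mean copying and editing torchvision's ``forward`` methods by hand. The approach below avoids that.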

Capturing the Model with Symbolic Tracing
-----------------------------------------
Next, we are going to use FX's symbolic tracing mechanism to capture
the definition of our model in a data structure we can manipulate
and examine.

.. code-block:: default

    traced_rn18 = torch.fx.symbolic_trace(rn18)
    print(traced_rn18.graph)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    graph():
        %x : torch.Tensor [num_users=1] = placeholder[target=x]
        %conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
        %bn1 : [num_users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
        %relu : [num_users=1] = call_module[target=relu](args = (%bn1,), kwargs = {})
        %maxpool : [num_users=2] = call_module[target=maxpool](args = (%relu,), kwargs = {})
        %layer1_0_conv1 : [num_users=1] = call_module[target=layer1.0.conv1](args = (%maxpool,), kwargs = {})
        %layer1_0_bn1 : [num_users=1] = call_module[target=layer1.0.bn1](args = (%layer1_0_conv1,), kwargs = {})
        %layer1_0_relu : [num_users=1] = call_module[target=layer1.0.relu](args = (%layer1_0_bn1,), kwargs = {})
        %layer1_0_conv2 : [num_users=1] = call_module[target=layer1.0.conv2](args = (%layer1_0_relu,), kwargs = {})
        %layer1_0_bn2 : [num_users=1] = call_module[target=layer1.0.bn2](args = (%layer1_0_conv2,), kwargs = {})
        %add : [num_users=1] = call_function[target=operator.add](args = (%layer1_0_bn2, %maxpool), kwargs = {})
        %layer1_0_relu_1 : [num_users=2] = call_module[target=layer1.0.relu](args = (%add,), kwargs = {})
        %layer1_1_conv1 : [num_users=1] = call_module[target=layer1.1.conv1](args = (%layer1_0_relu_1,), kwargs = {})
        %layer1_1_bn1 : [num_users=1] = call_module[target=layer1.1.bn1](args = (%layer1_1_conv1,), kwargs = {})
        %layer1_1_relu : [num_users=1] = call_module[target=layer1.1.relu](args = (%layer1_1_bn1,), kwargs = {})
        %layer1_1_conv2 : [num_users=1] = call_module[target=layer1.1.conv2](args = (%layer1_1_relu,), kwargs = {})
        %layer1_1_bn2 : [num_users=1] = call_module[target=layer1.1.bn2](args = (%layer1_1_conv2,), kwargs = {})
        %add_1 : [num_users=1] = call_function[target=operator.add](args = (%layer1_1_bn2, %layer1_0_relu_1), kwargs = {})
        %layer1_1_relu_1 : [num_users=2] = call_module[target=layer1.1.relu](args = (%add_1,), kwargs = {})
        %layer2_0_conv1 : [num_users=1] = call_module[target=layer2.0.conv1](args = (%layer1_1_relu_1,), kwargs = {})
        %layer2_0_bn1 : [num_users=1] = call_module[target=layer2.0.bn1](args = (%layer2_0_conv1,), kwargs = {})
        %layer2_0_relu : [num_users=1] = call_module[target=layer2.0.relu](args = (%layer2_0_bn1,), kwargs = {})
        %layer2_0_conv2 : [num_users=1] = call_module[target=layer2.0.conv2](args = (%layer2_0_relu,), kwargs = {})
        %layer2_0_bn2 : [num_users=1] = call_module[target=layer2.0.bn2](args = (%layer2_0_conv2,), kwargs = {})
        %layer2_0_downsample_0 : [num_users=1] = call_module[target=layer2.0.downsample.0](args = (%layer1_1_relu_1,), kwargs = {})
        %layer2_0_downsample_1 : [num_users=1] = call_module[target=layer2.0.downsample.1](args = (%layer2_0_downsample_0,), kwargs = {})
        %add_2 : [num_users=1] = call_function[target=operator.add](args = (%layer2_0_bn2, %layer2_0_downsample_1), kwargs = {})
        %layer2_0_relu_1 : [num_users=2] = call_module[target=layer2.0.relu](args = (%add_2,), kwargs = {})
        %layer2_1_conv1 : [num_users=1] = call_module[target=layer2.1.conv1](args = (%layer2_0_relu_1,), kwargs = {})
        %layer2_1_bn1 : [num_users=1] = call_module[target=layer2.1.bn1](args = (%layer2_1_conv1,), kwargs = {})
        %layer2_1_relu : [num_users=1] = call_module[target=layer2.1.relu](args = (%layer2_1_bn1,), kwargs = {})
        %layer2_1_conv2 : [num_users=1] = call_module[target=layer2.1.conv2](args = (%layer2_1_relu,), kwargs = {})
        %layer2_1_bn2 : [num_users=1] = call_module[target=layer2.1.bn2](args = (%layer2_1_conv2,), kwargs = {})
        %add_3 : [num_users=1] = call_function[target=operator.add](args = (%layer2_1_bn2, %layer2_0_relu_1), kwargs = {})
        %layer2_1_relu_1 : [num_users=2] = call_module[target=layer2.1.relu](args = (%add_3,), kwargs = {})
        %layer3_0_conv1 : [num_users=1] = call_module[target=layer3.0.conv1](args = (%layer2_1_relu_1,), kwargs = {})
        %layer3_0_bn1 : [num_users=1] = call_module[target=layer3.0.bn1](args = (%layer3_0_conv1,), kwargs = {})
        %layer3_0_relu : [num_users=1] = call_module[target=layer3.0.relu](args = (%layer3_0_bn1,), kwargs = {})
        %layer3_0_conv2 : [num_users=1] = call_module[target=layer3.0.conv2](args = (%layer3_0_relu,), kwargs = {})
        %layer3_0_bn2 : [num_users=1] = call_module[target=layer3.0.bn2](args = (%layer3_0_conv2,), kwargs = {})
        %layer3_0_downsample_0 : [num_users=1] = call_module[target=layer3.0.downsample.0](args = (%layer2_1_relu_1,), kwargs = {})
        %layer3_0_downsample_1 : [num_users=1] = call_module[target=layer3.0.downsample.1](args = (%layer3_0_downsample_0,), kwargs = {})
        %add_4 : [num_users=1] = call_function[target=operator.add](args = (%layer3_0_bn2, %layer3_0_downsample_1), kwargs = {})
        %layer3_0_relu_1 : [num_users=2] = call_module[target=layer3.0.relu](args = (%add_4,), kwargs = {})
        %layer3_1_conv1 : [num_users=1] = call_module[target=layer3.1.conv1](args = (%layer3_0_relu_1,), kwargs = {})
        %layer3_1_bn1 : [num_users=1] = call_module[target=layer3.1.bn1](args = (%layer3_1_conv1,), kwargs = {})
        %layer3_1_relu : [num_users=1] = call_module[target=layer3.1.relu](args = (%layer3_1_bn1,), kwargs = {})
        %layer3_1_conv2 : [num_users=1] = call_module[target=layer3.1.conv2](args = (%layer3_1_relu,), kwargs = {})
        %layer3_1_bn2 : [num_users=1] = call_module[target=layer3.1.bn2](args = (%layer3_1_conv2,), kwargs = {})
        %add_5 : [num_users=1] = call_function[target=operator.add](args = (%layer3_1_bn2, %layer3_0_relu_1), kwargs = {})
        %layer3_1_relu_1 : [num_users=2] = call_module[target=layer3.1.relu](args = (%add_5,), kwargs = {})
        %layer4_0_conv1 : [num_users=1] = call_module[target=layer4.0.conv1](args = (%layer3_1_relu_1,), kwargs = {})
        %layer4_0_bn1 : [num_users=1] = call_module[target=layer4.0.bn1](args = (%layer4_0_conv1,), kwargs = {})
        %layer4_0_relu : [num_users=1] = call_module[target=layer4.0.relu](args = (%layer4_0_bn1,), kwargs = {})
        %layer4_0_conv2 : [num_users=1] = call_module[target=layer4.0.conv2](args = (%layer4_0_relu,), kwargs = {})
        %layer4_0_bn2 : [num_users=1] = call_module[target=layer4.0.bn2](args = (%layer4_0_conv2,), kwargs = {})
        %layer4_0_downsample_0 : [num_users=1] = call_module[target=layer4.0.downsample.0](args = (%layer3_1_relu_1,), kwargs = {})
        %layer4_0_downsample_1 : [num_users=1] = call_module[target=layer4.0.downsample.1](args = (%layer4_0_downsample_0,), kwargs = {})
        %add_6 : [num_users=1] = call_function[target=operator.add](args = (%layer4_0_bn2, %layer4_0_downsample_1), kwargs = {})
        %layer4_0_relu_1 : [num_users=2] = call_module[target=layer4.0.relu](args = (%add_6,), kwargs = {})
        %layer4_1_conv1 : [num_users=1] = call_module[target=layer4.1.conv1](args = (%layer4_0_relu_1,), kwargs = {})
        %layer4_1_bn1 : [num_users=1] = call_module[target=layer4.1.bn1](args = (%layer4_1_conv1,), kwargs = {})
        %layer4_1_relu : [num_users=1] = call_module[target=layer4.1.relu](args = (%layer4_1_bn1,), kwargs = {})
        %layer4_1_conv2 : [num_users=1] = call_module[target=layer4.1.conv2](args = (%layer4_1_relu,), kwargs = {})
        %layer4_1_bn2 : [num_users=1] = call_module[target=layer4.1.bn2](args = (%layer4_1_conv2,), kwargs = {})
        %add_7 : [num_users=1] = call_function[target=operator.add](args = (%layer4_1_bn2, %layer4_0_relu_1), kwargs = {})
        %layer4_1_relu_1 : [num_users=1] = call_module[target=layer4.1.relu](args = (%add_7,), kwargs = {})
        %avgpool : [num_users=1] = call_module[target=avgpool](args = (%layer4_1_relu_1,), kwargs = {})
        %flatten : [num_users=1] = call_function[target=torch.flatten](args = (%avgpool, 1), kwargs = {})
        %fc : [num_users=1] = call_module[target=fc](args = (%flatten,), kwargs = {})
        return fc

This gives us a Graph representation of the ResNet18 model.
A Graph consists of a series of Nodes connected to each other. Each Node
represents a call-site in the Python code (whether to a function, a module,
or a method), and the edges (represented as ``args`` and ``kwargs`` on each
node) represent the values passed between these call-sites. More information
about the Graph representation and the rest of FX's APIs can be found in the
FX documentation: https://pytorch.org/docs/master/fx.html.

Creating a Profiling Interpreter
--------------------------------
Next, we are going to create a class that inherits from ``torch.fx.Interpreter``.
Calling the ``GraphModule`` that ``symbolic_trace`` produces runs Python code
that FX generated from the ``Graph``. An alternative way to run a
``GraphModule`` is to execute each ``Node`` in its ``Graph`` one at a time.
That is the functionality that ``Interpreter`` provides: it interprets the
graph node by node.

By inheriting from ``Interpreter``, we can override specific methods and
install the profiling behavior we want. The goal is an object to which we can
pass a model, run the model one or more times, and then get statistics about
how long the model as a whole, and each part of it, took during those runs.

Let's define our ``ProfilingInterpreter`` class:

.. code-block:: default

    class ProfilingInterpreter(Interpreter):
        def __init__(self, mod: torch.nn.Module):
            # Rather than have the user symbolically trace their model,
            # we're going to do it in the constructor. As a result, the
            # user can pass in any ``Module`` without having to worry about
            # symbolic tracing APIs
            gm = torch.fx.symbolic_trace(mod)
            super().__init__(gm)

            # We are going to store away two things here:
            #
            # 1. A list of total runtimes for ``mod``. In other words, we are
            #    storing away the time ``mod(...)`` took each time this
            #    interpreter is called.
            self.total_runtime_sec: List[float] = []
            # 2. A map from ``Node`` to a list of times (in seconds) that
            #    node took to run. This can be seen as similar to (1) but
            #    for specific sub-parts of the model.
            self.runtimes_sec: Dict[torch.fx.Node, List[float]] = {}

        ######################################################################
        # Next, let's override our first method: ``run()``. ``Interpreter``'s ``run``
        # method is the top-level entry point for execution of the model. We
        # want to intercept this so that we can record the total runtime of the
        # model.

        def run(self, *args) -> Any:
            # Record the time we started running the model
            t_start = time.time()
            # Run the model by delegating back into Interpreter.run()
            return_val = super().run(*args)
            # Record the time we finished running the model
            t_end = time.time()
            # Store the total elapsed time this model execution took in the
            # ``ProfilingInterpreter``
            self.total_runtime_sec.append(t_end - t_start)
            return return_val

        ######################################################################
        # Now, let's override ``run_node``. ``Interpreter`` calls ``run_node`` each
        # time it executes a single node. We will intercept this so that we
        # can measure and record the time taken for each individual call in
        # the model.

        def run_node(self, n: torch.fx.Node) -> Any:
            # Record the time we started running the op
            t_start = time.time()
            # Run the op by delegating back into Interpreter.run_node()
            return_val = super().run_node(n)
            # Record the time we finished running the op
            t_end = time.time()
            # If we don't have an entry for this node in our runtimes_sec
            # data structure, add one with an empty list value.
            self.runtimes_sec.setdefault(n, [])
            # Record the total elapsed time for this single invocation
            # in the runtimes_sec data structure
            self.runtimes_sec[n].append(t_end - t_start)
            return return_val

        ######################################################################
        # Finally, we are going to define a method (one which doesn't override
        # any ``Interpreter`` method) that provides us a nice, organized view of
        # the data we have collected.

        def summary(self, should_sort: bool = False) -> str:
            # Build up a list of summary information for each node
            node_summaries: List[List[Any]] = []
            # Calculate the mean runtime for the whole network. Because the
            # network may have been called multiple times during profiling,
            # we need to summarize the runtimes. We choose to use the
            # arithmetic mean for this.
            mean_total_runtime = statistics.mean(self.total_runtime_sec)

            # For each node, record summary statistics
            for node, runtimes in self.runtimes_sec.items():
                # Similarly, compute the mean runtime for ``node``
                mean_runtime = statistics.mean(runtimes)
                # For easier understanding, we also compute the percentage
                # time each node took with respect to the whole network.
                pct_total = mean_runtime / mean_total_runtime * 100
                # Record the node's type, name of the node, mean runtime, and
                # percent runtime.
                node_summaries.append(
                    [node.op, str(node), mean_runtime, pct_total])

            # One of the most important questions to answer when doing performance
            # profiling is "Which op(s) took the longest?". We can make this easy
            # to see by providing sorting functionality in our summary view
            if should_sort:
                node_summaries.sort(key=lambda s: s[2], reverse=True)

            # Use the ``tabulate`` library to create a well-formatted table
            # presenting our summary information
            headers: List[str] = [
                'Op type', 'Op', 'Average runtime (s)', 'Pct total runtime'
            ]
            return tabulate.tabulate(node_summaries, headers=headers)
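
The following is a stripped-down sketch that shows the ``Interpreter`` mechanics on their own. It uses a small, hypothetical ``SmallNet`` module. First it prints the ``op``, ``name``, and ``target`` of every ``Node`` in the traced graph. It then overrides only ``run_node``, grouping runtimes by node kind (``node.op``) instead of by individual node. The ``SmallNet`` and ``OpKindTimer`` names and the use of ``time.perf_counter`` are assumptions of this sketch, not part of ``ProfilingInterpreter``.

.. code-block:: python

    import time
    from collections import defaultdict
    from typing import Any, Dict, List

    import torch
    import torch.fx


    class SmallNet(torch.nn.Module):
        """Hypothetical toy model: one module call and two function calls."""

        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(8, 8)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x)) + 1


    class OpKindTimer(torch.fx.Interpreter):
        """Hypothetical Interpreter subclass: collects runtimes per node kind."""

        def __init__(self, gm: torch.fx.GraphModule) -> None:
            super().__init__(gm)
            # Maps node.op ("placeholder", "call_module", ...) to elapsed times
            self.times_by_op: Dict[str, List[float]] = defaultdict(list)

        def run_node(self, n: torch.fx.Node) -> Any:
            t_start = time.perf_counter()
            result = super().run_node(n)
            self.times_by_op[n.op].append(time.perf_counter() - t_start)
            return result


    gm = torch.fx.symbolic_trace(SmallNet())
    # Each Node records what kind of call it is and what it calls
    for node in gm.graph.nodes:
        print(node.op, node.name, node.target)

    timer = OpKindTimer(gm)
    x = torch.randn(2, 8)
    out = timer.run(x)
    # Interpreting node by node should match calling the GraphModule directly
    print(torch.allclose(out, gm(x)))
    for op, times in timer.times_by_op.items():
        print(f"{op}: {len(times)} node(s), {sum(times):.6f} s")

``Interpreter.run`` passes every node through ``run_node``, including the ``placeholder`` and ``output`` nodes. Overriding that one method is therefore enough to observe the whole execution. This is also why ``x`` and ``output`` appear as rows in the ResNet18 summary.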

.. note::
    We use Python's ``time.time`` function to pull wall clock
    timestamps and compare them. This is not the most accurate
    way to measure performance, and will only give us a first-order
    approximation. We use this simple technique only for the
    purpose of demonstration in this tutorial.

Investigating the Performance of ResNet18
-----------------------------------------
We can now use ``ProfilingInterpreter`` to inspect the performance
characteristics of our ResNet18 model:

.. code-block:: default

    interp = ProfilingInterpreter(rn18)
    interp.run(input)
    print(interp.summary(True))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Op type        Op                       Average runtime (s)    Pct total runtime
    -------------  ---------------------  ---------------------  -------------------
    call_module    maxpool                          0.00674701            12.0222
    call_module    conv1                            0.00448108             7.98466
    call_module    layer4_1_conv2                   0.00302887             5.39702
    call_module    layer4_0_conv2                   0.00300789             5.35964
    call_module    layer4_1_conv1                   0.00299788             5.3418
    call_module    layer1_0_conv1                   0.00269198             4.79674
    call_module    layer1_1_conv2                   0.00249958             4.4539
    call_module    layer1_0_conv2                   0.00246072             4.38466
    call_module    layer2_0_conv2                   0.00232244             4.13826
    call_module    layer3_0_conv2                   0.00221491             3.94666
    call_module    layer1_1_conv1                   0.00220823             3.93476
    call_module    layer2_1_conv2                   0.00219274             3.90715
    call_module    layer3_1_conv2                   0.00215411             3.83833
    call_module    layer3_1_conv1                   0.00211573             3.76993
    call_module    layer2_1_conv1                   0.00194478             3.46533
    call_module    layer4_0_conv1                   0.00184011             3.27883
    call_module    layer2_0_conv1                   0.00140786             2.50861
    call_module    layer3_0_conv1                   0.00130749             2.32976
    call_module    layer2_0_downsample_0            0.0010097              1.79915
    call_module    bn1                              0.000562429            1.00217
    call_module    layer3_0_downsample_0            0.000504255            0.898513
    call_module    layer4_0_downsample_0            0.000440121            0.784234
    call_function  add                              0.000385046            0.686098
    call_function  add_1                            0.000368118            0.655935
    call_module    relu                             0.000238419            0.424829
    call_module    layer1_0_bn1                     0.000216007            0.384895
    call_function  add_3                            0.000208855            0.37215
    call_module    fc                               0.000193357            0.344536
    call_module    layer1_1_bn2                     0.00018692             0.333066
    call_module    layer1_0_bn2                     0.000183821            0.327543
    call_module    layer1_1_bn1                     0.000182867            0.325844
    call_module    layer2_0_bn1                     0.000171185            0.305027
    call_module    layer2_0_downsample_1            0.000131845            0.23493
    call_module    layer4_1_bn2                     0.000120163            0.214114
    call_module    avgpool                          0.000119925            0.213689
    call_module    layer3_1_bn2                     0.000113964            0.203068
    call_module    layer3_0_bn2                     0.000111818            0.199245
    call_module    layer4_0_bn2                     0.000110149            0.196271
    call_module    layer3_0_bn1                     0.000109196            0.194572
    call_module    layer1_0_relu                    9.32217e-05            0.166108
    call_module    layer4_1_bn1                     8.60691e-05            0.153363
    call_module    layer3_0_downsample_1            8.39233e-05            0.14954
    call_module    layer1_1_relu_1                  8.36849e-05            0.149115
    call_module    layer2_0_bn2                     8.2016e-05             0.146141
    call_module    layer2_1_bn2                     7.93934e-05            0.141468
    call_function  add_2                            7.89165e-05            0.140618
    call_function  add_5                            7.72476e-05            0.137644
    call_module    layer4_0_downsample_1            7.60555e-05            0.13552
    call_module    layer1_0_relu_1                  7.58171e-05            0.135096
    call_module    layer2_1_bn1                     7.39098e-05            0.131697
    call_function  add_6                            6.7234e-05             0.119802
    call_module    layer1_1_relu                    6.60419e-05            0.117678
    call_module    layer3_1_bn1                     6.60419e-05            0.117678
    call_module    layer4_0_bn1                     6.58035e-05            0.117253
    call_function  add_7                            6.48499e-05            0.115553
    call_module    layer4_1_relu                    5.26905e-05            0.0938871
    call_module    layer2_0_relu_1                  4.76837e-05            0.0849657
    call_module    layer2_1_relu_1                  4.57764e-05            0.0815671
    call_module    layer4_0_relu                    4.55379e-05            0.0811423
    call_module    layer4_1_relu_1                  4.29153e-05            0.0764692
    call_module    layer4_0_relu_1                  4.22001e-05            0.0751947
    call_module    layer3_1_relu                    4.17233e-05            0.074345
    call_module    layer2_0_relu                    4.02927e-05            0.071796
    call_function  add_4                            3.98159e-05            0.0709464
    call_module    layer2_1_relu                    3.86238e-05            0.0688222
    call_module    layer3_1_relu_1                  3.60012e-05            0.0641491
    call_module    layer3_0_relu                    3.50475e-05            0.0624498
    call_module    layer3_0_relu_1                  3.48091e-05            0.062025
    call_function  flatten                          2.52724e-05            0.0450318
    placeholder    x                                2.21729e-05            0.0395091
    output         output                           8.82149e-06            0.0157187

There are two things we should call out here:

* ``MaxPool2d`` takes up the most time. This is a known issue:
  https://github.com/pytorch/pytorch/issues/51393
* ``BatchNorm2d`` also takes up significant time. Its 20 instances add up to
  roughly 5% of the total runtime in the run above. We can continue this line
  of thinking and optimize this in the Conv-BN Fusion with FX tutorial.

Conclusion
----------
As we can see, using FX we can easily capture PyTorch programs (even
ones we don't have the source code for!) in a machine-interpretable
format and use that for analysis, such as the performance analysis
we've done here. FX opens up an exciting world of possibilities for
working with PyTorch programs.

Finally, since FX is still in beta, we would be happy to hear any
feedback you have about using it. Please feel free to use the
PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker
(https://github.com/pytorch/pytorch/issues) to provide any feedback
you might have.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.319 seconds)