Commit af1f585

James Reed committed
[WIP][FX] CPU Performance Profiling with FX

1 parent 9a38f3d commit af1f585

File tree

2 files changed: +253 −0 lines changed

index.rst

Lines changed: 17 additions & 0 deletions

@@ -215,6 +215,15 @@ Welcome to PyTorch Tutorials
    :link: advanced/super_resolution_with_onnxruntime.html
    :tags: Production
 
+.. Code Transformations with FX
+
+.. customcarditem::
+   :header: Building a Simple Performance Profiler with FX
+   :card_description: Build a simple FX interpreter to record the runtime of op, module, and function calls and report statistics
+   :image: _static/img/thumbnails/cropped/Deploying-PyTorch-in-Python-via-a-REST-API-with-Flask.png
+   :link: intermediate/fx_profiling_tutorial.html
+   :tags: FX
+
 .. Frontend APIs
 
 .. customcarditem::
@@ -505,6 +514,14 @@ Additional Resources
    advanced/cpp_export
    advanced/super_resolution_with_onnxruntime
 
+.. toctree::
+   :maxdepth: 2
+   :includehidden:
+   :hidden:
+   :caption: Code Transforms with FX
+
+   intermediate/fx_profiling_tutorial
+
 .. toctree::
    :maxdepth: 2
    :includehidden:
intermediate_source/fx_profiling_tutorial.py

Lines changed: 236 additions & 0 deletions

@@ -0,0 +1,236 @@
# -*- coding: utf-8 -*-
"""
(beta) Building a Simple CPU Performance Profiler with FX
*********************************************************
**Author**: `James Reed <https://github.com/jamesr66a>`_

In this tutorial, we are going to use FX to do the following:

1) Capture PyTorch Python code in a way that we can inspect and gather
   statistics about the structure and execution of the code
2) Build out a small class that will serve as a simple performance "profiler",
   collecting runtime statistics about each part of the model from actual
   runs.
"""
######################################################################
# For this tutorial, we are going to use the torchvision ResNet18 model
# for demonstration purposes.

import torch
import torch.fx
import torchvision.models as models

rn18 = models.resnet18()
rn18.eval()

######################################################################
# Now that we have our model, we want to look more deeply into its
# performance. That is, for the following invocation, which parts
# of the model are taking the longest?
input = torch.randn(5, 3, 224, 224)
output = rn18(input)
######################################################################
# A common way of answering that question is to go through the program
# source, add code that collects timestamps at various points in the
# program, and compare the difference between those timestamps to see
# how long the regions between the timestamps take.
#
# That technique is certainly applicable to PyTorch code; however, it
# would be nicer if we didn't have to copy over model code and edit it,
# especially code we haven't written (like this torchvision model).
# Instead, we are going to use FX to automate this "instrumentation"
# process without needing to modify any source.
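######################################################################
# To make the idea concrete, here is a minimal sketch of that manual
# technique at the coarsest possible granularity: timestamping a single
# whole forward pass. Timing individual ops this way would mean editing
# the model source to wrap each region of interest in a similar pair of
# timestamps.

import time

t_start = time.time()
rn18(input)
t_end = time.time()
print(f'Whole-model forward pass took {t_end - t_start} seconds')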
######################################################################
# First, let's get some imports out of the way (we will be using all
# of these later in the code).

import statistics, tabulate, time
from typing import Any, Dict, List
from torch.fx import Interpreter

######################################################################
# .. note::
#     ``tabulate`` is an external library that is not a dependency of PyTorch.
#     We will be using it to more easily visualize performance data. Please
#     make sure you've installed it from your favorite Python package source
#     (for example, ``pip install tabulate``).

######################################################################
# Capturing the Model with Symbolic Tracing
# -----------------------------------------
# Next, we are going to use FX's symbolic tracing mechanism to capture
# the definition of our model in a data structure we can manipulate
# and examine.

traced_rn18 = torch.fx.symbolic_trace(rn18)
print(traced_rn18.graph)

######################################################################
# This gives us a Graph representation of the ResNet18 model. A Graph
# consists of a series of Nodes connected to each other. Each Node
# represents a call-site in the Python code (whether to a function,
# a module, or a method) and the edges (represented as ``args`` and ``kwargs``
# on each node) represent the values passed between these call-sites. More
# information about the Graph representation and the rest of FX's APIs can
# be found in the FX documentation: https://pytorch.org/docs/master/fx.html.
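######################################################################
# As a quick illustration of that structure, we can walk the first few
# nodes of the traced graph and print each one's opcode, name, call
# target, and inputs (a read-only sketch; this inspects the graph but
# does not run the model):

for node in list(traced_rn18.graph.nodes)[:5]:
    # node.op is the kind of call-site, node.target is what it calls,
    # and node.args holds the incoming edges (values from other nodes).
    print(node.op, node.name, node.target, node.args)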
81+
######################################################################
# Creating a Profiling Interpreter
# --------------------------------
# Next, we are going to create a class that inherits from ``torch.fx.Interpreter``.
# Though the ``GraphModule`` that ``symbolic_trace`` produces compiles Python code
# that is run when you call a ``GraphModule``, an alternative way to run a
# ``GraphModule`` is by executing each ``Node`` in the ``Graph`` one by one. That is
# the functionality that ``Interpreter`` provides: it interprets the graph node-
# by-node.
#
# By inheriting from ``Interpreter``, we can override various functionality and
# install the profiling behavior we want. The goal is to have an object to which
# we can pass a model, invoke the model one or more times, and then get statistics
# about how long the model and each part of the model took during those runs.
#
# Let's define our ``ProfilingInterpreter`` class:

class ProfilingInterpreter(Interpreter):
    def __init__(self, mod : torch.nn.Module):
        # Rather than have the user symbolically trace their model,
        # we're going to do it in the constructor. As a result, the
        # user can pass in any ``Module`` without having to worry about
        # symbolic tracing APIs.
        gm = torch.fx.symbolic_trace(mod)
        super().__init__(gm)

        # We are going to store away two things here:
        #
        # 1. A list of total runtimes for ``mod``. In other words, we are
        #    storing away the time ``mod(...)`` took each time this
        #    interpreter is called.
        self.total_runtime_sec : List[float] = []
        # 2. A map from ``Node`` to a list of times (in seconds) that
        #    node took to run. This can be seen as similar to (1) but
        #    for specific sub-parts of the model.
        self.runtimes_sec : Dict[torch.fx.Node, List[float]] = {}

######################################################################
# Next, let's override our first method: ``run()``. ``Interpreter``'s ``run``
# method is the top-level entry point for execution of the model. We will
# want to intercept this so that we can record the total runtime of the
# model.

    def run(self, *args) -> Any:
        # Record the time we started running the model
        t_start = time.time()
        # Run the model by delegating back into Interpreter.run()
        return_val = super().run(*args)
        # Record the time we finished running the model
        t_end = time.time()
        # Store the total elapsed time this model execution took in the
        # ProfilingInterpreter
        self.total_runtime_sec.append(t_end - t_start)
        return return_val

######################################################################
# Now, let's override ``run_node``. ``Interpreter`` calls ``run_node`` each
# time it executes a single node. We will intercept this so that we
# can measure and record the time taken for each individual call in
# the model.

    def run_node(self, n : torch.fx.Node) -> Any:
        # Record the time we started running the op
        t_start = time.time()
        # Run the op by delegating back into Interpreter.run_node()
        return_val = super().run_node(n)
        # Record the time we finished running the op
        t_end = time.time()
        # If we don't have an entry for this node in our runtimes_sec
        # data structure, add one with an empty list value.
        self.runtimes_sec.setdefault(n, [])
        # Record the total elapsed time for this single invocation
        # in the runtimes_sec data structure
        self.runtimes_sec[n].append(t_end - t_start)
        return return_val

######################################################################
# Finally, we are going to define a method (one which doesn't override
# any ``Interpreter`` method) that provides us a nice, organized view of
# the data we have collected.

    def summary(self, should_sort : bool = False) -> str:
        # Build up a list of summary information for each node
        node_summaries : List[List[Any]] = []
        # Calculate the mean runtime for the whole network. Because the
        # network may have been called multiple times during profiling,
        # we need to summarize the runtimes. We choose to use the
        # arithmetic mean for this.
        mean_total_runtime = statistics.mean(self.total_runtime_sec)

        # For each node, record summary statistics
        for node, runtimes in self.runtimes_sec.items():
            # Similarly, compute the mean runtime for ``node``
            mean_runtime = statistics.mean(runtimes)
            # For easier understanding, we also compute the percentage
            # of time each node took with respect to the whole network.
            pct_total = mean_runtime / mean_total_runtime * 100
            # Record the node's type, name of the node, mean runtime, and
            # percent runtime.
            node_summaries.append(
                [node.op, str(node), mean_runtime, pct_total])

        # One of the most important questions to answer when doing performance
        # profiling is "Which op(s) took the longest?". We can make this easy
        # to see by providing sorting functionality in our summary view
        if should_sort:
            node_summaries.sort(key=lambda s: s[2], reverse=True)

        # Use the ``tabulate`` library to create a well-formatted table
        # presenting our summary information
        headers : List[str] = [
            'Op type', 'Op', 'Average runtime (s)', 'Pct total runtime'
        ]
        return tabulate.tabulate(node_summaries, headers=headers)

######################################################################
# .. note::
#     We use Python's ``time.time`` function to pull wall clock
#     timestamps and compare them. This is not the most accurate
#     way to measure performance, and will only give us a first-
#     order approximation. We use this simple technique only for the
#     purpose of demonstration in this tutorial.
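######################################################################
# As a sketch of one easy improvement: Python's ``time.perf_counter``
# is a monotonic, high-resolution clock, and so is a better fit for
# benchmarking than ``time.time``. Swapping it into ``run`` and
# ``run_node`` above would be a one-line change in each; used directly,
# it looks like this:

t0 = time.perf_counter()
rn18(input)
t1 = time.perf_counter()
print(f'perf_counter measurement of one forward pass: {t1 - t0:.6f} seconds')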

######################################################################
# Investigating the Performance of ResNet18
# -----------------------------------------
# We can now use ``ProfilingInterpreter`` to inspect the performance
# characteristics of our ResNet18 model:

interp = ProfilingInterpreter(rn18)
interp.run(input)
print(interp.summary(True))

######################################################################
# There are two things we should call out here:
#
# * ``MaxPool2d`` takes up the most time. This is a known issue:
#   https://github.com/pytorch/pytorch/issues/51393
# * ``BatchNorm2d`` also takes up significant time. We can continue this
#   line of thinking and optimize this in the Conv-BN Fusion with FX
#   tutorial. TODO: link
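######################################################################
# To make patterns like these easier to spot, we can aggregate the
# recorded per-node times by the kind of op rather than by individual
# node. The grouping below is a minimal sketch of our own (it is not
# part of ``ProfilingInterpreter``), and it assumes ``Interpreter``
# exposes the traced module as ``self.module``.

from collections import defaultdict

submodules = dict(interp.module.named_modules())
time_by_kind : Dict[str, float] = defaultdict(float)
for node, runtimes in interp.runtimes_sec.items():
    if node.op == 'call_module':
        # For module calls, ``node.target`` is the submodule's qualified
        # name; group by the submodule's class name instead.
        kind = type(submodules[node.target]).__name__
    else:
        kind = f'{node.op}: {node.target}'
    time_by_kind[kind] += statistics.mean(runtimes)

for kind, t in sorted(time_by_kind.items(), key=lambda kv: kv[1], reverse=True):
    print(f'{kind}: {t:.6f} s')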
######################################################################
# Conclusion
# ----------
# As we can see, using FX we can easily capture PyTorch programs (even
# ones we don't have the source code for!) in a machine-interpretable
# format and use that for analysis, such as the performance analysis
# we've done here. FX opens up an exciting world of possibilities for
# working with PyTorch programs.
#
# Finally, since FX is still in beta, we would be happy to hear any
# feedback you have about using it. Please feel free to use the
# PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker
# (https://github.com/pytorch/pytorch/issues) to provide any feedback
# you might have.
