# -*- coding: utf-8 -*-
"""
(beta) Building a Simple CPU Performance Profiler with FX
**********************************************************
**Author**: `James Reed <https://github.com/jamesr66a>`_

In this tutorial, we are going to use FX to do the following:

1) Capture PyTorch Python code in a way that we can inspect and gather
   statistics about the structure and execution of the code
2) Build out a small class that will serve as a simple performance "profiler",
   collecting runtime statistics about each part of the model from actual
   runs.

"""

######################################################################
# For this tutorial, we are going to use the torchvision ResNet18 model
# for demonstration purposes.

import torch
import torch.fx
import torchvision.models as models

rn18 = models.resnet18()
rn18.eval()

######################################################################
# Now that we have our model, we want to take a deeper look at its
# performance. That is, for the following invocation, which parts
# of the model are taking the longest?
input = torch.randn(5, 3, 224, 224)
output = rn18(input)

######################################################################
# A common way of answering that question is to go through the program
# source, add code that collects timestamps at various points in the
# program, and compare those timestamps to see how long the regions
# between them take.
#
# That technique is certainly applicable to PyTorch code; however, it
# would be nicer if we didn't have to copy over model code and edit it,
# especially code we haven't written (like this torchvision model).
# Instead, we are going to use FX to automate this "instrumentation"
# process without needing to modify any source.

######################################################################
# First, let's get some imports out of the way (we will be using all
# of these later in the code).

import statistics, tabulate, time
from typing import Any, Dict, List
from torch.fx import Interpreter

######################################################################
# .. note::
#    ``tabulate`` is an external library that is not a dependency of PyTorch.
#    We will be using it to more easily visualize performance data. Please
#    make sure you've installed it from your favorite Python package source.
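#    For example, ``pip install tabulate`` should work in most environments.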

######################################################################
# Capturing the Model with Symbolic Tracing
# ------------------------------------------
# Next, we are going to use FX's symbolic tracing mechanism to capture
# the definition of our model in a data structure we can manipulate
# and examine.

traced_rn18 = torch.fx.symbolic_trace(rn18)
print(traced_rn18.graph)

######################################################################
# This gives us a Graph representation of the ResNet18 model. A Graph
# consists of a series of Nodes connected to each other. Each Node
# represents a call-site in the Python code (whether to a function,
# a module, or a method) and the edges (represented as ``args`` and ``kwargs``
# on each node) represent the values passed between these call-sites. More
# information about the Graph representation and the rest of FX's APIs can
# be found in the FX documentation: https://pytorch.org/docs/master/fx.html.
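
######################################################################
# As a quick sketch of how we can poke at this data structure, we can
# iterate over the graph's nodes and print a few of their attributes
# (we look at only the first few nodes here to keep the output short):

for node in list(traced_rn18.graph.nodes)[:5]:
    print(node.op, node.target, node.args)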


######################################################################
# Creating a Profiling Interpreter
# --------------------------------
# Next, we are going to create a class that inherits from ``torch.fx.Interpreter``.
# Though the ``GraphModule`` that ``symbolic_trace`` produces compiles Python code
# that is run when you call a ``GraphModule``, an alternative way to run a
# ``GraphModule`` is by executing each ``Node`` in the ``Graph`` one by one. That is
# the functionality that ``Interpreter`` provides: it interprets the graph
# node-by-node.
#
# By inheriting from ``Interpreter``, we can override various functionality and
# install the profiling behavior we want. The goal is to have an object to which
# we can pass a model, invoke the model one or more times, and then get statistics
# about how long the model and each part of the model took during those runs.
#
# Let's define our ``ProfilingInterpreter`` class:

class ProfilingInterpreter(Interpreter):
    def __init__(self, mod : torch.nn.Module):
        # Rather than have the user symbolically trace their model,
        # we're going to do it in the constructor. As a result, the
        # user can pass in any ``Module`` without having to worry about
        # symbolic tracing APIs
        gm = torch.fx.symbolic_trace(mod)
        super().__init__(gm)

        # We are going to store away two things here:
        #
        # 1. A list of total runtimes for ``mod``. In other words, we are
        #    storing away the time ``mod(...)`` took each time this
        #    interpreter is called.
        self.total_runtime_sec : List[float] = []
        # 2. A map from ``Node`` to a list of times (in seconds) that
        #    node took to run. This can be seen as similar to (1) but
        #    for specific sub-parts of the model.
        self.runtimes_sec : Dict[torch.fx.Node, List[float]] = {}

    ######################################################################
    # Next, let's override our first method: ``run()``. ``Interpreter``'s ``run``
    # method is the top-level entrypoint for execution of the model. We will
    # want to intercept this so that we can record the total runtime of the
    # model.

    def run(self, *args) -> Any:
        # Record the time we started running the model
        t_start = time.time()
        # Run the model by delegating back into Interpreter.run()
        return_val = super().run(*args)
        # Record the time we finished running the model
        t_end = time.time()
        # Store the total elapsed time this model execution took in the
        # ProfilingInterpreter
        self.total_runtime_sec.append(t_end - t_start)
        return return_val

    ######################################################################
    # Now, let's override ``run_node``. ``Interpreter`` calls ``run_node`` each
    # time it executes a single node. We will intercept this so that we
    # can measure and record the time taken for each individual call in
    # the model.

    def run_node(self, n : torch.fx.Node) -> Any:
        # Record the time we started running the op
        t_start = time.time()
        # Run the op by delegating back into Interpreter.run_node()
        return_val = super().run_node(n)
        # Record the time we finished running the op
        t_end = time.time()
        # If we don't have an entry for this node in our runtimes_sec
        # data structure, add one with an empty list value.
        self.runtimes_sec.setdefault(n, [])
        # Record the total elapsed time for this single invocation
        # in the runtimes_sec data structure
        self.runtimes_sec[n].append(t_end - t_start)
        return return_val

    ######################################################################
    # Finally, we are going to define a method (one which doesn't override
    # any ``Interpreter`` method) that provides us a nice, organized view of
    # the data we have collected.

    def summary(self, should_sort : bool = False) -> str:
        # Build up a list of summary information for each node
        node_summaries : List[List[Any]] = []
        # Calculate the mean runtime for the whole network. Because the
        # network may have been called multiple times during profiling,
        # we need to summarize the runtimes. We choose to use the
        # arithmetic mean for this.
        mean_total_runtime = statistics.mean(self.total_runtime_sec)

        # For each node, record summary statistics
        for node, runtimes in self.runtimes_sec.items():
            # Similarly, compute the mean runtime for ``node``
            mean_runtime = statistics.mean(runtimes)
            # For easier understanding, we also compute the percentage
            # time each node took with respect to the whole network.
            pct_total = mean_runtime / mean_total_runtime * 100
            # Record the node's type, name of the node, mean runtime, and
            # percent runtime
            node_summaries.append(
                [node.op, str(node), mean_runtime, pct_total])

        # One of the most important questions to answer when doing performance
        # profiling is "Which op(s) took the longest?". We can make this easy
        # to see by providing sorting functionality in our summary view.
        if should_sort:
            node_summaries.sort(key=lambda s: s[2], reverse=True)

        # Use the ``tabulate`` library to create a well-formatted table
        # presenting our summary information
        headers : List[str] = [
            'Op type', 'Op', 'Average runtime (s)', 'Pct total runtime'
        ]
        return tabulate.tabulate(node_summaries, headers=headers)

######################################################################
# .. note::
#    We use Python's ``time.time`` function to pull wall clock
#    timestamps and compare them. This is not the most accurate
#    way to measure performance, and will only give us a first-
#    order approximation. We use this simple technique only for the
#    purpose of demonstration in this tutorial.
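#
#    If somewhat better timer resolution is desired, ``time.perf_counter()``
#    is a drop-in replacement for ``time.time()`` in the timing code above
#    (a small refinement this tutorial does not rely on), for example::
#
#        t_start = time.perf_counter()
#        return_val = super().run_node(n)
#        t_end = time.perf_counter()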

######################################################################
# Investigating the Performance of ResNet18
# ------------------------------------------
# We can now use ``ProfilingInterpreter`` to inspect the performance
# characteristics of our ResNet18 model:

interp = ProfilingInterpreter(rn18)
interp.run(input)
print(interp.summary(True))
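
######################################################################
# .. note::
#    ``ProfilingInterpreter`` accumulates statistics across calls to ``run``,
#    so one way to smooth out noisy measurements (a sketch, not something
#    this tutorial relies on) is to invoke the model several times before
#    printing the summary, for example::
#
#        for _ in range(5):
#            interp.run(input)
#        print(interp.summary(True))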

######################################################################
# There are two things we should call out here:
#
# * ``MaxPool2d`` takes up the most time. This is a known issue:
#   https://github.com/pytorch/pytorch/issues/51393
# * ``BatchNorm2d`` also takes up significant time. We can continue this
#   line of thinking and optimize it in the Conv-BN Fusion with FX
#   tutorial TODO: link
#
#
# Conclusion
# ----------
# As we can see, using FX we can easily capture PyTorch programs (even
# ones we don't have the source code for!) in a machine-interpretable
# format and use that for analysis, such as the performance analysis
# we've done here. FX opens up an exciting world of possibilities for
# working with PyTorch programs.
#
# Finally, since FX is still in beta, we would be happy to hear any
# feedback you have about using it. Please feel free to use the
# PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker
# (https://github.com/pytorch/pytorch/issues) to provide any feedback
# you might have.