.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "intermediate/torch_compile_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_intermediate_torch_compile_tutorial.py:


Introduction to ``torch.compile``
=================================

**Author:** William Wen

.. GENERATED FROM PYTHON SOURCE LINES 10-37

``torch.compile`` is the latest method to speed up your PyTorch code!
``torch.compile`` makes PyTorch code run faster by
JIT-compiling PyTorch code into optimized kernels,
all while requiring minimal code changes.

In this tutorial, we cover basic ``torch.compile`` usage,
and demonstrate the advantages of ``torch.compile`` over
previous PyTorch compiler solutions, such as
`TorchScript `__ and
`FX Tracing `__.

**Contents**

.. contents::
    :local:

**Required pip Dependencies**

- ``torch >= 2.0``
- ``torchvision``
- ``numpy``
- ``scipy``
- ``tabulate``

**System Requirements**

- A C++ compiler, such as ``g++``
- Python development package (``python-devel``/``python-dev``)

.. GENERATED FROM PYTHON SOURCE LINES 39-41

NOTE: a modern NVIDIA GPU (H100, A100, or V100) is recommended for this tutorial in order to
reproduce the speedup numbers shown below and documented elsewhere.

.. GENERATED FROM PYTHON SOURCE LINES 41-57

.. code-block:: default

    import torch
    import warnings

    gpu_ok = False
    if torch.cuda.is_available():
        device_cap = torch.cuda.get_device_capability()
        if device_cap in ((7, 0), (8, 0), (9, 0)):
            gpu_ok = True

    if not gpu_ok:
        warnings.warn(
            "GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower "
            "than expected."
        )

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:52: UserWarning: GPU is not NVIDIA V100, A100, or H100.
    Speedup numbers may be lower than expected.

.. GENERATED FROM PYTHON SOURCE LINES 58-70

Basic Usage
------------

``torch.compile`` is included in the latest PyTorch.
Running TorchInductor on GPU requires Triton, which is included with the PyTorch 2.0 nightly
binary. If Triton is still missing, try installing ``torchtriton`` via pip
(``pip install torchtriton --extra-index-url "https://download.pytorch.org/whl/nightly/cu117"``
for CUDA 11.7).

Arbitrary Python functions can be optimized by passing the callable to
``torch.compile``. We can then call the returned optimized
function in place of the original function.

.. GENERATED FROM PYTHON SOURCE LINES 70-78

.. code-block:: default

    def foo(x, y):
        a = torch.sin(x)
        b = torch.cos(y)
        return a + b
    opt_foo1 = torch.compile(foo)
    print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10)))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'
    tensor([[ 0.2064, 1.2323, 0.1162, 1.9844, -0.4039, -0.0563, 1.0293, 1.6001, 0.7094, -0.5227],
            [ 0.3270, 0.6888, -0.2318, 0.3682, 0.0054, 1.5653, 1.0798, 0.3274, -0.7216, 1.4286],
            [ 1.1166, 1.5464, 1.4986, -0.2112, 0.2012, 0.0188, 0.7430, 1.7743, -0.2269, 0.4664],
            [-0.9179, -0.1216, 0.7813, 1.2106, 0.1166, 0.2849, 0.2511, -0.2535, -0.6960, 0.4591],
            [-0.3820, 0.0822, -0.1324, 0.5291, 1.4519, 1.8927, 1.9773, -0.4737, 0.3605, 1.1481],
            [ 0.9345, 0.7358, -1.4884, 1.1542, 0.5837, -0.7215, 0.6450, 0.0737, 0.0341, 0.2624],
            [ 1.2980, 0.3397, 1.2401, 1.3413, 1.9852, -0.3563, 1.3711, 0.5093, 1.0857, -1.0871],
            [ 1.9868, -0.0514, 1.1913, 0.3325, 0.8146, 0.1027, 0.4929, 1.3424, 0.3617, -0.3192],
            [ 1.5075, 0.6393, 0.6106, 0.9151, -0.0736, 0.6184, -0.6001, 1.0612, 0.5941, 1.1044],
            [ 0.8302, 0.9031, 0.3992, 1.4239, 1.6328, -1.4245, 1.0346, 0.5697, 1.4806, 0.1568]])

.. GENERATED FROM PYTHON SOURCE LINES 79-80

Alternatively, we can decorate the function.

.. GENERATED FROM PYTHON SOURCE LINES 80-90

..

.. code-block:: default

    t1 = torch.randn(10, 10)
    t2 = torch.randn(10, 10)

    @torch.compile
    def opt_foo2(x, y):
        a = torch.sin(x)
        b = torch.cos(y)
        return a + b
    print(opt_foo2(t1, t2))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor([[-2.5352e-01, -6.5860e-01, 1.7139e+00, 1.6317e+00, -1.0283e+00, -6.0966e-01, 1.5140e+00, 1.9060e+00, 1.1836e+00, 7.5357e-01],
            [ 1.2014e+00, -1.0037e+00, 1.3054e+00, 1.2069e-01, 8.7427e-01, 1.1791e+00, 1.9804e+00, -7.6623e-01, 1.5417e+00, 9.6536e-01],
            [ 1.7000e+00, 1.7255e+00, 9.9250e-02, -4.8253e-01, 3.6082e-01, -1.6297e-02, -6.0315e-01, 1.1077e+00, 4.4549e-01, 3.1217e-01],
            [ 1.3802e+00, 1.3222e+00, 1.3881e+00, 4.4119e-01, 5.2560e-01, 1.2724e+00, 9.3867e-01, 1.7990e+00, 7.9353e-01, 3.4996e-01],
            [ 1.0997e+00, 3.9743e-01, 1.0511e+00, 1.9785e-01, -9.2254e-01, 4.0315e-01, -7.2081e-02, -6.7474e-01, 1.6736e+00, 1.4608e+00],
            [ 7.7275e-01, -3.5559e-01, -9.0909e-04, 4.4555e-01, 1.1515e+00, -3.4242e-01, -1.4194e+00, 9.3979e-01, 1.2617e+00, -8.7727e-01],
            [ 4.9980e-01, 1.6714e+00, 1.5997e+00, 1.9855e+00, 1.1320e+00, 7.6333e-01, 1.6410e+00, 1.7876e+00, 1.7370e+00, -5.9264e-01],
            [ 8.1104e-01, -2.2948e-01, 1.8480e+00, 1.1356e+00, 1.7497e+00, 5.2498e-01, -1.6087e-02, 8.2839e-01, 8.7289e-01, 1.0582e+00],
            [-4.2427e-01, 3.6893e-01, 1.7472e+00, 5.7673e-01, -3.1366e-01, -3.7856e-01, 5.4967e-01, 1.7602e+00, 1.0393e+00, 4.6959e-01],
            [ 6.4678e-01, 1.4350e+00, -1.5536e-01, -5.6575e-03, 1.0960e+00, 1.6234e+00, -1.2709e-01, 1.4242e+00, -5.2102e-01, 1.8811e+00]])

.. GENERATED FROM PYTHON SOURCE LINES 91-92

We can also optimize ``torch.nn.Module`` instances.

.. GENERATED FROM PYTHON SOURCE LINES 92-110

.. code-block:: default

    t = torch.randn(10, 100)

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(100, 10)

        def forward(self, x):
            return torch.nn.functional.relu(self.lin(x))

    mod = MyModule()
    mod.compile()
    print(mod(t))
    ## or:
    # opt_mod = torch.compile(mod)
    # print(opt_mod(t))

..

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor([[0.0000, 0.0000, 1.2454, 0.0000, 0.0000, 0.0000, 0.6308, 0.0000, 0.6850, 0.0000],
            [0.5021, 0.0000, 0.0000, 0.0000, 0.0886, 0.4646, 0.7747, 0.0000, 0.0000, 0.0969],
            [0.0000, 0.0000, 0.2703, 0.0000, 0.8512, 0.0000, 0.2160, 0.2986, 0.0471, 0.7194],
            [0.0000, 0.1985, 0.0000, 0.9523, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
            [0.0024, 0.1001, 0.0000, 0.0000, 0.0000, 0.3376, 0.0000, 0.7179, 0.0000, 0.3872],
            [0.3770, 0.1996, 0.1506, 0.0597, 0.0067, 1.2380, 0.0425, 0.0160, 0.2247, 0.7713],
            [0.0000, 0.0000, 0.2558, 0.9463, 0.0112, 0.7113, 0.0000, 0.6916, 0.4548, 0.0000],
            [0.4070, 0.0000, 0.2770, 0.2405, 0.0000, 0.7215, 0.0000, 0.9368, 0.0000, 0.1904],
            [0.0000, 0.1852, 0.0000, 0.0000, 0.6368, 1.0512, 0.0588, 0.6034, 0.0333, 0.0000],
            [0.6639, 0.0000, 0.0000, 0.0000, 0.6491, 0.0000, 0.2989, 0.0000, 0.0000, 1.1716]], grad_fn=)

.. GENERATED FROM PYTHON SOURCE LINES 111-114

torch.compile and Nested Calls
------------------------------
Nested function calls within the decorated function will also be compiled.

.. GENERATED FROM PYTHON SOURCE LINES 114-126

.. code-block:: default

    def nested_function(x):
        return torch.sin(x)

    @torch.compile
    def outer_function(x, y):
        a = nested_function(x)
        b = torch.cos(y)
        return a + b

    print(outer_function(t1, t2))

.. rst-class:: sphx-glr-script-out

..

.. code-block:: none

    tensor([[-2.5352e-01, -6.5860e-01, 1.7139e+00, 1.6317e+00, -1.0283e+00, -6.0966e-01, 1.5140e+00, 1.9060e+00, 1.1836e+00, 7.5357e-01],
            [ 1.2014e+00, -1.0037e+00, 1.3054e+00, 1.2069e-01, 8.7427e-01, 1.1791e+00, 1.9804e+00, -7.6623e-01, 1.5417e+00, 9.6536e-01],
            [ 1.7000e+00, 1.7255e+00, 9.9250e-02, -4.8253e-01, 3.6082e-01, -1.6297e-02, -6.0315e-01, 1.1077e+00, 4.4549e-01, 3.1217e-01],
            [ 1.3802e+00, 1.3222e+00, 1.3881e+00, 4.4119e-01, 5.2560e-01, 1.2724e+00, 9.3867e-01, 1.7990e+00, 7.9353e-01, 3.4996e-01],
            [ 1.0997e+00, 3.9743e-01, 1.0511e+00, 1.9785e-01, -9.2254e-01, 4.0315e-01, -7.2081e-02, -6.7474e-01, 1.6736e+00, 1.4608e+00],
            [ 7.7275e-01, -3.5559e-01, -9.0909e-04, 4.4555e-01, 1.1515e+00, -3.4242e-01, -1.4194e+00, 9.3979e-01, 1.2617e+00, -8.7727e-01],
            [ 4.9980e-01, 1.6714e+00, 1.5997e+00, 1.9855e+00, 1.1320e+00, 7.6333e-01, 1.6410e+00, 1.7876e+00, 1.7370e+00, -5.9264e-01],
            [ 8.1104e-01, -2.2948e-01, 1.8480e+00, 1.1356e+00, 1.7497e+00, 5.2498e-01, -1.6087e-02, 8.2839e-01, 8.7289e-01, 1.0582e+00],
            [-4.2427e-01, 3.6893e-01, 1.7472e+00, 5.7673e-01, -3.1366e-01, -3.7856e-01, 5.4967e-01, 1.7602e+00, 1.0393e+00, 4.6959e-01],
            [ 6.4678e-01, 1.4350e+00, -1.5536e-01, -5.6575e-03, 1.0960e+00, 1.6234e+00, -1.2709e-01, 1.4242e+00, -5.2102e-01, 1.8811e+00]])

.. GENERATED FROM PYTHON SOURCE LINES 127-129

In the same fashion, when a module is compiled, all of its sub-modules and methods
that are not in a skip list are also compiled.

.. GENERATED FROM PYTHON SOURCE LINES 129-144

.. code-block:: default

    class OuterModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.inner_module = MyModule()
            self.outer_lin = torch.nn.Linear(10, 2)

        def forward(self, x):
            x = self.inner_module(x)
            return torch.nn.functional.relu(self.outer_lin(x))

    outer_mod = OuterModule()
    outer_mod.compile()
    print(outer_mod(t))

.. rst-class:: sphx-glr-script-out

..

.. code-block:: none

    tensor([[0.0000, 0.0000],
            [0.0000, 0.0000],
            [0.0000, 0.0000],
            [0.0000, 0.0000],
            [0.0000, 0.0000],
            [0.0000, 0.0000],
            [0.0000, 0.0000],
            [0.0000, 0.0000],
            [0.0000, 0.0000],
            [0.0000, 0.0700]], grad_fn=)

.. GENERATED FROM PYTHON SOURCE LINES 145-151

We can also prevent some functions from being compiled by using
``torch.compiler.disable``. Suppose you want to disable tracing on just
the ``complex_function`` function, but want tracing to resume inside
``complex_conjugate``. In this case, you can use the
``torch.compiler.disable(recursive=False)`` option. Otherwise, the default is
``recursive=True``.

.. GENERATED FROM PYTHON SOURCE LINES 151-174

.. code-block:: default

    def complex_conjugate(z):
        return torch.conj(z)

    @torch.compiler.disable(recursive=False)
    def complex_function(real, imag):
        # Assuming this function causes problems in the compilation
        z = torch.complex(real, imag)
        return complex_conjugate(z)

    def outer_function():
        real = torch.tensor([2, 3], dtype=torch.float32)
        imag = torch.tensor([4, 5], dtype=torch.float32)
        z = complex_function(real, imag)
        return torch.abs(z)

    # Try to compile the outer_function
    try:
        opt_outer_function = torch.compile(outer_function)
        print(opt_outer_function())
    except Exception as e:
        print("Compilation of outer_function failed:", e)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    /usr/local/lib/python3.10/dist-packages/torch/_inductor/lowering.py:1917: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
    tensor([4.4721, 5.8310])

..

.. GENERATED FROM PYTHON SOURCE LINES 175-208

Best Practices and Recommendations
----------------------------------

Behavior of ``torch.compile`` with Nested Modules and Function Calls

When you use ``torch.compile``, the compiler will try to recursively compile
every function call inside the target function or module that is not in a
skip list (such as built-ins and some functions in the ``torch.*`` namespace).

**Best Practices:**

1. **Top-Level Compilation:** One approach is to compile at the highest level
   possible (i.e., when the top-level module is initialized/called) and
   selectively disable compilation when encountering excessive graph breaks
   or errors. If there are still many compile issues, compile individual
   subcomponents instead.

2. **Modular Testing:** Test individual functions and modules with ``torch.compile``
   before integrating them into larger models to isolate potential issues.

3. **Disable Compilation Selectively:** If certain functions or sub-modules
   cannot be handled by ``torch.compile``, use ``torch.compiler.disable``
   to recursively exclude them from compilation.

4. **Compile Leaf Functions First:** In complex models with multiple nested
   functions and modules, start by compiling the leaf functions or modules first.
   For more information see `TorchDynamo APIs for fine-grained tracing `__.

5. **Prefer ``mod.compile()`` over ``torch.compile(mod)``:** Avoids ``_orig_`` prefix issues in ``state_dict``.

6. **Use ``fullgraph=True`` to catch graph breaks:** Helps ensure end-to-end compilation, maximizing speedup
   and compatibility with ``torch.export``.

.. GENERATED FROM PYTHON SOURCE LINES 211-219

Demonstrating Speedups
-----------------------

Let's now demonstrate that using ``torch.compile`` can speed
up real models. We will compare standard eager mode and
``torch.compile`` by evaluating and training a ``torchvision`` model on random data.

Before we start, we need to define some utility functions.

..

.. GENERATED FROM PYTHON SOURCE LINES 219-246

.. code-block:: default

    # Returns the result of running `fn()` and the time it took for `fn()` to run,
    # in seconds. We use CUDA events and synchronization for the most accurate
    # measurements.
    def timed(fn):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        result = fn()
        end.record()
        torch.cuda.synchronize()
        return result, start.elapsed_time(end) / 1000

    # Generates random input and targets data for the model, where `b` is
    # batch size.
    def generate_data(b):
        return (
            torch.randn(b, 3, 128, 128).to(torch.float32).cuda(),
            torch.randint(1000, (b,)).cuda(),
        )

    N_ITERS = 10

    from torchvision.models import densenet121
    def init_model():
        return densenet121().to(torch.float32).cuda()

.. GENERATED FROM PYTHON SOURCE LINES 247-251

First, let's compare inference.

Note that in the call to ``torch.compile``, we have the additional
``mode`` argument, which we will discuss below.

.. GENERATED FROM PYTHON SOURCE LINES 251-265

.. code-block:: default

    model = init_model()

    # Reset since we are using a different mode.
    import torch._dynamo
    torch._dynamo.reset()

    model_opt = torch.compile(model, mode="reduce-overhead")

    inp = generate_data(16)[0]
    with torch.no_grad():
        print("eager:", timed(lambda: model(inp))[1])
        print("compile:", timed(lambda: model_opt(inp))[1])

.. rst-class:: sphx-glr-script-out

..

.. code-block:: pytb

    Traceback (most recent call last):
      File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 252, in <module>
        model = init_model()
      File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 244, in init_model
        return densenet121().to(torch.float32).cuda()
      File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1065, in cuda
        return self._apply(lambda t: t.cuda(device))
      File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 915, in _apply
        module._apply(fn)
      File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 915, in _apply
        module._apply(fn)
      File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 942, in _apply
        param_applied = fn(param)
      File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1065, in <lambda>
        return self._apply(lambda t: t.cuda(device))
      File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 372, in _lazy_init
        torch._C._cuda_init()
    RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

.. GENERATED FROM PYTHON SOURCE LINES 266-272

Notice that ``torch.compile`` takes a lot longer to complete
compared to eager. This is because ``torch.compile`` compiles
the model into optimized kernels as it executes. In our example, the
structure of the model doesn't change, and so recompilation is not
needed. So if we run our optimized model several more times, we should
see a significant improvement compared to eager.

.. GENERATED FROM PYTHON SOURCE LINES 272-300

..

.. code-block:: default

    eager_times = []
    for i in range(N_ITERS):
        inp = generate_data(16)[0]
        with torch.no_grad():
            _, eager_time = timed(lambda: model(inp))
        eager_times.append(eager_time)
        print(f"eager eval time {i}: {eager_time}")

    print("~" * 10)

    compile_times = []
    for i in range(N_ITERS):
        inp = generate_data(16)[0]
        with torch.no_grad():
            _, compile_time = timed(lambda: model_opt(inp))
        compile_times.append(compile_time)
        print(f"compile eval time {i}: {compile_time}")
    print("~" * 10)

    import numpy as np
    eager_med = np.median(eager_times)
    compile_med = np.median(compile_times)
    speedup = eager_med / compile_med
    assert speedup > 1
    print(f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
    print("~" * 10)

.. GENERATED FROM PYTHON SOURCE LINES 301-323

And indeed, we can see that running our model with ``torch.compile``
results in a significant speedup. Speedup mainly comes from reducing Python overhead and
GPU read/writes, and so the observed speedup may vary depending on factors such as model
architecture and batch size. For example, if a model's architecture is simple
and the amount of data is large, then the bottleneck would be
GPU compute and the observed speedup may be less significant.

You may also see different speedup results depending on the chosen ``mode``
argument. The ``"reduce-overhead"`` mode uses CUDA graphs to further reduce
the overhead of Python. For your own models,
you may need to experiment with different modes to maximize speedup. You can
read more about modes `here `__.

You might also notice that the second time we run our model with ``torch.compile`` is significantly
slower than the other runs, although it is much faster than the first run. This is because the ``"reduce-overhead"``
mode runs a few warm-up iterations for CUDA graphs.

For general PyTorch benchmarking, you can try using ``torch.utils.benchmark`` instead of the ``timed``
function we defined above.
We wrote our own timing function in this tutorial to show
``torch.compile``'s compilation latency.

Now, let's consider comparing training.

.. GENERATED FROM PYTHON SOURCE LINES 323-361

.. code-block:: default

    model = init_model()
    opt = torch.optim.Adam(model.parameters())

    def train(mod, data):
        opt.zero_grad(True)
        pred = mod(data[0])
        loss = torch.nn.CrossEntropyLoss()(pred, data[1])
        loss.backward()
        opt.step()

    eager_times = []
    for i in range(N_ITERS):
        inp = generate_data(16)
        _, eager_time = timed(lambda: train(model, inp))
        eager_times.append(eager_time)
        print(f"eager train time {i}: {eager_time}")
    print("~" * 10)

    model = init_model()
    opt = torch.optim.Adam(model.parameters())
    train_opt = torch.compile(train, mode="reduce-overhead")

    compile_times = []
    for i in range(N_ITERS):
        inp = generate_data(16)
        _, compile_time = timed(lambda: train_opt(model, inp))
        compile_times.append(compile_time)
        print(f"compile train time {i}: {compile_time}")
    print("~" * 10)

    eager_med = np.median(eager_times)
    compile_med = np.median(compile_times)
    speedup = eager_med / compile_med
    assert speedup > 1
    print(f"(train) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
    print("~" * 10)

.. GENERATED FROM PYTHON SOURCE LINES 362-369

Again, we can see that ``torch.compile`` takes longer in the first
iteration, as it must compile the model, but in subsequent iterations, we see
significant speedups compared to eager.

We remark that the speedup numbers presented in this tutorial are for
demonstration purposes only. Official speedup values can be seen at the
`TorchInductor performance dashboard `__.

.. GENERATED FROM PYTHON SOURCE LINES 371-383

Comparison to TorchScript and FX Tracing
-----------------------------------------

We have seen that ``torch.compile`` can speed up PyTorch code.
Why else should we use ``torch.compile`` over existing PyTorch
compiler solutions, such as TorchScript or FX Tracing?
Primarily, the advantage of ``torch.compile`` lies in its ability to handle
arbitrary Python code with minimal changes to existing code.

One case that ``torch.compile`` can handle that other compiler
solutions struggle with is data-dependent control flow (the
``if x.sum() < 0:`` line below).

.. GENERATED FROM PYTHON SOURCE LINES 383-400

.. code-block:: default

    def f1(x, y):
        if x.sum() < 0:
            return -y
        return y

    # Test that `fn1` and `fn2` return the same result, given
    # the same arguments `args`. Typically, `fn1` will be an eager function
    # while `fn2` will be a compiled function (torch.compile, TorchScript, or FX graph).
    def test_fns(fn1, fn2, args):
        out1 = fn1(*args)
        out2 = fn2(*args)
        return torch.allclose(out1, out2)

    inp1 = torch.randn(5, 5)
    inp2 = torch.randn(5, 5)

.. GENERATED FROM PYTHON SOURCE LINES 401-404

TorchScript tracing ``f1`` results in
silently incorrect results, since only the actual control flow path
is traced.

.. GENERATED FROM PYTHON SOURCE LINES 404-409

.. code-block:: default

    traced_f1 = torch.jit.trace(f1, (inp1, inp2))
    print("traced 1, 1:", test_fns(f1, traced_f1, (inp1, inp2)))
    print("traced 1, 2:", test_fns(f1, traced_f1, (-inp1, inp2)))

.. GENERATED FROM PYTHON SOURCE LINES 410-412

FX tracing ``f1`` results in an error due to the presence of
data-dependent control flow.

.. GENERATED FROM PYTHON SOURCE LINES 412-419

.. code-block:: default

    import traceback as tb
    try:
        torch.fx.symbolic_trace(f1)
    except:
        tb.print_exc()

.. GENERATED FROM PYTHON SOURCE LINES 420-423

If we provide a value for ``x`` as we try to FX trace ``f1``, then
we run into the same problem as TorchScript tracing, as the data-dependent
control flow is removed in the traced function.

.. GENERATED FROM PYTHON SOURCE LINES 423-428

.. code-block:: default

    fx_f1 = torch.fx.symbolic_trace(f1, concrete_args={"x": inp1})
    print("fx 1, 1:", test_fns(f1, fx_f1, (inp1, inp2)))
    print("fx 1, 2:", test_fns(f1, fx_f1, (-inp1, inp2)))

..

.. GENERATED FROM PYTHON SOURCE LINES 429-431

Now we can see that ``torch.compile`` correctly handles
data-dependent control flow.

.. GENERATED FROM PYTHON SOURCE LINES 431-440

.. code-block:: default

    # Reset since we are using a different mode.
    torch._dynamo.reset()

    compile_f1 = torch.compile(f1)
    print("compile 1, 1:", test_fns(f1, compile_f1, (inp1, inp2)))
    print("compile 1, 2:", test_fns(f1, compile_f1, (-inp1, inp2)))
    print("~" * 10)

.. GENERATED FROM PYTHON SOURCE LINES 441-449

TorchScript scripting can handle data-dependent control flow, but this
solution comes with its own set of problems. Namely, TorchScript scripting
can require major code changes and will raise errors when unsupported Python
is used.

In the example below, we forget TorchScript type annotations and we receive
a TorchScript error because the input type for argument ``y``, an ``int``,
does not match the default argument type, ``torch.Tensor``.

.. GENERATED FROM PYTHON SOURCE LINES 449-462

.. code-block:: default

    def f2(x, y):
        return x + y

    inp1 = torch.randn(5, 5)
    inp2 = 3

    script_f2 = torch.jit.script(f2)
    try:
        script_f2(inp1, inp2)
    except:
        tb.print_exc()

.. GENERATED FROM PYTHON SOURCE LINES 463-464

However, ``torch.compile`` is easily able to handle ``f2``.

.. GENERATED FROM PYTHON SOURCE LINES 464-469

.. code-block:: default

    compile_f2 = torch.compile(f2)
    print("compile 2:", test_fns(f2, compile_f2, (inp1, inp2)))
    print("~" * 10)

.. GENERATED FROM PYTHON SOURCE LINES 470-472

Another case that ``torch.compile`` handles well compared to
previous compiler solutions is the usage of non-PyTorch functions.

.. GENERATED FROM PYTHON SOURCE LINES 472-481

.. code-block:: default

    import scipy
    def f3(x):
        x = x * 2
        x = scipy.fft.dct(x.numpy())
        x = torch.from_numpy(x)
        x = x * 2
        return x

.. GENERATED FROM PYTHON SOURCE LINES 482-484

TorchScript tracing treats results from non-PyTorch function calls
as constants, and so our results can be silently wrong.

.. GENERATED FROM PYTHON SOURCE LINES 484-490

..

.. code-block:: default

    inp1 = torch.randn(5, 5)
    inp2 = torch.randn(5, 5)
    traced_f3 = torch.jit.trace(f3, (inp1,))
    print("traced 3:", test_fns(f3, traced_f3, (inp2,)))

.. GENERATED FROM PYTHON SOURCE LINES 491-492

TorchScript scripting and FX tracing disallow non-PyTorch function calls.

.. GENERATED FROM PYTHON SOURCE LINES 492-503

.. code-block:: default

    try:
        torch.jit.script(f3)
    except:
        tb.print_exc()

    try:
        torch.fx.symbolic_trace(f3)
    except:
        tb.print_exc()

.. GENERATED FROM PYTHON SOURCE LINES 504-506

In comparison, ``torch.compile`` is easily able to handle
the non-PyTorch function call.

.. GENERATED FROM PYTHON SOURCE LINES 506-510

.. code-block:: default

    compile_f3 = torch.compile(f3)
    print("compile 3:", test_fns(f3, compile_f3, (inp2,)))

.. GENERATED FROM PYTHON SOURCE LINES 511-525

TorchDynamo and FX Graphs
--------------------------

One important component of ``torch.compile`` is TorchDynamo.
TorchDynamo is responsible for JIT compiling arbitrary Python code into
`FX graphs `__, which can
then be further optimized. TorchDynamo extracts FX graphs by analyzing Python bytecode
during runtime and detecting calls to PyTorch operations.

Normally, TorchInductor, another component of ``torch.compile``,
further compiles the FX graphs into optimized kernels,
but TorchDynamo allows for different backends to be used. In order to inspect
the FX graphs that TorchDynamo outputs, let us create a custom backend that
outputs the FX graph and simply returns the graph's unoptimized forward method.

.. GENERATED FROM PYTHON SOURCE LINES 525-538

.. code-block:: default

    from typing import List
    def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        print("custom backend called with FX graph:")
        gm.graph.print_tabular()
        return gm.forward

    # Reset since we are using a different backend.
    torch._dynamo.reset()

    opt_model = torch.compile(init_model(), backend=custom_backend)
    opt_model(generate_data(16)[0])

..

.. GENERATED FROM PYTHON SOURCE LINES 539-542

Using our custom backend, we can now see how TorchDynamo is able to handle
data-dependent control flow. Consider the function below, where the line
``if b.sum() < 0`` is the source of data-dependent control flow.

.. GENERATED FROM PYTHON SOURCE LINES 542-555

.. code-block:: default

    def bar(a, b):
        x = a / (torch.abs(a) + 1)
        if b.sum() < 0:
            b = b * -1
        return x * b

    opt_bar = torch.compile(bar, backend=custom_backend)
    inp1 = torch.randn(10)
    inp2 = torch.randn(10)
    opt_bar(inp1, inp2)
    opt_bar(inp1, -inp2)

.. GENERATED FROM PYTHON SOURCE LINES 556-580

The output reveals that TorchDynamo extracted 3 different FX graphs
corresponding to the following code (order may differ from the output above):

1. ``x = a / (torch.abs(a) + 1)``
2. ``b = b * -1; return x * b``
3. ``return x * b``

When TorchDynamo encounters unsupported Python features, such as data-dependent
control flow, it breaks the computation graph, lets the default Python
interpreter handle the unsupported code, then resumes capturing the graph.

Let's investigate by example how TorchDynamo would step through ``bar``.
If ``b.sum() < 0``, then TorchDynamo would run graph 1, let
Python determine the result of the conditional, then run
graph 2. On the other hand, if ``not b.sum() < 0``, then TorchDynamo
would run graph 1, let Python determine the result of the conditional, then
run graph 3.

This highlights a major difference between TorchDynamo and previous PyTorch
compiler solutions. When encountering unsupported Python features,
previous solutions either raise an error or silently fail.
TorchDynamo, on the other hand, will break the computation graph.

We can see where TorchDynamo breaks the graph by using ``torch._dynamo.explain``:

.. GENERATED FROM PYTHON SOURCE LINES 580-586

.. code-block:: default

    # Reset since we are using a different backend.
    torch._dynamo.reset()
    explain_output = torch._dynamo.explain(bar)(torch.randn(10), torch.randn(10))
    print(explain_output)

..

.. GENERATED FROM PYTHON SOURCE LINES 587-590

In order to maximize speedup, graph breaks should be limited.
We can force TorchDynamo to raise an error upon the first graph
break encountered by using ``fullgraph=True``:

.. GENERATED FROM PYTHON SOURCE LINES 590-597

.. code-block:: default

    opt_bar = torch.compile(bar, fullgraph=True)
    try:
        opt_bar(torch.randn(10), torch.randn(10))
    except:
        tb.print_exc()

.. GENERATED FROM PYTHON SOURCE LINES 598-600

And below, we demonstrate that TorchDynamo does not break the graph on
the model we used above for demonstrating speedups.

.. GENERATED FROM PYTHON SOURCE LINES 600-604

.. code-block:: default

    opt_model = torch.compile(init_model(), fullgraph=True)
    print(opt_model(generate_data(16)[0]))

.. GENERATED FROM PYTHON SOURCE LINES 605-611

We can use ``torch.export`` (from PyTorch 2.1+) to extract a single, exportable
FX graph from the input PyTorch program. The exported graph is intended to be
run on different (i.e. Python-less) environments. One important restriction
is that ``torch.export`` does not support graph breaks. Please check
`this tutorial `__
for more details on ``torch.export``.

.. GENERATED FROM PYTHON SOURCE LINES 613-620

Conclusion
------------

In this tutorial, we introduced ``torch.compile`` by covering
basic usage, demonstrating speedups over eager mode, comparing to previous
PyTorch compiler solutions, and briefly investigating TorchDynamo and its interactions
with FX graphs. We hope that you will give ``torch.compile`` a try!

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 11.772 seconds)

.. _sphx_glr_download_intermediate_torch_compile_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: torch_compile_tutorial.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: torch_compile_tutorial.ipynb `

..

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_