TorchFuser: A Plug-and-Play MLIR-Based Compiler and Optimized Runtime Integration

Overview

TorchFuser is a planned compiler and runtime framework designed to improve PyTorch training performance by fusing common transformer operations. Drawing inspiration from projects like MLC-LLM, which leverages TVM for optimization, TorchFuser aims to integrate seamlessly into PyTorch workflows using MLIR-based compilation strategies.

Motivation

Training large-scale models such as LLMs and other transformers is computationally intensive, and much of the cost comes from unfused, memory-bound operations. While solutions like FlashAttention have optimized specific components, there’s a need for a more generalized approach that:

  • Integrates seamlessly with PyTorch, possibly through a @torchfuser.jit decorator (see the usage sketch after this list).
  • Leverages MLIR for backend optimizations, similar to Triton’s approach.
  • Supports diverse hardware, including NVIDIA GPUs, AMD GPUs, and Apple silicon via MLX.
  • Achieves tangible performance gains, targeting at least a 10% improvement over standard eager-mode PyTorch.
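
To make the integration goal concrete, here is a rough usage sketch. The torchfuser package and its jit decorator are proposed names only; nothing here exists yet:

```python
import torch
import torchfuser  # hypothetical package; a proposal, not an installable library

@torchfuser.jit  # hypothetical decorator: would trace, fuse, and compile this function
def attention(q, k, v):
    scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
    return torch.softmax(scores, dim=-1) @ v
```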

Technical Approach

  1. MLIR Integration:

    • Utilize torch-mlir as a foundation, benefiting from its alignment with PyTorch’s ecosystem.
    • Draw insights from Triton’s MLIR-based backend, which compiles Python-decorated functions into optimized GPU kernels.
  2. Decorator-Based API:

    • Introduce a @torchfuser.jit decorator, allowing users to annotate functions for optimization, akin to Numba’s approach; a design sketch combining this with the hardware abstraction from item 3 follows this list.
  3. Hardware Abstraction:

    • Design the compiler to generate optimized code for various hardware backends, ensuring broad compatibility.
  4. Performance Optimization:

    • Implement techniques inspired by FlashAttention-3, such as overlapping computation with data movement and leveraging low-precision formats such as FP8.
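
As referenced in item 2, the following is a hedged design sketch of how the decorator and the hardware abstraction might fit together. Every name in it (jit, register_backend, _BACKENDS) is hypothetical, and the CUDA backend is a stub that returns the function unchanged:

```python
import functools

_BACKENDS = {}  # hypothetical registry mapping a target name to a compile function

def register_backend(name):
    """Register a per-hardware compile function under a target name."""
    def deco(compile_fn):
        _BACKENDS[name] = compile_fn
        return compile_fn
    return deco

@register_backend("cuda")
def _compile_cuda(fn):
    # Placeholder: a real backend would lower the function through the MLIR
    # pipeline and return fused kernels; here the function is returned as-is.
    return fn

def jit(fn=None, *, target="cuda"):
    """Sketch of the proposed @torchfuser.jit decorator."""
    if fn is None:  # support the @jit(target="...") calling convention
        return functools.partial(jit, target=target)
    return functools.wraps(fn)(_BACKENDS[target](fn))
```

Keeping per-hardware lowering behind a registry like this would let new targets (e.g., ROCm or MLX) be added without changing the user-facing decorator.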

Roadmap

  1. Phase 1: Research & Design

    • Analyze existing solutions like MLC-LLM and Triton to inform design decisions.
    • Define the MLIR dialects and transformations required for TorchFuser.
  2. Phase 2: Prototype Development

    • Develop initial prototypes focusing on fusing transformer components like attention mechanisms and MLPs.
    • Benchmark performance against standard PyTorch implementations; a sketch of such a harness follows this list.
  3. Phase 3: Community Feedback & Iteration

    • Share prototypes on this forum to gather feedback.
    • Iterate on the design and implementation based on community insights.
  4. Phase 4: Production Readiness

    • Finalize the API and backend implementations.
    • Provide comprehensive documentation and tutorials to facilitate adoption.
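
For the Phase 2 benchmarking, a minimal harness could be built on torch.utils.benchmark. In this sketch, torch.compile stands in for a future TorchFuser-compiled variant purely to exercise the measurement path:

```python
import torch
from torch.utils import benchmark

def eager_attention(q, k, v):
    scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
    return torch.softmax(scores, dim=-1) @ v

# Stand-in for a TorchFuser-compiled variant; torch.compile is used here
# only to exercise the harness, not as the proposed implementation.
compiled_attention = torch.compile(eager_attention)

q = k = v = torch.randn(4, 8, 256, 64)
results = []
for sub_label, fn in [("eager", eager_attention), ("compiled", compiled_attention)]:
    timer = benchmark.Timer(
        stmt="fn(q, k, v)",
        globals={"fn": fn, "q": q, "k": k, "v": v},
        label="attention forward",
        sub_label=sub_label,
    )
    results.append(timer.timeit(50))
benchmark.Compare(results).print()
```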

Call to Action

We’re seeking feedback on:

  • The feasibility and design of integrating MLIR-based optimizations into PyTorch.
  • Strategies for broad hardware support, including AMD GPUs and Apple silicon via MLX.
  • Potential challenges in achieving the targeted performance improvements.

Your expertise and insights would be greatly appreciated as this project develops.


Why build a new compiler? Why not improve TorchInductor instead?

Would you post the GitHub link for this project? Thanks!

I think we could register TorchFuser as a new backend in TorchDynamo. That way it would reuse torch.compile: users keep a single torch.compile entry point and simply try different backends to see which one performs best.
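
A minimal sketch of that registration path, using the documented custom-backend interface for torch.compile; the pass-through body is a placeholder for TorchFuser's eventual MLIR lowering:

```python
import torch

def torchfuser_backend(gm: torch.fx.GraphModule, example_inputs):
    # A real implementation would hand the captured FX graph to the
    # TorchFuser MLIR pipeline and return an optimized callable.
    # Returning gm.forward falls back to the unmodified graph.
    return gm.forward

model = torch.nn.Linear(64, 64)
compiled = torch.compile(model, backend=torchfuser_backend)
out = compiled(torch.randn(4, 64))
```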