Overview
TorchFuser is (going to be) a compiler and runtime framework designed to enhance PyTorch training performance by fusing common transformer operations. Drawing inspiration from projects like MLC-LLM, which leverages TVM for optimization, TorchFuser aims to integrate seamlessly into PyTorch workflows, utilizing MLIR-based compilation strategies.
Motivation
Training large-scale models, such as LLMs and other transformers, is computationally intensive, and standard implementations leave significant performance on the table. While solutions like FlashAttention have optimized specific components, there's a need for a more generalized approach that:
- Integrates seamlessly with PyTorch, possibly through a `@torchfuser.jit` decorator.
- Leverages MLIR for backend optimizations, similar to Triton's approach.
- Supports diverse hardware, including NVIDIA GPUs, AMD GPUs, and Apple’s MLX.
- Achieves tangible performance gains, targeting at least a 10% improvement over standard PyTorch operations.
Technical Approach
- MLIR Integration:
  - Utilize `torch-mlir` as a foundation, benefiting from its alignment with PyTorch's ecosystem (see the lowering sketch after this list).
  - Draw insights from Triton's MLIR-based backend, which compiles Python-decorated functions into optimized GPU kernels.
- Decorator-Based API:
  - Introduce a `@torchfuser.jit` decorator, allowing users to annotate functions for optimization, akin to Numba's approach (see the API sketch after this list).
- Hardware Abstraction:
  - Design the compiler to generate optimized code for various hardware backends, ensuring broad compatibility.
- Performance Optimization:
  - Implement techniques inspired by FlashAttention-3, such as overlapping computation and data movement and leveraging low-precision computations (e.g., FP8); the stream-overlap sketch after this list illustrates the idea with stock PyTorch APIs.
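To make the MLIR Integration item concrete, here is a minimal sketch of lowering a small PyTorch module through `torch-mlir`. Note the entry point has changed across releases (older releases expose `torch_mlir.compile`; newer ones go through the FX importer), so treat the exact call as version-dependent:

```python
# Minimal sketch: lower a tiny module to the linalg dialect via torch-mlir.
# Assumes an older torch-mlir release that ships `torch_mlir.compile`;
# check your installed version for the current entry point.
import torch
import torch_mlir

class TinyMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(16, 16)

    def forward(self, x):
        return torch.relu(self.fc(x))

module = torch_mlir.compile(
    TinyMLP(),
    torch.randn(4, 16),                 # example input for shape inference
    output_type="linalg-on-tensors",    # lower the Torch dialect to linalg
)
print(module)  # MLIR text that downstream fusion passes could transform
```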
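The decorator API could look roughly like the following. Everything here is hypothetical — `torchfuser.jit` does not exist yet, and the compile pipeline is stubbed out with an eager fallback — but it shows the intended shape: trace once per input signature, compile, cache, and dispatch by device:

```python
# Hypothetical sketch of the proposed @torchfuser.jit decorator.
# All names are placeholders for the design, not an existing API.
import functools
import torch

_COMPILED_CACHE = {}

def jit(fn):
    """Compile `fn` once per (shape, dtype, device) signature and cache it."""
    @functools.wraps(fn)
    def wrapper(*args):
        # Keying on device.type is where per-backend (CUDA/ROCm/MLX)
        # dispatch would plug in.
        key = (fn, tuple((a.shape, a.dtype, a.device.type) for a in args))
        if key not in _COMPILED_CACHE:
            # Placeholder for the real pipeline: export the graph, run
            # MLIR fusion passes, generate backend code. Eager fallback here.
            _COMPILED_CACHE[key] = fn
        return _COMPILED_CACHE[key](*args)
    return wrapper

@jit
def fused_mlp(x, w1, w2):
    # The kind of matmul + elementwise chain TorchFuser would fuse.
    return torch.relu(x @ w1) @ w2
```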
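As a flavor of the Performance Optimization item: the compute/copy overlap that FlashAttention-3 performs at the kernel level can be sketched at the framework level with stock PyTorch CUDA streams (this sketch assumes a CUDA device is available):

```python
# Overlap host-to-device copies with compute using a side CUDA stream.
import torch

assert torch.cuda.is_available(), "sketch assumes a CUDA device"

copy_stream = torch.cuda.Stream()
weight = torch.randn(1024, 1024, device="cuda")
# Pinned host memory enables truly asynchronous H2D copies.
batches = [torch.randn(1024, 1024, pin_memory=True) for _ in range(4)]

results = []
for cpu_batch in batches:
    with torch.cuda.stream(copy_stream):
        # Copy on the side stream; it can run while the previous
        # iteration's matmul is still executing on the default stream.
        gpu_batch = cpu_batch.to("cuda", non_blocking=True)
    # Default stream waits for this batch's copy before computing on it.
    torch.cuda.current_stream().wait_stream(copy_stream)
    results.append(gpu_batch @ weight)
torch.cuda.synchronize()
```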
Roadmap
- Phase 1: Research & Design
  - Analyze existing solutions like MLC-LLM and Triton to inform design decisions.
  - Define the MLIR dialects and transformations required for TorchFuser.
- Phase 2: Prototype Development
  - Develop initial prototypes focusing on fusing transformer components such as attention mechanisms and MLPs.
  - Benchmark performance against standard PyTorch implementations (a harness sketch follows this list).
- Phase 3: Community Feedback & Iteration
  - Share prototypes here in this thread to gather feedback.
  - Iterate on the design and implementation based on that feedback.
- Phase 4: Production Readiness
  - Finalize the API and backend implementations.
  - Provide comprehensive documentation and tutorials to facilitate adoption.
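For the Phase 2 benchmarking step, a harness along these lines would compare candidates against the eager PyTorch baseline. It uses the real `torch.utils.benchmark` API; the "torchfuser" candidate is a stand-in (eager today) for a compiled function, and a CUDA device is assumed:

```python
# Benchmark harness sketch: eager baseline vs. a (stubbed) fused candidate.
import torch
import torch.utils.benchmark as benchmark

def eager_mlp(x, w1, w2):
    # Eager baseline: the unfused matmul + ReLU + matmul chain.
    return torch.relu(x @ w1) @ w2

# Stand-in for a TorchFuser-compiled candidate (hypothetical; eager today).
candidate_mlp = eager_mlp

x, w1, w2 = (torch.randn(512, 512, device="cuda") for _ in range(3))

for sub_label, fn in [("eager", eager_mlp), ("torchfuser (stub)", candidate_mlp)]:
    timer = benchmark.Timer(
        stmt="fn(x, w1, w2)",
        globals={"fn": fn, "x": x, "w1": w1, "w2": w2},
        label="fused MLP",
        sub_label=sub_label,
    )
    print(timer.timeit(100))  # Timer handles CUDA synchronization internally
```

This is also where the 10%+ improvement target from the Motivation section would be measured.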
Call to Action
We’re seeking feedback on:
- The feasibility and design of integrating MLIR-based optimizations into PyTorch.
- Strategies for broad hardware support, including AMD GPUs and Apple’s MLX.
- Potential challenges in achieving the targeted performance improvements.
Your expertise and insights would be very much appreciated as we develop this project.