Description
In collaboration with @liuliu, I have been developing a new kernel library for GEMM and attention operations. Stable Diffusion/NNC is the primary use case, but I hope to integrate it into LLaMA/GGML as well. The library consistently outperforms MPS by a large margin: https://twitter.com/philipturnerar/status/1669146393271730178
Here is out-of-the-box performance, with zero fine-tuning. The table below shows matrix sizes common in Stable Diffusion and LLaMA.
M | N | K | F16 win? | MFA (F16) | MPS (F16) | F32 win? | MFA (F32) | MPS (F32) |
---|---|---|---|---|---|---|---|---|
1280 | 4096 | 320 | ✅ | 83% | 68% | ✅ | 75% | 75% |
1024 | 2560 | 640 | ✅ | 82% | 69% | ✅ | 76% | 75% |
4096 | 4096 | 40 | ✅ | 62% | 35% | ✅ | 50% | 40% |
4096 | 40 | 4096 | ✅ | 50% | 11% | ✅ | 36% | 11% |
1024 | 1024 | 80 | ✅ | 54% | 42% | ✅ | 48% | 48% |
1024 | 80 | 1024 | ✅ | 43% | 17% | ✅ | 40% | 14% |
4096 | 320 | 320 | ✅ | 79% | 62% | ✅ | 70% | 68% |
4096 | 1713 | 40 | ✅ | 52% | 32% | ✅ | 40% | 34% |
4096 | 40 | 1713 | ✅ | 46% | 19% | ✅ | 40% | 9.7% |
4096 | 92 | 40 | ✅ | 28% | 7.2% | ✅ | 21% | 6.4% |
4096 | 40 | 92 | ✅ | 27% | 7.5% | ✅ | 19% | 7.3% |
1805 | 320 | 768 | ✅ | 75% | 51% | ✅ | 63% | 56% |
1805 | 1280 | 768 | ✅ | 81% | 64% | ❌ | 67% | 71% |
512 | 512 | 32 | ✅ | 26% | 14% | ✅ | 20% | 14% |
512 | 32 | 512 | ✅ | 8.2% | 8.1% | ❌ | 7.5% | 7.7% |
2048 | 2048 | 32 | ✅ | 61% | 40% | ✅ | 50% | 46% |
2048 | 32 | 2048 | ✅ | 35% | 35% | ✅ | 35% | 32% |
2048 | 2048 | 40 | ✅ | 56% | 32% | ✅ | 46% | 36% |
2048 | 40 | 2048 | ✅ | 40% | 5.8% | ✅ | 37% | 5.9% |
2048 | 2048 | 52 | ✅ | 52% | 32% | ✅ | 39% | 35% |
2048 | 52 | 2048 | ✅ | 49% | 14% | ✅ | 47% | 12% |
Values are compute utilization (higher is better). A ✅ means MFA is much faster than MPS; a ❌ means it is slower or roughly the same.
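For reference, utilization numbers like these can be derived from kernel timings. Below is a minimal sketch of the arithmetic, assuming a GEMM costs 2·M·N·K FLOPs; the function name, the timing, and the peak-throughput figure are illustrative assumptions, not part of the library or the measurements above.

```swift
import Foundation

// Minimal sketch of how a utilization percentage can be computed.
// `peakFLOPS` is an assumption about the test GPU (e.g. ~10.4 TFLOPS FP32
// on an M1 Max); substitute the figure for your own device.
func gemmUtilization(M: Int, N: Int, K: Int,
                     seconds: Double,
                     peakFLOPS: Double) -> Double {
    // A GEMM performs 2 * M * N * K floating-point operations (multiply + add).
    let flops = 2.0 * Double(M) * Double(N) * Double(K)
    let achieved = flops / seconds
    return achieved / peakFLOPS
}

// Illustrative numbers only (not measurements from the table above):
// a 1280 x 4096 x 320 GEMM finishing in 0.40 ms on a ~10.4 TFLOPS GPU.
let utilization = gemmUtilization(M: 1280, N: 4096, K: 320,
                                  seconds: 0.40e-3,
                                  peakFLOPS: 10.4e12)
print(String(format: "%.0f%% utilization", utilization * 100))
```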
I will open-source Metal FlashAttention, but it is not in a presentable state just yet. I am opening this thread to discuss anything relevant to integration, such as existing bottlenecks, simulation results, and dependencies.
cc: @ggerganov