
vulkan: Add fusion support for RMS_NORM+MUL #14366


Draft
jeffbolznv wants to merge 3 commits into master
Conversation

jeffbolznv (Collaborator)

  • Add a use_count to ggml_tensor, so we can detect if an output is used more than once.
  • Change the ggml-vulkan rms_norm shader to optionally multiply by another tensor.
  • Add detection logic and basic fusion logic in ggml-vulkan.
  • Add some testing support for fusion. Rather than computing one node at a time, allow for computing the whole graph and just testing one node's results. Add rms_norm_mul tests and enable a llama test.

I picked rms_norm+mul to fuse because it's commonly generated in llm_graph_context::build_norm.
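For illustration, the detection step could look roughly like the following sketch (not the actual patch; `use_count` is the field this change adds to `ggml_tensor`, and the helper name is made up):

```c
// Rough sketch of the detection idea: fuse an RMS_NORM node with the MUL that
// immediately follows it, but only if the norm result feeds that MUL and
// nothing else.
static bool can_fuse_rms_norm_mul(const struct ggml_cgraph * graph, int i) {
    if (i + 1 >= graph->n_nodes) {
        return false;
    }
    const struct ggml_tensor * norm = graph->nodes[i];
    const struct ggml_tensor * mul  = graph->nodes[i + 1];

    if (norm->op != GGML_OP_RMS_NORM || mul->op != GGML_OP_MUL) {
        return false;
    }
    // use_count is the field this PR adds: the norm output must have exactly
    // one consumer, and that consumer must be the mul
    if (mul->src[0] != norm || norm->use_count != 1) {
        return false;
    }
    return true;
}
```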

@jeffbolznv jeffbolznv requested a review from slaren June 24, 2025 20:12
@github-actions github-actions bot added the testing, Vulkan, and ggml labels Jun 24, 2025
@jeffbolznv jeffbolznv marked this pull request as draft June 24, 2025 20:13
// Since norm is the first operand of mul, it must be the same shape
GGML_ASSERT(ggml_are_same_shape(mul, norm));

// XXX TODO: Do we need a way to indicate that the user doesn't need the intermediate result?
jeffbolznv (Collaborator, Author)

@slaren I'd like to get your thoughts on this. How do we know that the user doesn't want to have all the intermediate results available to them? Seems like they'd need to opt in or out of fusion somehow... GGML_TENSOR_FLAG_OUTPUT seems like almost what we would want, but it's slightly different and doesn't seem to be used much?

Member

Checking for GGML_TENSOR_FLAG_OUTPUT is enough. This flag guarantees that the tensor will not be overwritten by another computation, so it is necessary to set this flag if the user wants the output of a tensor. In cases like imatrix, the graph is evaluated one operation at a time, so fusion should never be active.
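For context, a caller that wants to keep an intermediate result around can mark it explicitly; a minimal sketch using the existing ggml_set_output API, where ctx, x, eps, and weight stand in for whatever the surrounding graph-building code provides:

```c
// Mark the intermediate rms_norm result as a user-visible output so that a
// backend knows it must be preserved (i.e. not fused away or overwritten).
struct ggml_tensor * norm = ggml_rms_norm(ctx, x, eps);
ggml_set_output(norm);  // sets GGML_TENSOR_FLAG_OUTPUT
struct ggml_tensor * out = ggml_mul(ctx, norm, weight);
```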

Member

Precisely because getting details like this one right is not completely obvious, it would be good to have a function in ggml, similar to this one, that checks whether multiple operations can be fused, which all the backends could use to implement operator fusion.

jeffbolznv (Collaborator, Author)

Yeah, I can move some of the logic there, but I think eventually some of the implementation details may differ between backends. Hopefully there's a way to have some common logic and some backend-specific logic.

jeffbolznv (Collaborator, Author)

OK, I've extracted out some common logic into a helper function.
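To make the shape of such a helper concrete, here is an illustrative sketch (the name and exact signature are made up, not the actual ggml API; `use_count` is the field added by this PR):

```c
// Illustrative sketch of a shared fusion-eligibility check: each intermediate
// result in the op chain must have exactly one consumer (the next op in the
// chain) and must not be a user-requested output.
static bool ggml_can_fuse_sketch(const struct ggml_cgraph * graph, int start,
                                 const enum ggml_op * ops, int n_ops) {
    if (start + n_ops > graph->n_nodes) {
        return false;
    }
    for (int i = 0; i < n_ops; i++) {
        const struct ggml_tensor * node = graph->nodes[start + i];
        if (node->op != ops[i]) {
            return false;
        }
        if (i > 0) {
            const struct ggml_tensor * prev = graph->nodes[start + i - 1];
            // use_count is the field added by this PR
            if (node->src[0] != prev || prev->use_count != 1 ||
                (prev->flags & GGML_TENSOR_FLAG_OUTPUT)) {
                return false;
            }
        }
    }
    return true;
}
```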

jeffbolznv (Collaborator, Author)

Forgot to include perf results. Helps prompt processing (pp), but no meaningful change to token generation (tg).

before:
Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench -m C:\models\Phi-3-mini-4k-instruct-q4.gguf -fa 1 -n 128 -p 512 --prio 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |           pp512 |      6213.91 ± 37.97 |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |           tg128 |        140.52 ± 0.86 |

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench -m C:\models\meta-llama-3-8b-instruct.Q4_K_M.gguf -fa 1 -n 128 -p 512 --prio 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |           pp512 |      3600.60 ± 15.70 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |           tg128 |         84.19 ± 0.21 |

after:
Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench -m C:\models\Phi-3-mini-4k-instruct-q4.gguf -fa 1 -n 128 -p 512 --prio 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |           pp512 |      6456.74 ± 56.08 |
| phi3 3B Q4_K - Medium          |   2.23 GiB |     3.82 B | Vulkan     |  99 |  1 |           tg128 |        140.58 ± 1.12 |

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench -m C:\models\meta-llama-3-8b-instruct.Q4_K_M.gguf -fa 1 -n 128 -p 512 --prio 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |           pp512 |      3644.10 ± 12.41 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |           tg128 |         84.04 ± 0.22 |
