Description
What happened?
Since commit b3e5859 (PR #10301), the Vulkan backend tends to start producing garbage after a few tokens. This has happened to me a few times with Llama-3 8B, and it happens consistently with Qwen2.5-Coder 0.5B, which degenerates into a long run of repeated "G" characters. My guess is that invalid values such as NaN or Inf are appearing somewhere during the softmax computation.
How to reproduce
Command:
.\build\bin\Release\llama-cli.exe -m E:\Downloads\Qwen2.5-Coder-0.5B.f16.gguf -ngl 99 -t 6 -tb 12 -p "Hello, I'm a" --seed 0 -n 512
Output:
Hello, I'm a little bit of a beginner with C++ and I have a question regarding the
std::vector
. I am trying to add an element to the end of the vector, but I am getting a segmentation fault error. Could you help me with this? Sure, I can help with that. To add an element to the end of a vector in C++,GGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGG
Expected output (commit 557924f):
Hello, I'm a little bit of a beginner with C++ and I have a question regarding the `std::vector`. I am trying to add an element to the end of the vector, but I am getting a segmentation fault error. Could you help me with this? Sure, I can help with that. To add an element to the end of a vector in C++, you can use the `push_back()` function. The `push_back()` function takes a single argument which is the element you want to add to the vector. Here is an example: vector<int> myVector; myVector.push_back(5); In the above code, `myVector` is a reference to a vector of integers. The `push_back()` function is used to add an element to the end of the vector. The element is specified after the `push_back()` function call. If you want to add multiple elements at once, you can separate them with commas. Please let me know if you have any other questions. [end of text]
Git bisect results
b3e585988fc65d3a8083c6d94dfc0629f9ce226d is the first bad commit
commit b3e585988fc65d3a8083c6d94dfc0629f9ce226d (tag: b4128)
Author: Jeff Bolz
Date: Tue Nov 19 01:25:17 2024 -0600
vulkan: Optimize soft_max (#10301)
* vulkan: Optimize soft_max
Large soft_max could already saturate memory, but small/medium sizes were
pretty slow. The bulk of the gains for them comes from using a smaller
workgroup size, and making the workgroup size match the subgroup size also
makes the barriers much cheaper.
Cache some values in locals to avoid refetching/recomputing. And stamp
out a few "template instantiations" so smaller cases will fully unroll.
Add a missing early return for OOB rows. This happens when there are more
than 512 rows and the dispatch is 512 x H.
* vulkan: Further soft_max optimizations
Restore the workgroup size of 512 case, use it for >1024.
Use unrollable loops for more iteration counts.
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 13 ++-
ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp | 112 +++++++++++++++++-----
tests/test-backend-ops.cpp | 8 ++
3 files changed, 106 insertions(+), 27 deletions(-)
Full logs
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = AMD Radeon RX 5700 XT (AMD proprietary driver) | uma: 0 | fp16: 1 | warp size: 64
build: 4128 (b3e585988) with MSVC 19.41.34120.0 for x64
main: llama backend init
main: load the model and apply lora adapter, if any
llama_load_model_from_file: using device Vulkan0 (AMD Radeon RX 5700 XT) - 8176 MiB free
llama_model_loader: loaded meta data with 40 key-value pairs and 290 tensors from E:\Downloads\Qwen2.5-Coder-0.5B.f16.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv 0: general.architecture str = qwen2
llama_model_loader: - kv 1: general.type str = model
llama_model_loader: - kv 2: general.name str = Qwen2.5 Coder 0.5B
llama_model_loader: - kv 3: general.basename str = Qwen2.5-Coder
llama_model_loader: - kv 4: general.size_label str = 0.5B
llama_model_loader: - kv 5: general.license str = apache-2.0
llama_model_loader: - kv 6: general.license.link str = https://huggingface.co/Qwen/Qwen2.5-C...
llama_model_loader: - kv 7: general.base_model.count u32 = 1
llama_model_loader: - kv 8: general.base_model.0.name str = Qwen2.5 0.5B
llama_model_loader: - kv 9: general.base_model.0.organization str = Qwen
llama_model_loader: - kv 10: general.base_model.0.repo_url str = https://huggingface.co/Qwen/Qwen2.5-0.5B
llama_model_loader: - kv 11: general.tags arr[str,5] = ["code", "qwen", "qwen-coder", "codeq...
llama_model_loader: - kv 12: general.languages arr[str,1] = ["en"]
llama_model_loader: - kv 13: qwen2.block_count u32 = 24
llama_model_loader: - kv 14: qwen2.context_length u32 = 32768
llama_model_loader: - kv 15: qwen2.embedding_length u32 = 896
llama_model_loader: - kv 16: qwen2.feed_forward_length u32 = 4864
llama_model_loader: - kv 17: qwen2.attention.head_count u32 = 14
llama_model_loader: - kv 18: qwen2.attention.head_count_kv u32 = 2
llama_model_loader: - kv 19: qwen2.rope.freq_base f32 = 1000000.000000
llama_model_loader: - kv 20: qwen2.attention.layer_norm_rms_epsilon f32 = 0.000001
llama_model_loader: - kv 21: general.file_type u32 = 1
llama_model_loader: - kv 22: tokenizer.ggml.model str = gpt2
llama_model_loader: - kv 23: tokenizer.ggml.pre str = qwen2
llama_model_loader: - kv 24: tokenizer.ggml.tokens arr[str,151936] = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv 25: tokenizer.ggml.token_type arr[i32,151936] = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv 26: tokenizer.ggml.merges arr[str,151387] = ["Ġ Ġ", "ĠĠ ĠĠ", "i n", "Ġ t",...
llama_model_loader: - kv 27: tokenizer.ggml.eos_token_id u32 = 151645
llama_model_loader: - kv 28: tokenizer.ggml.padding_token_id u32 = 151643
llama_model_loader: - kv 29: tokenizer.ggml.bos_token_id u32 = 151643
llama_model_loader: - kv 30: tokenizer.ggml.add_bos_token bool = false
llama_model_loader: - kv 31: tokenizer.chat_template str = {%- if tools %}\n {{- '<|im_start|>...
llama_model_loader: - kv 32: general.quantization_version u32 = 2
llama_model_loader: - kv 33: general.url str = https://huggingface.co/mradermacher/Q...
llama_model_loader: - kv 34: mradermacher.quantize_version str = 2
llama_model_loader: - kv 35: mradermacher.quantized_by str = mradermacher
llama_model_loader: - kv 36: mradermacher.quantized_at str = 2024-11-12T06:29:06+01:00
llama_model_loader: - kv 37: mradermacher.quantized_on str = db2
llama_model_loader: - kv 38: general.source.url str = https://huggingface.co/Qwen/Qwen2.5-C...
llama_model_loader: - kv 39: mradermacher.convert_type str = hf
llama_model_loader: - type f32: 121 tensors
llama_model_loader: - type f16: 169 tensors
llm_load_vocab: special tokens cache size = 22
llm_load_vocab: token to piece cache size = 0.9310 MB
llm_load_print_meta: format = GGUF V3 (latest)
llm_load_print_meta: arch = qwen2
llm_load_print_meta: vocab type = BPE
llm_load_print_meta: n_vocab = 151936
llm_load_print_meta: n_merges = 151387
llm_load_print_meta: vocab_only = 0
llm_load_print_meta: n_ctx_train = 32768
llm_load_print_meta: n_embd = 896
llm_load_print_meta: n_layer = 24
llm_load_print_meta: n_head = 14
llm_load_print_meta: n_head_kv = 2
llm_load_print_meta: n_rot = 64
llm_load_print_meta: n_swa = 0
llm_load_print_meta: n_embd_head_k = 64
llm_load_print_meta: n_embd_head_v = 64
llm_load_print_meta: n_gqa = 7
llm_load_print_meta: n_embd_k_gqa = 128
llm_load_print_meta: n_embd_v_gqa = 128
llm_load_print_meta: f_norm_eps = 0.0e+00
llm_load_print_meta: f_norm_rms_eps = 1.0e-06
llm_load_print_meta: f_clamp_kqv = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale = 0.0e+00
llm_load_print_meta: n_ff = 4864
llm_load_print_meta: n_expert = 0
llm_load_print_meta: n_expert_used = 0
llm_load_print_meta: causal attn = 1
llm_load_print_meta: pooling type = 0
llm_load_print_meta: rope type = 2
llm_load_print_meta: rope scaling = linear
llm_load_print_meta: freq_base_train = 1000000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn = 32768
llm_load_print_meta: rope_finetuned = unknown
llm_load_print_meta: ssm_d_conv = 0
llm_load_print_meta: ssm_d_inner = 0
llm_load_print_meta: ssm_d_state = 0
llm_load_print_meta: ssm_dt_rank = 0
llm_load_print_meta: ssm_dt_b_c_rms = 0
llm_load_print_meta: model type = 1B
llm_load_print_meta: model ftype = F16
llm_load_print_meta: model params = 494.03 M
llm_load_print_meta: model size = 942.43 MiB (16.00 BPW)
llm_load_print_meta: general.name = Qwen2.5 Coder 0.5B
llm_load_print_meta: BOS token = 151643 '<|endoftext|>'
llm_load_print_meta: EOS token = 151645 '<|im_end|>'
llm_load_print_meta: EOT token = 151645 '<|im_end|>'
llm_load_print_meta: PAD token = 151643 '<|endoftext|>'
llm_load_print_meta: LF token = 148848 'ÄĬ'
llm_load_print_meta: FIM PRE token = 151659 '<|fim_prefix|>'
llm_load_print_meta: FIM SUF token = 151661 '<|fim_suffix|>'
llm_load_print_meta: FIM MID token = 151660 '<|fim_middle|>'
llm_load_print_meta: FIM PAD token = 151662 '<|fim_pad|>'
llm_load_print_meta: FIM REP token = 151663 '<|repo_name|>'
llm_load_print_meta: FIM SEP token = 151664 '<|file_sep|>'
llm_load_print_meta: EOG token = 151643 '<|endoftext|>'
llm_load_print_meta: EOG token = 151645 '<|im_end|>'
llm_load_print_meta: EOG token = 151662 '<|fim_pad|>'
llm_load_print_meta: EOG token = 151663 '<|repo_name|>'
llm_load_print_meta: EOG token = 151664 '<|file_sep|>'
llm_load_print_meta: max token length = 256
ggml_vulkan: Compiling shaders..............................Done!
llm_load_tensors: offloading 24 repeating layers to GPU
llm_load_tensors: offloading output layer to GPU
llm_load_tensors: offloaded 25/25 layers to GPU
llm_load_tensors: Vulkan0 model buffer size = 942.43 MiB
llm_load_tensors: CPU_Mapped model buffer size = 259.66 MiB
...........................................................
llama_new_context_with_model: n_seq_max = 1
llama_new_context_with_model: n_ctx = 4096
llama_new_context_with_model: n_ctx_per_seq = 4096
llama_new_context_with_model: n_batch = 2048
llama_new_context_with_model: n_ubatch = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base = 1000000.0
llama_new_context_with_model: freq_scale = 1
llama_new_context_with_model: n_ctx_per_seq (4096) < n_ctx_train (32768) -- the full capacity of the model will not be utilized
llama_kv_cache_init: Vulkan0 KV buffer size = 48.00 MiB
llama_new_context_with_model: KV self size = 48.00 MiB, K (f16): 24.00 MiB, V (f16): 24.00 MiB
llama_new_context_with_model: Vulkan_Host output buffer size = 0.58 MiB
llama_new_context_with_model: Vulkan0 compute buffer size = 298.50 MiB
llama_new_context_with_model: Vulkan_Host compute buffer size = 9.76 MiB
llama_new_context_with_model: graph nodes = 846
llama_new_context_with_model: graph splits = 2
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
main: llama threadpool init, n_threads = 6
system_info: n_threads = 6 (n_threads_batch = 12) / 24 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | AMX_INT8 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | RISCV_VECT = 0 | WASM_SIMD = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 |
sampler seed: 0
sampler params:
repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
dry_multiplier = 0.000, dry_base = 1.750, dry_allowed_length = 2, dry_penalty_last_n = -1
top_k = 40, top_p = 0.950, min_p = 0.050, xtc_probability = 0.000, xtc_threshold = 0.100, typical_p = 1.000, temp = 0.800
mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampler chain: logits -> logit-bias -> penalties -> dry -> top-k -> typical -> top-p -> min-p -> xtc -> temp-ext -> dist
generate: n_ctx = 4096, n_batch = 2048, n_predict = 512, n_keep = 0
Hello, I'm a little bit of a beginner with C++ and I have a question regarding the `std::vector`. I am trying to add an element to the end of the vector, but I am getting a segmentation fault error. Could you help me with this? Sure, I can help with that. To add an element to the end of a vector in C++,GGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGG
llama_perf_sampler_print: sampling time = 46.22 ms / 517 runs ( 0.09 ms per token, 11186.60 tokens per second)
llama_perf_context_print: load time = 1177.49 ms
llama_perf_context_print: prompt eval time = 50.35 ms / 5 tokens ( 10.07 ms per token, 99.30 tokens per second)
llama_perf_context_print: eval time = 2678.84 ms / 511 runs ( 5.24 ms per token, 190.75 tokens per second)
llama_perf_context_print: total time = 2830.83 ms / 516 tokens
Name and Version
.\build\bin\Release\llama-cli.exe --version
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = AMD Radeon RX 5700 XT (AMD proprietary driver) | uma: 0 | fp16: 1 | warp size: 64
version: 4128 (b3e5859)
built with MSVC 19.41.34120.0 for x64
What operating system are you seeing the problem on?
Windows