@@ -2101,7 +2101,7 @@ struct llama_hparams {
     uint32_t n_ff_shexp = 0;
     uint32_t n_expert_shared = 0;
     float    expert_weights_scale = 0.0;
-    uint32_t n_sliding = 0; // sliding window attention (SWA)
+    uint32_t n_swa = 0; // sliding window attention (SWA)

     float f_norm_eps;
     float f_norm_rms_eps;
@@ -2665,7 +2665,7 @@ struct llama_context {
     struct ggml_tensor * inp_s_seq;    // I32 [n_kv, n_batch]

     // KQ mask per layer, used by sliding window attention (gemma 2)
-    struct ggml_tensor * inp_KQ_mask_SWA;
+    struct ggml_tensor * inp_KQ_mask_swa;

     // control vectors
     struct llama_control_vector cvec;
@@ -4715,8 +4715,8 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_GEMMA2:
             {
-                hparams.n_sliding = 4096; // default value of gemma 2
-                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_sliding, false);
+                hparams.n_swa = 4096; // default value of gemma 2
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
                 ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
@@ -7794,7 +7794,7 @@ struct llm_build_context {
         lctx.inp_s_copy      = nullptr;
         lctx.inp_s_mask      = nullptr;
         lctx.inp_s_seq       = nullptr;
-        lctx.inp_KQ_mask_SWA = nullptr;
+        lctx.inp_KQ_mask_swa = nullptr;
     }

     void free() {
@@ -7954,7 +7954,7 @@ struct llm_build_context {
         cb(KQ_mask, "KQ_mask", -1);
         ggml_set_input(KQ_mask);
         if (sliding_window) {
-            lctx.inp_KQ_mask_SWA = KQ_mask;
+            lctx.inp_KQ_mask_swa = KQ_mask;
         } else {
             lctx.inp_KQ_mask = KQ_mask;
         }
@@ -12689,14 +12689,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {

            float * data     = (float *) lctx.inp_KQ_mask->data;
            float * data_swa = nullptr;
-           const llama_pos n_keep_swa = hparams.n_sliding - batch.n_tokens;

-           if (lctx.model.arch == LLM_ARCH_GEMMA2) {
-               GGML_ASSERT(lctx.inp_KQ_mask_SWA);
-               GGML_ASSERT(hparams.n_sliding > 0);
-               data = (float *) lctx.inp_KQ_mask->data;
-               data_swa = (float *) lctx.inp_KQ_mask_SWA->data;
-               // because layer masks are alternate for gemma 2, we only need to take first 2 layers
+           if (lctx.inp_KQ_mask_swa) {
+               data_swa = (float *) lctx.inp_KQ_mask_swa->data;
            }

            // For causal attention, use only the previous KV cells
@@ -12722,7 +12717,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {

                        // may need to cut off old tokens for sliding window
                        if (data_swa) {
-                           if (pos - lctx.kv_self.cells[i].pos > n_keep_swa) {
+                           if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
                                f = -INFINITY;
                            }
                            data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
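
For reference, a minimal standalone sketch of the mask rule introduced by the new `>= (int32_t)hparams.n_swa` check: a KV cell stays visible to a query token only if it is causal and lies within the last n_swa positions of that token; everything older gets -INFINITY. The helper build_swa_mask and its parameters below are illustrative only and are not part of the llama.cpp API.

// sketch of the sliding-window mask rule (hypothetical helper, not llama.cpp code)
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

static std::vector<float> build_swa_mask(const std::vector<int32_t> & kv_pos,
                                         const std::vector<int32_t> & tok_pos,
                                         uint32_t n_swa) {
    const size_t n_kv     = kv_pos.size();
    const size_t n_tokens = tok_pos.size();
    std::vector<float> mask(n_tokens*n_kv, 0.0f);
    for (size_t j = 0; j < n_tokens; ++j) {
        for (size_t i = 0; i < n_kv; ++i) {
            const int32_t d = tok_pos[j] - kv_pos[i];
            // mask future cells (d < 0) and cells outside the window (d >= n_swa)
            if (d < 0 || d >= (int32_t) n_swa) {
                mask[j*n_kv + i] = -INFINITY;
            }
        }
    }
    return mask;
}

int main() {
    // with n_swa = 4 and a query at position 6, cells at positions 0..2 fall outside the window
    const std::vector<int32_t> kv_pos  = {0, 1, 2, 3, 4, 5, 6};
    const std::vector<int32_t> tok_pos = {6};
    const std::vector<float> mask = build_swa_mask(kv_pos, tok_pos, 4);
    for (size_t i = 0; i < mask.size(); ++i) {
        printf("cell %zu: %s\n", i, std::isinf(mask[i]) ? "-inf" : "0");
    }
    return 0;
}

Note that the window includes the query token itself plus the previous n_swa - 1 positions, which matches the strict ">=" cutoff in the patch; the removed n_keep_swa formula depended on batch.n_tokens and therefore shifted the window with the batch size.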