Commit ed5496f

update
1 parent d09ecb8 commit ed5496f


src/llama.cpp

Lines changed: 9 additions & 14 deletions
@@ -2101,7 +2101,7 @@ struct llama_hparams {
    uint32_t n_ff_shexp = 0;
    uint32_t n_expert_shared = 0;
    float    expert_weights_scale = 0.0;
-   uint32_t n_sliding = 0; // sliding window attention (SWA)
+   uint32_t n_swa = 0; // sliding window attention (SWA)

    float f_norm_eps;
    float f_norm_rms_eps;
@@ -2665,7 +2665,7 @@ struct llama_context {
    struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]

    // KQ mask per layer, used by sliding window attention (gemma 2)
-   struct ggml_tensor * inp_KQ_mask_SWA;
+   struct ggml_tensor * inp_KQ_mask_swa;

    // control vectors
    struct llama_control_vector cvec;
@@ -4715,8 +4715,8 @@ static void llm_load_hparams(
            } break;
        case LLM_ARCH_GEMMA2:
            {
-               hparams.n_sliding = 4096; // default value of gemma 2
-               ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_sliding, false);
+               hparams.n_swa = 4096; // default value of gemma 2
+               ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
                ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
@@ -7794,7 +7794,7 @@ struct llm_build_context {
        lctx.inp_s_copy = nullptr;
        lctx.inp_s_mask = nullptr;
        lctx.inp_s_seq = nullptr;
-       lctx.inp_KQ_mask_SWA = nullptr;
+       lctx.inp_KQ_mask_swa = nullptr;
    }

    void free() {
@@ -7954,7 +7954,7 @@ struct llm_build_context {
        cb(KQ_mask, "KQ_mask", -1);
        ggml_set_input(KQ_mask);
        if (sliding_window) {
-           lctx.inp_KQ_mask_SWA = KQ_mask;
+           lctx.inp_KQ_mask_swa = KQ_mask;
        } else {
            lctx.inp_KQ_mask = KQ_mask;
        }
@@ -12689,14 +12689,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {

        float * data     = (float *) lctx.inp_KQ_mask->data;
        float * data_swa = nullptr;
-       const llama_pos n_keep_swa = hparams.n_sliding - batch.n_tokens;

-       if (lctx.model.arch == LLM_ARCH_GEMMA2) {
-           GGML_ASSERT(lctx.inp_KQ_mask_SWA);
-           GGML_ASSERT(hparams.n_sliding > 0);
-           data = (float *) lctx.inp_KQ_mask->data;
-           data_swa = (float *) lctx.inp_KQ_mask_SWA->data;
-           // because layer masks are alternate for gemma 2, we only need to take first 2 layers
+       if (lctx.inp_KQ_mask_swa) {
+           data_swa = (float *) lctx.inp_KQ_mask_swa->data;
        }

        // For causal attention, use only the previous KV cells
@@ -12722,7 +12717,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {

                    // may need to cut off old tokens for sliding window
                    if (data_swa) {
-                       if (pos - lctx.kv_self.cells[i].pos > n_keep_swa) {
+                       if (pos - lctx.kv_self.cells[i].pos >= (int32_t) hparams.n_swa) {
                            f = -INFINITY;
                        }
                        data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
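
For reference, a minimal standalone sketch (not part of the commit) of the masking rule the last two hunks implement: after the change, a KV cell is masked for sliding window attention as soon as its position falls n_swa or more tokens behind the current query position, independent of the batch size (the old n_keep_swa = n_sliding - batch.n_tokens term is gone). The helper name swa_mask_value, the toy values, and the loop below are illustrative assumptions, not llama.cpp code.

// Sketch of the SWA mask condition after this commit: a KV cell at cell_pos
// is visible to a query at pos only if it is causal (cell_pos <= pos) and
// inside the window (pos - cell_pos < n_swa); otherwise it gets -INFINITY
// as an additive bias so the softmax ignores it.
#include <cmath>
#include <cstdint>
#include <cstdio>

static float swa_mask_value(int32_t pos, int32_t cell_pos, uint32_t n_swa) {
    if (cell_pos > pos) {
        return -INFINITY; // causal mask: never attend to future cells
    }
    if (pos - cell_pos >= (int32_t) n_swa) {
        return -INFINITY; // cell has slid out of the attention window
    }
    return 0.0f; // visible: additive bias of 0 before the softmax
}

int main() {
    const uint32_t n_swa = 4;  // toy window; the diff defaults gemma 2 to 4096
    const int32_t  pos   = 10; // current query position

    for (int32_t cell_pos = 5; cell_pos <= 11; ++cell_pos) {
        const float f = swa_mask_value(pos, cell_pos, n_swa);
        std::printf("cell_pos=%2d mask=%s\n", cell_pos, std::isinf(f) ? "-inf" : "0");
    }
    // With n_swa = 4 and pos = 10, only cell positions 7..10 stay visible.
    return 0;
}

Applied per layer, this keeps only the most recent n_swa positions visible to the SWA layers, while the non-SWA layers continue to use the regular causal mask via inp_KQ_mask.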
