@@ -2081,7 +2081,7 @@ struct llama_hparams {
     bool use_par_res;
 
     uint32_t n_vocab;
-    uint32_t n_ctx_train; // context size the model was trained on
+    uint32_t n_ctx_train; // context size the model was trained on
     uint32_t n_embd;
     uint32_t n_head;
     uint32_t n_head_kv;
@@ -2665,7 +2665,7 @@ struct llama_context {
     struct ggml_tensor * inp_s_seq;  // I32 [n_kv, n_batch]
 
     // KQ mask per layer, used by sliding window attention (gemma 2)
-    std::vector<struct ggml_tensor *> inp_KQ_mask_l;
+    struct ggml_tensor * inp_KQ_mask_SWA;
 
     // control vectors
     struct llama_control_vector cvec;
@@ -7794,6 +7794,7 @@ struct llm_build_context {
         lctx.inp_s_copy = nullptr;
         lctx.inp_s_mask = nullptr;
         lctx.inp_s_seq  = nullptr;
+        lctx.inp_KQ_mask_SWA = nullptr;
     }
 
     void free() {
@@ -7946,15 +7947,18 @@ struct llm_build_context {
         return lctx.inp_out_ids;
     }
 
-    struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
-        if (causal) {
-            lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    struct ggml_tensor * build_inp_KQ_mask(bool causal = true, bool sliding_window = false) {
+        struct ggml_tensor * KQ_mask = causal
+            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
+            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        cb(KQ_mask, "KQ_mask", -1);
+        ggml_set_input(KQ_mask);
+        if (sliding_window) {
+            lctx.inp_KQ_mask_SWA = KQ_mask;
         } else {
-            lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+            lctx.inp_KQ_mask = KQ_mask;
         }
-        cb(lctx.inp_KQ_mask, "KQ_mask", -1);
-        ggml_set_input(lctx.inp_KQ_mask);
-        return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
+        return flash_attn ? ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16) : KQ_mask;
     }
 
     struct ggml_tensor * build_inp_mean() {
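
Note on the refactor above: the two tensor shapes are unchanged (causal masks are n_kv wide, non-causal ones n_tokens wide, both padded along n_tokens); the sliding_window flag only decides whether the tensor is registered as lctx.inp_KQ_mask or as the new lctx.inp_KQ_mask_SWA. For reference, GGML_PAD rounds its first argument up to a multiple of the second; the standalone sketch below only illustrates that padding, and pad_to plus the sample numbers are local stand-ins, not llama.cpp code.

    #include <cassert>
    #include <cstdint>

    // Local stand-in for GGML_PAD: round x up to the next multiple of n.
    static int64_t pad_to(int64_t x, int64_t n) {
        return (x + n - 1) / n * n;
    }

    int main() {
        // Made-up batch size; with a padding granularity of 32 (a typical
        // value for GGML_KQ_MASK_PAD), a 37-token batch gets mask rows of 64.
        const int64_t n_tokens = 37;
        assert(pad_to(n_tokens, 32) == 64);
        assert(pad_to(64, 32) == 64); // already aligned values are unchanged
        return 0;
    }
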
@@ -11038,14 +11042,12 @@ struct llm_build_context {
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
        // gemma 2 requires different mask for layers using sliding window (SWA)
-        struct ggml_tensor * KQ_mask_full = build_inp_KQ_mask();
-        struct ggml_tensor * KQ_mask_SWA  = build_inp_KQ_mask();
-        lctx.inp_KQ_mask_l.clear();
+        struct ggml_tensor * KQ_mask_full = build_inp_KQ_mask(true, false);
+        struct ggml_tensor * KQ_mask_SWA  = build_inp_KQ_mask(true, true);
 
         for (int il = 0; il < n_layer; ++il) {
             // (il % 2) layers use SWA
             struct ggml_tensor * KQ_mask = (il % 2 == 0) ? KQ_mask_SWA : KQ_mask_full;
-            lctx.inp_KQ_mask_l.push_back(KQ_mask);
 
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
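
Since gemma 2 alternates strictly between sliding-window and full-attention layers, the per-layer vector lctx.inp_KQ_mask_l becomes redundant: layer parity alone picks one of the two masks, which is exactly what the (il % 2 == 0) selection above does. A minimal standalone sketch of that rule, with a made-up layer count (mask_kind and mask_for_layer are illustrative helpers, not llama.cpp API):

    #include <cstdio>

    enum class mask_kind { full, swa };

    // Mirrors the (il % 2 == 0) ? KQ_mask_SWA : KQ_mask_full selection above:
    // even layers use the sliding window mask, odd layers the full causal mask.
    static mask_kind mask_for_layer(int il) {
        return (il % 2 == 0) ? mask_kind::swa : mask_kind::full;
    }

    int main() {
        const int n_layer = 6; // made-up layer count for illustration
        for (int il = 0; il < n_layer; ++il) {
            std::printf("layer %d -> %s\n", il,
                        mask_for_layer(il) == mask_kind::swa ? "SWA mask" : "full mask");
        }
        return 0;
    }
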
@@ -12685,15 +12687,15 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
 
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
 
-        float * data = (float *) lctx.inp_KQ_mask->data;
+        float * data = (float *) lctx.inp_KQ_mask->data;
         float * data_swa = nullptr;
         const llama_pos n_keep_swa = hparams.n_sliding - batch.n_tokens;
 
         if (lctx.model.arch == LLM_ARCH_GEMMA2) {
-            GGML_ASSERT(!lctx.inp_KQ_mask_l.empty() && "gemma 2 requires different KQ mask per layer");
+            GGML_ASSERT(lctx.inp_KQ_mask_SWA);
             GGML_ASSERT(hparams.n_sliding > 0);
-            data_swa = (float *) lctx.inp_KQ_mask_l[0]->data;
-            data     = (float *) lctx.inp_KQ_mask_l[1]->data;
+            data     = (float *) lctx.inp_KQ_mask->data;
+            data_swa = (float *) lctx.inp_KQ_mask_SWA->data;
             // because layer masks are alternate for gemma 2, we only need to take first 2 layers
         }
 
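
The hunk above only redirects where data and data_swa point; the loop that writes the actual mask values, and the exact role of n_keep_swa, sit outside this diff. Purely as an illustration of what the two buffers encode, the standalone sketch below fills a full causal mask and a sliding-window causal mask for a tiny contiguous sequence: in the SWA variant a query additionally masks out keys more than window positions behind it. The sizes and the row-major layout are made up and do not reproduce llama.cpp's batched [n_kv, n_tokens] indexing.

    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        // Made-up sizes: 8 tokens at positions 0..7, sliding window of 4.
        const int n      = 8;
        const int window = 4;

        std::vector<float> full(n * n), swa(n * n);
        for (int q = 0; q < n; ++q) {      // query position
            for (int k = 0; k < n; ++k) {  // key position
                const bool causal_ok = k <= q;          // no attending to the future
                const bool in_window = q - k < window;  // key within the sliding window
                full[q * n + k] = causal_ok              ? 0.0f : -INFINITY;
                swa [q * n + k] = causal_ok && in_window ? 0.0f : -INFINITY;
            }
        }

        // Print both masks: '.' = attendable key, 'x' = masked out.
        std::printf("q : full      SWA\n");
        for (int q = 0; q < n; ++q) {
            std::printf("%d : ", q);
            for (int k = 0; k < n; ++k) std::printf("%c", full[q * n + k] == 0.0f ? '.' : 'x');
            std::printf("  ");
            for (int k = 0; k < n; ++k) std::printf("%c", swa[q * n + k] == 0.0f ? '.' : 'x');
            std::printf("\n");
        }
        return 0;
    }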