Commit d09ecb8

replace list with single tensor
1 parent 231dae4 commit d09ecb8
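
The gist of the change, sketched below as standalone C++ (the `Mask`/`Context` types and the `main` driver are illustrative stand-ins, not llama.cpp code): gemma 2 alternates sliding-window (SWA) and full-attention layers, so only two distinct KQ masks ever exist. Rather than keeping a per-layer `std::vector<struct ggml_tensor *> inp_KQ_mask_l`, the context now holds one extra tensor pointer, `inp_KQ_mask_SWA`, and each layer simply selects between the full and the SWA mask.

// Sketch only: two shared masks instead of a per-layer list of masks.
#include <cstdio>

struct Mask { const char * name; };          // stand-in for ggml_tensor

struct Context {                             // stand-in for llama_context
    Mask * inp_KQ_mask     = nullptr;        // full causal mask (non-SWA layers)
    Mask * inp_KQ_mask_SWA = nullptr;        // sliding-window mask (SWA layers)
};

int main() {
    Mask full{"KQ_mask"}, swa{"KQ_mask_SWA"};
    Context lctx;
    lctx.inp_KQ_mask     = &full;
    lctx.inp_KQ_mask_SWA = &swa;

    const int n_layer = 6;
    for (int il = 0; il < n_layer; ++il) {
        // even layers use SWA, odd layers use the full mask, as in the gemma 2 graph
        Mask * KQ_mask = (il % 2 == 0) ? lctx.inp_KQ_mask_SWA : lctx.inp_KQ_mask;
        std::printf("layer %d uses %s\n", il, KQ_mask->name);
    }
    return 0;
}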

File tree

1 file changed: +19 −17 lines


src/llama.cpp

Lines changed: 19 additions & 17 deletions
@@ -2081,7 +2081,7 @@ struct llama_hparams {
     bool use_par_res;

     uint32_t n_vocab;
-    uint32_t n_ctx_train; // context size the model was trained on
+    uint32_t n_ctx_train; // context size the model was trained on
     uint32_t n_embd;
     uint32_t n_head;
     uint32_t n_head_kv;
@@ -2665,7 +2665,7 @@ struct llama_context {
     struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]

     // KQ mask per layer, used by sliding window attention (gemma 2)
-    std::vector<struct ggml_tensor *> inp_KQ_mask_l;
+    struct ggml_tensor * inp_KQ_mask_SWA;

     // control vectors
     struct llama_control_vector cvec;
@@ -7794,6 +7794,7 @@ struct llm_build_context {
         lctx.inp_s_copy = nullptr;
         lctx.inp_s_mask = nullptr;
         lctx.inp_s_seq = nullptr;
+        lctx.inp_KQ_mask_SWA = nullptr;
     }

     void free() {
@@ -7946,15 +7947,18 @@ struct llm_build_context {
         return lctx.inp_out_ids;
     }

-    struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
-        if (causal) {
-            lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    struct ggml_tensor * build_inp_KQ_mask(bool causal = true, bool sliding_window = false) {
+        struct ggml_tensor * KQ_mask = causal
+            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
+            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        cb(KQ_mask, "KQ_mask", -1);
+        ggml_set_input(KQ_mask);
+        if (sliding_window) {
+            lctx.inp_KQ_mask_SWA = KQ_mask;
         } else {
-            lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+            lctx.inp_KQ_mask = KQ_mask;
         }
-        cb(lctx.inp_KQ_mask, "KQ_mask", -1);
-        ggml_set_input(lctx.inp_KQ_mask);
-        return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
+        return flash_attn ? ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16) : KQ_mask;
     }

     struct ggml_tensor * build_inp_mean() {
@@ -11038,14 +11042,12 @@ struct llm_build_context {

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         // gemma 2 requires different mask for layers using sliding window (SWA)
-        struct ggml_tensor * KQ_mask_full = build_inp_KQ_mask();
-        struct ggml_tensor * KQ_mask_SWA  = build_inp_KQ_mask();
-        lctx.inp_KQ_mask_l.clear();
+        struct ggml_tensor * KQ_mask_full = build_inp_KQ_mask(true, false);
+        struct ggml_tensor * KQ_mask_SWA  = build_inp_KQ_mask(true, true);

         for (int il = 0; il < n_layer; ++il) {
             // (il % 2) layers use SWA
             struct ggml_tensor * KQ_mask = (il % 2 == 0) ? KQ_mask_SWA : KQ_mask_full;
-            lctx.inp_KQ_mask_l.push_back(KQ_mask);

             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
@@ -12685,15 +12687,15 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {

         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));

-        float * data = (float *) lctx.inp_KQ_mask->data;
+        float * data     = (float *) lctx.inp_KQ_mask->data;
         float * data_swa = nullptr;
         const llama_pos n_keep_swa = hparams.n_sliding - batch.n_tokens;

         if (lctx.model.arch == LLM_ARCH_GEMMA2) {
-            GGML_ASSERT(!lctx.inp_KQ_mask_l.empty() && "gemma 2 requires different KQ mask per layer");
+            GGML_ASSERT(lctx.inp_KQ_mask_SWA);
             GGML_ASSERT(hparams.n_sliding > 0);
-            data_swa = (float *) lctx.inp_KQ_mask_l[0]->data;
-            data     = (float *) lctx.inp_KQ_mask_l[1]->data;
+            data     = (float *) lctx.inp_KQ_mask->data;
+            data_swa = (float *) lctx.inp_KQ_mask_SWA->data;
             // because layer masks are alternate for gemma 2, we only need to take first 2 layers
         }

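For context on what the two masks hold: a sliding-window causal mask is a causal mask that additionally blocks key positions further than the window size behind the query. A self-contained sketch of that idea follows, using 0 / -INFINITY entries as in ggml KQ masks; it is conceptual only, not the KV-cache-based fill loop in `llama_set_inputs`, and `n_window` here is a hypothetical parameter.

// Conceptual sketch: full causal mask vs. sliding-window causal mask.
// mask[i*n + j] is the entry for query position i attending to key position j.
#include <cmath>
#include <vector>

static void fill_causal(std::vector<float> & mask, int n, int n_window) {
    // n_window <= 0 means "no window", i.e. a plain causal mask
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            const bool future     = j > i;                               // causal constraint
            const bool out_of_swa = n_window > 0 && (i - j) >= n_window; // sliding-window constraint
            mask[(size_t) i*n + j] = (future || out_of_swa) ? -INFINITY : 0.0f;
        }
    }
}

int main() {
    const int n = 8;
    std::vector<float> full(n*n), swa(n*n);
    fill_causal(full, n, /*n_window =*/ 0); // roughly what lctx.inp_KQ_mask holds
    fill_causal(swa,  n, /*n_window =*/ 4); // roughly what lctx.inp_KQ_mask_SWA holds
    return 0;
}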
0 commit comments
