Commit 0c05c60
llama: restore a kv_cache in case of failed computation
1 parent 7c083f5 commit 0c05c60
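In short: before llama_kv_cache_find_slot mutates the cache, a new llama_kv_slot_restorer captures enough state to undo the allocation (the head/size/used/n metadata, plus a full copy of the cells for recurrent models, or the allocated slot boundaries for non-recurrent ones). If llama_graph_compute then fails, llama_decode_internal rolls the cache back instead of leaving partially written entries behind.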

File tree: 1 file changed, +65 −12 lines

src/llama.cpp

Lines changed: 65 additions & 12 deletions
@@ -2815,6 +2815,22 @@ struct llama_kv_cache {
     }
 };
 
+// saves the kv_cache state for future recovery
+// used to preserve the kv_cache state before searching for a slot
+struct llama_kv_slot_restorer {
+    struct llama_kv_cache_state {
+        uint32_t head = 0;
+        uint32_t size = 0;
+        uint32_t used = 0;
+        uint32_t n    = 0;
+    } old_state;
+
+    std::vector<llama_kv_cell>    recurrent_cells; // for recurrent models only
+    std::pair<uint32_t, uint32_t> slot_boundaries; // for non-recurrent models only
+
+    bool restore = false;
+};
+
 struct llama_control_vector {
     std::vector<struct ggml_tensor *> tensors; // per layer
     std::vector<struct ggml_context *> ctxs;
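Note: the restorer carries two alternative payloads because the two cache layouts need different undo information. A recurrent cache (e.g. Mamba or RWKV) keeps per-sequence state directly in its cells, so the only reliable undo is a full copy of the cells; a non-recurrent cache only needs to remember which contiguous slot range was handed out so it can be cleared again.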
@@ -3652,11 +3668,19 @@ static bool llama_kv_cache_init(
 // to the first cell of the slot.
 static bool llama_kv_cache_find_slot(
            struct llama_kv_cache & cache,
-           const struct llama_ubatch & batch) {
+           const struct llama_ubatch & batch,
+           struct llama_kv_slot_restorer * slot_restorer = nullptr) {
     const uint32_t n_tokens = batch.n_tokens;
     const uint32_t n_seqs   = batch.n_seqs;
     const uint32_t n_seq_tokens = batch.n_seq_tokens;
 
+    if (slot_restorer != nullptr) {
+        slot_restorer->old_state.head = cache.head;
+        slot_restorer->old_state.size = cache.size;
+        slot_restorer->old_state.used = cache.used;
+        slot_restorer->old_state.n    = cache.n;
+    }
+
     if (cache.recurrent) {
         // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
@@ -3665,6 +3689,11 @@ static bool llama_kv_cache_find_slot(
         // can only process batches with an equal number of new tokens in each sequence
         GGML_ASSERT(batch.equal_seqs);
 
+        if (slot_restorer != nullptr) {
+            slot_restorer->recurrent_cells = cache.cells;
+            slot_restorer->restore = true;
+        }
+
         int32_t min = cache.size - 1;
         int32_t max = 0;

@@ -3853,6 +3882,11 @@ static bool llama_kv_cache_find_slot(
         }
     }
 
+    if (slot_restorer != nullptr) {
+        slot_restorer->slot_boundaries = std::make_pair(cache.head, cache.head + n_tokens);
+        slot_restorer->restore = true;
+    }
+
     for (uint32_t s = 0; s < n_seqs; s++) {
         for (uint32_t i = 0; i < n_seq_tokens; ++i) {
             uint32_t k = s*n_seq_tokens + i;
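On the non-recurrent path the boundaries are recorded only once a slot has actually been found, so if llama_kv_cache_find_slot fails early, restore stays false and the later rollback is a no-op.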
@@ -4142,6 +4176,23 @@ static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams)
     return cparams.flash_attn ? 256u : 32u;
 }
 
+static void llama_kv_cache_slot_restore(
+        const struct llama_kv_slot_restorer & restorer,
+              struct llama_kv_cache & cache) {
+    if (restorer.restore) {
+        cache.head = restorer.old_state.head;
+        cache.size = restorer.old_state.size;
+        cache.used = restorer.old_state.used;
+        cache.n    = restorer.old_state.n;
+
+        if (cache.recurrent) {
+            cache.cells = restorer.recurrent_cells;
+        } else {
+            llama_kv_cache_seq_rm(cache, -1, restorer.slot_boundaries.first, restorer.slot_boundaries.second + 1);
+        }
+    }
+}
+
 //
 // model loading and saving
 //
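Note that llama_kv_cache_seq_rm is called with a seq_id of -1, which in llama.cpp's KV-cache API matches any sequence, so the freshly allocated range is cleared regardless of which sequences the batch touched.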
@@ -17184,6 +17235,7 @@ static int llama_decode_internal(
     lctx.n_queued_tokens += n_tokens_all;
 
     auto & kv_self = lctx.kv_self;
+    llama_kv_slot_restorer kv_slot_restorer;
 
     const int64_t n_embd  = hparams.n_embd;
     const int64_t n_vocab = hparams.n_vocab;
@@ -17268,7 +17320,7 @@
             kv_self.head = 0;
         }
 
-        if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
+        if (!llama_kv_cache_find_slot(kv_self, ubatch, &kv_slot_restorer)) {
            return 1;
        }

@@ -17318,16 +17370,17 @@
         llama_set_inputs(lctx, ubatch);
 
         const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
-        switch (compute_status) {
-            case GGML_STATUS_SUCCESS:
-                break;
-            case GGML_STATUS_ABORTED:
-                return 2;
-            case GGML_STATUS_ALLOC_FAILED:
-                return -2;
-            case GGML_STATUS_FAILED:
-            default:
-                return -3;
+        if (compute_status != GGML_STATUS_SUCCESS) {
+            llama_kv_cache_slot_restore(kv_slot_restorer, kv_self);
+            switch (compute_status) {
+                case GGML_STATUS_ABORTED:
+                    return 2;
+                case GGML_STATUS_ALLOC_FAILED:
+                    return -2;
+                case GGML_STATUS_FAILED:
+                default:
+                    return -3;
+            }
         }
 
         // update the kv ring buffer
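Taken together, the patch is a straightforward snapshot-and-rollback pattern around the graph computation. Below is a minimal self-contained sketch of that pattern; Cell, Cache, Restorer, snapshot, and rollback are illustrative stand-ins, not the real llama.cpp definitions.

    #include <cstdint>
    #include <utility>
    #include <vector>

    struct Cell { int32_t pos = -1; };                 // stand-in for llama_kv_cell

    struct Cache {                                     // stand-in for llama_kv_cache
        uint32_t head = 0, size = 0, used = 0, n = 0;
        bool recurrent = false;
        std::vector<Cell> cells;
    };

    struct Restorer {                                  // mirrors llama_kv_slot_restorer
        uint32_t head = 0, size = 0, used = 0, n = 0;
        std::vector<Cell> cells_copy;                  // recurrent models only
        std::pair<uint32_t, uint32_t> bounds{0, 0};    // non-recurrent slot range
        bool restore = false;
    };

    // what llama_kv_cache_find_slot now does first: snapshot before mutating
    // (the real patch defers restore = true until a mutation is imminent)
    static void snapshot(const Cache & c, Restorer & r) {
        r.head = c.head; r.size = c.size; r.used = c.used; r.n = c.n;
        if (c.recurrent) {
            r.cells_copy = c.cells;  // per-sequence state lives in the cells
        }
        r.restore = true;
    }

    // what llama_kv_cache_slot_restore does: undo after a failed computation
    static void rollback(const Restorer & r, Cache & c) {
        if (!r.restore) {
            return;  // nothing was allocated, nothing to undo
        }
        c.head = r.head; c.size = r.size; c.used = r.used; c.n = r.n;
        if (c.recurrent) {
            c.cells = r.cells_copy;
        } else {
            // the real code clears the range via llama_kv_cache_seq_rm(cache, -1, ...)
            for (uint32_t i = r.bounds.first; i <= r.bounds.second && i < c.cells.size(); ++i) {
                c.cells[i] = Cell{};
            }
        }
    }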
