@@ -2815,6 +2815,22 @@ struct llama_kv_cache {
    }
};

+// saves the kv_cache state for future recovery
+// used to preserve the kv_cache state before searching for a slot
+struct llama_kv_slot_restorer {
+    struct llama_kv_cache_state {
+        uint32_t head = 0;
+        uint32_t size = 0;
+        uint32_t used = 0;
+        uint32_t n    = 0;
+    } old_state;
+
+    std::vector<llama_kv_cell> recurrent_cells;    // for recurrent models only
+    std::pair<uint32_t, uint32_t> slot_boundaries; // for non-recurrent models only
+
+    bool restore = false;
+};
+
struct llama_control_vector {
    std::vector<struct ggml_tensor *> tensors; // per layer
    std::vector<struct ggml_context *> ctxs;
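For orientation, here is a minimal, self-contained sketch of the save-and-rollback pattern this struct enables. It is not part of the patch: MiniCache and MiniRestorer are hypothetical stand-ins for llama_kv_cache and llama_kv_slot_restorer, modeling only the metadata fields the patch snapshots (head, size, used, n).

```cpp
#include <cassert>
#include <cstdint>

// MiniCache models only the KV-cache metadata that the patch snapshots.
struct MiniCache {
    uint32_t head = 0;
    uint32_t size = 32;
    uint32_t used = 0;
    uint32_t n    = 0;
};

// MiniRestorer plays the role of llama_kv_slot_restorer in this sketch.
struct MiniRestorer {
    MiniCache saved;          // copy taken before the slot search
    bool      restore = false;

    void save(const MiniCache & c) { saved = c; }
    void apply(MiniCache & c) const {
        if (restore) {
            c = saved;        // roll the metadata back to the snapshot
        }
    }
};

int main() {
    MiniCache    cache;
    MiniRestorer restorer;

    restorer.save(cache);               // what the slot search does on entry
    cache.head = 8; cache.used = 8;     // the slot search mutates the cache
    restorer.restore = true;            // a slot was actually reserved

    // pretend the graph compute failed: undo the reservation
    restorer.apply(cache);
    assert(cache.head == 0 && cache.used == 0);
    return 0;
}
```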
@@ -3652,11 +3668,19 @@ static bool llama_kv_cache_init(
// to the first cell of the slot.
static bool llama_kv_cache_find_slot(
        struct llama_kv_cache & cache,
-       const struct llama_ubatch & batch) {
+       const struct llama_ubatch & batch,
+       struct llama_kv_slot_restorer * slot_restorer = nullptr) {
    const uint32_t n_tokens = batch.n_tokens;
    const uint32_t n_seqs = batch.n_seqs;
    const uint32_t n_seq_tokens = batch.n_seq_tokens;

+    if (slot_restorer != nullptr) {
+        slot_restorer->old_state.head = cache.head;
+        slot_restorer->old_state.size = cache.size;
+        slot_restorer->old_state.used = cache.used;
+        slot_restorer->old_state.n    = cache.n;
+    }
+
    if (cache.recurrent) {
        // For recurrent state architectures (like Mamba or RWKV),
        // each cache cell can store the state for a whole sequence.
@@ -3665,6 +3689,11 @@ static bool llama_kv_cache_find_slot(
        // can only process batches with an equal number of new tokens in each sequence
        GGML_ASSERT(batch.equal_seqs);

+        if (slot_restorer != nullptr) {
+            slot_restorer->recurrent_cells = cache.cells;
+            slot_restorer->restore = true;
+        }
+
        int32_t min = cache.size - 1;
        int32_t max = 0;
@@ -3853,6 +3882,11 @@ static bool llama_kv_cache_find_slot(
3853
3882
}
3854
3883
}
3855
3884
3885
+ if (slot_restorer != nullptr) {
3886
+ slot_restorer->slot_boundaries = std::make_pair(cache.head, cache.head + n_tokens);
3887
+ slot_restorer->restore = true;
3888
+ }
3889
+
3856
3890
for (uint32_t s = 0; s < n_seqs; s++) {
3857
3891
for (uint32_t i = 0; i < n_seq_tokens; ++i) {
3858
3892
uint32_t k = s*n_seq_tokens + i;
@@ -4142,6 +4176,23 @@ static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams)
    return cparams.flash_attn ? 256u : 32u;
}

+static void llama_kv_cache_slot_restore(
+        const struct llama_kv_slot_restorer & restorer,
+        struct llama_kv_cache & cache) {
+    if (restorer.restore) {
+        cache.head = restorer.old_state.head;
+        cache.size = restorer.old_state.size;
+        cache.used = restorer.old_state.used;
+        cache.n    = restorer.old_state.n;
+
+        if (cache.recurrent) {
+            cache.cells = restorer.recurrent_cells;
+        } else {
+            llama_kv_cache_seq_rm(cache, -1, restorer.slot_boundaries.first, restorer.slot_boundaries.second + 1);
+        }
+    }
+}
+
//
// model loading and saving
//
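The restore path differs between the two cache kinds: a recurrent cache keeps whole-sequence state in individual cells that the slot search may rewrite anywhere, so only a full copy of cache.cells is a safe rollback, while a non-recurrent cache only reserves a contiguous run of cells starting at the old head, so recording the slot boundaries and clearing that range (the patch delegates this to llama_kv_cache_seq_rm) is enough. Below is a simplified, self-contained illustration of the two strategies; Cell is a hypothetical stand-in for llama_kv_cell and the range handling is deliberately simplified rather than a copy of the seq_rm semantics.

```cpp
#include <cstdint>
#include <utility>
#include <vector>

struct Cell { int32_t pos = -1; };    // hypothetical stand-in for llama_kv_cell

// Non-recurrent rollback: only the cells reserved by the failed slot search
// (a contiguous range starting at the old head) need to be emptied again.
static void rollback_range(std::vector<Cell> & cells, std::pair<uint32_t, uint32_t> bounds) {
    for (uint32_t i = bounds.first; i < bounds.second && i < cells.size(); ++i) {
        cells[i] = Cell{};
    }
}

// Recurrent rollback: each cell holds per-sequence state and the slot search may
// touch arbitrary cells, so the only safe option is restoring a full snapshot.
static void rollback_full(std::vector<Cell> & cells, const std::vector<Cell> & snapshot) {
    cells = snapshot;
}

int main() {
    std::vector<Cell> cells(8);
    cells[2].pos = 0; cells[3].pos = 1;    // pretend a 2-token slot was reserved at head = 2

    rollback_range(cells, {2, 4});         // non-recurrent: clear just that range

    std::vector<Cell> snapshot = cells;    // recurrent: keep a full copy instead
    cells[5].pos = 42;
    rollback_full(cells, snapshot);        // restore every cell from the snapshot
    return 0;
}
```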
@@ -17184,6 +17235,7 @@ static int llama_decode_internal(
    lctx.n_queued_tokens += n_tokens_all;

    auto & kv_self = lctx.kv_self;
+    llama_kv_slot_restorer kv_slot_restorer;

    const int64_t n_embd  = hparams.n_embd;
    const int64_t n_vocab = hparams.n_vocab;
@@ -17268,7 +17320,7 @@ static int llama_decode_internal(
                kv_self.head = 0;
            }

-            if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
+            if (!llama_kv_cache_find_slot(kv_self, ubatch, &kv_slot_restorer)) {
                return 1;
            }

@@ -17318,16 +17370,17 @@ static int llama_decode_internal(
        llama_set_inputs(lctx, ubatch);

        const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
-        switch (compute_status) {
-            case GGML_STATUS_SUCCESS:
-                break;
-            case GGML_STATUS_ABORTED:
-                return 2;
-            case GGML_STATUS_ALLOC_FAILED:
-                return -2;
-            case GGML_STATUS_FAILED:
-            default:
-                return -3;
+        if (compute_status != GGML_STATUS_SUCCESS) {
+            llama_kv_cache_slot_restore(kv_slot_restorer, kv_self);
+            switch (compute_status) {
+                case GGML_STATUS_ABORTED:
+                    return 2;
+                case GGML_STATUS_ALLOC_FAILED:
+                    return -2;
+                case GGML_STATUS_FAILED:
+                default:
+                    return -3;
+            }
        }

        // update the kv ring buffer
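From the caller's point of view, the practical effect of this change is that a failed graph compute no longer leaves the tokens of the failed ubatch half-inserted in the KV cache. The sketch below is a hypothetical caller-side helper, not part of the patch: only llama_decode() and llama_batch.n_tokens are real API surface, and the retry policy in the comments is one possible choice, not the library's.

```cpp
#include "llama.h"

#include <cstdio>

// Hypothetical helper: interpret llama_decode() return codes, knowing that with
// this patch a failed graph compute rolls the KV cache back before returning.
static bool decode_or_report(llama_context * ctx, llama_batch batch) {
    const int ret = llama_decode(ctx, batch);
    if (ret == 0) {
        return true;  // success: the KV cache now contains the batch
    }
    if (ret == 1) {
        // no KV slot was found for the whole batch; the caller can free cells
        // or split the batch and try again
        fprintf(stderr, "decode: no KV slot for %d tokens\n", batch.n_tokens);
        return false;
    }
    // ret == 2 (aborted) or ret < 0 (alloc/compute failure): the slot reserved by
    // llama_kv_cache_find_slot() has been restored, so the context remains usable
    fprintf(stderr, "decode: graph compute failed (status %d)\n", ret);
    return false;
}
```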