@@ -764,7 +764,7 @@ struct completion_token_output {
     }
 };

-struct swa_checkpoint {
+struct ctx_checkpoint {
     llama_pos pos_min;
     llama_pos pos_max;

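For reference, a minimal standalone sketch of the renamed checkpoint type as the rest of this diff uses it (not part of the patch): the `data` member does not appear in this hunk but is read and written by the hunks below, and the `llama_pos` typedef is assumed from llama.h.

```cpp
// Sketch only, not the verbatim server.cpp definition: a context checkpoint pairs
// the covered position range with an opaque serialized state blob for one sequence.
#include <cstdint>
#include <vector>

typedef int32_t llama_pos; // llama.h defines llama_pos as a 32-bit position index

struct ctx_checkpoint {
    llama_pos pos_min;          // smallest position covered by the saved state
    llama_pos pos_max;          // largest position covered by the saved state

    std::vector<uint8_t> data;  // blob produced by llama_state_seq_get_data_ext()
};
```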
@@ -1460,7 +1460,7 @@ struct server_slot {

     std::vector<completion_token_output> generated_token_probs;

-    std::vector<swa_checkpoint> swa_checkpoints;
+    std::vector<ctx_checkpoint> ctx_checkpoints;

     bool has_next_token = true;
     bool has_new_line   = false;
@@ -3541,7 +3541,11 @@ struct server_context {
                     slot.n_past = 0;
                 }

-                const auto n_swa = llama_model_n_swa(model);
+                // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
+                const auto n_swa = std::max(1, llama_model_n_swa(model));
+
+                // the largest pos_min required for a checkpoint to be useful
+                const auto pos_min_thold = std::max(0, slot.n_past - n_swa);

                 if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
                     const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
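A small illustrative example (made-up values, not part of the patch) of how the clamped `n_swa` and the new `pos_min_thold` behave:

```cpp
// Illustrative only: behavior of the two new expressions for hypothetical values.
#include <algorithm>
#include <cstdio>

static int pos_min_thold(int n_past, int n_swa_model) {
    // n_swa_model == 0 means "no SWA", which is treated as a window of 1
    const int n_swa = std::max(1, n_swa_model);
    return std::max(0, n_past - n_swa);
}

int main() {
    const int n_past = 100;                                                  // tokens already cached for the slot
    printf("SWA window 32 -> pos_min_thold = %d\n", pos_min_thold(n_past, 32)); // 68
    printf("no SWA (0)    -> pos_min_thold = %d\n", pos_min_thold(n_past,  0)); // 99
    return 0;
}
```

A checkpoint whose `pos_min` is at or above this threshold can no longer supply the positions needed to resume at `n_past`, which is why the hunks below skip and eventually erase such checkpoints.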
@@ -3550,66 +3554,62 @@ struct server_context {
                         GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
                     }

-                    const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
-
                     if (pos_min > pos_min_thold) {
                         SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);

-                        // search for a SWA checkpoint
+                        // search for a context checkpoint
                         const auto it = std::find_if(
-                            slot.swa_checkpoints.rbegin(),
-                            slot.swa_checkpoints.rend(),
+                            slot.ctx_checkpoints.rbegin(),
+                            slot.ctx_checkpoints.rend(),
                             [&](const auto & cur) {
-                                return cur.pos_min <= pos_min_thold;
+                                // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
+                                return cur.pos_min < pos_min_thold;
                             }
                         );

-                        bool do_reset = it == slot.swa_checkpoints.rend();
+                        bool do_reset = it == slot.ctx_checkpoints.rend();
+                        // printf("[DEBUG] `do_reset` was set to `%s`\n", do_reset ? "true" : "false");

                         if (!do_reset) {
-                            // restore the checkpoint
-                            const size_t swa_size = it->data.size();
-                            const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+                            // restore the context checkpoint
+                            const size_t ctx_checkpoint_size = it->data.size();
+                            const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), ctx_checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

-                            if (n != swa_size) {
-                                SLT_ERR(slot, "failed to restore SWA checkpoint, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+                            if (n != ctx_checkpoint_size) {
+                                SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
                                 do_reset = true;
+                                // printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
                             } else {
-                                slot.n_past = std::min(slot.n_past, it->pos_max);
-
-                                SLT_WRN(slot, "SWA checkpoint restore, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+                                slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
+                                SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
                             }
                         }

                         if (do_reset) {
-                            SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
+                            SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
                                     "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
-
                             slot.n_past = 0;
-                            slot.swa_checkpoints.clear();
                         }
                     }
                 }

-                if (n_swa > 0) {
-                    const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
-
+                {
                     // erase any checkpoints with pos_min > pos_min_thold
-                    for (int i = (int) slot.swa_checkpoints.size() - 1; i >= 0; i--) {
-                        const auto & cur = slot.swa_checkpoints[i];
+                    for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) {
+                        const auto & cur = slot.ctx_checkpoints[i];
                         if (cur.pos_min > pos_min_thold) {
-                            slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + i);
-
-                            SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+                            SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
+                            slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i);
                         }
                     }
                 }
             }

+            // [TAG_PROMPT_LOGITS]
             if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
-                SLT_WRN(slot, "need to evaluate at least 1 token for each active slot, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
-
+                SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens);
                 slot.n_past--;
+                SLT_WRN(slot, "n_past was set to %d\n", slot.n_past);
             }

             slot.n_prompt_tokens_cache = slot.n_past;
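A self-contained sketch of the newest-first checkpoint search and the new `n_past` clamp above, using stand-in types and hypothetical values instead of `slot.ctx_checkpoints` and the llama state API (not part of the patch):

```cpp
// Illustrative only: mirrors the checkpoint search/restore logic with plain data.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

struct ctx_checkpoint {
    int pos_min;
    int pos_max;
    std::vector<uint8_t> data;
};

int main() {
    // oldest first, newest last (hypothetical position ranges)
    std::vector<ctx_checkpoint> checkpoints = {{0, 40, {}}, {30, 80, {}}, {60, 120, {}}};

    int       n_past        = 100;
    const int pos_min_thold = 70; // assume std::max(0, n_past - n_swa) computed earlier

    // search newest-first; the strict '<' guarantees at least one token gets re-processed
    const auto it = std::find_if(checkpoints.rbegin(), checkpoints.rend(),
        [&](const auto & cur) { return cur.pos_min < pos_min_thold; });

    if (it == checkpoints.rend()) {
        // no usable checkpoint: force full prompt re-processing
        n_past = 0;
    } else {
        // resume no further than the checkpoint covers, but always past its pos_min
        n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max));
    }

    printf("resume at n_past = %d\n", n_past); // picks {60, 120} -> 100
    return 0;
}
```

The strict `<` (instead of the old `<=`) together with the `it->pos_min + 1` clamp ensures the slot always re-processes at least one token, which is what the `[TAG_PROMPT_LOGITS]` marker refers to.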
@@ -3623,17 +3623,17 @@ struct server_context {
                 }
             }

-            // keep only the common part
+            // truncate any tokens that are beyond n_past for this slot
             if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)) {
-                // could not partially delete (likely using a non-Transformer model)
+                SLT_WRN(slot, "failed to truncate tokens beyond n_past = %d\n", slot.n_past);
                 llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);

                 // there is no common part left
                 slot.n_past = 0;
                 slot.n_prompt_tokens_cache = 0;
             }

-            SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
+            SLT_INF(slot, "n_past = %d, memory_seq_rm [%d, end)\n", slot.n_past, slot.n_past);

             // remove the non-common part from the cache
             slot.cache_tokens.keep_first(slot.n_past);
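The fallback in the context lines above (partial removal fails, so the whole sequence is cleared) can be sketched with a stand-in cache, since calling `llama_memory_seq_rm()` itself requires a loaded model; this is illustrative only, not the patch's code:

```cpp
// Illustrative only: the "partial removal failed -> clear everything" fallback,
// using a fake cache in place of llama_memory_seq_rm().
#include <cstdio>
#include <vector>

struct fake_cache {
    std::vector<int> tokens;
    bool can_remove_partially = false; // e.g. recurrent/hybrid memory

    // mimics llama_memory_seq_rm(mem, seq_id, p0, -1): keep only positions [0, p0)
    bool seq_rm(int p0) {
        if (p0 > 0 && !can_remove_partially) {
            return false;
        }
        tokens.resize(p0 < 0 ? 0 : (size_t) p0);
        return true;
    }
};

int main() {
    fake_cache cache;
    cache.tokens = {11, 22, 33, 44, 55};

    int n_past = 3;
    if (!cache.seq_rm(n_past)) {
        printf("failed to truncate tokens beyond n_past = %d\n", n_past);
        cache.seq_rm(-1); // remove the whole sequence instead

        // there is no common part left
        n_past = 0;
    }
    printf("n_past = %d, cached tokens = %zu\n", n_past, cache.tokens.size());
    return 0;
}
```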
@@ -3854,37 +3854,38 @@ struct server_context {
                 // prompt evaluated for next-token prediction
                 slot.state = SLOT_STATE_GENERATING;

-                // make a checkpoint with the SWA memory
-                // checkpoints are needed only if we are not using "--swa-full"
-                if (llama_model_n_swa(model) > 0 && !params_base.swa_full && params_base.n_swa_checkpoints > 0) {
-                    if (slot.swa_checkpoints.size() >= (size_t) params_base.n_swa_checkpoints) {
-                        {
-                            const auto & cur = slot.swa_checkpoints.back();
-
-                            SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n",
-                                    cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
-                        }
-
-                        slot.swa_checkpoints.erase(slot.swa_checkpoints.begin());
+                // make a checkpoint of the parts of the memory that cannot be rolled back.
+                // checkpoints are created only if:
+                // - the model uses SWA and we are not using `swa_full`
+                // - the model architecture is marked as recurrent or hybrid
+                //
+                // TODO: try to make this conditional on the context or the memory module, instead of the model type
+                const bool do_checkpoint =
+                    (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
+                    (llama_model_n_swa(model) > 0 && !params_base.swa_full);
+
+                if (do_checkpoint && params_base.n_ctx_checkpoints > 0) {
+                    while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
+                        // make room for the new checkpoint, if needed
+                        const auto & cur = slot.ctx_checkpoints.front();
+                        SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+
+                        slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
                     }

-                    const size_t swa_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+                    const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

-                    auto & cur = slot.swa_checkpoints.emplace_back(swa_checkpoint{
+                    auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
                         /* .pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
                         /* .pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
-                        /* .data    = */ std::vector<uint8_t>(swa_size),
+                        /* .data    = */ std::vector<uint8_t>(checkpoint_size),
                     });

-                    llama_state_seq_get_data_ext(ctx, cur.data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
-
-                    float size_total = 0.0f;
-                    for (const auto & checkpoint : slot.swa_checkpoints) {
-                        size_total += (float) checkpoint.data.size() / 1024 / 1024;
-                    }
+                    llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

-                    SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %.3f MiB, total = %d/%d (%.3f MiB)\n",
-                            cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024, (int) slot.swa_checkpoints.size(), params_base.n_swa_checkpoints, size_total);
+                    SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                            (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
                 }
             } else if (slot.state != SLOT_STATE_GENERATING) {
                 continue; // continue loop of slots
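Finally, a sketch of the bounded, oldest-first checkpoint store maintained by this hunk, again with stand-in types and made-up values (in the server the `data` blob is sized and filled by the `llama_state_seq_*_ext()` calls shown above):

```cpp
// Illustrative only: FIFO eviction of context checkpoints with a fixed capacity.
#include <cstdint>
#include <cstdio>
#include <vector>

struct ctx_checkpoint {
    int pos_min;
    int pos_max;
    std::vector<uint8_t> data;
};

int main() {
    const size_t n_ctx_checkpoints = 2; // stand-in for params_base.n_ctx_checkpoints

    std::vector<ctx_checkpoint> checkpoints = {
        {0, 40, std::vector<uint8_t>(16)},
        {30, 80, std::vector<uint8_t>(16)},
    };

    // make room for the new checkpoint, dropping the oldest entries first
    while (checkpoints.size() >= n_ctx_checkpoints) {
        const auto & cur = checkpoints.front();
        printf("erasing old checkpoint (pos_min = %d, pos_max = %d)\n", cur.pos_min, cur.pos_max);
        checkpoints.erase(checkpoints.begin());
    }

    // append the new checkpoint; the server fills `data` from the slot's sequence state
    checkpoints.push_back({60, 120, std::vector<uint8_t>(16)});

    printf("holding %zu of %zu checkpoints\n", checkpoints.size(), n_ctx_checkpoints);
    return 0;
}
```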