
Commit f6dcda3

ddh0 and ggerganov authored
server : context checkpointing for hybrid and recurrent models (#16382)
* initial commit for branch 3

* generalize `swa_checkpoint` to `ctx_checkpoint`

  this extends `llama-server`'s SWA checkpointing logic to include hybrid/recurrent models such as Jamba, Granite

* oops

* disable debug prints

* keep backwards compat with `--swa-checkpoints`

  Co-authored-by: Georgi Gerganov <[email protected]>

* update prompt re-processing message

* fix off-by-one error per GG

* keep `seq_rm` log per GG

  Co-authored-by: Georgi Gerganov <[email protected]>

* server : fix checkpoint logic to support recurrent caches

* server : cleanup and fixes

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 606a73f commit f6dcda3
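
For reference, the gist of the new behavior is the condition the server now evaluates before saving a checkpoint for a slot. The sketch below restates it outside the diff; `should_checkpoint` is a hypothetical helper (not part of this commit), and it assumes an already-loaded `llama_model * model` plus the server's `--swa-full` and `--ctx-checkpoints` settings.

```cpp
#include "llama.h"

// Hypothetical helper: restates the do_checkpoint condition introduced in
// tools/server/server.cpp. `swa_full` and `n_ctx_checkpoints` stand in for the
// server's --swa-full and --ctx-checkpoints settings.
static bool should_checkpoint(const llama_model * model, bool swa_full, int n_ctx_checkpoints) {
    // recurrent and hybrid models carry state that cannot be rolled back, so they always need checkpoints
    const bool is_recurrent_or_hybrid = llama_model_is_recurrent(model) || llama_model_is_hybrid(model);

    // SWA models discard KV entries outside the attention window unless --swa-full is used
    const bool needs_swa_checkpoint = llama_model_n_swa(model) > 0 && !swa_full;

    return (is_recurrent_or_hybrid || needs_swa_checkpoint) && n_ctx_checkpoints > 0;
}
```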

File tree

8 files changed: +89 -74 lines changed


common/arg.cpp (5 additions, 5 deletions)

```diff
@@ -1932,13 +1932,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_env("LLAMA_ARG_SWA_FULL"));
     add_opt(common_arg(
-        {"--swa-checkpoints"}, "N",
-        string_format("max number of SWA checkpoints per slot to create (default: %d)\n"
-            "[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_swa_checkpoints),
+        {"--ctx-checkpoints", "--swa-checkpoints"}, "N",
+        string_format("max number of context checkpoints to create per slot (default: %d)\n"
+            "[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_ctx_checkpoints),
        [](common_params & params, int value) {
-            params.n_swa_checkpoints = value;
+            params.n_ctx_checkpoints = value;
        }
-    ).set_env("LLAMA_ARG_SWA_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
+    ).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"--kv-unified", "-kvu"},
         string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
```

common/common.h (1 addition, 1 deletion)

```diff
@@ -424,7 +424,7 @@ struct common_params {
     int32_t timeout_write = timeout_read; // http write timeout in seconds
     int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
     int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
-    int32_t n_swa_checkpoints = 3; // max number of SWA checkpoints per slot
+    int32_t n_ctx_checkpoints = 3; // max number of context checkpoints per slot

     std::string hostname = "127.0.0.1";
     std::string public_path = ""; // NOLINT
```

include/llama.h (7 additions, 0 deletions)

```diff
@@ -543,6 +543,9 @@ extern "C" {
     // Returns true if the model is recurrent (like Mamba, RWKV, etc.)
     LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);

+    // Returns true if the model is hybrid (like Jamba, Granite, etc.)
+    LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model);
+
     // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
     LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);

@@ -791,8 +794,12 @@ extern "C" {
            size_t   n_token_capacity,
            size_t * n_token_count_out);

+    // for backwards-compat
     #define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1

+    // work only with partial states, such as SWA KV cache or recurrent cache (e.g. Mamba)
+    #define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1
+
     typedef uint32_t llama_state_seq_flags;

     LLAMA_API size_t llama_state_seq_get_size_ext(
```
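
The new flag is consumed by the existing `*_ext` state APIs that take a `llama_state_seq_flags` argument. As a rough usage sketch (not part of this commit), saving and restoring only the partial state of a sequence could look like the following; `save_partial_state` and `restore_partial_state` are hypothetical helper names:

```cpp
#include <cstdint>
#include <vector>

#include "llama.h"

// Capture only the partial state (SWA KV cache / recurrent state) of a sequence.
static std::vector<uint8_t> save_partial_state(llama_context * ctx, llama_seq_id seq_id) {
    const size_t size = llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
    std::vector<uint8_t> data(size);
    llama_state_seq_get_data_ext(ctx, data.data(), size, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
    return data;
}

// Restore a previously captured partial state; a size mismatch signals failure,
// which is how the server decides to fall back to full prompt re-processing.
static bool restore_partial_state(llama_context * ctx, llama_seq_id seq_id, const std::vector<uint8_t> & data) {
    const size_t n = llama_state_seq_set_data_ext(ctx, data.data(), data.size(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
    return n == data.size();
}
```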

src/llama-kv-cache-iswa.cpp (2 additions, 2 deletions)

```diff
@@ -220,15 +220,15 @@ bool llama_kv_cache_iswa::get_can_shift() const {
 }

 void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
-    if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
         kv_base->state_write(io, seq_id, flags);
     }

     kv_swa->state_write(io, seq_id, flags);
 }

 void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
-    if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
         kv_base->state_read(io, seq_id, flags);
     }

```
src/llama-memory-hybrid.cpp (8 additions, 8 deletions)

```diff
@@ -175,17 +175,17 @@ std::map<ggml_backend_buffer_type_t, size_t> llama_memory_hybrid::memory_breakdo
 }

 void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
-    GGML_UNUSED(flags);
-
-    mem_attn->state_write(io, seq_id);
-    mem_recr->state_write(io, seq_id);
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
+        mem_attn->state_write(io, seq_id, flags);
+    }
+    mem_recr->state_write(io, seq_id, flags);
 }

 void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
-    GGML_UNUSED(flags);
-
-    mem_attn->state_read(io, seq_id);
-    mem_recr->state_read(io, seq_id);
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
+        mem_attn->state_read(io, seq_id, flags);
+    }
+    mem_recr->state_read(io, seq_id, flags);
 }

 llama_kv_cache * llama_memory_hybrid::get_mem_attn() const {
```

src/llama-memory-recurrent.cpp (4 additions, 1 deletion)

```diff
@@ -136,6 +136,7 @@ void llama_memory_recurrent::clear(bool data) {
 }

 bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    //printf("[DEBUG] calling llama_memory_recurrent::seq_rm` with `seq_id=%d, p0=%d, p1=%d`\n", seq_id, p0, p1);
     uint32_t new_head = size;

     if (p0 < 0) {
@@ -156,7 +157,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
         if (tail_id >= 0) {
             const auto & cell = cells[tail_id];
             // partial intersection is invalid
-            if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
+            if ((0 < p0 && p0 < cell.pos) || (0 < p1 && p1 <= cell.pos)) {
+                //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n");
                 return false;
             }
             // invalidate tails which will be cleared
@@ -167,6 +169,7 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
     } else {
         // seq_id is negative, then the range should include everything or nothing
         if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
+            //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: `seq_id` is negative, so returning false\n");
            return false;
         }
     }
```
src/llama-model.cpp (4 additions, 0 deletions)

```diff
@@ -20151,6 +20151,10 @@ bool llama_model_is_recurrent(const llama_model * model) {
     return llm_arch_is_recurrent(model->arch);
 }

+bool llama_model_is_hybrid(const llama_model * model) {
+    return llm_arch_is_hybrid(model->arch);
+}
+
 bool llama_model_is_diffusion(const llama_model * model) {
     return llm_arch_is_diffusion(model->arch);
 }
```

tools/server/server.cpp (58 additions, 57 deletions)

```diff
@@ -764,7 +764,7 @@ struct completion_token_output {
     }
 };

-struct swa_checkpoint {
+struct ctx_checkpoint {
     llama_pos pos_min;
     llama_pos pos_max;

@@ -1460,7 +1460,7 @@ struct server_slot {

     std::vector<completion_token_output> generated_token_probs;

-    std::vector<swa_checkpoint> swa_checkpoints;
+    std::vector<ctx_checkpoint> ctx_checkpoints;

     bool has_next_token = true;
     bool has_new_line = false;
@@ -3541,7 +3541,11 @@ struct server_context {
                         slot.n_past = 0;
                     }

-                    const auto n_swa = llama_model_n_swa(model);
+                    // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
+                    const auto n_swa = std::max(1, llama_model_n_swa(model));
+
+                    // the largest pos_min required for a checkpoint to be useful
+                    const auto pos_min_thold = std::max(0, slot.n_past - n_swa);

                     if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
                         const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
@@ -3550,66 +3554,62 @@ struct server_context {
                             GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
                         }

-                        const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
-
                         if (pos_min > pos_min_thold) {
                             SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);

-                            // search for a SWA checkpoint
+                            // search for a context checkpoint
                             const auto it = std::find_if(
-                                slot.swa_checkpoints.rbegin(),
-                                slot.swa_checkpoints.rend(),
+                                slot.ctx_checkpoints.rbegin(),
+                                slot.ctx_checkpoints.rend(),
                                 [&](const auto & cur) {
-                                    return cur.pos_min <= pos_min_thold;
+                                    // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
+                                    return cur.pos_min < pos_min_thold;
                                 }
                             );

-                            bool do_reset = it == slot.swa_checkpoints.rend();
+                            bool do_reset = it == slot.ctx_checkpoints.rend();
+                            //printf("[DEBUG] `do_reset` was set to `%s`\n", do_reset ? "true" : "false");

                             if (!do_reset) {
-                                // restore the checkpoint
-                                const size_t swa_size = it->data.size();
-                                const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+                                // restore the context checkpoint
+                                const size_t ctx_checkpoint_size = it->data.size();
+                                const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), ctx_checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

-                                if (n != swa_size) {
-                                    SLT_ERR(slot, "failed to restore SWA checkpoint, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+                                if (n != ctx_checkpoint_size) {
+                                    SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
                                     do_reset = true;
+                                    //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
                                 } else {
-                                    slot.n_past = std::min(slot.n_past, it->pos_max);
-
-                                    SLT_WRN(slot, "SWA checkpoint restore, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+                                    slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
+                                    SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
                                 }
                             }

                             if (do_reset) {
-                                SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
+                                SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
                                     "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
-
                                 slot.n_past = 0;
-                                slot.swa_checkpoints.clear();
                             }
                         }
                     }

-                    if (n_swa > 0) {
-                        const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
-
+                    {
                         // erase any checkpoints with pos_min > pos_min_thold
-                        for (int i = (int) slot.swa_checkpoints.size() - 1; i >= 0; i--) {
-                            const auto & cur = slot.swa_checkpoints[i];
+                        for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) {
+                            const auto & cur = slot.ctx_checkpoints[i];
                             if (cur.pos_min > pos_min_thold) {
-                                slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + i);
-
-                                SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+                                SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
+                                slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i);
                             }
                         }
                     }
                 }

+                // [TAG_PROMPT_LOGITS]
                 if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
-                    SLT_WRN(slot, "need to evaluate at least 1 token for each active slot, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
-
+                    SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens);
                     slot.n_past--;
+                    SLT_WRN(slot, "n_past was set to %d\n", slot.n_past);
                 }

                 slot.n_prompt_tokens_cache = slot.n_past;
@@ -3623,17 +3623,17 @@ struct server_context {
                     }
                 }

-                // keep only the common part
+                // truncate any tokens that are beyond n_past for this slot
                 if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)) {
-                    // could not partially delete (likely using a non-Transformer model)
+                    SLT_WRN(slot, "failed to truncate tokens beyond n_past = %d\n", slot.n_past);
                     llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);

                     // there is no common part left
                     slot.n_past = 0;
                     slot.n_prompt_tokens_cache = 0;
                 }

-                SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
+                SLT_INF(slot, "n_past = %d, memory_seq_rm [%d, end)\n", slot.n_past, slot.n_past);

                 // remove the non-common part from the cache
                 slot.cache_tokens.keep_first(slot.n_past);
@@ -3854,37 +3854,38 @@ struct server_context {
                     // prompt evaluated for next-token prediction
                     slot.state = SLOT_STATE_GENERATING;

-                    // make a checkpoint with the SWA memory
-                    // checkpoints are needed only if we are not using "--swa-full"
-                    if (llama_model_n_swa(model) > 0 && !params_base.swa_full && params_base.n_swa_checkpoints > 0) {
-                        if (slot.swa_checkpoints.size() >= (size_t) params_base.n_swa_checkpoints) {
-                            {
-                                const auto & cur = slot.swa_checkpoints.back();
-
-                                SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n",
-                                        cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
-                            }
-
-                            slot.swa_checkpoints.erase(slot.swa_checkpoints.begin());
+                    // make a checkpoint of the parts of the memory that cannot be rolled back.
+                    // checkpoints are created only if:
+                    // - the model uses SWA and we are not using `swa_full`
+                    // - the model architecture is marked as recurrent or hybrid
+                    //
+                    // TODO: try to make this conditional on the context or the memory module, instead of the model type
+                    const bool do_checkpoint =
+                        (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
+                        (llama_model_n_swa(model) > 0 && !params_base.swa_full);
+
+                    if (do_checkpoint && params_base.n_ctx_checkpoints > 0) {
+                        while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
+                            // make room for the new checkpoint, if needed
+                            const auto & cur = slot.ctx_checkpoints.front();
+                            SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                    cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+
+                            slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
                         }

-                        const size_t swa_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+                        const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

-                        auto & cur = slot.swa_checkpoints.emplace_back(swa_checkpoint{
+                        auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
                             /*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
                             /*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
-                            /*.data    = */ std::vector<uint8_t>(swa_size),
+                            /*.data    = */ std::vector<uint8_t>(checkpoint_size),
                         });

-                        llama_state_seq_get_data_ext(ctx, cur.data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
-
-                        float size_total = 0.0f;
-                        for (const auto & checkpoint : slot.swa_checkpoints) {
-                            size_total += (float) checkpoint.data.size() / 1024 / 1024;
-                        }
+                        llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

-                        SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %.3f MiB, total = %d/%d (%.3f MiB)\n",
-                                cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024, (int) slot.swa_checkpoints.size(), params_base.n_swa_checkpoints, size_total);
+                        SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+                                (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
                     }
                 } else if (slot.state != SLOT_STATE_GENERATING) {
                     continue; // continue loop of slots
```
