From d7ac0d3d06dadabf734441fed3d5c09f86e36ddf Mon Sep 17 00:00:00 2001
From: Maximilian Winter
Date: Wed, 24 Jan 2024 01:24:57 +0100
Subject: [PATCH 01/16] Ported self extension to server example

---
 examples/server/server.cpp | 134 +++++++++++++++++++++++++++++++------
 1 file changed, 112 insertions(+), 22 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 0462fbd24739b..cc488b481c1b0 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -410,6 +410,11 @@ struct llama_client_slot
     struct llama_sampling_params sparams;
     llama_sampling_context *ctx_sampling = nullptr;
 
+    int ga_i = 0;       // group-attention state
+
+    int32_t ga_n = 1;   // group-attention factor
+    int32_t ga_w = 512; // group-attention width
+
     // multimodal
     std::vector<slot_image> images;
 
@@ -438,7 +443,7 @@ struct llama_client_slot
         sent_count = 0;
         sent_token_probs_index = 0;
         infill = false;
-
+        ga_i = 0;
         generated_token_probs.clear();
 
         for (slot_image & img : images)
@@ -633,9 +638,26 @@ struct llama_server_context
             slot.id = i;
             slot.n_ctx = n_ctx_slot;
 
+            LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, n_ctx_slot);
+
+            const int ga_n = params.grp_attn_n;
+            const int ga_w = params.grp_attn_w;
+
+            if (ga_n != 1) {
+                GGML_ASSERT(ga_n > 0 && "ga_n must be positive");                   // NOLINT
+                GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
+                //GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w");        // NOLINT
+                //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
+                LOG_TEE(" -> Slot %i - self-extend: ga_n = %d, ga_w = %d\n", slot.id, ga_n, ga_w);
+            }
+
+            slot.ga_i = 0;
+            slot.ga_n = ga_n;
+            slot.ga_w = ga_w;
+
             slot.reset();
 
-            LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, n_ctx_slot);
+
             slots.push_back(slot);
         }
 
@@ -1692,32 +1714,37 @@ struct llama_server_context
 
         for (llama_client_slot &slot : slots)
         {
-            if (slot.is_processing() && slot.cache_tokens.size() >= (size_t) slot.n_ctx)
+            if (slot.ga_n == 1)
             {
-                // Shift context
-                const int n_left = slot.n_past - slot.params.n_keep - 1;
-                const int n_discard = n_left / 2;
+                if (slot.is_processing() && slot.cache_tokens.size() >= (size_t) slot.n_ctx)
+                {
 
-                LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, slot.params.n_keep, n_left, n_discard);
-                llama_kv_cache_seq_rm   (ctx, slot.id, slot.params.n_keep + 1            , slot.params.n_keep + n_discard + 1);
-                llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard);
+                    // Shift context
+                    const int n_left = slot.n_past - slot.params.n_keep - 1;
+                    const int n_discard = n_left / 2;
 
-                for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++)
-                {
-                    slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
-                }
+                    LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id,
+                            slot.params.n_keep, n_left, n_discard);
+                    llama_kv_cache_seq_rm(ctx, slot.id, slot.params.n_keep + 1, slot.params.n_keep + n_discard + 1);
+                    llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard);
+
+                    for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++)
+                    {
+                        slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
+                    }
 
-                slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
+                    slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
 
-                slot.n_past -= n_discard;
+                    slot.n_past -= n_discard;
 
-                slot.truncated = true;
+                    slot.truncated = true;
 
-                LOG_VERBOSE("context shift", {
-                    {"n_ctx", n_ctx},
-                    {"n_keep", params.n_keep},
-                    {"n_left", n_left},
-                });
+                    LOG_VERBOSE("context shift", {
+                        { "n_ctx", n_ctx },
+                        { "n_keep", params.n_keep },
+                        { "n_left", n_left },
+                    });
+                }
             }
         }
 
@@ -1880,11 +1907,25 @@ struct llama_server_context
             // process the prefix of first image
             std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
+
+            int32_t slot_npast = 0;
+            int ga_i = slot.ga_i;
+            int ga_n = slot.ga_n;
+            int ga_w = slot.ga_w;
+
             for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past)
             {
-                llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot.n_past, { slot.id }, false);
+                if(slot.ga_n != 1)
+                {
+                    while (slot_npast >= ga_i + ga_w) {
+                        const int bd = (ga_w/ga_n)*(ga_n - 1);
+                        slot_npast -= bd;
+                        ga_i += ga_w/ga_n;
+                    }
+                }
+                llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
+                slot_npast += 1;
             }
+            slot.n_past = 0;
 
             if (has_images && !ingest_images(slot, n_batch))
             {
                 LOG_TEE("failed processing images\n");
@@ -1912,6 +1953,35 @@ struct llama_server_context
         for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
         {
             const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+            for (auto & slot : slots)
+            {
+                if(slot.ga_n != 1)
+                {
+                    // context extension via Self-Extend
+                    while (slot.n_past >= slot.ga_i + slot.ga_w) {
+                        const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
+                        const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
+                        const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
+
+                        LOG_TEE("\n");
+                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past, ib * bd, slot.ga_i + ib * bd, slot.n_past + ib * bd);
+                        LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
+                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past + ib * bd + dd);
+
+                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past, ib * bd);
+                        llama_kv_cache_seq_div (ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
+                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w, slot.n_past + ib * bd, dd);
+
+                        slot.n_past -= bd;
+
+                        slot.ga_i += slot.ga_w / slot.ga_n;
+
+                        LOG_TEE("\nslot.n_past_old = %d, slot.n_past = %d, ga_i = %d\n\n", slot.n_past + bd, slot.n_past, slot.ga_i);
+                    }
+                }
+                slot.n_past += n_tokens;
+            }
+
             llama_batch batch_view =
             {
                 n_tokens,
@@ -1925,6 +1995,7 @@ struct llama_server_context
             };
 
             const int ret = llama_decode(ctx, batch_view);
+
             if (ret != 0)
             {
                 if (n_batch == 1 || ret < 0)
@@ -1944,6 +2015,7 @@ struct llama_server_context
 
             for (auto & slot : slots)
             {
+
                 if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens))
                 {
                     continue;
@@ -1995,6 +2067,7 @@ struct llama_server_context
                 slot.i_batch = -1;
             }
         }
+
         return true;
     }
 };
@@ -2251,6 +2324,23 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
         }
         params.n_threads = std::stoi(argv[i]);
     }
+    else if (arg == "--grp-attn-n" || arg == "-gan") {
+        if (++i >= argc) {
+            invalid_param = true;
+            break;
+        }
+
+        params.grp_attn_n = std::stoi(argv[i]);
+    } else if (arg == "--grp-attn-w" || arg == "-gaw")
+    {
+        if (++i >= argc)
+        {
+            invalid_param = true;
+            break;
+        }
+
+        params.grp_attn_w = std::stoi(argv[i]);
+    }
     else if (arg == "--threads-batch" || arg == "-tb")
     {
         if (++i >= argc)

From 0978977e56e808acbab3072d1ac83cdc2c71e03c Mon Sep 17 00:00:00 2001
From: Maximilian Winter
Date: Wed, 24 Jan 2024 01:32:33 +0100
Subject: [PATCH 02/16] Update server.cpp

---
 examples/server/server.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index cc488b481c1b0..cdb0592cf5af0 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1976,7 +1976,7 @@ struct llama_server_context
 
                         slot.ga_i += slot.ga_w / slot.ga_n;
 
-                        LOG_TEE("\nslot.n_past_old = %d, slot.n_past = %d, ga_i = %d\n\n", slot.n_past + bd, slot.n_past, slot.ga_i);
+                        LOG_TEE("\nn_past_old = %d, 2n_past = %d, ga_i = %d\n\n", slot.n_past + bd, slot.n_past, slot.ga_i);
                     }
                 }
                 slot.n_past += n_tokens;

From edc2c08943aa4e1c4e1f3a65a391242eaef996c8 Mon Sep 17 00:00:00 2001
From: Maximilian Winter
Date: Fri, 26 Jan 2024 22:53:34 +0100
Subject: [PATCH 03/16] Fixed prompt caching without self extend

---
 examples/server/server.cpp | 29 +++++++++++++++++++----------
 1 file changed, 19 insertions(+), 10 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 195caf0b21fec..a606a927321e9 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1565,7 +1565,7 @@ struct llama_server_context
             // process the prefix of first image
             std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
 
-            int32_t slot_npast = 0;
+            int32_t slot_npast = slot.n_past;
             int ga_i = slot.ga_i;
             int ga_n = slot.ga_n;
             int ga_w = slot.ga_w;
@@ -1583,7 +1583,7 @@ struct llama_server_context
                 slot_npast += 1;
             }
-            slot.n_past = 0;
+
             if (has_images && !ingest_images(slot, n_batch))
             {
                 LOG_TEE("failed processing images\n");
@@ -1608,36 +1608,45 @@ struct llama_server_context
             return true;
         }
 
+        std::vector<int32_t> slot_npasts;
+        for (auto & slot : slots)
+        {
+            slot_npasts.emplace_back(slot.n_past);
+        }
+
         for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
         {
             const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+            int slot_id = 0;
             for (auto & slot : slots)
            {
+
                 if(slot.ga_n != 1)
                 {
                     // context extension via Self-Extend
-                    while (slot.n_past >= slot.ga_i + slot.ga_w) {
+                    while (slot_npasts[slot_id] >= slot.ga_i + slot.ga_w) {
                         const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
                         const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
                         const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
 
                         LOG_TEE("\n");
-                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past, ib * bd, slot.ga_i + ib * bd, slot.n_past + ib * bd);
+                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot_npasts[slot_id], ib * bd, slot.ga_i + ib * bd, slot_npasts[slot_id] + ib * bd);
                         LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
-                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past + ib * bd + dd);
+                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot_npasts[slot_id] + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot_npasts[slot_id] + ib * bd + dd);
 
-                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past, ib * bd);
+                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot_npasts[slot_id], ib * bd);
                         llama_kv_cache_seq_div (ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
-                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w, slot.n_past + ib * bd, dd);
+                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w, slot_npasts[slot_id] + ib * bd, dd);
 
-                        slot.n_past -= bd;
+                        slot_npasts[slot_id] -= bd;
 
                         slot.ga_i += slot.ga_w / slot.ga_n;
 
-                        LOG_TEE("\nn_past_old = %d, 2n_past = %d, ga_i = %d\n\n", slot.n_past + bd, slot.n_past, slot.ga_i);
+                        LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot_npasts[slot_id] + bd, slot_npasts[slot_id], slot.ga_i);
                     }
                 }
-                slot.n_past += n_tokens;
+                slot_npasts[slot_id] += n_tokens;
+                slot_id++;
             }
 
             llama_batch batch_view =

From 960cfb003f5b70b24b58a86d0b5d97254aa07f5b Mon Sep 17 00:00:00 2001
From: Maximilian Winter
Date: Fri, 26 Jan 2024 22:54:26 +0100
Subject: [PATCH 04/16] Update server.cpp

---
 examples/server/server.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index a606a927321e9..ef96e823a518d 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1611,7 +1611,7 @@ struct llama_server_context
         std::vector<int32_t> slot_npasts;
         for (auto & slot : slots)
         {
-            slot_npasts.emplace_back(slot.n_past);
+            slot_npasts.emplace_back(0);
         }

From 4df0e88aedb010c04aed8e06de220463c1d7b044 Mon Sep 17 00:00:00 2001
From: Maximilian Winter
Date: Sat, 27 Jan 2024 00:37:53 +0100
Subject: [PATCH 05/16] Added description to server readme.

---
 examples/server/README.md  | 4 +++-
 examples/server/server.cpp | 2 ++
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/examples/server/README.md b/examples/server/README.md
index fd3034b99c3d2..a4a922bce43b9 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -30,7 +30,9 @@ Command line options:
 - `-cb`, `--cont-batching`: enable continuous batching (a.k.a dynamic batching) (default: disabled)
 - `-spf FNAME`, `--system-prompt-file FNAME` Set a file to load "a system prompt (initial prompt of all slots), this is useful for chat applications. [See more](#change-system-prompt-on-runtime)
 - `--mmproj MMPROJ_FILE`: Path to a multimodal projector file for LLaVA.
-
+- `--grp-attn-n`: Extend context size through self extend. Extend context size n-times (default: 1), used together with `--grp-attn-w`
+- `--grp-attn-w`: Width of the self extend context size extension. (default: 512) shouldn't be greater than original context size
+- 
 ## Build
 
 server is build alongside everything else from the root of the project
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index ef96e823a518d..d9ca2fdd90a2e 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1810,6 +1810,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("  --override-kv KEY=TYPE:VALUE\n");
     printf("                        advanced option to override model metadata by key. may be specified multiple times.\n");
     printf("                        types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
+    printf("  --grp-attn-n N        Extend context size through self extend. Extend context size n-times (default: 1), used together with `--grp-attn-w`");
+    printf("  --grp-attn-w N        Width of the self extend context size extension. (default: 512) shouldn't be greater than original context size");
     printf("\n");
 }

From aa14068a2b1efa425f704031ec40d1c74258e043 Mon Sep 17 00:00:00 2001
From: Maximilian Winter
Date: Sat, 27 Jan 2024 06:18:40 +0100
Subject: [PATCH 06/16] Update server.cpp

---
 examples/server/server.cpp | 49 ++++++++++++++++++++------------------
 1 file changed, 26 insertions(+), 23 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index d9ca2fdd90a2e..0db99c16775f8 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -189,6 +189,7 @@ struct llama_client_slot
 
     int32_t ga_n = 1;   // group-attention factor
     int32_t ga_w = 512; // group-attention width
+    int32_t n_past_self_extension = 0;
 
     // multimodal
     std::vector<slot_image> images;
@@ -218,6 +219,7 @@ struct llama_client_slot
         sent_token_probs_index = 0;
         infill = false;
         ga_i = 0;
+        n_past_self_extension = 0;
         generated_token_probs.clear();
 
         for (slot_image & img : images)
@@ -1427,9 +1429,8 @@ struct llama_server_context
         }
 
         slot.i_batch = batch.n_tokens;
-
-        llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.n_past, { slot.id }, true);
-
+        int32_t slot_npast = slot.n_past_self_extension > 0 ? slot.n_past_self_extension : slot.n_past;
+        llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
         slot.n_past += 1;
     }
@@ -1526,6 +1527,7 @@ struct llama_server_context
                 llama_sampling_reset(slot.ctx_sampling);
 
                 slot.n_past = 0;
+                slot.n_past_self_extension = 0;
                 slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
             }
             else
@@ -1553,6 +1555,10 @@ struct llama_server_context
                 // we have to evaluate at least 1 token to generate logits.
                 LOG_TEE("slot %d : we have to evaluate at least 1 token to generate logits\n", slot.id);
                 slot.n_past--;
+                if(slot.n_past_self_extension > 0)
+                {
+                    slot.n_past_self_extension--;
+                }
             }
 
             LOG_VERBOSE("prompt ingested", {
@@ -1565,10 +1571,10 @@ struct llama_server_context
             // process the prefix of first image
             std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
 
-            int32_t slot_npast = slot.n_past;
+            int32_t slot_npast = slot.n_past_self_extension > 0 ? slot.n_past_self_extension : slot.n_past;
             int ga_i = slot.ga_i;
-            int ga_n = slot.ga_n;
-            int ga_w = slot.ga_w;
+            int32_t ga_n = slot.ga_n;
+            int32_t ga_w = slot.ga_w;
             for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past)
             {
-                if(slot.ga_n != 1)
+                if (slot.ga_n != 1)
                 {
                     while (slot_npast >= ga_i + ga_w) {
                         const int bd = (ga_w/ga_n)*(ga_n - 1);
                         slot_npast -= bd;
                         ga_i += ga_w/ga_n;
                     }
                 }
-                llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
+                llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
                 slot_npast += 1;
             }
-
             if (has_images && !ingest_images(slot, n_batch))
             {
                 LOG_TEE("failed processing images\n");
@@ -1616,38 +1622,37 @@ struct llama_server_context
         for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
         {
             const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
-            int slot_id = 0;
+
             for (auto & slot : slots)
             {
-
-                if(slot.ga_n != 1)
+                if (slot.ga_n != 1)
                 {
-                    // context extension via Self-Extend
-                    while (slot_npasts[slot_id] >= slot.ga_i + slot.ga_w) {
+                    // context extension via Self-Extend
+                    while (slot.n_past_self_extension >= slot.ga_i + slot.ga_w)
+                    {
                         const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
                         const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
                         const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
 
                         LOG_TEE("\n");
-                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot_npasts[slot_id], ib * bd, slot.ga_i + ib * bd, slot_npasts[slot_id] + ib * bd);
+                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_self_extension, ib * bd, slot.ga_i + ib * bd, slot.n_past_self_extension + ib * bd);
                         LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
-                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot_npasts[slot_id] + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot_npasts[slot_id] + ib * bd + dd);
+                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_self_extension + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_self_extension + ib * bd + dd);
 
-                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot_npasts[slot_id], ib * bd);
-                        llama_kv_cache_seq_div  (ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
-                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w, slot_npasts[slot_id] + ib * bd, dd);
+                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past_self_extension, ib * bd);
+                        llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
+                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_self_extension + ib * bd, dd);
 
-                        slot_npasts[slot_id] -= bd;
+                        slot.n_past_self_extension -= bd;
 
                         slot.ga_i += slot.ga_w / slot.ga_n;
 
-                        LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot_npasts[slot_id] + bd, slot_npasts[slot_id], slot.ga_i);
+                        LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_self_extension + bd, slot.n_past_self_extension, slot.ga_i);
                     }
+                    slot.n_past_self_extension += n_tokens;
                 }
-                slot_npasts[slot_id] += n_tokens;
-                slot_id++;
             }
-
             llama_batch batch_view =

From c8bc9297c0a394b21a7a16e80c948e9462a7e053 Mon Sep 17 00:00:00 2001
From: Maximilian Winter
Date: Sat, 27 Jan 2024 06:30:07 +0100
Subject: [PATCH 07/16] Update server.cpp

---
 examples/server/server.cpp | 10 ----------
 1 file changed, 10 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 0db99c16775f8..92a1090f78531 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1555,10 +1555,6 @@ struct llama_server_context
                 // we have to evaluate at least 1 token to generate logits.
                 LOG_TEE("slot %d : we have to evaluate at least 1 token to generate logits\n", slot.id);
                 slot.n_past--;
-                if(slot.n_past_self_extension > 0)
-                {
-                    slot.n_past_self_extension--;
-                }
             }
 
             LOG_VERBOSE("prompt ingested", {
@@ -1613,12 +1609,6 @@ struct llama_server_context
             return true;
         }
 
-        std::vector<int32_t> slot_npasts;
-        for (auto & slot : slots)
-        {
-            slot_npasts.emplace_back(0);
-        }
-
         for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)

From aa41b22f26d277fe53dcc502f8faace6f1a566a4 Mon Sep 17 00:00:00 2001
From: Maximilian Winter
Date: Sat, 27 Jan 2024 06:41:50 +0100
Subject: [PATCH 08/16] Update server.cpp

---
 examples/server/server.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 92a1090f78531..0b21a5263609c 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1528,6 +1528,7 @@ struct llama_server_context
 
                 slot.n_past = 0;
                 slot.n_past_self_extension = 0;
+                slot.ga_i = 0;
                 slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
             }
             else

From cb96a91f7cd0cdecc229f1e4e46ca43b1406d278 Mon Sep 17 00:00:00 2001
From: Maximilian Winter
Date: Sat, 27 Jan 2024 07:06:08 +0100
Subject: [PATCH 09/16] Update server.cpp

---
 examples/server/server.cpp | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 0b21a5263609c..ff3b018766c1f 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1542,6 +1542,26 @@ struct llama_server_context
                 slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
                 slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
 
+                if(slot.ga_n != 1)
+                {
+                    int ga_i = 0;
+                    int32_t ga_n = slot.ga_n;
+                    int32_t ga_w = slot.ga_w;
+                    int32_t slot_npast = 0;
+                    for (int k = 0; k < slot.n_past; ++k)
+                    {
+                        while (slot_npast >= ga_i + ga_w) {
+                            const int bd = (ga_w/ga_n)*(ga_n - 1);
+                            slot_npast -= bd;
+                            ga_i += ga_w/ga_n;
+                        }
+                        slot_npast++;
+                    }
+                    slot.n_past_self_extension = slot_npast;
+                    slot.ga_i = ga_i;
+
+                }
+
                 LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
             }
@@ -1556,6 +1576,10 @@ struct llama_server_context
                 // we have to evaluate at least 1 token to generate logits.
LOG_TEE("slot %d : we have to evaluate at least 1 token to generate logits\n", slot.id); slot.n_past--; + if(slot.ga_i > 0) + { + slot.n_past_self_extension--; + } } LOG_VERBOSE("prompt ingested", { From 8f20af4ac6e3a830169f3d8855e06ba9611fd9b3 Mon Sep 17 00:00:00 2001 From: Maximilian Winter Date: Sat, 27 Jan 2024 07:24:27 +0100 Subject: [PATCH 10/16] Update README.md --- examples/server/README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/server/README.md b/examples/server/README.md index a4a922bce43b9..d5626dad424d2 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -32,7 +32,6 @@ Command line options: - `--mmproj MMPROJ_FILE`: Path to a multimodal projector file for LLaVA. - `--grp-attn-n`: Extend context size through self extend. Extend context size n-times (default: 1), used together with `--grp-attn-w` - `--grp-attn-w`: Width of the self extend context size extension. (default: 512) shouldn't be greater than original context size -- ## Build server is build alongside everything else from the root of the project From 9d7b7e686cc766c5fa0e9622f2b35730d36e0d25 Mon Sep 17 00:00:00 2001 From: Maximilian Winter Date: Sat, 27 Jan 2024 07:33:11 +0100 Subject: [PATCH 11/16] Changed descriptions --- examples/server/README.md | 4 ++-- examples/server/server.cpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/server/README.md b/examples/server/README.md index d5626dad424d2..1c92a2041c430 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -30,8 +30,8 @@ Command line options: - `-cb`, `--cont-batching`: enable continuous batching (a.k.a dynamic batching) (default: disabled) - `-spf FNAME`, `--system-prompt-file FNAME` Set a file to load "a system prompt (initial prompt of all slots), this is useful for chat applications. [See more](#change-system-prompt-on-runtime) - `--mmproj MMPROJ_FILE`: Path to a multimodal projector file for LLaVA. -- `--grp-attn-n`: Extend context size through self extend. Extend context size n-times (default: 1), used together with `--grp-attn-w` -- `--grp-attn-w`: Width of the self extend context size extension. (default: 512) shouldn't be greater than original context size +- `--grp-attn-n`: Set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w` +- `--grp-attn-w`: Set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n` ## Build server is build alongside everything else from the root of the project diff --git a/examples/server/server.cpp b/examples/server/server.cpp index ff3b018766c1f..bf0220cffe6ba 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1828,8 +1828,8 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, printf(" --override-kv KEY=TYPE:VALUE\n"); printf(" advanced option to override model metadata by key. may be specified multiple times.\n"); printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n"); - printf(" --grp-attn-n N Extend context size through self extend. Extend context size n-times (default: 1), used together with `--grp-attn-w`"); - printf(" --grp-attn-w N Width of the self extend context size extension. 
+    printf("  -gan N, --grp-attn-n N  Set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`");
+    printf("  -gaw N, --grp-attn-w N  Set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`");
     printf("\n");
 }

From 26f95fb079d59c8e2155c607f42dfafb3514c1df Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sat, 27 Jan 2024 12:39:33 +0200
Subject: [PATCH 12/16] server : formatting

---
 examples/server/server.cpp | 40 +++++++++++++++++++-------------------
 1 file changed, 20 insertions(+), 20 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index bf0220cffe6ba..c107083927ff0 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -184,12 +184,13 @@ struct llama_client_slot
     struct llama_sampling_params sparams;
     llama_sampling_context *ctx_sampling = nullptr;
 
-    int ga_i = 0; // group-attention state
+    int ga_i = 0;       // group-attention state
 
-    int32_t ga_n = 1; // group-attention factor
-    int32_t ga_w = 512; // group-attention width
+    int32_t ga_n = 1;   // group-attention factor
+    int32_t ga_w = 512; // group-attention width
 
     int32_t n_past_self_extension = 0;
+
     // multimodal
     std::vector<slot_image> images;
@@ -218,8 +219,8 @@ struct llama_client_slot
         sent_count = 0;
         sent_token_probs_index = 0;
         infill = false;
-        ga_i = 0;
-        n_past_self_extension = 0;
+        ga_i                  = 0;
+        n_past_self_extension = 0;
 
         generated_token_probs.clear();
 
         for (slot_image & img : images)
@@ -406,6 +407,7 @@ struct llama_server_context
             slot.id = i;
             slot.n_ctx = n_ctx_slot;
+
             LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, n_ctx_slot);
 
             const int ga_n = params.grp_attn_n;
@@ -425,7 +427,6 @@ struct llama_server_context
             slot.reset();
 
-
             slots.push_back(slot);
         }
@@ -1377,14 +1378,12 @@ struct llama_server_context
         {
             if (slot.is_processing() && slot.cache_tokens.size() >= (size_t) slot.n_ctx)
             {
-                // Shift context
-                const int n_left = slot.n_past - slot.params.n_keep - 1;
+                const int n_left    = slot.n_past - slot.params.n_keep - 1;
                 const int n_discard = n_left / 2;
 
-                LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id,
-                        slot.params.n_keep, n_left, n_discard);
-                llama_kv_cache_seq_rm(ctx, slot.id, slot.params.n_keep + 1, slot.params.n_keep + n_discard + 1);
+                LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, slot.params.n_keep, n_left, n_discard);
+                llama_kv_cache_seq_rm   (ctx, slot.id, slot.params.n_keep + 1            , slot.params.n_keep + n_discard + 1);
                 llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard);
 
                 for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++)
@@ -1429,8 +1428,10 @@ struct llama_server_context
         }
 
         slot.i_batch = batch.n_tokens;
-        int32_t slot_npast = slot.n_past_self_extension > 0 ? slot.n_past_self_extension : slot.n_past;
+
+        const int32_t slot_npast = slot.n_past_self_extension > 0 ? slot.n_past_self_extension : slot.n_past;
         llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
+
         slot.n_past += 1;
     }
@@ -1542,7 +1543,7 @@ struct llama_server_context
                 slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
                 slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
 
-                if(slot.ga_n != 1)
+                if (slot.ga_n != 1)
                 {
                     int ga_i = 0;
                     int32_t ga_n = slot.ga_n;
@@ -1559,7 +1560,6 @@ struct llama_server_context
                     }
                     slot.n_past_self_extension = slot_npast;
                     slot.ga_i = ga_i;
-
                 }
 
                 LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
@@ -1576,7 +1576,7 @@ struct llama_server_context
                 slot.n_past--;
-                if(slot.ga_i > 0)
+                if (slot.ga_i > 0)
                 {
                     slot.n_past_self_extension--;
                 }
@@ -1598,7 +1598,7 @@ struct llama_server_context
             int32_t ga_w = slot.ga_w;
             for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past)
             {
-                if(slot.ga_n != 1)
+                if (slot.ga_n != 1)
                 {
                     while (slot_npast >= ga_i + ga_w) {
                         const int bd = (ga_w/ga_n)*(ga_n - 1);
@@ -1642,7 +1642,6 @@ struct llama_server_context
                 if (slot.ga_n != 1)
                 {
-
                     // context extension via Self-Extend
                     while (slot.n_past_self_extension >= slot.ga_i + slot.ga_w)
                     {
@@ -1752,7 +1751,6 @@ struct llama_server_context
                 slot.i_batch = -1;
             }
         }
-
         return true;
     }
@@ -2015,14 +2013,16 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
         }
         params.n_threads = std::stoi(argv[i]);
     }
-    else if (arg == "--grp-attn-n" || arg == "-gan") {
+    else if (arg == "--grp-attn-n" || arg == "-gan")
+    {
         if (++i >= argc) {
             invalid_param = true;
             break;
        }
 
         params.grp_attn_n = std::stoi(argv[i]);
-    } else if (arg == "--grp-attn-w" || arg == "-gaw")
+    }
+    else if (arg == "--grp-attn-w" || arg == "-gaw")
     {
         if (++i >= argc)
         {

From 67096a353977c583764942b88e93a1b626af8b01 Mon Sep 17 00:00:00 2001
From: Maximilian Winter
Date: Sat, 27 Jan 2024 11:47:19 +0100
Subject: [PATCH 13/16] Update examples/server/server.cpp

Co-authored-by: Georgi Gerganov
---
 examples/server/server.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index c107083927ff0..c6b13d718e634 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -189,7 +189,7 @@ struct llama_client_slot
 
     int32_t ga_n = 1;   // group-attention factor
     int32_t ga_w = 512; // group-attention width
-    int32_t n_past_self_extension = 0;
+    int32_t n_past_se = 0; // self-extend
 
     // multimodal
     std::vector<slot_image> images;

From d56cecc2cd6f18e4dfe7629311c608ea7bd0dbda Mon Sep 17 00:00:00 2001
From: Maximilian Winter
Date: Sat, 27 Jan 2024 11:47:27 +0100
Subject: [PATCH 14/16] Update examples/server/server.cpp

Co-authored-by: Georgi Gerganov
---
 examples/server/server.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index c6b13d718e634..4228739562325 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -184,8 +184,7 @@ struct llama_client_slot
     struct llama_sampling_params sparams;
     llama_sampling_context *ctx_sampling = nullptr;
 
-    int ga_i = 0;       // group-attention state
-
+    int32_t ga_i = 0;   // group-attention state
     int32_t ga_n = 1;   // group-attention factor
     int32_t ga_w = 512; // group-attention width

From 826e6dcad95d743d66c042379e2935fde7dc646d Mon Sep 17 00:00:00 2001
From: Maximilian Winter
Date: Sat, 27 Jan 2024 14:13:34 +0100
Subject: [PATCH 15/16] Update server.cpp

---
 examples/server/server.cpp | 28 ++++++++++++++--------------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 4228739562325..3a82e408e75e9 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -219,7 +219,7 @@ struct llama_client_slot
         sent_token_probs_index = 0;
         infill = false;
         ga_i = 0;
-        n_past_self_extension = 0;
+        n_past_se = 0;
 
         generated_token_probs.clear();
 
         for (slot_image & img : images)
@@ -1428,7 +1428,7 @@ struct llama_server_context
 
         slot.i_batch = batch.n_tokens;
 
-        const int32_t slot_npast = slot.n_past_self_extension > 0 ? slot.n_past_self_extension : slot.n_past;
+        const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
         llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
 
         slot.n_past += 1;
@@ -1527,7 +1527,7 @@ struct llama_server_context
                 llama_sampling_reset(slot.ctx_sampling);
 
                 slot.n_past = 0;
-                slot.n_past_self_extension = 0;
+                slot.n_past_se = 0;
                 slot.ga_i = 0;
                 slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
             }
@@ -1557,7 +1557,7 @@ struct llama_server_context
                         }
                         slot_npast++;
                     }
-                    slot.n_past_self_extension = slot_npast;
+                    slot.n_past_se = slot_npast;
                     slot.ga_i = ga_i;
                 }
@@ -1577,7 +1577,7 @@ struct llama_server_context
                 slot.n_past--;
                 if (slot.ga_i > 0)
                 {
-                    slot.n_past_self_extension--;
+                    slot.n_past_se--;
                 }
@@ -1591,7 +1591,7 @@ struct llama_server_context
             // process the prefix of first image
             std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
 
-            int32_t slot_npast = slot.n_past_self_extension > 0 ? slot.n_past_self_extension : slot.n_past;
+            int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
             int ga_i = slot.ga_i;
             int32_t ga_n = slot.ga_n;
             int32_t ga_w = slot.ga_w;
@@ -1642,28 +1642,28 @@ struct llama_server_context
                 if (slot.ga_n != 1)
                 {
                     // context extension via Self-Extend
-                    while (slot.n_past_self_extension >= slot.ga_i + slot.ga_w)
+                    while (slot.n_past_se >= slot.ga_i + slot.ga_w)
                     {
                         const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
                         const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
                         const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
 
                         LOG_TEE("\n");
-                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_self_extension, ib * bd, slot.ga_i + ib * bd, slot.n_past_self_extension + ib * bd);
+                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
                         LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
-                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_self_extension + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_self_extension + ib * bd + dd);
+                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
 
-                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past_self_extension, ib * bd);
+                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
                         llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
-                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_self_extension + ib * bd, dd);
+                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
 
-                        slot.n_past_self_extension -= bd;
+                        slot.n_past_se -= bd;
 
                         slot.ga_i += slot.ga_w / slot.ga_n;
 
-                        LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_self_extension + bd, slot.n_past_self_extension, slot.ga_i);
+                        LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
                     }
-                    slot.n_past_self_extension += n_tokens;
+                    slot.n_past_se += n_tokens;
                 }

From f7c0e043de458dde15fefe1f82f8e9332bb02651 Mon Sep 17 00:00:00 2001
From: Maximilian Winter
Date: Sat, 27 Jan 2024 14:15:36 +0100
Subject: [PATCH 16/16] Update server.cpp

---
 examples/server/server.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 3a82e408e75e9..af63f2f6f7d12 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -184,8 +184,8 @@ struct llama_client_slot
     struct llama_sampling_params sparams;
     llama_sampling_context *ctx_sampling = nullptr;
 
-    int32_t ga_i = 0; // group-attention state
-    int32_t ga_n = 1; // group-attention factor
+    int32_t ga_i = 0;   // group-attention state
+    int32_t ga_n = 1;   // group-attention factor
     int32_t ga_w = 512; // group-attention width
 
     int32_t n_past_se = 0; // self-extend
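
Note on the group-attention (self-extend) remapping used throughout this series: each token's absolute index is mapped to a compressed KV-cache position by the `while (slot_npast >= ga_i + ga_w)` loop. The following standalone C++ sketch is an illustration only (it is not part of the patches, and ga_n/ga_w are deliberately tiny here for readable output; the server defaults are ga_n = 1, i.e. disabled, and ga_w = 512). It mirrors the bookkeeping the server performs per slot:

    #include <cstdio>

    int main() {
        const int ga_n = 4; // group-attention factor (--grp-attn-n)
        const int ga_w = 8; // group-attention width  (--grp-attn-w)

        int ga_i       = 0; // group-attention state
        int slot_npast = 0; // compressed KV position assigned to the next token

        for (int n_past = 0; n_past < 24; ++n_past) {
            // same remapping loop as in the prompt-ingestion path above:
            // whenever the raw position crosses the current window boundary,
            // fold it back by bd and advance the group boundary ga_i
            while (slot_npast >= ga_i + ga_w) {
                const int bd = (ga_w / ga_n) * (ga_n - 1);
                slot_npast -= bd;
                ga_i       += ga_w / ga_n;
            }
            printf("token %2d -> kv pos %2d (ga_i = %2d)\n", n_past, slot_npast, ga_i);
            slot_npast += 1;
        }

        return 0;
    }

In steady state the assigned positions advance by ga_w/ga_n per ga_w tokens, i.e. the effective context grows by roughly the factor ga_n. With these patches applied, an invocation such as `./server -m model.gguf --grp-attn-n 4 --grp-attn-w 2048` (model path hypothetical) would enable this remapping with ga_n = 4 and ga_w = 2048.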