diff --git a/.gitignore b/.gitignore
index 274f868..0f023ba 100644
--- a/.gitignore
+++ b/.gitignore
@@ -42,4 +42,6 @@ src/test/resources/**/*.gbnf
**/*.etag
**/*.lastModified
-src/main/cpp/llama.cpp/
\ No newline at end of file
+src/main/cpp/llama.cpp/
+/.classpath
+/.project
diff --git a/src/main/cpp/server.hpp b/src/main/cpp/server.hpp
index 66169a8..652e821 100644
--- a/src/main/cpp/server.hpp
+++ b/src/main/cpp/server.hpp
@@ -31,16 +31,15 @@ enum stop_type {
// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
enum slot_state {
SLOT_STATE_IDLE,
- SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it
- // with launch_slot_with_task in the future
+ SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
SLOT_STATE_PROCESSING_PROMPT,
SLOT_STATE_DONE_PROMPT,
SLOT_STATE_GENERATING,
};
enum server_state {
- SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
- SERVER_STATE_READY, // Server is ready and model is loaded
+ SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
+ SERVER_STATE_READY, // Server is ready and model is loaded
};
enum server_task_type {
@@ -71,22 +70,21 @@ enum error_type {
ERROR_TYPE_SERVER,
ERROR_TYPE_NOT_FOUND,
ERROR_TYPE_PERMISSION,
- ERROR_TYPE_UNAVAILABLE, // custom error
+ ERROR_TYPE_UNAVAILABLE, // custom error
ERROR_TYPE_NOT_SUPPORTED, // custom error
};
struct slot_params {
- bool stream = true;
- bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
+ bool stream = true;
+ bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
bool return_tokens = false;
- int32_t n_keep = 0; // number of tokens to keep from initial prompt
- int32_t n_discard =
- 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
+ int32_t n_keep = 0; // number of tokens to keep from initial prompt
+ int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
int32_t n_predict = -1; // new tokens to predict
- int32_t n_indent = 0; // mininum line indentation for the generated text in number of whitespace characters
+ int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters
- int64_t t_max_prompt_ms = -1; // TODO: implement
+ int64_t t_max_prompt_ms = -1; // TODO: implement
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
std::vector<common_adapter_lora_info> lora;
@@ -101,16 +99,16 @@ struct slot_params {
struct common_params_speculative speculative;
// OAI-compat fields
- bool verbose = false;
- oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
- std::string oaicompat_model;
- std::string oaicompat_cmpl_id;
- common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+ bool verbose = false;
+ oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
+ std::string oaicompat_model;
+ std::string oaicompat_cmpl_id;
+ common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
json to_json() const {
std::vector<std::string> samplers;
samplers.reserve(sampling.samplers.size());
- for (const auto &sampler : sampling.samplers) {
+ for (const auto & sampler : sampling.samplers) {
samplers.emplace_back(common_sampler_type_to_str(sampler));
}
@@ -120,61 +118,61 @@ struct slot_params {
}
auto grammar_triggers = json::array();
- for (const auto &trigger : sampling.grammar_triggers) {
+ for (const auto & trigger : sampling.grammar_triggers) {
grammar_triggers.push_back(trigger.to_json());
}
- return json{
- {"n_predict", n_predict}, // Server configured n_predict
- {"seed", sampling.seed},
- {"temperature", sampling.temp},
- {"dynatemp_range", sampling.dynatemp_range},
- {"dynatemp_exponent", sampling.dynatemp_exponent},
- {"top_k", sampling.top_k},
- {"top_p", sampling.top_p},
- {"min_p", sampling.min_p},
- {"xtc_probability", sampling.xtc_probability},
- {"xtc_threshold", sampling.xtc_threshold},
- {"typical_p", sampling.typ_p},
- {"repeat_last_n", sampling.penalty_last_n},
- {"repeat_penalty", sampling.penalty_repeat},
- {"presence_penalty", sampling.penalty_present},
- {"frequency_penalty", sampling.penalty_freq},
- {"dry_multiplier", sampling.dry_multiplier},
- {"dry_base", sampling.dry_base},
- {"dry_allowed_length", sampling.dry_allowed_length},
- {"dry_penalty_last_n", sampling.dry_penalty_last_n},
- {"dry_sequence_breakers", sampling.dry_sequence_breakers},
- {"mirostat", sampling.mirostat},
- {"mirostat_tau", sampling.mirostat_tau},
- {"mirostat_eta", sampling.mirostat_eta},
- {"stop", antiprompt},
- {"max_tokens", n_predict}, // User configured n_predict
- {"n_keep", n_keep},
- {"n_discard", n_discard},
- {"ignore_eos", sampling.ignore_eos},
- {"stream", stream},
- {"logit_bias", format_logit_bias(sampling.logit_bias)},
- {"n_probs", sampling.n_probs},
- {"min_keep", sampling.min_keep},
- {"grammar", sampling.grammar},
- {"grammar_lazy", sampling.grammar_lazy},
- {"grammar_triggers", grammar_triggers},
- {"preserved_tokens", sampling.preserved_tokens},
- {"chat_format", common_chat_format_name(oaicompat_chat_format)},
- {"samplers", samplers},
- {"speculative.n_max", speculative.n_max},
- {"speculative.n_min", speculative.n_min},
- {"speculative.p_min", speculative.p_min},
- {"timings_per_token", timings_per_token},
- {"post_sampling_probs", post_sampling_probs},
- {"lora", lora},
+ return json {
+ {"n_predict", n_predict}, // Server configured n_predict
+ {"seed", sampling.seed},
+ {"temperature", sampling.temp},
+ {"dynatemp_range", sampling.dynatemp_range},
+ {"dynatemp_exponent", sampling.dynatemp_exponent},
+ {"top_k", sampling.top_k},
+ {"top_p", sampling.top_p},
+ {"min_p", sampling.min_p},
+ {"xtc_probability", sampling.xtc_probability},
+ {"xtc_threshold", sampling.xtc_threshold},
+ {"typical_p", sampling.typ_p},
+ {"repeat_last_n", sampling.penalty_last_n},
+ {"repeat_penalty", sampling.penalty_repeat},
+ {"presence_penalty", sampling.penalty_present},
+ {"frequency_penalty", sampling.penalty_freq},
+ {"dry_multiplier", sampling.dry_multiplier},
+ {"dry_base", sampling.dry_base},
+ {"dry_allowed_length", sampling.dry_allowed_length},
+ {"dry_penalty_last_n", sampling.dry_penalty_last_n},
+ {"dry_sequence_breakers", sampling.dry_sequence_breakers},
+ {"mirostat", sampling.mirostat},
+ {"mirostat_tau", sampling.mirostat_tau},
+ {"mirostat_eta", sampling.mirostat_eta},
+ {"stop", antiprompt},
+ {"max_tokens", n_predict}, // User configured n_predict
+ {"n_keep", n_keep},
+ {"n_discard", n_discard},
+ {"ignore_eos", sampling.ignore_eos},
+ {"stream", stream},
+ {"logit_bias", format_logit_bias(sampling.logit_bias)},
+ {"n_probs", sampling.n_probs},
+ {"min_keep", sampling.min_keep},
+ {"grammar", sampling.grammar},
+ {"grammar_lazy", sampling.grammar_lazy},
+ {"grammar_triggers", grammar_triggers},
+ {"preserved_tokens", sampling.preserved_tokens},
+ {"chat_format", common_chat_format_name(oaicompat_chat_format)},
+ {"samplers", samplers},
+ {"speculative.n_max", speculative.n_max},
+ {"speculative.n_min", speculative.n_min},
+ {"speculative.p_min", speculative.p_min},
+ {"timings_per_token", timings_per_token},
+ {"post_sampling_probs", post_sampling_probs},
+ {"lora", lora},
};
}
};
struct server_task {
- int id = -1; // to be filled by server_queue
+ int id = -1; // to be filled by server_queue
int index = -1; // used when there are multiple prompts (batch request)
server_task_type type;
@@ -183,7 +181,7 @@ struct server_task {
int id_target = -1;
// used by SERVER_TASK_TYPE_INFERENCE
- slot_params params;
+ slot_params params;
llama_tokens prompt_tokens;
int id_selected_slot = -1;
@@ -203,61 +201,59 @@ struct server_task {
server_task(server_task_type type) : type(type) {}
- static slot_params params_from_json_cmpl(const llama_context *ctx, const common_params ¶ms_base,
- const json &data) {
- const llama_model *model = llama_get_model(ctx);
- const llama_vocab *vocab = llama_model_get_vocab(model);
+ static slot_params params_from_json_cmpl(
+ const llama_context * ctx,
+ const common_params & params_base,
+ const json & data) {
+ const llama_model * model = llama_get_model(ctx);
+ const llama_vocab * vocab = llama_model_get_vocab(model);
slot_params params;
- // Sampling parameter defaults are loaded from the global server context (but individual requests can still
- // override them)
+ // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
slot_params defaults;
- defaults.sampling = params_base.sampling;
+ defaults.sampling = params_base.sampling;
defaults.speculative = params_base.speculative;
// enabling this will output extra debug information in the HTTP responses from the server
- params.verbose = params_base.verbosity > 9;
+ params.verbose = params_base.verbosity > 9;
params.timings_per_token = json_value(data, "timings_per_token", false);
- params.stream = json_value(data, "stream", false);
- params.cache_prompt = json_value(data, "cache_prompt", true);
- params.return_tokens = json_value(data, "return_tokens", false);
- params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
- params.n_indent = json_value(data, "n_indent", defaults.n_indent);
- params.n_keep = json_value(data, "n_keep", defaults.n_keep);
- params.n_discard = json_value(data, "n_discard", defaults.n_discard);
- // params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO:
- // implement
- params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
- params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
-
- params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
- params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
- params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
- params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
- params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
- params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);
- params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp);
- params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range);
- params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent);
- params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n);
- params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat);
- params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq);
- params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present);
- params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier);
- params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base);
- params.sampling.dry_allowed_length =
- json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
- params.sampling.dry_penalty_last_n =
- json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
- params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
- params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
- params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
- params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
- params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
- params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
- params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
+ params.stream = json_value(data, "stream", false);
+ params.cache_prompt = json_value(data, "cache_prompt", true);
+ params.return_tokens = json_value(data, "return_tokens", false);
+ params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
+ params.n_indent = json_value(data, "n_indent", defaults.n_indent);
+ params.n_keep = json_value(data, "n_keep", defaults.n_keep);
+ params.n_discard = json_value(data, "n_discard", defaults.n_discard);
+ //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
+ params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
+ params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
+
+ params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
+ params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
+ params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
+ params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
+ params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
+ params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);
+ params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp);
+ params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range);
+ params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent);
+ params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n);
+ params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat);
+ params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq);
+ params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present);
+ params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier);
+ params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base);
+ params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
+ params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
+ params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
+ params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
+ params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
+ params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
+ params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
+ params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
+ params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
@@ -268,7 +264,7 @@ struct server_task {
params.speculative.n_max = std::max(params.speculative.n_max, 0);
// Use OpenAI API logprobs only if n_probs wasn't provided
- if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs) {
+ if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){
params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs);
}
@@ -308,12 +304,10 @@ struct server_task {
// sequence breakers for DRY
{
// Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
- // Ref:
- // https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
+ // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
if (data.contains("dry_sequence_breakers")) {
- params.sampling.dry_sequence_breakers =
- json_value(data, "dry_sequence_breakers", std::vector());
+ params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector());
if (params.sampling.dry_sequence_breakers.empty()) {
throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings");
}
@@ -323,15 +317,15 @@ struct server_task {
// process "json_schema" and "grammar"
if (data.contains("json_schema") && !data.contains("grammar")) {
try {
- auto schema = json_value(data, "json_schema", json::object());
+ auto schema = json_value(data, "json_schema", json::object());
SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str());
- params.sampling.grammar = json_schema_to_grammar(schema);
+ params.sampling.grammar = json_schema_to_grammar(schema);
SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str());
- } catch (const std::exception &e) {
+ } catch (const std::exception & e) {
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
}
} else {
- params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
+ params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str());
params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy);
SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");
@@ -350,39 +344,35 @@ struct server_task {
{
const auto preserved_tokens = data.find("preserved_tokens");
if (preserved_tokens != data.end()) {
- for (const auto &t : *preserved_tokens) {
- auto ids = common_tokenize(vocab, t.get<std::string>(), /* add_special= */ false,
- /* parse_special= */ true);
+ for (const auto & t : *preserved_tokens) {
+ auto ids = common_tokenize(vocab, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
SRV_DBG("Preserved token: %d\n", ids[0]);
params.sampling.preserved_tokens.insert(ids[0]);
} else {
- // This may happen when using a tool call style meant for a model with special tokens to
- // preserve on a model without said tokens.
+ // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
SRV_DBG("Not preserved because more than 1 token: %s\n", t.get().c_str());
}
}
}
const auto grammar_triggers = data.find("grammar_triggers");
if (grammar_triggers != data.end()) {
- for (const auto &t : *grammar_triggers) {
+ for (const auto & t : *grammar_triggers) {
auto ct = common_grammar_trigger::from_json(t);
if (ct.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
- const auto &word = ct.value;
+ const auto & word = ct.value;
auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
auto token = ids[0];
- if (std::find(params.sampling.preserved_tokens.begin(),
- params.sampling.preserved_tokens.end(),
- (llama_token)token) == params.sampling.preserved_tokens.end()) {
- throw std::runtime_error("Grammar trigger word should be marked as preserved token: " +
- word);
+ if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) {
+ throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word);
}
SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str());
common_grammar_trigger trigger;
trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
- trigger.value = (llama_token)token;
- params.sampling.grammar_triggers.push_back(trigger);
+ trigger.value = word;
+ trigger.token = token;
+ params.sampling.grammar_triggers.push_back(std::move(trigger));
} else {
SRV_DBG("Grammar trigger word: `%s`\n", word.c_str());
params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
@@ -401,10 +391,10 @@ struct server_task {
params.sampling.logit_bias.clear();
params.ignore_eos = json_value(data, "ignore_eos", false);
- const auto &logit_bias = data.find("logit_bias");
+ const auto & logit_bias = data.find("logit_bias");
if (logit_bias != data.end() && logit_bias->is_array()) {
const int n_vocab = llama_vocab_n_tokens(vocab);
- for (const auto &el : *logit_bias) {
+ for (const auto & el : *logit_bias) {
// TODO: we may want to throw errors here, in case "el" is incorrect
if (el.is_array() && el.size() == 2) {
float bias;
@@ -435,9 +425,9 @@ struct server_task {
{
params.antiprompt.clear();
- const auto &stop = data.find("stop");
+ const auto & stop = data.find("stop");
if (stop != data.end() && stop->is_array()) {
- for (const auto &word : *stop) {
+ for (const auto & word : *stop) {
if (!word.empty()) {
params.antiprompt.push_back(word);
}
@@ -450,7 +440,7 @@ struct server_task {
if (samplers != data.end()) {
if (samplers->is_array()) {
params.sampling.samplers = common_sampler_types_from_names(*samplers, false);
- } else if (samplers->is_string()) {
+ } else if (samplers->is_string()){
params.sampling.samplers = common_sampler_types_from_chars(samplers->get<std::string>());
}
} else {
@@ -465,7 +455,7 @@ struct server_task {
}
// utility function
- static std::unordered_set<int> get_list_id(const std::vector<server_task> &tasks) {
+ static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
std::unordered_set<int> ids(tasks.size());
for (size_t i = 0; i < tasks.size(); i++) {
ids.insert(tasks[i].id);
@@ -487,22 +477,22 @@ struct result_timings {
json to_json() const {
return {
- {"prompt_n", prompt_n},
- {"prompt_ms", prompt_ms},
- {"prompt_per_token_ms", prompt_per_token_ms},
- {"prompt_per_second", prompt_per_second},
+ {"prompt_n", prompt_n},
+ {"prompt_ms", prompt_ms},
+ {"prompt_per_token_ms", prompt_per_token_ms},
+ {"prompt_per_second", prompt_per_second},
- {"predicted_n", predicted_n},
- {"predicted_ms", predicted_ms},
+ {"predicted_n", predicted_n},
+ {"predicted_ms", predicted_ms},
{"predicted_per_token_ms", predicted_per_token_ms},
- {"predicted_per_second", predicted_per_second},
+ {"predicted_per_second", predicted_per_second},
};
}
};
struct server_task_result {
- int id = -1;
- int id_slot = -1;
+ int id = -1;
+ int id_slot = -1;
virtual bool is_error() {
// only used by server_task_result_error
return false;
@@ -511,7 +501,9 @@ struct server_task_result {
// only used by server_task_result_cmpl_*
return false;
}
- virtual int get_index() { return -1; }
+ virtual int get_index() {
+ return -1;
+ }
virtual json to_json() = 0;
virtual ~server_task_result() = default;
};
@@ -521,14 +513,10 @@ using server_task_result_ptr = std::unique_ptr<server_task_result>;
inline std::string stop_type_to_str(stop_type type) {
switch (type) {
- case STOP_TYPE_EOS:
- return "eos";
- case STOP_TYPE_WORD:
- return "word";
- case STOP_TYPE_LIMIT:
- return "limit";
- default:
- return "none";
+ case STOP_TYPE_EOS: return "eos";
+ case STOP_TYPE_WORD: return "word";
+ case STOP_TYPE_LIMIT: return "limit";
+ default: return "none";
}
}
@@ -545,30 +533,39 @@ struct completion_token_output {
json to_json(bool post_sampling_probs) const {
json probs_for_token = json::array();
- for (const auto &p : probs) {
+ for (const auto & p : probs) {
std::string txt(p.txt);
txt.resize(validate_utf8(txt));
- probs_for_token.push_back(json{
- {"id", p.tok},
- {"token", txt},
- {"bytes", str_to_bytes(p.txt)},
- {post_sampling_probs ? "prob" : "logprob", post_sampling_probs ? p.prob : logarithm(p.prob)},
+ probs_for_token.push_back(json {
+ {"id", p.tok},
+ {"token", txt},
+ {"bytes", str_to_bytes(p.txt)},
+ {
+ post_sampling_probs ? "prob" : "logprob",
+ post_sampling_probs ? p.prob : logarithm(p.prob)
+ },
});
}
return probs_for_token;
}
- static json probs_vector_to_json(const std::vector<completion_token_output> &probs, bool post_sampling_probs) {
+ static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs) {
json out = json::array();
- for (const auto &p : probs) {
+ for (const auto & p : probs) {
std::string txt(p.text_to_send);
txt.resize(validate_utf8(txt));
- out.push_back(json{
- {"id", p.tok},
- {"token", txt},
- {"bytes", str_to_bytes(p.text_to_send)},
- {post_sampling_probs ? "prob" : "logprob", post_sampling_probs ? p.prob : logarithm(p.prob)},
- {post_sampling_probs ? "top_probs" : "top_logprobs", p.to_json(post_sampling_probs)},
+ out.push_back(json {
+ {"id", p.tok},
+ {"token", txt},
+ {"bytes", str_to_bytes(p.text_to_send)},
+ {
+ post_sampling_probs ? "prob" : "logprob",
+ post_sampling_probs ? p.prob : logarithm(p.prob)
+ },
+ {
+ post_sampling_probs ? "top_probs" : "top_logprobs",
+ p.to_json(post_sampling_probs)
+ },
});
}
return out;
@@ -579,7 +576,7 @@ struct completion_token_output {
return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
}
- static std::vector<unsigned char> str_to_bytes(const std::string &str) {
+ static std::vector<unsigned char> str_to_bytes(const std::string & str) {
std::vector<unsigned char> bytes;
for (unsigned char c : str) {
bytes.push_back(c);
@@ -608,18 +605,20 @@ struct server_task_result_cmpl_final : server_task_result {
bool post_sampling_probs;
std::vector<completion_token_output> probs_output;
- std::vector<std::string> response_fields;
+ std::vector<std::string> response_fields;
slot_params generation_params;
// OAI-compat fields
- bool verbose = false;
- oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
- std::string oaicompat_model;
- std::string oaicompat_cmpl_id;
- common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+ bool verbose = false;
+ oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
+ std::string oaicompat_model;
+ std::string oaicompat_cmpl_id;
+ common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
- virtual int get_index() override { return index; }
+ virtual int get_index() override {
+ return index;
+ }
virtual bool is_stop() override {
return true; // in stream mode, final responses are considered stop
@@ -627,39 +626,38 @@ struct server_task_result_cmpl_final : server_task_result {
virtual json to_json() override {
switch (oaicompat) {
- case OAICOMPAT_TYPE_NONE:
- return to_json_non_oaicompat();
- case OAICOMPAT_TYPE_COMPLETION:
- return to_json_oaicompat();
- case OAICOMPAT_TYPE_CHAT:
- return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat();
- default:
- GGML_ASSERT(false && "Invalid oaicompat_type");
+ case OAICOMPAT_TYPE_NONE:
+ return to_json_non_oaicompat();
+ case OAICOMPAT_TYPE_COMPLETION:
+ return to_json_oaicompat();
+ case OAICOMPAT_TYPE_CHAT:
+ return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat();
+ default:
+ GGML_ASSERT(false && "Invalid oaicompat_type");
}
}
json to_json_non_oaicompat() {
- json res = json{
- {"index", index},
- {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
- {"tokens", stream ? llama_tokens{} : tokens},
- {"id_slot", id_slot},
- {"stop", true},
- {"model", oaicompat_model},
- {"tokens_predicted", n_decoded},
- {"tokens_evaluated", n_prompt_tokens},
+ json res = json {
+ {"index", index},
+ {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
+ {"tokens", stream ? llama_tokens {} : tokens},
+ {"id_slot", id_slot},
+ {"stop", true},
+ {"model", oaicompat_model},
+ {"tokens_predicted", n_decoded},
+ {"tokens_evaluated", n_prompt_tokens},
{"generation_settings", generation_params.to_json()},
- {"prompt", prompt},
- {"has_new_line", has_new_line},
- {"truncated", truncated},
- {"stop_type", stop_type_to_str(stop)},
- {"stopping_word", stopping_word},
- {"tokens_cached", n_tokens_cached},
- {"timings", timings.to_json()},
+ {"prompt", prompt},
+ {"has_new_line", has_new_line},
+ {"truncated", truncated},
+ {"stop_type", stop_type_to_str(stop)},
+ {"stopping_word", stopping_word},
+ {"tokens_cached", n_tokens_cached},
+ {"timings", timings.to_json()},
};
if (!stream && !probs_output.empty()) {
- res["completion_probabilities"] =
- completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
+ res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
}
return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
}
@@ -676,21 +674,26 @@ struct server_task_result_cmpl_final : server_task_result {
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
finish_reason = "stop";
}
- json res = json{
- {"choices", json::array({json{
- {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
- {"index", index},
- {"logprobs", logprobs},
- {"finish_reason", finish_reason},
- }})},
- {"created", t},
- {"model", oaicompat_model},
+ json res = json {
+ {"choices", json::array({
+ json{
+ {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
+ {"index", index},
+ {"logprobs", logprobs},
+ {"finish_reason", finish_reason},
+ }
+ })},
+ {"created", t},
+ {"model", oaicompat_model},
{"system_fingerprint", build_info},
- {"object", "text_completion"},
- {"usage", json{{"completion_tokens", n_decoded},
- {"prompt_tokens", n_prompt_tokens},
- {"total_tokens", n_decoded + n_prompt_tokens}}},
- {"id", oaicompat_cmpl_id}};
+ {"object", "text_completion"},
+ {"usage", json {
+ {"completion_tokens", n_decoded},
+ {"prompt_tokens", n_prompt_tokens},
+ {"total_tokens", n_decoded + n_prompt_tokens}
+ }},
+ {"id", oaicompat_cmpl_id}
+ };
// extra fields for debugging purposes
if (verbose) {
@@ -714,7 +717,7 @@ struct server_task_result_cmpl_final : server_task_result {
msg.content = content;
}
- json message{
+ json message {
{"role", "assistant"},
};
if (!msg.reasoning_content.empty()) {
@@ -727,21 +730,23 @@ struct server_task_result_cmpl_final : server_task_result {
}
if (!msg.tool_calls.empty()) {
auto tool_calls = json::array();
- for (const auto &tc : msg.tool_calls) {
+ for (const auto & tc : msg.tool_calls) {
tool_calls.push_back({
{"type", "function"},
- {"function",
- {
- {"name", tc.name},
- {"arguments", tc.arguments},
- }},
- {"id", tc.id},
+ {"function", {
+ {"name", tc.name},
+ {"arguments", tc.arguments},
+ }},
+ // Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo).
+ // We only generate a random id for the ones that don't generate one by themselves
+ // (they also won't get to see it as their template likely doesn't use it, so it's all for the client)
+ {"id", tc.id.empty() ? gen_tool_call_id() : tc.id},
});
}
message["tool_calls"] = tool_calls;
}
- json choice{
+ json choice {
{"finish_reason", finish_reason},
{"index", 0},
{"message", message},
@@ -755,15 +760,19 @@ struct server_task_result_cmpl_final : server_task_result {
std::time_t t = std::time(0);
- json res = json{{"choices", json::array({choice})},
- {"created", t},
- {"model", oaicompat_model},
- {"system_fingerprint", build_info},
- {"object", "chat.completion"},
- {"usage", json{{"completion_tokens", n_decoded},
- {"prompt_tokens", n_prompt_tokens},
- {"total_tokens", n_decoded + n_prompt_tokens}}},
- {"id", oaicompat_cmpl_id}};
+ json res = json {
+ {"choices", json::array({choice})},
+ {"created", t},
+ {"model", oaicompat_model},
+ {"system_fingerprint", build_info},
+ {"object", "chat.completion"},
+ {"usage", json {
+ {"completion_tokens", n_decoded},
+ {"prompt_tokens", n_prompt_tokens},
+ {"total_tokens", n_decoded + n_prompt_tokens}
+ }},
+ {"id", oaicompat_cmpl_id}
+ };
// extra fields for debugging purposes
if (verbose) {
@@ -783,21 +792,24 @@ struct server_task_result_cmpl_final : server_task_result {
finish_reason = "stop";
}
- json choice = json{{"finish_reason", finish_reason}, {"index", 0}, {"delta", json::object()}};
+ json choice = json {
+ {"finish_reason", finish_reason},
+ {"index", 0},
+ {"delta", json::object()}
+ };
- json ret = json{
- {"choices", json::array({choice})},
- {"created", t},
- {"id", oaicompat_cmpl_id},
- {"model", oaicompat_model},
+ json ret = json {
+ {"choices", json::array({choice})},
+ {"created", t},
+ {"id", oaicompat_cmpl_id},
+ {"model", oaicompat_model},
{"system_fingerprint", build_info},
- {"object", "chat.completion.chunk"},
- {"usage",
- json{
- {"completion_tokens", n_decoded},
- {"prompt_tokens", n_prompt_tokens},
- {"total_tokens", n_decoded + n_prompt_tokens},
- }},
+ {"object", "chat.completion.chunk"},
+ {"usage", json {
+ {"completion_tokens", n_decoded},
+ {"prompt_tokens", n_prompt_tokens},
+ {"total_tokens", n_decoded + n_prompt_tokens},
+ }},
};
if (timings.prompt_n >= 0) {
@@ -811,7 +823,7 @@ struct server_task_result_cmpl_final : server_task_result {
struct server_task_result_cmpl_partial : server_task_result {
int index = 0;
- std::string content;
+ std::string content;
llama_tokens tokens;
int32_t n_decoded;
@@ -822,12 +834,14 @@ struct server_task_result_cmpl_partial : server_task_result {
result_timings timings;
// OAI-compat fields
- bool verbose = false;
+ bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
- std::string oaicompat_model;
- std::string oaicompat_cmpl_id;
+ std::string oaicompat_model;
+ std::string oaicompat_cmpl_id;
- virtual int get_index() override { return index; }
+ virtual int get_index() override {
+ return index;
+ }
virtual bool is_stop() override {
return false; // in stream mode, partial responses are not considered stop
@@ -835,25 +849,25 @@ struct server_task_result_cmpl_partial : server_task_result {
virtual json to_json() override {
switch (oaicompat) {
- case OAICOMPAT_TYPE_NONE:
- return to_json_non_oaicompat();
- case OAICOMPAT_TYPE_COMPLETION:
- return to_json_oaicompat();
- case OAICOMPAT_TYPE_CHAT:
- return to_json_oaicompat_chat();
- default:
- GGML_ASSERT(false && "Invalid oaicompat_type");
+ case OAICOMPAT_TYPE_NONE:
+ return to_json_non_oaicompat();
+ case OAICOMPAT_TYPE_COMPLETION:
+ return to_json_oaicompat();
+ case OAICOMPAT_TYPE_CHAT:
+ return to_json_oaicompat_chat();
+ default:
+ GGML_ASSERT(false && "Invalid oaicompat_type");
}
}
json to_json_non_oaicompat() {
// non-OAI-compat JSON
- json res = json{
- {"index", index},
- {"content", content},
- {"tokens", tokens},
- {"stop", false},
- {"id_slot", id_slot},
+ json res = json {
+ {"index", index},
+ {"content", content},
+ {"tokens", tokens},
+ {"stop", false},
+ {"id_slot", id_slot},
{"tokens_predicted", n_decoded},
{"tokens_evaluated", n_prompt_tokens},
};
@@ -862,8 +876,7 @@ struct server_task_result_cmpl_partial : server_task_result {
res.push_back({"timings", timings.to_json()});
}
if (!prob_output.probs.empty()) {
- res["completion_probabilities"] =
- completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs);
+ res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs);
}
return res;
}
@@ -876,17 +889,21 @@ struct server_task_result_cmpl_partial : server_task_result {
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
};
}
- json res = json{{"choices", json::array({json{
- {"text", content},
- {"index", index},
- {"logprobs", logprobs},
- {"finish_reason", nullptr},
- }})},
- {"created", t},
- {"model", oaicompat_model},
- {"system_fingerprint", build_info},
- {"object", "text_completion"},
- {"id", oaicompat_cmpl_id}};
+ json res = json {
+ {"choices", json::array({
+ json{
+ {"text", content},
+ {"index", index},
+ {"logprobs", logprobs},
+ {"finish_reason", nullptr},
+ }
+ })},
+ {"created", t},
+ {"model", oaicompat_model},
+ {"system_fingerprint", build_info},
+ {"object", "text_completion"},
+ {"id", oaicompat_cmpl_id}
+ };
// extra fields for debugging purposes
if (verbose) {
@@ -906,26 +923,32 @@ struct server_task_result_cmpl_partial : server_task_result {
if (first) {
if (content.empty()) {
- choices = json::array(
- {json{{"finish_reason", nullptr}, {"index", 0}, {"delta", json{{"role", "assistant"}}}}});
+ choices = json::array({json{{"finish_reason", nullptr},
+ {"index", 0},
+ {"delta", json{{"role", "assistant"}}}}});
} else {
// We have to send this as two updates to conform to openai behavior
- json initial_ret = json{{"choices", json::array({json{{"finish_reason", nullptr},
- {"index", 0},
- {"delta", json{{"role", "assistant"}}}}})},
- {"created", t},
- {"id", oaicompat_cmpl_id},
- {"model", oaicompat_model},
- {"object", "chat.completion.chunk"}};
-
- json second_ret =
- json{{"choices",
- json::array(
- {json{{"finish_reason", nullptr}, {"index", 0}, {"delta", json{{"content", content}}}}})},
- {"created", t},
- {"id", oaicompat_cmpl_id},
- {"model", oaicompat_model},
- {"object", "chat.completion.chunk"}};
+ json initial_ret = json{{"choices", json::array({json{
+ {"finish_reason", nullptr},
+ {"index", 0},
+ {"delta", json{
+ {"role", "assistant"}
+ }}}})},
+ {"created", t},
+ {"id", oaicompat_cmpl_id},
+ {"model", oaicompat_model},
+ {"object", "chat.completion.chunk"}};
+
+ json second_ret = json{
+ {"choices", json::array({json{{"finish_reason", nullptr},
+ {"index", 0},
+ {"delta", json {
+ {"content", content}}}
+ }})},
+ {"created", t},
+ {"id", oaicompat_cmpl_id},
+ {"model", oaicompat_model},
+ {"object", "chat.completion.chunk"}};
return std::vector<json>({initial_ret, second_ret});
}
@@ -934,9 +957,9 @@ struct server_task_result_cmpl_partial : server_task_result {
{"finish_reason", nullptr},
{"index", 0},
{"delta",
- json{
- {"content", content},
- }},
+ json {
+ {"content", content},
+ }},
}});
}
@@ -948,12 +971,14 @@ struct server_task_result_cmpl_partial : server_task_result {
};
}
- json ret = json{{"choices", choices},
- {"created", t},
- {"id", oaicompat_cmpl_id},
- {"model", oaicompat_model},
- {"system_fingerprint", build_info},
- {"object", "chat.completion.chunk"}};
+ json ret = json {
+ {"choices", choices},
+ {"created", t},
+ {"id", oaicompat_cmpl_id},
+ {"model", oaicompat_model},
+ {"system_fingerprint", build_info},
+ {"object", "chat.completion.chunk"}
+ };
if (timings.prompt_n >= 0) {
ret.push_back({"timings", timings.to_json()});
@@ -972,23 +997,27 @@ struct server_task_result_embd : server_task_result {
// OAI-compat fields
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
- virtual int get_index() override { return index; }
+ virtual int get_index() override {
+ return index;
+ }
virtual json to_json() override {
- return oaicompat == OAICOMPAT_TYPE_EMBEDDING ? to_json_oaicompat() : to_json_non_oaicompat();
+ return oaicompat == OAICOMPAT_TYPE_EMBEDDING
+ ? to_json_oaicompat()
+ : to_json_non_oaicompat();
}
json to_json_non_oaicompat() {
- return json{
- {"index", index},
+ return json {
+ {"index", index},
{"embedding", embedding},
};
}
json to_json_oaicompat() {
- return json{
- {"index", index},
- {"embedding", embedding[0]},
+ return json {
+ {"index", index},
+ {"embedding", embedding[0]},
{"tokens_evaluated", n_tokens},
};
}
@@ -1000,52 +1029,54 @@ struct server_task_result_rerank : server_task_result {
int32_t n_tokens;
- virtual int get_index() override { return index; }
+ virtual int get_index() override {
+ return index;
+ }
virtual json to_json() override {
- return json{
- {"index", index},
- {"score", score},
+ return json {
+ {"index", index},
+ {"score", score},
{"tokens_evaluated", n_tokens},
};
}
};
// this function maybe used outside of server_task_result_error
-static json format_error_response(const std::string &message, const enum error_type type) {
+static json format_error_response(const std::string & message, const enum error_type type) {
std::string type_str;
int code = 500;
switch (type) {
- case ERROR_TYPE_INVALID_REQUEST:
- type_str = "invalid_request_error";
- code = 400;
- break;
- case ERROR_TYPE_AUTHENTICATION:
- type_str = "authentication_error";
- code = 401;
- break;
- case ERROR_TYPE_NOT_FOUND:
- type_str = "not_found_error";
- code = 404;
- break;
- case ERROR_TYPE_SERVER:
- type_str = "server_error";
- code = 500;
- break;
- case ERROR_TYPE_PERMISSION:
- type_str = "permission_error";
- code = 403;
- break;
- case ERROR_TYPE_NOT_SUPPORTED:
- type_str = "not_supported_error";
- code = 501;
- break;
- case ERROR_TYPE_UNAVAILABLE:
- type_str = "unavailable_error";
- code = 503;
- break;
- }
- return json{
+ case ERROR_TYPE_INVALID_REQUEST:
+ type_str = "invalid_request_error";
+ code = 400;
+ break;
+ case ERROR_TYPE_AUTHENTICATION:
+ type_str = "authentication_error";
+ code = 401;
+ break;
+ case ERROR_TYPE_NOT_FOUND:
+ type_str = "not_found_error";
+ code = 404;
+ break;
+ case ERROR_TYPE_SERVER:
+ type_str = "server_error";
+ code = 500;
+ break;
+ case ERROR_TYPE_PERMISSION:
+ type_str = "permission_error";
+ code = 403;
+ break;
+ case ERROR_TYPE_NOT_SUPPORTED:
+ type_str = "not_supported_error";
+ code = 501;
+ break;
+ case ERROR_TYPE_UNAVAILABLE:
+ type_str = "unavailable_error";
+ code = 503;
+ break;
+ }
+ return json {
{"code", code},
{"message", message},
{"type", type_str},
@@ -1057,9 +1088,13 @@ struct server_task_result_error : server_task_result {
error_type err_type = ERROR_TYPE_SERVER;
std::string err_msg;
- virtual bool is_error() override { return true; }
+ virtual bool is_error() override {
+ return true;
+ }
- virtual json to_json() override { return format_error_response(err_msg, err_type); }
+ virtual json to_json() override {
+ return format_error_response(err_msg, err_type);
+ }
};
struct server_task_result_metrics : server_task_result {
@@ -1073,17 +1108,17 @@ struct server_task_result_metrics : server_task_result {
// TODO: somehow reuse server_metrics in the future, instead of duplicating the fields
uint64_t n_prompt_tokens_processed_total = 0;
- uint64_t t_prompt_processing_total = 0;
- uint64_t n_tokens_predicted_total = 0;
- uint64_t t_tokens_generation_total = 0;
+ uint64_t t_prompt_processing_total = 0;
+ uint64_t n_tokens_predicted_total = 0;
+ uint64_t t_tokens_generation_total = 0;
uint64_t n_prompt_tokens_processed = 0;
- uint64_t t_prompt_processing = 0;
+ uint64_t t_prompt_processing = 0;
- uint64_t n_tokens_predicted = 0;
+ uint64_t n_tokens_predicted = 0;
uint64_t t_tokens_generation = 0;
- uint64_t n_decode_total = 0;
+ uint64_t n_decode_total = 0;
uint64_t n_busy_slots_total = 0;
// while we can also use std::vector<server_slot> this requires copying the slot object which can be quite messy
@@ -1091,29 +1126,29 @@ struct server_task_result_metrics : server_task_result {
json slots_data = json::array();
virtual json to_json() override {
- return json{
- {"idle", n_idle_slots},
- {"processing", n_processing_slots},
- {"deferred", n_tasks_deferred},
- {"t_start", t_start},
+ return json {
+ { "idle", n_idle_slots },
+ { "processing", n_processing_slots },
+ { "deferred", n_tasks_deferred },
+ { "t_start", t_start },
- {"n_prompt_tokens_processed_total", n_prompt_tokens_processed_total},
- {"t_tokens_generation_total", t_tokens_generation_total},
- {"n_tokens_predicted_total", n_tokens_predicted_total},
- {"t_prompt_processing_total", t_prompt_processing_total},
+ { "n_prompt_tokens_processed_total", n_prompt_tokens_processed_total },
+ { "t_tokens_generation_total", t_tokens_generation_total },
+ { "n_tokens_predicted_total", n_tokens_predicted_total },
+ { "t_prompt_processing_total", t_prompt_processing_total },
- {"n_prompt_tokens_processed", n_prompt_tokens_processed},
- {"t_prompt_processing", t_prompt_processing},
- {"n_tokens_predicted", n_tokens_predicted},
- {"t_tokens_generation", t_tokens_generation},
+ { "n_prompt_tokens_processed", n_prompt_tokens_processed },
+ { "t_prompt_processing", t_prompt_processing },
+ { "n_tokens_predicted", n_tokens_predicted },
+ { "t_tokens_generation", t_tokens_generation },
- {"n_decode_total", n_decode_total},
- {"n_busy_slots_total", n_busy_slots_total},
+ { "n_decode_total", n_decode_total },
+ { "n_busy_slots_total", n_busy_slots_total },
- {"kv_cache_tokens_count", kv_cache_tokens_count},
- {"kv_cache_used_cells", kv_cache_used_cells},
+ { "kv_cache_tokens_count", kv_cache_tokens_count },
+ { "kv_cache_used_cells", kv_cache_used_cells },
- {"slots", slots_data},
+ { "slots", slots_data },
};
}
};
@@ -1128,17 +1163,24 @@ struct server_task_result_slot_save_load : server_task_result {
virtual json to_json() override {
if (is_save) {
- return json{
- {"id_slot", id_slot}, {"filename", filename}, {"n_saved", n_tokens},
- {"n_written", n_bytes}, {"timings", {{"save_ms", t_ms}}},
+ return json {
+ { "id_slot", id_slot },
+ { "filename", filename },
+ { "n_saved", n_tokens },
+ { "n_written", n_bytes },
+ { "timings", {
+ { "save_ms", t_ms }
+ }},
};
} else {
- return json{
- {"id_slot", id_slot},
- {"filename", filename},
- {"n_restored", n_tokens},
- {"n_read", n_bytes},
- {"timings", {{"restore_ms", t_ms}}},
+ return json {
+ { "id_slot", id_slot },
+ { "filename", filename },
+ { "n_restored", n_tokens },
+ { "n_read", n_bytes },
+ { "timings", {
+ { "restore_ms", t_ms }
+ }},
};
}
}
@@ -1148,15 +1190,17 @@ struct server_task_result_slot_erase : server_task_result {
size_t n_erased;
virtual json to_json() override {
- return json{
- {"id_slot", id_slot},
- {"n_erased", n_erased},
+ return json {
+ { "id_slot", id_slot },
+ { "n_erased", n_erased },
};
}
};
struct server_task_result_apply_lora : server_task_result {
- virtual json to_json() override { return json{{"success", true}}; }
+ virtual json to_json() override {
+ return json {{ "success", true }};
+ }
};
struct server_slot {
@@ -1168,10 +1212,10 @@ struct server_slot {
llama_batch batch_spec = {};
- llama_context *ctx = nullptr;
- llama_context *ctx_dft = nullptr;
+ llama_context * ctx = nullptr;
+ llama_context * ctx_dft = nullptr;
- common_speculative *spec = nullptr;
+ common_speculative * spec = nullptr;
std::vector<common_adapter_lora_info> lora;
@@ -1186,15 +1230,15 @@ struct server_slot {
int64_t t_last_used = -1;
// generation props
- int32_t n_ctx = 0; // context size per slot
- int32_t n_past = 0;
- int32_t n_decoded = 0;
+ int32_t n_ctx = 0; // context size per slot
+ int32_t n_past = 0;
+ int32_t n_decoded = 0;
int32_t n_remaining = -1;
- int32_t i_batch = -1;
- int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
+ int32_t i_batch = -1;
+ int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
// n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated
- int32_t n_prompt_tokens = 0;
+ int32_t n_prompt_tokens = 0;
int32_t n_prompt_tokens_processed = 0;
// input prompt tokens
@@ -1202,7 +1246,7 @@ struct server_slot {
size_t last_nl_pos = 0;
- std::string generated_text;
+ std::string generated_text;
llama_tokens generated_tokens;
llama_tokens cache_tokens;
@@ -1210,8 +1254,8 @@ struct server_slot {
std::vector<completion_token_output> generated_token_probs;
bool has_next_token = true;
- bool has_new_line = false;
- bool truncated = false;
+ bool has_new_line = false;
+ bool truncated = false;
stop_type stop;
std::string stopping_word;
@@ -1219,14 +1263,14 @@ struct server_slot {
// sampling
json json_schema;
- struct common_sampler *smpl = nullptr;
+ struct common_sampler * smpl = nullptr;
llama_token sampled;
common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
// stats
- size_t n_sent_text = 0; // number of sent text character
+ size_t n_sent_text = 0; // number of sent text character
int64_t t_start_process_prompt;
int64_t t_start_generation;
@@ -1239,16 +1283,16 @@ struct server_slot {
void reset() {
SLT_DBG(*this, "%s", "\n");
- n_prompt_tokens = 0;
- last_nl_pos = 0;
- generated_text = "";
- has_new_line = false;
- truncated = false;
- stop = STOP_TYPE_NONE;
- stopping_word = "";
- n_past = 0;
- n_sent_text = 0;
- task_type = SERVER_TASK_TYPE_COMPLETION;
+ n_prompt_tokens = 0;
+ last_nl_pos = 0;
+ generated_text = "";
+ has_new_line = false;
+ truncated = false;
+ stop = STOP_TYPE_NONE;
+ stopping_word = "";
+ n_past = 0;
+ n_sent_text = 0;
+ task_type = SERVER_TASK_TYPE_COMPLETION;
generated_tokens.clear();
generated_token_probs.clear();
@@ -1258,11 +1302,12 @@ struct server_slot {
return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
}
- bool can_batch_with(server_slot &other_slot) {
- return is_non_causal() == other_slot.is_non_causal() && are_lora_equal(lora, other_slot.lora);
+ bool can_batch_with(server_slot & other_slot) const {
+ return is_non_causal() == other_slot.is_non_causal()
+ && are_lora_equal(lora, other_slot.lora);
}
- bool has_budget(const common_params &global_params) {
+ bool has_budget(const common_params & global_params) {
if (params.n_predict == -1 && global_params.n_predict == -1) {
return true; // limitless
}
@@ -1278,11 +1323,15 @@ struct server_slot {
return n_remaining > 0; // no budget
}
- bool is_processing() const { return state != SLOT_STATE_IDLE; }
+ bool is_processing() const {
+ return state != SLOT_STATE_IDLE;
+ }
- bool can_speculate() const { return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt; }
+ bool can_speculate() const {
+ return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt;
+ }
- void add_token(const completion_token_output &token) {
+ void add_token(const completion_token_output & token) {
if (!is_processing()) {
SLT_WRN(*this, "%s", "slot is not processing\n");
return;
@@ -1316,14 +1365,14 @@ struct server_slot {
return timings;
}
- size_t find_stopping_strings(const std::string &text, const size_t last_token_size, bool is_full_stop) {
+ size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
size_t stop_pos = std::string::npos;
- for (const std::string &word : params.antiprompt) {
+ for (const std::string & word : params.antiprompt) {
size_t pos;
if (is_full_stop) {
- const size_t tmp = word.size() + last_token_size;
+ const size_t tmp = word.size() + last_token_size;
const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
pos = text.find(word, from_pos);
@@ -1334,8 +1383,8 @@ struct server_slot {
if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
if (is_full_stop) {
- stop = STOP_TYPE_WORD;
- stopping_word = word;
+ stop = STOP_TYPE_WORD;
+ stopping_word = word;
has_next_token = false;
}
stop_pos = pos;
@@ -1346,10 +1395,10 @@ struct server_slot {
}
void print_timings() const {
- const double t_prompt = t_prompt_processing / n_prompt_tokens_processed;
+ const double t_prompt = t_prompt_processing / n_prompt_tokens_processed;
const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
- const double t_gen = t_token_generation / n_decoded;
+ const double t_gen = t_token_generation / n_decoded;
const double n_gen_second = 1e3 / t_token_generation * n_decoded;
SLT_INF(*this,
@@ -1357,29 +1406,30 @@ struct server_slot {
"prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
" eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
" total time = %10.2f ms / %5d tokens\n",
- t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second, t_token_generation,
- n_decoded, t_gen, n_gen_second, t_prompt_processing + t_token_generation,
- n_prompt_tokens_processed + n_decoded);
+ t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
+ t_token_generation, n_decoded, t_gen, n_gen_second,
+ t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
}
json to_json() const {
- return json{
- {"id", id},
- {"id_task", id_task},
- {"n_ctx", n_ctx},
- {"speculative", can_speculate()},
+ return json {
+ {"id", id},
+ {"id_task", id_task},
+ {"n_ctx", n_ctx},
+ {"speculative", can_speculate()},
{"is_processing", is_processing()},
- {"non_causal", is_non_causal()},
- {"params", params.to_json()},
- {"prompt", common_detokenize(ctx, prompt_tokens)},
+ {"non_causal", is_non_causal()},
+ {"params", params.to_json()},
+ {"prompt", common_detokenize(ctx, prompt_tokens)},
{"next_token",
- {
- {"has_next_token", has_next_token},
- {"has_new_line", has_new_line},
- {"n_remain", n_remaining},
- {"n_decoded", n_decoded},
- {"stopping_word", stopping_word},
- }},
+ {
+ {"has_next_token", has_next_token},
+ {"has_new_line", has_new_line},
+ {"n_remain", n_remaining},
+ {"n_decoded", n_decoded},
+ {"stopping_word", stopping_word},
+ }
+ },
};
}
};
@@ -1388,38 +1438,40 @@ struct server_metrics {
int64_t t_start = 0;
uint64_t n_prompt_tokens_processed_total = 0;
- uint64_t t_prompt_processing_total = 0;
- uint64_t n_tokens_predicted_total = 0;
- uint64_t t_tokens_generation_total = 0;
+ uint64_t t_prompt_processing_total = 0;
+ uint64_t n_tokens_predicted_total = 0;
+ uint64_t t_tokens_generation_total = 0;
uint64_t n_prompt_tokens_processed = 0;
- uint64_t t_prompt_processing = 0;
+ uint64_t t_prompt_processing = 0;
- uint64_t n_tokens_predicted = 0;
+ uint64_t n_tokens_predicted = 0;
uint64_t t_tokens_generation = 0;
- uint64_t n_decode_total = 0;
+ uint64_t n_decode_total = 0;
uint64_t n_busy_slots_total = 0;
- void init() { t_start = ggml_time_us(); }
+ void init() {
+ t_start = ggml_time_us();
+ }
- void on_prompt_eval(const server_slot &slot) {
+ void on_prompt_eval(const server_slot & slot) {
n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed;
- n_prompt_tokens_processed += slot.n_prompt_tokens_processed;
- t_prompt_processing += slot.t_prompt_processing;
- t_prompt_processing_total += slot.t_prompt_processing;
+ n_prompt_tokens_processed += slot.n_prompt_tokens_processed;
+ t_prompt_processing += slot.t_prompt_processing;
+ t_prompt_processing_total += slot.t_prompt_processing;
}
- void on_prediction(const server_slot &slot) {
- n_tokens_predicted_total += slot.n_decoded;
- n_tokens_predicted += slot.n_decoded;
- t_tokens_generation += slot.t_token_generation;
- t_tokens_generation_total += slot.t_token_generation;
+ void on_prediction(const server_slot & slot) {
+ n_tokens_predicted_total += slot.n_decoded;
+ n_tokens_predicted += slot.n_decoded;
+ t_tokens_generation += slot.t_token_generation;
+ t_tokens_generation_total += slot.t_token_generation;
}
- void on_decoded(const std::vector<server_slot> &slots) {
+ void on_decoded(const std::vector<server_slot> & slots) {
n_decode_total++;
- for (const auto &slot : slots) {
+ for (const auto & slot : slots) {
if (slot.is_processing()) {
n_busy_slots_total++;
}
@@ -1428,9 +1480,9 @@ struct server_metrics {
void reset_bucket() {
n_prompt_tokens_processed = 0;
- t_prompt_processing = 0;
- n_tokens_predicted = 0;
- t_tokens_generation = 0;
+ t_prompt_processing = 0;
+ n_tokens_predicted = 0;
+ t_tokens_generation = 0;
}
};
@@ -1447,7 +1499,7 @@ struct server_queue {
// callback functions
std::function<void(server_task)> callback_new_task;
- std::function<void(void)> callback_update_slots;
+ std::function<void(void)> callback_update_slots;
// Add a new task to the end of the queue
int post(server_task task, bool front = false) {
@@ -1468,9 +1520,9 @@ struct server_queue {
}
// multi-task version of post()
- int post(std::vector<server_task> &tasks, bool front = false) {
+ int post(std::vector<server_task> & tasks, bool front = false) {
std::unique_lock<std::mutex> lock(mutex_tasks);
- for (auto &task : tasks) {
+ for (auto & task : tasks) {
if (task.id == -1) {
task.id = id++;
}
@@ -1478,7 +1530,7 @@ struct server_queue {
if (task.type == SERVER_TASK_TYPE_CANCEL) {
cleanup_pending_task(task.id_target);
}
- QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int)tasks.size(), front);
+ QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front);
if (front) {
queue_tasks.push_front(std::move(task));
} else {
@@ -1505,10 +1557,14 @@ struct server_queue {
}
// Register function to process a new task
- void on_new_task(std::function<void(server_task)> callback) { callback_new_task = std::move(callback); }
+ void on_new_task(std::function<void(server_task)> callback) {
+ callback_new_task = std::move(callback);
+ }
// Register the function to be called when all slots data is ready to be processed
- void on_update_slots(std::function<void(void)> callback) { callback_update_slots = std::move(callback); }
+ void on_update_slots(std::function<void(void)> callback) {
+ callback_update_slots = std::move(callback);
+ }
// Call when the state of one slot is changed, it will move one task from deferred to main queue
void pop_deferred_task() {
@@ -1571,19 +1627,26 @@ struct server_queue {
return;
}
if (queue_tasks.empty()) {
- condition_tasks.wait(lock, [&] { return (!queue_tasks.empty() || !running); });
+ condition_tasks.wait(lock, [&]{
+ return (!queue_tasks.empty() || !running);
+ });
}
}
}
}
- private:
+private:
void cleanup_pending_task(int id_target) {
// no need lock because this is called exclusively by post()
- auto rm_func = [id_target](const server_task &task) { return task.id_target == id_target; };
- queue_tasks.erase(std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func), queue_tasks.end());
- queue_tasks_deferred.erase(std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func),
- queue_tasks_deferred.end());
+ auto rm_func = [id_target](const server_task & task) {
+ return task.id_target == id_target;
+ };
+ queue_tasks.erase(
+ std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func),
+ queue_tasks.end());
+ queue_tasks_deferred.erase(
+ std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func),
+ queue_tasks_deferred.end());
}
};
@@ -1599,51 +1662,51 @@ struct server_response {
// add the id_task to the list of tasks waiting for response
void add_waiting_task_id(int id_task) {
- SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task,
- (int)waiting_task_ids.size());
+ SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
std::unique_lock lock(mutex_results);
waiting_task_ids.insert(id_task);
}
- void add_waiting_tasks(const std::vector<server_task> &tasks) {
+ void add_waiting_tasks(const std::vector<server_task> & tasks) {
std::unique_lock lock(mutex_results);
- for (const auto &task : tasks) {
- SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id,
- (int)waiting_task_ids.size());
+ for (const auto & task : tasks) {
+ SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size());
waiting_task_ids.insert(task.id);
}
}
// when the request is finished, we can remove the task associated with it
void remove_waiting_task_id(int id_task) {
- SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task,
- (int)waiting_task_ids.size());
+ SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
std::unique_lock lock(mutex_results);
waiting_task_ids.erase(id_task);
// make sure to clean up all pending results
- queue_results.erase(std::remove_if(queue_results.begin(), queue_results.end(),
- [id_task](const server_task_result_ptr &res) { return res->id == id_task; }),
- queue_results.end());
+ queue_results.erase(
+ std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) {
+ return res->id == id_task;
+ }),
+ queue_results.end());
}
- void remove_waiting_task_ids(const std::unordered_set<int> &id_tasks) {
+ void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
std::unique_lock lock(mutex_results);
- for (const auto &id_task : id_tasks) {
- SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task,
- (int)waiting_task_ids.size());
+ for (const auto & id_task : id_tasks) {
+ SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
waiting_task_ids.erase(id_task);
}
}
// This function blocks the thread until there is a response for one of the id_tasks
- server_task_result_ptr recv(const std::unordered_set<int> &id_tasks) {
+ server_task_result_ptr recv(const std::unordered_set<int> & id_tasks) {
while (true) {
std::unique_lock lock(mutex_results);
- condition_results.wait(lock, [&] { return !queue_results.empty(); });
+ condition_results.wait(lock, [&]{
+ return !queue_results.empty();
+ });
for (size_t i = 0; i < queue_results.size(); i++) {
if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
@@ -1659,11 +1722,11 @@ struct server_response {
// same as recv(), but have timeout in seconds
// if timeout is reached, nullptr is returned
- server_task_result_ptr recv_with_timeout(const std::unordered_set<int> &id_tasks, int timeout) {
+ server_task_result_ptr recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout) {
while (true) {
std::unique_lock lock(mutex_results);
- for (int i = 0; i < (int)queue_results.size(); i++) {
+ for (int i = 0; i < (int) queue_results.size(); i++) {
if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
server_task_result_ptr res = std::move(queue_results[i]);
queue_results.erase(queue_results.begin() + i);
@@ -1687,11 +1750,11 @@ struct server_response {
}
// Send a new result to a waiting id_task
- void send(server_task_result_ptr &&result) {
+ void send(server_task_result_ptr && result) {
SRV_DBG("sending result for task id = %d\n", result->id);
std::unique_lock lock(mutex_results);
- for (const auto &id_task : waiting_task_ids) {
+ for (const auto & id_task : waiting_task_ids) {
if (result->id == id_task) {
SRV_DBG("task id = %d pushed to result queue\n", result->id);
@@ -1710,20 +1773,20 @@ struct server_context {
common_init_result llama_init;
common_init_result llama_init_dft;
- llama_model *model = nullptr;
- llama_context *ctx = nullptr;
+ llama_model * model = nullptr;
+ llama_context * ctx = nullptr;
- const llama_vocab *vocab = nullptr;
+ const llama_vocab * vocab = nullptr;
- llama_model *model_dft = nullptr;
+ llama_model * model_dft = nullptr;
llama_context_params cparams_dft;
llama_batch batch = {};
bool clean_kv_cache = true;
- bool add_bos_token = true;
- bool has_eos_token = false;
+ bool add_bos_token = true;
+ bool has_eos_token = false;
int32_t n_ctx; // total context for all clients / slots
@@ -1731,7 +1794,7 @@ struct server_context {
std::vector<server_slot> slots;
json default_generation_settings_for_props;
- server_queue queue_tasks;
+ server_queue queue_tasks;
server_response queue_results;
server_metrics metrics;
@@ -1743,7 +1806,7 @@ struct server_context {
~server_context() {
// Clear any sampling context
- for (server_slot &slot : slots) {
+ for (server_slot & slot : slots) {
common_sampler_free(slot.smpl);
slot.smpl = nullptr;
@@ -1759,7 +1822,7 @@ struct server_context {
llama_batch_free(batch);
}
- bool load_model(const common_params ¶ms) {
+ bool load_model(const common_params & params) {
SRV_INF("loading model '%s'\n", params.model.c_str());
params_base = params;
@@ -1767,7 +1830,7 @@ struct server_context {
llama_init = common_init_from_params(params_base);
model = llama_init.model.get();
- ctx = llama_init.context.get();
+ ctx = llama_init.context.get();
if (model == nullptr) {
SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
@@ -1786,15 +1849,14 @@ struct server_context {
auto params_dft = params_base;
- params_dft.devices = params_base.speculative.devices;
- params_dft.hf_file = params_base.speculative.hf_file;
- params_dft.hf_repo = params_base.speculative.hf_repo;
- params_dft.model = params_base.speculative.model;
- params_dft.model_url = params_base.speculative.model_url;
- params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel
- : params_base.speculative.n_ctx;
+ params_dft.devices = params_base.speculative.devices;
+ params_dft.hf_file = params_base.speculative.hf_file;
+ params_dft.hf_repo = params_base.speculative.hf_repo;
+ params_dft.model = params_base.speculative.model;
+ params_dft.model_url = params_base.speculative.model_url;
+ params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
- params_dft.n_parallel = 1;
+ params_dft.n_parallel = 1;
llama_init_dft = common_init_from_params(params_dft);
@@ -1806,8 +1868,7 @@ struct server_context {
}
if (!common_speculative_are_compatible(ctx, llama_init_dft.context.get())) {
- SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n",
- params_base.speculative.model.c_str(), params_base.model.c_str());
+ SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.c_str(), params_base.model.c_str());
return false;
}
@@ -1828,10 +1889,9 @@ struct server_context {
chat_templates = common_chat_templates_init(model, params_base.chat_template);
try {
common_chat_format_example(chat_templates.get(), params.use_jinja);
- } catch (const std::exception &e) {
- SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. "
- "This may cause the model to output suboptimal responses\n",
- __func__);
+ } catch (const std::exception & e) {
+ SRV_WRN("%s: Chat template parsing error: %s\n", __func__, e.what());
+ SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
chat_templates = common_chat_templates_init(model, "chatml");
}
@@ -1871,7 +1931,9 @@ struct server_context {
slot.params.sampling = params_base.sampling;
- slot.callback_on_release = [this](int) { queue_tasks.pop_deferred_task(); };
+ slot.callback_on_release = [this](int) {
+ queue_tasks.pop_deferred_task();
+ };
slot.reset();
@@ -1881,8 +1943,7 @@ struct server_context {
default_generation_settings_for_props = slots[0].to_json();
// the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
- // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not
- // used)
+ // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
{
const int32_t n_batch = llama_n_batch(ctx);
@@ -1893,8 +1954,8 @@ struct server_context {
metrics.init();
}
- server_slot *get_slot_by_id(int id) {
- for (server_slot &slot : slots) {
+ server_slot * get_slot_by_id(int id) {
+ for (server_slot & slot : slots) {
if (slot.id == id) {
return &slot;
}
@@ -1903,15 +1964,15 @@ struct server_context {
return nullptr;
}
- server_slot *get_available_slot(const server_task &task) {
- server_slot *ret = nullptr;
+ server_slot * get_available_slot(const server_task & task) {
+ server_slot * ret = nullptr;
// find the slot that has at least n% prompt similarity
if (ret == nullptr && slot_prompt_similarity != 0.0f) {
int lcs_len = 0;
float similarity = 0;
- for (server_slot &slot : slots) {
+ for (server_slot & slot : slots) {
// skip the slot if it is not available
if (slot.is_processing()) {
continue;
@@ -1944,7 +2005,7 @@ struct server_context {
// find the slot that has been least recently used
if (ret == nullptr) {
int64_t t_last = ggml_time_us();
- for (server_slot &slot : slots) {
+ for (server_slot & slot : slots) {
// skip the slot if it is not available
if (slot.is_processing()) {
continue;
@@ -1965,12 +2026,24 @@ struct server_context {
return ret;
}
- bool launch_slot_with_task(server_slot &slot, const server_task &task) {
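+ // returns false when any prompt token id falls outside the model vocabulary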
+ bool can_be_detokenized(const struct llama_context * ctx, const std::vector<llama_token> & tokens) {
+ const llama_model * model = llama_get_model(ctx);
+ const llama_vocab * vocab = llama_model_get_vocab(model);
+ const int32_t n_vocab = llama_vocab_n_tokens(vocab);
+ for (const auto & token : tokens) {
+ if (token < 0 || token >= n_vocab) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ bool launch_slot_with_task(server_slot & slot, const server_task & task) {
slot.reset();
- slot.id_task = task.id;
- slot.index = task.index;
- slot.task_type = task.type;
- slot.params = std::move(task.params);
+ slot.id_task = task.id;
+ slot.index = task.index;
+ slot.task_type = task.type;
+ slot.params = std::move(task.params);
slot.prompt_tokens = std::move(task.prompt_tokens);
if (!are_lora_equal(task.params.lora, slot.lora)) {
@@ -1979,12 +2052,16 @@ struct server_context {
slot.lora = task.params.lora;
}
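+ // reject the request early if the prompt references token ids outside the vocabulary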
+ bool can_detokenize = can_be_detokenized(ctx, slot.prompt_tokens);
+ if (!can_detokenize) {
+ send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
+ return false;
+ }
SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
// Might be better to reject the request with a 400 ?
- SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d\n", slot.params.n_predict,
- slot.n_predict);
+ SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d\n", slot.params.n_predict, slot.n_predict);
slot.params.n_predict = slot.n_predict;
}
@@ -2022,11 +2099,11 @@ struct server_context {
SRV_DBG("%s", "clearing KV cache\n");
// clear the entire KV cache
- llama_kv_cache_clear(ctx);
+ llama_kv_self_clear(ctx);
clean_kv_cache = false;
}
- bool process_token(completion_token_output &result, server_slot &slot) {
+ bool process_token(completion_token_output & result, server_slot & slot) {
// remember which tokens were sampled - used for repetition penalties during sampling
const std::string token_str = result.text_to_send;
slot.sampled = result.tok;
@@ -2049,7 +2126,9 @@ struct server_context {
size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true);
if (stop_pos != std::string::npos) {
- slot.generated_text.erase(slot.generated_text.begin() + pos + stop_pos, slot.generated_text.end());
+ slot.generated_text.erase(
+ slot.generated_text.begin() + pos + stop_pos,
+ slot.generated_text.end());
pos = std::min(slot.n_sent_text, slot.generated_text.size());
} else if (slot.has_next_token) {
stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false);
@@ -2078,23 +2157,13 @@ struct server_context {
// check the limits
if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) {
- slot.stop = STOP_TYPE_LIMIT;
+ slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false;
SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
}
if (slot.has_new_line) {
- // if we have already seen a new line, we stop after a certain time limit
- if (slot.params.t_max_predict_ms > 0 &&
- (ggml_time_us() - slot.t_start_generation > 1000.0f * slot.params.t_max_predict_ms)) {
- slot.stop = STOP_TYPE_LIMIT;
- slot.has_next_token = false;
-
- SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded,
- (int)slot.params.t_max_predict_ms);
- }
-
// require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent
if (slot.params.n_indent > 0) {
// check the current indentation
@@ -2103,21 +2172,19 @@ struct server_context {
size_t pos = slot.last_nl_pos;
int n_indent = 0;
- while (pos < slot.generated_text.size() &&
- (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) {
+ while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) {
n_indent++;
pos++;
}
if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) {
- slot.stop = STOP_TYPE_LIMIT;
+ slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false;
// cut the last line
slot.generated_text.erase(pos, std::string::npos);
- SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded,
- n_indent);
+ SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent);
}
}
@@ -2135,22 +2202,28 @@ struct server_context {
// check if there is a new line in the generated text
if (result.text_to_send.find('\n') != std::string::npos) {
slot.has_new_line = true;
+
+ // if we have seen a new line, we stop after a certain time limit, but only upon another new line
+ if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
+ slot.stop = STOP_TYPE_LIMIT;
+ slot.has_next_token = false;
+
+ SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
+ }
}
// if context shift is disabled, we stop when it reaches the context limit
if (slot.n_past >= slot.n_ctx) {
- slot.truncated = true;
- slot.stop = STOP_TYPE_LIMIT;
+ slot.truncated = true;
+ slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false;
- SLT_DBG(slot,
- "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = "
- "%d, n_ctx = %d\n",
+ SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx);
}
if (llama_vocab_is_eog(vocab, result.tok)) {
- slot.stop = STOP_TYPE_EOS;
+ slot.stop = STOP_TYPE_EOS;
slot.has_next_token = false;
SLT_DBG(slot, "%s", "stopped by EOS\n");
@@ -2159,8 +2232,8 @@ struct server_context {
const auto n_ctx_train = llama_model_n_ctx_train(model);
if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
- slot.truncated = true;
- slot.stop = STOP_TYPE_LIMIT;
+ slot.truncated = true;
+ slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false; // stop prediction
SLT_WRN(slot,
@@ -2169,18 +2242,16 @@ struct server_context {
slot.params.n_predict, n_ctx_train);
}
- SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining,
- result.tok, token_str.c_str());
+ SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str());
return slot.has_next_token; // continue
}
- void populate_token_probs(const server_slot &slot, completion_token_output &result, bool post_sampling,
- bool special, int idx) {
+ void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
size_t n_probs = slot.params.sampling.n_probs;
size_t n_vocab = llama_vocab_n_tokens(vocab);
if (post_sampling) {
- const auto *cur_p = common_sampler_get_candidates(slot.smpl);
+ const auto * cur_p = common_sampler_get_candidates(slot.smpl);
const size_t max_probs = cur_p->size;
// set probability for sampled token
@@ -2194,8 +2265,11 @@ struct server_context {
// set probability for top n_probs tokens
result.probs.reserve(max_probs);
for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
- result.probs.push_back(
- {cur_p->data[i].id, common_token_to_piece(ctx, cur_p->data[i].id, special), cur_p->data[i].p});
+ result.probs.push_back({
+ cur_p->data[i].id,
+ common_token_to_piece(ctx, cur_p->data[i].id, special),
+ cur_p->data[i].p
+ });
}
} else {
// TODO: optimize this with min-p optimization
@@ -2213,45 +2287,49 @@ struct server_context {
// set probability for top n_probs tokens
result.probs.reserve(n_probs);
for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
- result.probs.push_back({cur[i].id, common_token_to_piece(ctx, cur[i].id, special), cur[i].p});
+ result.probs.push_back({
+ cur[i].id,
+ common_token_to_piece(ctx, cur[i].id, special),
+ cur[i].p
+ });
}
}
}
- void send_error(const server_task &task, const std::string &error, const enum error_type type = ERROR_TYPE_SERVER) {
+ void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
send_error(task.id, error, type);
}
- void send_error(const server_slot &slot, const std::string &error, const enum error_type type = ERROR_TYPE_SERVER) {
+ void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
send_error(slot.id_task, error, type);
}
- void send_error(const int id_task, const std::string &error, const enum error_type type = ERROR_TYPE_SERVER) {
+ void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str());
auto res = std::make_unique<server_task_result_error>();
- res->id = id_task;
+ res->id = id_task;
res->err_type = type;
- res->err_msg = error;
+ res->err_msg = error;
queue_results.send(std::move(res));
}
- void send_partial_response(server_slot &slot, const completion_token_output &tkn) {
+ void send_partial_response(server_slot & slot, const completion_token_output & tkn) {
auto res = std::make_unique<server_task_result_cmpl_partial>();
- res->id = slot.id_task;
- res->index = slot.index;
+ res->id = slot.id_task;
+ res->index = slot.index;
res->content = tkn.text_to_send;
- res->tokens = {tkn.tok};
+ res->tokens = { tkn.tok };
- res->n_decoded = slot.n_decoded;
- res->n_prompt_tokens = slot.n_prompt_tokens;
+ res->n_decoded = slot.n_decoded;
+ res->n_prompt_tokens = slot.n_prompt_tokens;
res->post_sampling_probs = slot.params.post_sampling_probs;
- res->verbose = slot.params.verbose;
- res->oaicompat = slot.params.oaicompat;
- res->oaicompat_model = slot.params.oaicompat_model;
+ res->verbose = slot.params.verbose;
+ res->oaicompat = slot.params.oaicompat;
+ res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
// populate res.probs_output
@@ -2267,32 +2345,32 @@ struct server_context {
queue_results.send(std::move(res));
}
- void send_final_response(server_slot &slot) {
+ void send_final_response(server_slot & slot) {
auto res = std::make_unique<server_task_result_cmpl_final>();
- res->id = slot.id_task;
- res->id_slot = slot.id;
-
- res->index = slot.index;
- res->content = std::move(slot.generated_text);
- res->tokens = std::move(slot.generated_tokens);
- res->timings = slot.get_timings();
- res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
+ res->id = slot.id_task;
+ res->id_slot = slot.id;
+
+ res->index = slot.index;
+ res->content = std::move(slot.generated_text);
+ res->tokens = std::move(slot.generated_tokens);
+ res->timings = slot.get_timings();
+ res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
res->response_fields = std::move(slot.params.response_fields);
- res->truncated = slot.truncated;
- res->n_decoded = slot.n_decoded;
- res->n_prompt_tokens = slot.n_prompt_tokens;
- res->n_tokens_cached = slot.n_past;
- res->has_new_line = slot.has_new_line;
- res->stopping_word = slot.stopping_word;
- res->stop = slot.stop;
+ res->truncated = slot.truncated;
+ res->n_decoded = slot.n_decoded;
+ res->n_prompt_tokens = slot.n_prompt_tokens;
+ res->n_tokens_cached = slot.n_past;
+ res->has_new_line = slot.has_new_line;
+ res->stopping_word = slot.stopping_word;
+ res->stop = slot.stop;
res->post_sampling_probs = slot.params.post_sampling_probs;
- res->verbose = slot.params.verbose;
- res->stream = slot.params.stream;
- res->oaicompat = slot.params.oaicompat;
- res->oaicompat_model = slot.params.oaicompat_model;
- res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
+ res->verbose = slot.params.verbose;
+ res->stream = slot.params.stream;
+ res->oaicompat = slot.params.oaicompat;
+ res->oaicompat_model = slot.params.oaicompat_model;
+ res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
res->oaicompat_chat_format = slot.params.oaicompat_chat_format;
// populate res.probs_output
if (slot.params.sampling.n_probs > 0) {
@@ -2301,10 +2379,12 @@ struct server_context {
size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
res->probs_output = std::vector<completion_token_output>(
- slot.generated_token_probs.begin(), slot.generated_token_probs.end() - safe_offset);
+ slot.generated_token_probs.begin(),
+ slot.generated_token_probs.end() - safe_offset);
} else {
- res->probs_output = std::vector<completion_token_output>(slot.generated_token_probs.begin(),
- slot.generated_token_probs.end());
+ res->probs_output = std::vector<completion_token_output>(
+ slot.generated_token_probs.begin(),
+ slot.generated_token_probs.end());
}
}
@@ -2313,11 +2393,11 @@ struct server_context {
queue_results.send(std::move(res));
}
- void send_embedding(const server_slot &slot, const llama_batch &batch) {
+ void send_embedding(const server_slot & slot, const llama_batch & batch) {
auto res = std::make_unique<server_task_result_embd>();
- res->id = slot.id_task;
- res->index = slot.index;
- res->n_tokens = slot.n_prompt_tokens;
+ res->id = slot.id_task;
+ res->index = slot.index;
+ res->n_tokens = slot.n_prompt_tokens;
res->oaicompat = slot.params.oaicompat;
const int n_embd = llama_model_n_embd(model);
@@ -2329,14 +2409,13 @@ struct server_context {
continue;
}
- const float *embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+ const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
if (embd == NULL) {
embd = llama_get_embeddings_ith(ctx, i);
}
if (embd == NULL) {
- SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i],
- batch.seq_id[i][0]);
+ SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
continue;
@@ -2348,7 +2427,7 @@ struct server_context {
common_embd_normalize(embd, embd_res.data(), n_embd, 2);
res->embedding.push_back(embd_res);
} else {
- res->embedding.push_back({embd, embd + n_embd});
+ res->embedding.push_back({ embd, embd + n_embd });
}
}
@@ -2357,9 +2436,9 @@ struct server_context {
queue_results.send(std::move(res));
}
- void send_rerank(const server_slot &slot, const llama_batch &batch) {
+ void send_rerank(const server_slot & slot, const llama_batch & batch) {
auto res = std::make_unique<server_task_result_rerank>();
- res->id = slot.id_task;
+ res->id = slot.id_task;
res->index = slot.index;
res->n_tokens = slot.n_prompt_tokens;
@@ -2368,14 +2447,13 @@ struct server_context {
continue;
}
- const float *embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+ const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
if (embd == NULL) {
embd = llama_get_embeddings_ith(ctx, i);
}
if (embd == NULL) {
- SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i],
- batch.seq_id[i][0]);
+ SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
res->score = -1e6;
continue;
@@ -2393,10 +2471,10 @@ struct server_context {
// Functions to create new task(s) and receive result(s)
//
- void cancel_tasks(const std::unordered_set<int> &id_tasks) {
+ void cancel_tasks(const std::unordered_set<int> & id_tasks) {
std::vector<server_task> cancel_tasks;
cancel_tasks.reserve(id_tasks.size());
- for (const auto &id_task : id_tasks) {
+ for (const auto & id_task : id_tasks) {
SRV_WRN("cancel task, id_task = %d\n", id_task);
server_task task(SERVER_TASK_TYPE_CANCEL);
@@ -2409,10 +2487,11 @@ struct server_context {
}
// receive the results from task(s)
- void receive_multi_results(const std::unordered_set<int> &id_tasks,
- const std::function<void(std::vector<server_task_result_ptr> &)> &result_handler,
- const std::function<void(json)> &error_handler,
- const std::function<bool()> &is_connection_closed) {
+ void receive_multi_results(
+ const std::unordered_set<int> & id_tasks,
+ const std::function<void(std::vector<server_task_result_ptr>&)> & result_handler,
+ const std::function<void(json)> & error_handler,
+ const std::function<bool()> & is_connection_closed) {
std::vector<server_task_result_ptr> results(id_tasks.size());
for (int i = 0; i < (int)id_tasks.size(); i++) {
server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
@@ -2433,9 +2512,11 @@ struct server_context {
return;
}
- GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final *>(result.get()) != nullptr ||
- dynamic_cast<server_task_result_embd *>(result.get()) != nullptr ||
- dynamic_cast<server_task_result_rerank *>(result.get()) != nullptr);
+ GGML_ASSERT(
+ dynamic_cast<server_task_result_cmpl_final *>(result.get()) != nullptr
+ || dynamic_cast<server_task_result_embd *>(result.get()) != nullptr
+ || dynamic_cast<server_task_result_rerank *>(result.get()) != nullptr
+ );
const size_t idx = result->get_index();
GGML_ASSERT(idx < results.size() && "index out of range");
results[idx] = std::move(result);
@@ -2444,10 +2525,11 @@ struct server_context {
}
// receive the results from task(s), in stream mode
- void receive_cmpl_results_stream(const std::unordered_set<int> &id_tasks,
- const std::function<bool(server_task_result_ptr &)> &result_handler,
- const std::function<void(json)> &error_handler,
- const std::function<bool()> &is_connection_closed) {
+ void receive_cmpl_results_stream(
+ const std::unordered_set<int> & id_tasks,
+ const std::function<bool(server_task_result_ptr &)> & result_handler,
+ const std::function<void(json)> & error_handler,
+ const std::function<bool()> & is_connection_closed) {
size_t n_finished = 0;
while (true) {
server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
@@ -2467,8 +2549,10 @@ struct server_context {
return;
}
- GGML_ASSERT(dynamic_cast<server_task_result_cmpl_partial *>(result.get()) != nullptr ||
- dynamic_cast<server_task_result_cmpl_final *>(result.get()) != nullptr);
+ GGML_ASSERT(
+ dynamic_cast<server_task_result_cmpl_partial *>(result.get()) != nullptr
+ || dynamic_cast<server_task_result_cmpl_final *>(result.get()) != nullptr
+ );
if (!result_handler(result)) {
cancel_tasks(id_tasks);
break;
@@ -2488,203 +2572,208 @@ struct server_context {
void process_single_task(server_task task) {
switch (task.type) {
- case SERVER_TASK_TYPE_COMPLETION:
- case SERVER_TASK_TYPE_INFILL:
- case SERVER_TASK_TYPE_EMBEDDING:
- case SERVER_TASK_TYPE_RERANK: {
- const int id_slot = task.id_selected_slot;
-
- server_slot *slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);
-
- if (slot == nullptr) {
- // if no slot is available, we defer this task for processing later
- SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id);
- queue_tasks.defer(task);
- break;
- }
- if (slot->is_processing()) {
- // if requested slot is unavailable, we defer this task for processing later
- SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
- queue_tasks.defer(task);
- break;
- }
-
- if (!launch_slot_with_task(*slot, task)) {
- SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
- break;
- }
- } break;
- case SERVER_TASK_TYPE_CANCEL: {
- // release slot linked with the task id
- for (auto &slot : slots) {
- if (slot.id_task == task.id_target) {
- slot.release();
- break;
- }
- }
- } break;
- case SERVER_TASK_TYPE_NEXT_RESPONSE: {
- // do nothing
- } break;
- case SERVER_TASK_TYPE_METRICS: {
- json slots_data = json::array();
+ case SERVER_TASK_TYPE_COMPLETION:
+ case SERVER_TASK_TYPE_INFILL:
+ case SERVER_TASK_TYPE_EMBEDDING:
+ case SERVER_TASK_TYPE_RERANK:
+ {
+ const int id_slot = task.id_selected_slot;
- int n_idle_slots = 0;
- int n_processing_slots = 0;
+ server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);
- for (server_slot &slot : slots) {
- json slot_data = slot.to_json();
+ if (slot == nullptr) {
+ // if no slot is available, we defer this task for processing later
+ SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id);
+ queue_tasks.defer(task);
+ break;
+ }
+ if (slot->is_processing()) {
+ // if requested slot is unavailable, we defer this task for processing later
+ SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
+ queue_tasks.defer(task);
+ break;
+ }
- if (slot.is_processing()) {
- n_processing_slots++;
- } else {
- n_idle_slots++;
- }
+ if (!launch_slot_with_task(*slot, task)) {
+ SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
+ break;
+ }
+ } break;
+ case SERVER_TASK_TYPE_CANCEL:
+ {
+ // release slot linked with the task id
+ for (auto & slot : slots) {
+ if (slot.id_task == task.id_target) {
+ slot.release();
+ break;
+ }
+ }
+ } break;
+ case SERVER_TASK_TYPE_NEXT_RESPONSE:
+ {
+ // do nothing
+ } break;
+ case SERVER_TASK_TYPE_METRICS:
+ {
+ json slots_data = json::array();
- slots_data.push_back(slot_data);
- }
- SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots);
+ int n_idle_slots = 0;
+ int n_processing_slots = 0;
- auto res = std::make_unique<server_task_result_metrics>();
- res->id = task.id;
- res->slots_data = std::move(slots_data);
- res->n_idle_slots = n_idle_slots;
- res->n_processing_slots = n_processing_slots;
- res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size();
- res->t_start = metrics.t_start;
+ for (server_slot & slot : slots) {
+ json slot_data = slot.to_json();
- res->kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx);
- res->kv_cache_used_cells = llama_get_kv_cache_used_cells(ctx);
+ if (slot.is_processing()) {
+ n_processing_slots++;
+ } else {
+ n_idle_slots++;
+ }
- res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total;
- res->t_prompt_processing_total = metrics.t_prompt_processing_total;
- res->n_tokens_predicted_total = metrics.n_tokens_predicted_total;
- res->t_tokens_generation_total = metrics.t_tokens_generation_total;
+ slots_data.push_back(slot_data);
+ }
+ SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots);
+
+ auto res = std::make_unique<server_task_result_metrics>();
+ res->id = task.id;
+ res->slots_data = std::move(slots_data);
+ res->n_idle_slots = n_idle_slots;
+ res->n_processing_slots = n_processing_slots;
+ res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size();
+ res->t_start = metrics.t_start;
+
+ res->kv_cache_tokens_count = llama_kv_self_n_tokens(ctx);
+ res->kv_cache_used_cells = llama_kv_self_used_cells(ctx);
+
+ res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total;
+ res->t_prompt_processing_total = metrics.t_prompt_processing_total;
+ res->n_tokens_predicted_total = metrics.n_tokens_predicted_total;
+ res->t_tokens_generation_total = metrics.t_tokens_generation_total;
+
+ res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed;
+ res->t_prompt_processing = metrics.t_prompt_processing;
+ res->n_tokens_predicted = metrics.n_tokens_predicted;
+ res->t_tokens_generation = metrics.t_tokens_generation;
+
+ res->n_decode_total = metrics.n_decode_total;
+ res->n_busy_slots_total = metrics.n_busy_slots_total;
+
+ if (task.metrics_reset_bucket) {
+ metrics.reset_bucket();
+ }
+ queue_results.send(std::move(res));
+ } break;
+ case SERVER_TASK_TYPE_SLOT_SAVE:
+ {
+ int id_slot = task.slot_action.slot_id;
+ server_slot * slot = get_slot_by_id(id_slot);
+ if (slot == nullptr) {
+ send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
+ break;
+ }
+ if (slot->is_processing()) {
+ // if requested slot is unavailable, we defer this task for processing later
+ SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
+ queue_tasks.defer(task);
+ break;
+ }
- res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed;
- res->t_prompt_processing = metrics.t_prompt_processing;
- res->n_tokens_predicted = metrics.n_tokens_predicted;
- res->t_tokens_generation = metrics.t_tokens_generation;
+ const size_t token_count = slot->cache_tokens.size();
+ const int64_t t_start = ggml_time_us();
- res->n_decode_total = metrics.n_decode_total;
- res->n_busy_slots_total = metrics.n_busy_slots_total;
+ std::string filename = task.slot_action.filename;
+ std::string filepath = task.slot_action.filepath;
- if (task.metrics_reset_bucket) {
- metrics.reset_bucket();
- }
- queue_results.send(std::move(res));
- } break;
- case SERVER_TASK_TYPE_SLOT_SAVE: {
- int id_slot = task.slot_action.slot_id;
- server_slot *slot = get_slot_by_id(id_slot);
- if (slot == nullptr) {
- send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
- break;
- }
- if (slot->is_processing()) {
- // if requested slot is unavailable, we defer this task for processing later
- SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
- queue_tasks.defer(task);
- break;
- }
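+ // write this slot's cached tokens and KV-cache state to the requested file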
+ const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
- const size_t token_count = slot->cache_tokens.size();
- const int64_t t_start = ggml_time_us();
-
- std::string filename = task.slot_action.filename;
- std::string filepath = task.slot_action.filepath;
-
- const size_t nwrite =
- llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
-
- const int64_t t_end = ggml_time_us();
- const double t_save_ms = (t_end - t_start) / 1000.0;
-
- auto res = std::make_unique<server_task_result_slot_save_load>();
- res->id = task.id;
- res->id_slot = id_slot;
- res->filename = filename;
- res->is_save = true;
- res->n_tokens = token_count;
- res->n_bytes = nwrite;
- res->t_ms = t_save_ms;
- queue_results.send(std::move(res));
- } break;
- case SERVER_TASK_TYPE_SLOT_RESTORE: {
- int id_slot = task.slot_action.slot_id;
- server_slot *slot = get_slot_by_id(id_slot);
- if (slot == nullptr) {
- send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
- break;
- }
- if (slot->is_processing()) {
- // if requested slot is unavailable, we defer this task for processing later
- SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
- queue_tasks.defer(task);
- break;
- }
+ const int64_t t_end = ggml_time_us();
+ const double t_save_ms = (t_end - t_start) / 1000.0;
- const int64_t t_start = ggml_time_us();
+ auto res = std::make_unique<server_task_result_slot_save_load>();
+ res->id = task.id;
+ res->id_slot = id_slot;
+ res->filename = filename;
+ res->is_save = true;
+ res->n_tokens = token_count;
+ res->n_bytes = nwrite;
+ res->t_ms = t_save_ms;
+ queue_results.send(std::move(res));
+ } break;
+ case SERVER_TASK_TYPE_SLOT_RESTORE:
+ {
+ int id_slot = task.slot_action.slot_id;
+ server_slot * slot = get_slot_by_id(id_slot);
+ if (slot == nullptr) {
+ send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
+ break;
+ }
+ if (slot->is_processing()) {
+ // if requested slot is unavailable, we defer this task for processing later
+ SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
+ queue_tasks.defer(task);
+ break;
+ }
- std::string filename = task.slot_action.filename;
- std::string filepath = task.slot_action.filepath;
+ const int64_t t_start = ggml_time_us();
- slot->cache_tokens.resize(slot->n_ctx);
- size_t token_count = 0;
- size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(),
- slot->cache_tokens.size(), &token_count);
- if (nread == 0) {
- slot->cache_tokens.resize(0);
- send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file",
- ERROR_TYPE_INVALID_REQUEST);
- break;
- }
- slot->cache_tokens.resize(token_count);
-
- const int64_t t_end = ggml_time_us();
- const double t_restore_ms = (t_end - t_start) / 1000.0;
-
- auto res = std::make_unique<server_task_result_slot_save_load>();
- res->id = task.id;
- res->id_slot = id_slot;
- res->filename = filename;
- res->is_save = false;
- res->n_tokens = token_count;
- res->n_bytes = nread;
- res->t_ms = t_restore_ms;
- queue_results.send(std::move(res));
- } break;
- case SERVER_TASK_TYPE_SLOT_ERASE: {
- int id_slot = task.slot_action.slot_id;
- server_slot *slot = get_slot_by_id(id_slot);
- if (slot == nullptr) {
- send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
- break;
- }
- if (slot->is_processing()) {
- // if requested slot is unavailable, we defer this task for processing later
- SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
- queue_tasks.defer(task);
- break;
- }
+ std::string filename = task.slot_action.filename;
+ std::string filepath = task.slot_action.filepath;
- // Erase token cache
- const size_t n_erased = slot->cache_tokens.size();
- llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
- slot->cache_tokens.clear();
+ slot->cache_tokens.resize(slot->n_ctx);
+ size_t token_count = 0;
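+ // load the saved sequence state back into this slot; token_count receives the number of tokens read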
+ size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
+ if (nread == 0) {
+ slot->cache_tokens.resize(0);
+ send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
+ break;
+ }
+ slot->cache_tokens.resize(token_count);
+
+ const int64_t t_end = ggml_time_us();
+ const double t_restore_ms = (t_end - t_start) / 1000.0;
+
+ auto res = std::make_unique<server_task_result_slot_save_load>();
+ res->id = task.id;
+ res->id_slot = id_slot;
+ res->filename = filename;
+ res->is_save = false;
+ res->n_tokens = token_count;
+ res->n_bytes = nread;
+ res->t_ms = t_restore_ms;
+ queue_results.send(std::move(res));
+ } break;
+ case SERVER_TASK_TYPE_SLOT_ERASE:
+ {
+ int id_slot = task.slot_action.slot_id;
+ server_slot * slot = get_slot_by_id(id_slot);
+ if (slot == nullptr) {
+ send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
+ break;
+ }
+ if (slot->is_processing()) {
+ // if requested slot is unavailable, we defer this task for processing later
+ SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
+ queue_tasks.defer(task);
+ break;
+ }
- auto res = std::make_unique<server_task_result_slot_erase>();
- res->id = task.id;
- res->id_slot = id_slot;
- res->n_erased = n_erased;
- queue_results.send(std::move(res));
- } break;
- case SERVER_TASK_TYPE_SET_LORA: {
- params_base.lora_adapters = std::move(task.set_lora);
- auto res = std::make_unique<server_task_result_apply_lora>();
- res->id = task.id;
- queue_results.send(std::move(res));
- } break;
+ // Erase token cache
+ const size_t n_erased = slot->cache_tokens.size();
+ llama_kv_self_seq_rm(ctx, slot->id, -1, -1);
+ slot->cache_tokens.clear();
+
+ auto res = std::make_unique<server_task_result_slot_erase>();
+ res->id = task.id;
+ res->id_slot = id_slot;
+ res->n_erased = n_erased;
+ queue_results.send(std::move(res));
+ } break;
+ case SERVER_TASK_TYPE_SET_LORA:
+ {
+ params_base.lora_adapters = std::move(task.set_lora);
+ auto res = std::make_unique<server_task_result_apply_lora>();
+ res->id = task.id;
+ queue_results.send(std::move(res));
+ } break;
}
}
@@ -2693,7 +2782,7 @@ struct server_context {
{
bool all_idle = true;
- for (auto &slot : slots) {
+ for (auto & slot : slots) {
if (slot.is_processing()) {
all_idle = false;
break;
@@ -2720,7 +2809,7 @@ struct server_context {
// apply context-shift if needed
// TODO: simplify and improve
- for (server_slot &slot : slots) {
+ for (server_slot & slot : slots) {
if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) {
if (!params_base.ctx_shift) {
// this check is redundant (for good)
@@ -2731,15 +2820,14 @@ struct server_context {
}
// Shift context
- const int n_keep = slot.params.n_keep + add_bos_token;
- const int n_left = slot.n_past - n_keep;
+ const int n_keep = slot.params.n_keep + add_bos_token;
+ const int n_left = slot.n_past - n_keep;
const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
- SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left,
- n_discard);
+ SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
- llama_kv_cache_seq_rm(ctx, slot.id, n_keep, n_keep + n_discard);
- llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
+ llama_kv_self_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
+ llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
if (slot.params.cache_prompt) {
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
@@ -2759,15 +2847,14 @@ struct server_context {
common_batch_clear(batch);
// track if given slot can be batched with slots already in the batch
- server_slot *slot_batched = nullptr;
+ server_slot * slot_batched = nullptr;
- auto accept_special_token = [&](server_slot &slot, llama_token token) {
- return params_base.special ||
- slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end();
+ auto accept_special_token = [&](server_slot & slot, llama_token token) {
+ return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end();
};
// first, add sampled tokens from any ongoing sequences
- for (auto &slot : slots) {
+ for (auto & slot : slots) {
if (slot.state != SLOT_STATE_GENERATING) {
continue;
}
@@ -2781,7 +2868,7 @@ struct server_context {
slot.i_batch = batch.n_tokens;
- common_batch_add(batch, slot.sampled, slot.n_past, {slot.id}, true);
+ common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
slot.n_past += 1;
@@ -2790,16 +2877,16 @@ struct server_context {
}
SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
- slot.n_ctx, slot.n_past, (int)slot.cache_tokens.size(), slot.truncated);
+ slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated);
}
// process in chunks of params.n_batch
- int32_t n_batch = llama_n_batch(ctx);
+ int32_t n_batch = llama_n_batch(ctx);
int32_t n_ubatch = llama_n_ubatch(ctx);
// next, batch any pending prompts without exceeding n_batch
if (params_base.cont_batching || batch.n_tokens == 0) {
- for (auto &slot : slots) {
+ for (auto & slot : slots) {
// check if we can batch this slot with the previous one
if (slot.is_processing()) {
if (!slot_batched) {
@@ -2811,7 +2898,7 @@ struct server_context {
// this slot still has a prompt to be processed
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
- auto &prompt_tokens = slot.prompt_tokens;
+ auto & prompt_tokens = slot.prompt_tokens;
// TODO: maybe move branch to outside of this loop in the future
if (slot.state == SLOT_STATE_STARTED) {
@@ -2822,21 +2909,18 @@ struct server_context {
slot.n_prompt_tokens = prompt_tokens.size();
slot.state = SLOT_STATE_PROCESSING_PROMPT;
- SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx,
- slot.params.n_keep, slot.n_prompt_tokens);
+ SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
// print prompt tokens (for debugging)
if (1) {
// first 16 tokens (avoid flooding logs)
for (int i = 0; i < std::min<int>(16, prompt_tokens.size()); i++) {
- SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i],
- common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+ SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
}
} else {
// all
- for (int i = 0; i < (int)prompt_tokens.size(); i++) {
- SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i],
- common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+ for (int i = 0; i < (int) prompt_tokens.size(); i++) {
+ SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
}
}
@@ -2853,15 +2937,13 @@ struct server_context {
if (slot.is_non_causal()) {
if (slot.n_prompt_tokens > n_ubatch) {
slot.release();
- send_error(slot, "input is too large to process. increase the physical batch size",
- ERROR_TYPE_SERVER);
+ send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
continue;
}
if (slot.n_prompt_tokens > slot.n_ctx) {
slot.release();
- send_error(slot, "input is larger than the max context size. skipping",
- ERROR_TYPE_SERVER);
+ send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_SERVER);
continue;
}
} else {
@@ -2871,10 +2953,7 @@ struct server_context {
// context shift should be applied only during the generation phase
if (slot.n_prompt_tokens >= slot.n_ctx) {
slot.release();
- send_error(slot,
- "the request exceeds the available context size. try increasing the "
- "context size or enable context shift",
- ERROR_TYPE_INVALID_REQUEST);
+ send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
continue;
}
}
@@ -2888,25 +2967,23 @@ struct server_context {
const int n_left = slot.n_ctx - slot.params.n_keep;
const int n_block_size = n_left / 2;
- const int erased_blocks =
- (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
+ const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
- llama_tokens new_tokens(prompt_tokens.begin(),
- prompt_tokens.begin() + slot.params.n_keep);
+ llama_tokens new_tokens(
+ prompt_tokens.begin(),
+ prompt_tokens.begin() + slot.params.n_keep);
- new_tokens.insert(new_tokens.end(),
- prompt_tokens.begin() + slot.params.n_keep +
- erased_blocks * n_block_size,
- prompt_tokens.end());
+ new_tokens.insert(
+ new_tokens.end(),
+ prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
+ prompt_tokens.end());
prompt_tokens = std::move(new_tokens);
slot.truncated = true;
slot.n_prompt_tokens = prompt_tokens.size();
- SLT_WRN(slot,
- "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n",
- slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens);
+ SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens);
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
}
@@ -2920,33 +2997,29 @@ struct server_context {
size_t head_c = slot.n_past; // cache
size_t head_p = slot.n_past; // current prompt
- SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n",
- params_base.n_cache_reuse, slot.n_past);
+ SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past);
- while (head_c < slot.cache_tokens.size() && head_p < prompt_tokens.size()) {
+ while (head_c < slot.cache_tokens.size() &&
+ head_p < prompt_tokens.size()) {
size_t n_match = 0;
while (head_c + n_match < slot.cache_tokens.size() &&
- head_p + n_match < prompt_tokens.size() &&
+ head_p + n_match < prompt_tokens.size() &&
slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
n_match++;
}
- if (n_match >= (size_t)params_base.n_cache_reuse) {
- SLT_INF(slot,
- "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> "
- "[%zu, %zu)\n",
- n_match, head_c, head_c + n_match, head_p, head_p + n_match);
- // for (size_t i = head_p; i < head_p + n_match; i++) {
- // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i],
- // common_token_to_piece(ctx, prompt_tokens[i]).c_str());
- // }
+ if (n_match >= (size_t) params_base.n_cache_reuse) {
+ SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
+ //for (size_t i = head_p; i < head_p + n_match; i++) {
+ // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+ //}
- const int64_t kv_shift = (int64_t)head_p - (int64_t)head_c;
+ const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
- llama_kv_cache_seq_rm(ctx, slot.id, head_p, head_c);
- llama_kv_cache_seq_add(ctx, slot.id, head_c, head_c + n_match, kv_shift);
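+ // remove the stale cache cells between head_p and head_c, then shift the matched chunk back so it starts at head_p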
+ llama_kv_self_seq_rm (ctx, slot.id, head_p, head_c);
+ llama_kv_self_seq_add(ctx, slot.id, head_c, head_c + n_match, kv_shift);
for (size_t i = 0; i < n_match; i++) {
slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
@@ -2967,10 +3040,7 @@ struct server_context {
if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
// we have to evaluate at least 1 token to generate logits.
- SLT_WRN(slot,
- "need to evaluate at least 1 token to generate logits, n_past = %d, "
- "n_prompt_tokens = %d\n",
- slot.n_past, slot.n_prompt_tokens);
+ SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
slot.n_past--;
}
@@ -2987,9 +3057,9 @@ struct server_context {
}
// keep only the common part
- if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
+ if (!llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1)) {
// could not partially delete (likely using a non-Transformer model)
- llama_kv_cache_seq_rm(ctx, slot.id, -1, -1);
+ llama_kv_self_seq_rm(ctx, slot.id, -1, -1);
// there is no common part left
slot.n_past = 0;
@@ -3003,10 +3073,9 @@ struct server_context {
// add prompt tokens for processing in the current batch
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
// without pooling, we want to output the embeddings for all the tokens in the batch
- const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING &&
- llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
+ const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
- common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, {slot.id}, need_embd);
+ common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
if (slot.params.cache_prompt) {
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -3016,8 +3085,7 @@ struct server_context {
slot.n_past++;
}
- SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n",
- slot.n_past, batch.n_tokens, (float)slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
+ SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
// entire prompt has been processed
if (slot.n_past == slot.n_prompt_tokens) {
@@ -3036,7 +3104,7 @@ struct server_context {
batch.logits[batch.n_tokens - 1] = true;
slot.n_decoded = 0;
- slot.i_batch = batch.n_tokens - 1;
+ slot.i_batch = batch.n_tokens - 1;
SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
}
@@ -3067,8 +3135,13 @@ struct server_context {
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
llama_batch batch_view = {
- n_tokens, batch.token + i, nullptr, batch.pos + i,
- batch.n_seq_id + i, batch.seq_id + i, batch.logits + i,
+ n_tokens,
+ batch.token + i,
+ nullptr,
+ batch.pos + i,
+ batch.n_seq_id + i,
+ batch.seq_id + i,
+ batch.logits + i,
};
const int ret = llama_decode(ctx, batch_view);
@@ -3077,10 +3150,8 @@ struct server_context {
if (ret != 0) {
if (n_batch == 1 || ret < 0) {
// if you get here, it means the KV cache is full - try increasing it via the context size
- SRV_ERR("failed to decode the batch: KV cache is full - try increasing it via the context size, i "
- "= %d, n_batch = %d, ret = %d\n",
- i, n_batch, ret);
- for (auto &slot : slots) {
+ SRV_ERR("failed to decode the batch: KV cache is full - try increasing it via the context size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
+ for (auto & slot : slots) {
slot.release();
send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
}
@@ -3091,15 +3162,13 @@ struct server_context {
n_batch /= 2;
i -= n_batch;
- SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing "
- "it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n",
- i, n_batch, ret);
+ SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
continue; // continue loop of n_batch
}
- for (auto &slot : slots) {
- if (slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) {
+ for (auto & slot : slots) {
+ if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
continue; // continue loop of slots
}
@@ -3146,9 +3215,9 @@ struct server_context {
slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;
completion_token_output result;
- result.tok = id;
+ result.tok = id;
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
- result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
+ result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
if (slot.params.sampling.n_probs > 0) {
populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);
@@ -3165,7 +3234,7 @@ struct server_context {
}
// do speculative decoding
- for (auto &slot : slots) {
+ for (auto & slot : slots) {
if (!slot.is_processing() || !slot.can_speculate()) {
continue;
}
@@ -3188,8 +3257,7 @@ struct server_context {
SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
if (n_draft_max < slot.params.speculative.n_min) {
- SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n",
- n_draft_max, slot.params.speculative.n_min);
+ SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
continue;
}
@@ -3197,25 +3265,25 @@ struct server_context {
llama_token id = slot.sampled;
struct common_speculative_params params_spec;
- params_spec.n_draft = n_draft_max;
- params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
- params_spec.p_min = slot.params.speculative.p_min;
+ params_spec.n_draft = n_draft_max;
+ params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
+ params_spec.p_min = slot.params.speculative.p_min;
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
// ignore small drafts
- if (slot.params.speculative.n_min > (int)draft.size()) {
- SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int)draft.size(), slot.params.speculative.n_min);
+ if (slot.params.speculative.n_min > (int) draft.size()) {
+ SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
continue;
}
// construct the speculation batch
common_batch_clear(slot.batch_spec);
- common_batch_add(slot.batch_spec, id, slot.n_past, {slot.id}, true);
+ common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
for (size_t i = 0; i < draft.size(); ++i) {
- common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, {slot.id}, true);
+ common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
}
SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
@@ -3225,21 +3293,20 @@ struct server_context {
// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
- slot.n_past += ids.size();
+ slot.n_past += ids.size();
slot.n_decoded += ids.size();
slot.cache_tokens.push_back(id);
slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
- llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
+ llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1);
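+ // drop the drafted positions beyond the accepted prefix from the KV cache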
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;
- result.tok = ids[i];
- result.text_to_send =
- common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
- result.prob = 1.0f; // set later
+ result.tok = ids[i];
+ result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
+ result.prob = 1.0f; // set later
// TODO: set result.probs
@@ -3253,8 +3320,7 @@ struct server_context {
}
}
- SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int)ids.size() - 1, (int)draft.size(),
- slot.n_past);
+ SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
}
}
@@ -3262,14 +3328,31 @@ struct server_context {
}
json model_meta() const {
- return json{
- {"vocab_type", llama_vocab_type(vocab)}, {"n_vocab", llama_vocab_n_tokens(vocab)},
- {"n_ctx_train", llama_model_n_ctx_train(model)}, {"n_embd", llama_model_n_embd(model)},
- {"n_params", llama_model_n_params(model)}, {"size", llama_model_size(model)},
+ return json {
+ {"vocab_type", llama_vocab_type (vocab)},
+ {"n_vocab", llama_vocab_n_tokens (vocab)},
+ {"n_ctx_train", llama_model_n_ctx_train(model)},
+ {"n_embd", llama_model_n_embd (model)},
+ {"n_params", llama_model_n_params (model)},
+ {"size", llama_model_size (model)},
};
}
};
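+// presumably assigned by the embedding host before any signal arrives; invoking an empty std::function throws std::bad_function_call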
+std::function<void(int)> shutdown_handler;
+std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
+
+inline void signal_handler(int signal) {
+ if (is_terminating.test_and_set()) {
+ // in case it hangs, we can force terminate the server by hitting Ctrl+C twice
+ // this is for better developer experience, we can remove when the server is stable enough
+ fprintf(stderr, "Received second interrupt, terminating immediately.\n");
+ exit(1);
+ }
+
+ shutdown_handler(signal);
+}
+
static void common_params_handle_model_default(std::string &model, const std::string &model_url, std::string &hf_repo,
std::string &hf_file, const std::string &hf_token) {
if (!hf_repo.empty()) {
diff --git a/src/main/cpp/utils.hpp b/src/main/cpp/utils.hpp
index 603424b..ca0a327 100644
--- a/src/main/cpp/utils.hpp
+++ b/src/main/cpp/utils.hpp
@@ -48,14 +48,13 @@ using json = nlohmann::ordered_json;
#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-template <typename T> static T json_value(const json &body, const std::string &key, const T &default_value) {
+template <typename T> static T json_value(const json & body, const std::string & key, const T & default_value) {
// Fallback null to default value
if (body.contains(key) && !body.at(key).is_null()) {
try {
return body.at(key);
} catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const &) {
- LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value\n", key.c_str(),
- json(default_value).type_name());
+ LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value\n", key.c_str(), json(default_value).type_name());
return default_value;
}
} else {
@@ -69,9 +68,9 @@ const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "
// tokenizer and input processing utils
//
-static bool json_is_array_of_numbers(const json &data) {
+static bool json_is_array_of_numbers(const json & data) {
if (data.is_array()) {
- for (const auto &e : data) {
+ for (const auto & e : data) {
if (!e.is_number_integer()) {
return false;
}
@@ -82,11 +81,11 @@ static bool json_is_array_of_numbers(const json &data) {
}
// is array having BOTH numbers & strings?
-static bool json_is_array_of_mixed_numbers_strings(const json &data) {
+static bool json_is_array_of_mixed_numbers_strings(const json & data) {
bool seen_string = false;
bool seen_number = false;
if (data.is_array()) {
- for (const auto &e : data) {
+ for (const auto & e : data) {
seen_string |= e.is_string();
seen_number |= e.is_number_integer();
if (seen_number && seen_string) {
@@ -98,14 +97,14 @@ static bool json_is_array_of_mixed_numbers_strings(const json &data) {
}
// get value by path(key1 / key2)
-static json json_get_nested_values(const std::vector<std::string> &paths, const json &js) {
+static json json_get_nested_values(const std::vector<std::string> & paths, const json & js) {
json result = json::object();
- for (const std::string &path : paths) {
+ for (const std::string & path : paths) {
json current = js;
const auto keys = string_split<std::string>(path, /*separator*/ '/');
bool valid_path = true;
- for (const std::string &k : keys) {
+ for (const std::string & k : keys) {
if (valid_path && current.is_object() && current.contains(k)) {
current = current[k];
} else {
@@ -124,15 +123,14 @@ static json json_get_nested_values(const std::vector<std::string> &paths, const
* - only string, example: "string"
* - mixed string and tokens, example: [12, 34, "string", 56, 78]
*/
-static llama_tokens tokenize_mixed(const llama_vocab *vocab, const json &json_prompt, bool add_special,
- bool parse_special) {
+static llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) {
// If `add_bos` is true, we only add BOS, when json_prompt is a string,
// or the first element of the json_prompt array is a string.
llama_tokens prompt_tokens;
if (json_prompt.is_array()) {
bool first = true;
- for (const auto &p : json_prompt) {
+ for (const auto & p : json_prompt) {
if (p.is_string()) {
auto s = p.template get<std::string>();
@@ -173,8 +171,7 @@ static llama_tokens tokenize_mixed(const llama_vocab *vocab, const json &json_pr
* - "prompt": [[12, 34, 56], [78, 90, 12]]
* - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
*/
-static std::vector<llama_tokens> tokenize_input_prompts(const llama_vocab *vocab, const json &json_prompt,
- bool add_special, bool parse_special) {
+static std::vector<llama_tokens> tokenize_input_prompts(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) {
 std::vector<llama_tokens> result;
if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) {
// string or mixed
@@ -185,20 +182,18 @@ static std::vector<llama_tokens> tokenize_input_prompts(const llama_vocab *vocab
} else if (json_prompt.is_array()) {
// array of prompts
result.reserve(json_prompt.size());
- for (const auto &p : json_prompt) {
+ for (const auto & p : json_prompt) {
if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) {
result.push_back(tokenize_mixed(vocab, p, add_special, parse_special));
} else if (json_is_array_of_numbers(p)) {
// array of tokens
result.push_back(p.get<llama_tokens>());
} else {
- throw std::runtime_error(
- "element of \"prompt\" must be a string, an list of tokens, or a list of mixed strings & tokens");
+ throw std::runtime_error("element of \"prompt\" must be a string, an list of tokens, or a list of mixed strings & tokens");
}
}
} else {
- throw std::runtime_error(
- "\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts");
+ throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts");
}
if (result.empty()) {
throw std::runtime_error("\"prompt\" must not be empty");
@@ -209,10 +204,9 @@ static std::vector<llama_tokens> tokenize_input_prompts(const llama_vocab *vocab
// return the last index of character that can form a valid string
// if the last character is potentially cut in half, return the index before the cut
// if validate_utf8(text) == text.size(), then the whole text is valid utf8
-static size_t validate_utf8(const std::string &text) {
+static size_t validate_utf8(const std::string& text) {
size_t len = text.size();
- if (len == 0)
- return 0;
+ if (len == 0) return 0;
// Check the last few bytes to see if a multi-byte character is cut off
for (size_t i = 1; i <= 4 && i <= len; ++i) {
@@ -221,18 +215,15 @@ static size_t validate_utf8(const std::string &text) {
if ((c & 0xE0) == 0xC0) {
// 2-byte character start: 110xxxxx
// Needs at least 2 bytes
- if (i < 2)
- return len - i;
+ if (i < 2) return len - i;
} else if ((c & 0xF0) == 0xE0) {
// 3-byte character start: 1110xxxx
// Needs at least 3 bytes
- if (i < 3)
- return len - i;
+ if (i < 3) return len - i;
} else if ((c & 0xF8) == 0xF0) {
// 4-byte character start: 11110xxx
// Needs at least 4 bytes
- if (i < 4)
- return len - i;
+ if (i < 4) return len - i;
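+ // e.g. a string whose last byte is the lone lead byte 0xF0 of a 4-byte character returns len - 1, letting the caller buffer the partial character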
}
}
@@ -245,7 +236,7 @@ static size_t validate_utf8(const std::string &text) {
//
// format rerank task: [BOS]query[EOS][SEP]doc[EOS]
-static llama_tokens format_rerank(const struct llama_vocab *vocab, const llama_tokens &query, const llama_tokens &doc) {
+static llama_tokens format_rerank(const struct llama_vocab * vocab, const llama_tokens & query, const llama_tokens & doc) {
llama_tokens result;
result.reserve(doc.size() + query.size() + 4);
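+ // the +4 leaves room for the BOS/EOS/SEP/EOS specials in the [BOS]query[EOS][SEP]doc[EOS] layout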
@@ -260,9 +251,17 @@ static llama_tokens format_rerank(const struct llama_vocab *vocab, const llama_t
}
// format infill task
-static llama_tokens format_infill(const llama_vocab *vocab, const json &input_prefix, const json &input_suffix,
- const json &input_extra, const int n_batch, const int n_predict, const int n_ctx,
- const bool spm_infill, const llama_tokens &tokens_prompt) {
+static llama_tokens format_infill(
+ const llama_vocab * vocab,
+ const json & input_prefix,
+ const json & input_suffix,
+ const json & input_extra,
+ const int n_batch,
+ const int n_predict,
+ const int n_ctx,
+ const bool spm_infill,
+ const llama_tokens & tokens_prompt
+ ) {
// TODO: optimize this block by reducing memory allocations and movement
// use FIM repo-level pattern:
@@ -290,9 +289,9 @@ static llama_tokens format_infill(const llama_vocab *vocab, const json &input_pr
extra_tokens.push_back(llama_vocab_fim_rep(vocab));
extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
}
- for (const auto &chunk : input_extra) {
+ for (const auto & chunk : input_extra) {
// { "text": string, "filename": string }
- const std::string text = json_value(chunk, "text", std::string());
+ const std::string text = json_value(chunk, "text", std::string());
const std::string filename = json_value(chunk, "filename", std::string("tmp"));
if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) {
@@ -302,8 +301,7 @@ static llama_tokens format_infill(const llama_vocab *vocab, const json &input_pr
extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
} else {
// chunk separator in binary form to avoid confusing the AI
- static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70,
- 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
+ static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
static const auto k_chunk_prefix_tokens = common_tokenize(vocab, k_chunk_prefix_str, false, false);
extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
@@ -322,21 +320,19 @@ static llama_tokens format_infill(const llama_vocab *vocab, const json &input_pr
}
// for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
- const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3 * (n_batch / 4));
- const int n_suffix_take =
- std::min<int>(tokens_suffix.size(), std::max<int>(0, (n_batch / 4) - (2 + tokens_prompt.size())));
+ const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4));
+ const int n_suffix_take = std::min<int>(tokens_suffix.size(), std::max<int>(0, (n_batch/4) - (2 + tokens_prompt.size())));
- SRV_DBG("n_prefix_take = %d, n_suffix_take = %d, total = %d\n", n_prefix_take, n_suffix_take,
- (n_prefix_take + n_suffix_take));
+ SRV_DBG("n_prefix_take = %d, n_suffix_take = %d, total = %d\n", n_prefix_take, n_suffix_take, (n_prefix_take + n_suffix_take));
// fill the rest of the context with extra chunks
- const int n_extra_take = std::min<int>(std::max<int>(0, n_ctx - (n_batch)-2 * n_predict), extra_tokens.size());
+ const int n_extra_take = std::min<int>(std::max<int>(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size());
tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
tokens_suffix.resize(n_suffix_take);
tokens_prefix.insert(tokens_prefix.begin(), llama_vocab_fim_pre(vocab));
- tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
+ tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
tokens_suffix.insert(tokens_suffix.begin(), llama_vocab_fim_suf(vocab));
auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix;
@@ -346,7 +342,7 @@ static llama_tokens format_infill(const llama_vocab *vocab, const json &input_pr
embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab));
}
- SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int)extra_tokens.size());
+ SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size());
// put the extra context before the FIM prefix
embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end());
@@ -361,13 +357,16 @@ static llama_tokens format_infill(const llama_vocab *vocab, const json &input_pr
// base64 utils (TODO: move to common in the future)
//
-static const std::string base64_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
- "abcdefghijklmnopqrstuvwxyz"
- "0123456789+/";
+static const std::string base64_chars =
+ "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ "abcdefghijklmnopqrstuvwxyz"
+ "0123456789+/";
-static inline bool is_base64(uint8_t c) { return (isalnum(c) || (c == '+') || (c == '/')); }
+static inline bool is_base64(uint8_t c) {
+ return (isalnum(c) || (c == '+') || (c == '/'));
+}
-static inline std::vector<uint8_t> base64_decode(const std::string &encoded_string) {
+static inline std::vector<uint8_t> base64_decode(const std::string & encoded_string) {
int i = 0;
int j = 0;
int in_ = 0;
@@ -380,16 +379,15 @@ static inline std::vector<uint8_t> base64_decode(const std::string &encoded_stri
 std::vector<uint8_t> ret;
while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) {
- char_array_4[i++] = encoded_string[in_];
- in_++;
+ char_array_4[i++] = encoded_string[in_]; in_++;
if (i == 4) {
for (i = 0; i < 4; i++) {
char_array_4[i] = base64_chars.find(char_array_4[i]);
}
- char_array_3[0] = ((char_array_4[0]) << 2) + ((char_array_4[1] & 0x30) >> 4);
+ char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4);
char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
- char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
+ char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
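+ // the three assignments above repack four 6-bit base64 values into three 8-bit bytes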
for (i = 0; (i < 3); i++) {
ret.push_back(char_array_3[i]);
@@ -408,9 +406,9 @@ static inline std::vector<uint8_t> base64_decode(const std::string &encoded_stri
char_array_4[j] = base64_chars.find(char_array_4[j]);
}
- char_array_3[0] = ((char_array_4[0]) << 2) + ((char_array_4[1] & 0x30) >> 4);
+ char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4);
char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
- char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
+ char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
for (j = 0; j < i - 1; j++) {
ret.push_back(char_array_3[j]);
@@ -439,13 +437,19 @@ static std::string random_string() {
return result;
}
-static std::string gen_chatcmplid() { return "chatcmpl-" + random_string(); }
+static std::string gen_chatcmplid() {
+ return "chatcmpl-" + random_string();
+}
+
+static std::string gen_tool_call_id() {
+ return random_string();
+}
//
// other common utils
//
-static bool ends_with(const std::string &str, const std::string &suffix) {
+static bool ends_with(const std::string & str, const std::string & suffix) {
return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
}
@@ -466,7 +470,8 @@ static size_t find_partial_stop_string(const std::string &stop, const std::strin
}
// TODO: reuse llama_detokenize
-template <class Iter> static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end) {
+template <class Iter>
+static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
std::string ret;
for (; begin != end; ++begin) {
ret += common_token_to_piece(ctx, *begin);
@@ -476,7 +481,7 @@ template <class Iter> static std::string tokens_to_str(llama_context *ctx, Iter
}
// format incomplete utf-8 multibyte character for output
-static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token) {
+static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) {
std::string out = token == LLAMA_TOKEN_NULL ? "" : common_token_to_piece(ctx, token);
// if the size is 1 and first bit is 1, meaning it's a partial character
@@ -491,22 +496,22 @@ static std::string tokens_to_output_formatted_string(const llama_context *ctx, c
return out;
}
-// static bool server_sent_event(httplib::DataSink & sink, const char * event, const json & data) {
-// const std::string str =
-// std::string(event) + ": " +
-// data.dump(-1, ' ', false, json::error_handler_t::replace) +
-// "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row).
+//static bool server_sent_event(httplib::DataSink & sink, const char * event, const json & data) {
+// const std::string str =
+// std::string(event) + ": " +
+// data.dump(-1, ' ', false, json::error_handler_t::replace) +
+// "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row).
//
-// LOG_DBG("data stream, to_send: %s", str.c_str());
+// LOG_DBG("data stream, to_send: %s", str.c_str());
//
-// return sink.write(str.c_str(), str.size());
-// }
+// return sink.write(str.c_str(), str.size());
+//}
//
// OAI utils
//
-static json oaicompat_completion_params_parse(const json &body) {
+static json oaicompat_completion_params_parse(const json & body) {
json llama_params;
if (!body.contains("prompt")) {
@@ -532,15 +537,15 @@ static json oaicompat_completion_params_parse(const json &body) {
}
// Params supported by OAI but unsupported by llama.cpp
- static const std::vector<std::string> unsupported_params{"best_of", "suffix"};
- for (const auto &param : unsupported_params) {
+ static const std::vector<std::string> unsupported_params { "best_of", "suffix" };
+ for (const auto & param : unsupported_params) {
if (body.contains(param)) {
throw std::runtime_error("Unsupported param: " + param);
}
}
// Copy remaining properties to llama_params
- for (const auto &item : body.items()) {
+ for (const auto & item : body.items()) {
// Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens"
if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
llama_params[item.key()] = item.value();
@@ -550,9 +555,12 @@ static json oaicompat_completion_params_parse(const json &body) {
return llama_params;
}
-static json oaicompat_completion_params_parse(const json &body, /* openai api json semantics */
- bool use_jinja, common_reasoning_format reasoning_format,
- const struct common_chat_templates *tmpls) {
+static json oaicompat_completion_params_parse(
+ const json & body, /* openai api json semantics */
+ bool use_jinja,
+ common_reasoning_format reasoning_format,
+ const struct common_chat_templates * tmpls)
+{
json llama_params;
auto tools = json_value(body, "tools", json());
@@ -587,7 +595,7 @@ static json oaicompat_completion_params_parse(const json &body, /* openai api js
// Handle "response_format" field
if (body.contains("response_format")) {
- json response_format = json_value(body, "response_format", json::object());
+ json response_format = json_value(body, "response_format", json::object());
std::string response_type = json_value(response_format, "type", std::string());
if (response_type == "json_object") {
json_schema = json_value(response_format, "schema", json::object());
@@ -595,21 +603,20 @@ static json oaicompat_completion_params_parse(const json &body, /* openai api js
auto schema_wrapper = json_value(response_format, "json_schema", json::object());
json_schema = json_value(schema_wrapper, "schema", json::object());
} else if (!response_type.empty() && response_type != "text") {
- throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " +
- response_type);
+ throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
}
}
common_chat_templates_inputs inputs;
- inputs.messages = common_chat_msgs_parse_oaicompat(body.at("messages"));
- inputs.tools = common_chat_tools_parse_oaicompat(tools);
- inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto")));
- inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
- inputs.grammar = grammar;
+ inputs.messages = common_chat_msgs_parse_oaicompat(body.at("messages"));
+ inputs.tools = common_chat_tools_parse_oaicompat(tools);
+ inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto")));
+ inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
+ inputs.grammar = grammar;
inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
- inputs.use_jinja = use_jinja;
- inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
- inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE;
+ inputs.use_jinja = use_jinja;
+ inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
+ inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE;
inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) {
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
@@ -618,17 +625,19 @@ static json oaicompat_completion_params_parse(const json &body, /* openai api js
// Apply chat template to the list of messages
auto chat_params = common_chat_templates_apply(tmpls, inputs);
- llama_params["chat_format"] = static_cast(chat_params.format);
- llama_params["prompt"] = chat_params.prompt;
- llama_params["grammar"] = chat_params.grammar;
- llama_params["grammar_lazy"] = chat_params.grammar_lazy;
+ llama_params["chat_format"] = static_cast(chat_params.format);
+ llama_params["prompt"] = chat_params.prompt;
+ if (!chat_params.grammar.empty()) {
+ llama_params["grammar"] = chat_params.grammar;
+ }
+ llama_params["grammar_lazy"] = chat_params.grammar_lazy;
auto grammar_triggers = json::array();
- for (const auto &trigger : chat_params.grammar_triggers) {
+ for (const auto & trigger : chat_params.grammar_triggers) {
grammar_triggers.push_back(trigger.to_json());
}
llama_params["grammar_triggers"] = grammar_triggers;
llama_params["preserved_tokens"] = chat_params.preserved_tokens;
- for (const auto &stop : chat_params.additional_stops) {
+ for (const auto & stop : chat_params.additional_stops) {
llama_params["stop"].push_back(stop);
}
@@ -639,8 +648,7 @@ static json oaicompat_completion_params_parse(const json &body, /* openai api js
}
// Handle "logprobs" field
- // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may
- // need to fix it in the future
+ // TODO: The response format of this option is not yet OAI-compatible, but it seems like no one is really using it; we may need to fix it in the future
if (json_value(body, "logprobs", false)) {
llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
} else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) {
@@ -650,7 +658,7 @@ static json oaicompat_completion_params_parse(const json &body, /* openai api js
// Copy remaining properties to llama_params
// This allows user to use llama.cpp-specific params like "mirostat", ... via OAI endpoint.
// See "launch_slot_with_task()" for a complete list of params supported by llama.cpp
- for (const auto &item : body.items()) {
+ for (const auto & item : body.items()) {
// Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens"
if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
llama_params[item.key()] = item.value();
@@ -660,46 +668,59 @@ static json oaicompat_completion_params_parse(const json &body, /* openai api js
return llama_params;
}
-static json format_embeddings_response_oaicompat(const json &request, const json &embeddings, bool use_base64 = false) {
+static json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false) {
json data = json::array();
int32_t n_tokens = 0;
int i = 0;
- for (const auto &elem : embeddings) {
+ for (const auto & elem : embeddings) {
json embedding_obj;
if (use_base64) {
- const auto &vec = json_value(elem, "embedding", json::array()).get<std::vector<float>>();
- const char *data_ptr = reinterpret_cast<const char *>(vec.data());
+ const auto& vec = json_value(elem, "embedding", json::array()).get<std::vector<float>>();
+ const char* data_ptr = reinterpret_cast<const char*>(vec.data());
size_t data_size = vec.size() * sizeof(float);
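+ // the raw bytes of the float vector are base64-encoded below, matching the OpenAI embeddings "encoding_format": "base64" convention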
- embedding_obj = {{"embedding", base64::encode(data_ptr, data_size)},
- {"index", i++},
- {"object", "embedding"},
- {"encoding_format", "base64"}};
+ embedding_obj = {
+ {"embedding", base64::encode(data_ptr, data_size)},
+ {"index", i++},
+ {"object", "embedding"},
+ {"encoding_format", "base64"}
+ };
} else {
embedding_obj = {
- {"embedding", json_value(elem, "embedding", json::array())}, {"index", i++}, {"object", "embedding"}};
+ {"embedding", json_value(elem, "embedding", json::array())},
+ {"index", i++},
+ {"object", "embedding"}
+ };
}
data.push_back(embedding_obj);
n_tokens += json_value(elem, "tokens_evaluated", 0);
}
- json res = json{{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
- {"object", "list"},
- {"usage", json{{"prompt_tokens", n_tokens}, {"total_tokens", n_tokens}}},
- {"data", data}};
+ json res = json {
+ {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+ {"object", "list"},
+ {"usage", json {
+ {"prompt_tokens", n_tokens},
+ {"total_tokens", n_tokens}
+ }},
+ {"data", data}
+ };
return res;
}
-static json format_response_rerank(const json &request, const json &ranks, bool is_tei_format,
- std::vector<std::string> &texts) {
+static json format_response_rerank(
+ const json & request,
+ const json & ranks,
+ bool is_tei_format,
+ std::vector<std::string> & texts) {
json res;
if (is_tei_format) {
// TEI response format
res = json::array();
bool return_text = json_value(request, "return_text", false);
- for (const auto &rank : ranks) {
+ for (const auto & rank : ranks) {
int index = json_value(rank, "index", 0);
json elem = json{
{"index", index},
@@ -714,27 +735,32 @@ static json format_response_rerank(const json &request, const json &ranks, bool
// Jina response format
json results = json::array();
int32_t n_tokens = 0;
- for (const auto &rank : ranks) {
+ for (const auto & rank : ranks) {
results.push_back(json{
- {"index", json_value(rank, "index", 0)},
+ {"index", json_value(rank, "index", 0)},
{"relevance_score", json_value(rank, "score", 0.0)},
});
n_tokens += json_value(rank, "tokens_evaluated", 0);
}
- res = json{{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
- {"object", "list"},
- {"usage", json{{"prompt_tokens", n_tokens}, {"total_tokens", n_tokens}}},
- {"results", results}};
+ res = json{
+ {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+ {"object", "list"},
+ {"usage", json{
+ {"prompt_tokens", n_tokens},
+ {"total_tokens", n_tokens}
+ }},
+ {"results", results}
+ };
}
return res;
}
-static bool is_valid_utf8(const std::string &str) {
- const unsigned char *bytes = reinterpret_cast<const unsigned char *>(str.data());
- const unsigned char *end = bytes + str.length();
+static bool is_valid_utf8(const std::string & str) {
+ const unsigned char* bytes = reinterpret_cast<const unsigned char*>(str.data());
+ const unsigned char* end = bytes + str.length();
while (bytes < end) {
if (*bytes <= 0x7F) {
@@ -752,7 +778,8 @@ static bool is_valid_utf8(const std::string &str) {
bytes += 3;
} else if ((*bytes & 0xF8) == 0xF0) {
// 4-byte sequence (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx)
- if (end - bytes < 4 || (bytes[1] & 0xC0) != 0x80 || (bytes[2] & 0xC0) != 0x80 || (bytes[3] & 0xC0) != 0x80)
+ if (end - bytes < 4 || (bytes[1] & 0xC0) != 0x80 ||
+ (bytes[2] & 0xC0) != 0x80 || (bytes[3] & 0xC0) != 0x80)
return false;
bytes += 4;
} else {
@@ -764,13 +791,21 @@ static bool is_valid_utf8(const std::string &str) {
return true;
}
-static json format_tokenizer_response(const json &tokens) { return json{{"tokens", tokens}}; }
+static json format_tokenizer_response(const json & tokens) {
+ return json {
+ {"tokens", tokens}
+ };
+}
-static json format_detokenized_response(const std::string &content) { return json{{"content", content}}; }
+static json format_detokenized_response(const std::string & content) {
+ return json {
+ {"content", content}
+ };
+}
-static json format_logit_bias(const std::vector<llama_logit_bias> &logit_bias) {
+static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) {
json data = json::array();
- for (const auto &lb : logit_bias) {
+ for (const auto & lb : logit_bias) {
data.push_back(json{
{"bias", lb.bias},
{"token", lb.token},
@@ -779,16 +814,16 @@ static json format_logit_bias(const std::vector<llama_logit_bias> &logit_bias) {
return data;
}
-static std::string safe_json_to_str(const json &data) {
+static std::string safe_json_to_str(const json & data) {
return data.dump(-1, ' ', false, json::error_handler_t::replace);
}
-static std::vector<llama_token_data> get_token_probabilities(llama_context *ctx, int idx) {
+static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
 std::vector<llama_token_data> cur;
- const auto *logits = llama_get_logits_ith(ctx, idx);
+ const auto * logits = llama_get_logits_ith(ctx, idx);
- const llama_model *model = llama_get_model(ctx);
- const llama_vocab *vocab = llama_model_get_vocab(model);
+ const llama_model * model = llama_get_model(ctx);
+ const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_vocab = llama_vocab_n_tokens(vocab);
@@ -798,8 +833,9 @@ static std::vector<llama_token_data> get_token_probabilities(llama_context *ctx,
}
// sort tokens by logits
- std::sort(cur.begin(), cur.end(),
- [](const llama_token_data &a, const llama_token_data &b) { return a.logit > b.logit; });
+ std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) {
+ return a.logit > b.logit;
+ });
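+ // sorted in descending order, so cur[0] holds the maximum logit used as the softmax baseline below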
// apply softmax
float max_l = cur[0].logit;
@@ -816,8 +852,9 @@ static std::vector<llama_token_data> get_token_probabilities(llama_context *ctx,
return cur;
}
-static bool are_lora_equal(const std::vector<common_adapter_lora_info> &l1,
- const std::vector<common_adapter_lora_info> &l2) {
+static bool are_lora_equal(
+ const std::vector<common_adapter_lora_info> & l1,
+ const std::vector<common_adapter_lora_info> & l2) {
if (l1.size() != l2.size()) {
return false;
}
@@ -831,19 +868,20 @@ static bool are_lora_equal(const std::vector<common_adapter_lora_info> &l1,
}
// parse lora config from JSON request, returned a copy of lora_base with updated scale
-static std::vector<common_adapter_lora_info> parse_lora_request(const std::vector<common_adapter_lora_info> &lora_base,
- const json &data) {
+static std::vector<common_adapter_lora_info> parse_lora_request(
+ const std::vector<common_adapter_lora_info> & lora_base,
+ const json & data) {
 std::vector<common_adapter_lora_info> lora(lora_base);
int max_idx = lora.size();
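+ // illustrative request body: [{"id": 0, "scale": 0.5}, {"id": 1, "scale": 1.0}]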
// clear existing value
- for (auto &entry : lora) {
+ for (auto & entry : lora) {
entry.scale = 0.0f;
}
// set value
- for (const auto &entry : data) {
- int id = json_value(entry, "id", -1);
+ for (const auto & entry : data) {
+ int id = json_value(entry, "id", -1);
float scale = json_value(entry, "scale", 0.0f);
if (0 <= id && id < max_idx) {
lora[id].scale = scale;