From 36092aa1d70cf87f8ee8fffcfc56ae4ab7d72c89 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Sat, 15 Feb 2025 10:24:05 +0000 Subject: [PATCH 01/41] address clang-tidy lints --- common/chat.cpp | 36 ++++++++++++++++++------------------ common/chat.hpp | 2 +- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index f21a9d2a63a4b..37a6e8981f126 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -38,22 +38,22 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons json_error_locator() : position(0), found_error(false) {} - bool parse_error(std::size_t position, const std::string &, const json::exception &) override { + bool parse_error(std::size_t position, const std::string &, const json::exception &) override { // NOLINT this->position = position - 1; this->found_error = true; return false; } - bool null() override { return true; } - bool boolean(bool) override { return true; } - bool number_integer(number_integer_t) override { return true; } - bool number_unsigned(number_unsigned_t) override { return true; } - bool number_float(number_float_t, const string_t &) override { return true; } - bool string(string_t &) override { return true; } - bool binary(binary_t &) override { return true; } - bool start_object(std::size_t) override { return true; } - bool key(string_t &) override { return true; } + bool null() override { return true; } // NOLINT + bool boolean(bool) override { return true; } // NOLINT + bool number_integer(number_integer_t) override { return true; } // NOLINT + bool number_unsigned(number_unsigned_t) override { return true; } // NOLINT + bool number_float(number_float_t, const string_t &) override { return true; } // NOLINT + bool string(string_t &) override { return true; } // NOLINT + bool binary(binary_t &) override { return true; } // NOLINT + bool start_object(std::size_t) override { return true; } // NOLINT + bool key(string_t &) override { return true; } // NOLINT bool end_object() override { return true; } - bool start_array(std::size_t) override { return true; } + bool start_array(std::size_t) override { return true; } // NOLINT bool end_array() override { return true; } }; json_error_locator err_loc; @@ -455,10 +455,10 @@ static void expect_tool_parameters(const std::string & name, const json & parame const auto & parameters_required = parameters.at("required"); for (const auto & prop : expected_properties) { if (!parameters_properties.contains(prop)) { - throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop); + throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop); // NOLINT } if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) { - throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); + throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); // NOLINT } } if (parameters_properties.size() != expected_properties.size()) { @@ -474,10 +474,8 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com std::vector tool_rules; auto handle_builtin_tool = [&](const std::string & name, const json & parameters) { - if (name == "wolfram_alpha") { + if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") { // 
https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py - expect_tool_parameters(name, parameters, {"query"}); - } else if (name == "web_search" || name == "brave_search") { // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py expect_tool_parameters(name, parameters, {"query"}); } else if (name == "python" || name == "code_interpreter") { @@ -489,7 +487,7 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com std::vector kvs; for (const auto & [key, value] : parameters.at("properties").items()) { - kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); + kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT } tool_rules.push_back( @@ -588,6 +586,7 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_ const auto & function = tool.at("function"); std::string name = function.at("name"); auto parameters = function.at("parameters"); + builder.resolve_refs(parameters); auto args_rule = builder.add_schema(name + "-args", parameters); tool_rules.push_back(builder.add_rule(name + "-call", "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n" @@ -727,6 +726,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ const auto & function = tool.at("function"); std::string name = function.at("name"); auto parameters = function.at("parameters"); + builder.resolve_refs(parameters); auto args_rule = builder.add_schema(name + "-args", parameters); first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule)); subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule)); @@ -814,7 +814,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con throw std::runtime_error("Missing type in python tool"); } has_raw_python = true; - auto type = parameters.at("type"); + const auto & type = parameters.at("type"); if (type == "object") { auto properties = parameters.at("properties"); for (auto it = properties.begin(); it != properties.end(); ++it) { diff --git a/common/chat.hpp b/common/chat.hpp index ba1632f669cf7..4147c80cdbafc 100644 --- a/common/chat.hpp +++ b/common/chat.hpp @@ -50,6 +50,6 @@ struct common_chat_params { std::vector additional_stops; }; -struct common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & params); +struct common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs); std::string common_chat_format_name(common_chat_format format); common_chat_msg common_chat_parse( const std::string & input, common_chat_format format); From ef9b91ac8a69126668abe31a8040b15645621652 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Sat, 15 Feb 2025 10:36:32 +0000 Subject: [PATCH 02/41] tool-call: massive refactoring --- common/arg.cpp | 1 + common/chat.cpp | 453 ++++++++++++++++++++++++++------ common/chat.hpp | 86 +++++- common/common.cpp | 168 ------------ common/common.h | 56 ---- examples/main/main.cpp | 22 +- examples/run/run.cpp | 69 ++--- examples/server/server.cpp | 59 ++--- examples/server/utils.hpp | 150 ++++++----- tests/test-chat-template.cpp | 51 ++-- tests/test-chat.cpp | 494 +++++++++++++++-------------------- 11 files changed, 825 insertions(+), 784 
deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index a4d65ad00f675..c7a8a2bbf84af 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2,6 +2,7 @@ #include "log.h" #include "sampling.h" +#include "chat.hpp" #include #include diff --git a/common/chat.cpp b/common/chat.cpp index 37a6e8981f126..c6d8dbb394927 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1,9 +1,184 @@ #include "chat.hpp" +#include #include "chat-template.hpp" #include "json-schema-to-grammar.h" #include "log.h" #include "minja.hpp" +namespace minja { + class chat_template; +} + +typedef minja::chat_template common_chat_template; + +struct common_chat_templates { + bool has_explicit_template; // Model had builtin template or template overridde was specified. + std::unique_ptr template_default; // always set (defaults to chatml) + std::unique_ptr template_tool_use; +}; + +struct templates_params { + json messages; + json tools; + common_chat_tool_choice tool_choice; + json json_schema; + bool parallel_tool_calls; + bool stream; + std::string grammar; + bool add_generation_prompt = true; + bool extract_reasoning = true; +}; + +bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { + if (use_jinja) { + try { + common_chat_msg msg; + msg.role = "user"; + msg.content = "test"; + + auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl); + + common_chat_templates_inputs inputs; + inputs.messages = {msg}; + + common_chat_templates_apply(tmpls, inputs); + return true; + } catch (const std::exception & e) { + LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what()); + return false; + } + } + llama_chat_message chat[] = {{"user", "test"}}; + const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0); + return res >= 0; +} + +std::string common_chat_format_single( + const struct common_chat_templates * tmpls, + const std::vector & past_msg, + const common_chat_msg & new_msg, + bool add_ass, + bool use_jinja) { + + common_chat_templates_inputs inputs; + inputs.use_jinja = use_jinja; + + std::string fmt_past_msg; + if (!past_msg.empty()) { + inputs.messages = past_msg; + inputs.add_generation_prompt = false; + fmt_past_msg = common_chat_templates_apply(tmpls, inputs).prompt; + } + std::ostringstream ss; + // if the past_msg ends with a newline, we must preserve it in the formatted version + if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') { + ss << "\n"; + }; + // format chat with new_msg + inputs.messages.push_back(new_msg); + inputs.add_generation_prompt = add_ass; + auto fmt_new_msg = common_chat_templates_apply(tmpls, inputs).prompt; + // get the diff part + ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); + return ss.str(); +} + +std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) { + common_chat_templates_inputs inputs; + inputs.use_jinja = use_jinja; + inputs.messages = { + {"system", "You are a helpful assistant", {}, {}, ""}, + {"user", "Hello", {}, {}, ""}, + {"assistant", "Hi there", {}, {}, ""}, + {"user", "How are you?", {}, {}, ""}, + }; + return common_chat_templates_apply(tmpls, inputs).prompt; +} + +#define CHATML_TEMPLATE_SRC \ + "{%- for message in messages -%}\n" \ + " {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \ + "{%- endfor -%}\n" \ + "{%- if add_generation_prompt -%}\n" \ + " {{- '<|im_start|>assistant\n' -}}\n" \ + "{%- endif -%}" + +void common_chat_templates_free(struct 
common_chat_templates * tmpls) { + delete tmpls; +} + +bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls) { + return tmpls->has_explicit_template; +} + +struct common_chat_templates * common_chat_templates_init( + const struct llama_model * model, + const std::string & chat_template_override, + const std::string & bos_token_override, + const std::string & eos_token_override) +{ + std::string default_template_src; + std::string template_tool_use_src; + + bool has_explicit_template = !chat_template_override.empty(); + if (chat_template_override.empty()) { + GGML_ASSERT(model != nullptr); + auto str = llama_model_chat_template(model, /* name */ nullptr); + if (str) { + default_template_src = str; + has_explicit_template = true; + } + str = llama_model_chat_template(model, /* name */ "tool_use"); + if (str) { + template_tool_use_src = str; + has_explicit_template = true; + } + } else { + default_template_src = chat_template_override; + } + if (default_template_src.empty() || default_template_src == "chatml") { + if (!template_tool_use_src.empty()) { + default_template_src = template_tool_use_src; + } else { + default_template_src = CHATML_TEMPLATE_SRC; + } + } + std::string token_bos = bos_token_override; + std::string token_eos = eos_token_override; + if (model) { + auto vocab = llama_model_get_vocab(model); + const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) { + if (token == LLAMA_TOKEN_NULL) { + if (default_template_src.find(jinja_variable_name) != std::string::npos + || template_tool_use_src.find(jinja_variable_name) != std::string::npos) { + LOG_WRN("%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name); + } + return std::string(); + } else { + return common_token_to_piece(vocab, token, true); + } + }; + token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token"); + token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token"); + } + auto tmpls = new common_chat_templates(); + tmpls->has_explicit_template = has_explicit_template; + try { + tmpls->template_default = std::make_unique(default_template_src, token_bos, token_eos); + } catch (const std::exception & e) { + LOG_ERR("%s: failed to parse chat template (defaulting to chatml): %s \n", __func__, e.what()); + tmpls->template_default = std::make_unique(CHATML_TEMPLATE_SRC, token_bos, token_eos); + } + if (!template_tool_use_src.empty()) { + try { + tmpls->template_tool_use = std::make_unique(template_tool_use_src, token_bos, token_eos); + } catch (const std::exception & e) { + LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what()); + } + } + return tmpls; +} + std::string common_chat_format_name(common_chat_format format) { switch (format) { case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "Content-only"; @@ -193,7 +368,7 @@ static std::string apply( return tmpl.apply(tmpl_inputs, tmpl_opts); } -static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { +static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; auto tool_call_schemas = json::array(); @@ -247,7 +422,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp {"required", json::array({"tool_call"})}, }; const auto schema = - inputs.tool_choice != "required" + inputs.tool_choice != 
COMMON_CHAT_TOOL_CHOICE_REQUIRED ? json { {"anyOf", json::array({ tool_call, @@ -303,9 +478,9 @@ static common_chat_msg common_chat_parse_generic(const std::string & input) { return result; } -static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { +static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; - data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { auto schemas = json::array(); foreach_function(inputs.tools, [&](const json & tool) { @@ -348,9 +523,9 @@ static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input) return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]"); } -static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { +static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; - data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { auto schemas = json::array(); foreach_function(inputs.tools, [&](const json & tool) { @@ -466,10 +641,10 @@ static void expect_tool_parameters(const std::string & name, const json & parame } } -static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_inputs & inputs, bool allow_python_tag_builtin_tools) { +static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) { auto builtin_tools = json::array(); common_chat_params data; - data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; @@ -561,6 +736,7 @@ static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bo return { /* .role = */ "assistant", /* .content = */ match.prefix().str(), + /* .content_parts = */ {}, /* .tool_calls = */ { { /* .name = */ match[1], @@ -570,16 +746,17 @@ static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bo /* .id = */ "", }, }, + /* .reasoning_content = */ "", }; } } return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex); } -static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { +static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; if (inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = inputs.tool_choice != "required" && inputs.json_schema.is_null(); + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null(); data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; foreach_function(inputs.tools, [&](const json & 
tool) { @@ -665,15 +842,15 @@ static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, return msg; } -static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { - fprintf(stderr, "%s\n", __func__); +static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) { + LOG_DBG("%s\n", __func__); common_chat_params data; data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, { {"datetime", "Jan 29 2025 13:00:00 GMT"}, {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))}, }); if (inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { auto schemas = json::array(); foreach_function(inputs.tools, [&](const json & tool) { @@ -711,14 +888,14 @@ static common_chat_msg common_chat_parse_firefunction_v2(const std::string & inp return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1); } -static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { +static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) { // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar common_chat_params data; data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2; if (inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector first_tool_rules; std::vector subsequent_tool_rules; @@ -795,14 +972,14 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in } } -static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { +static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) { // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt common_chat_params data; json tools = inputs.tools.is_null() ? 
inputs.tools : json::array(); std::string python_code_argument_name; auto has_raw_python = false; - data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; foreach_function(inputs.tools, [&](const json & tool) { @@ -857,13 +1034,15 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s return { /* .role = */ "assistant", /* .content = */ match.prefix().str(), + /* .content_parts = */ {}, /* .tool_calls = */ { { /* .name = */ "python", /* .arguments = */ (json {{"code", code}}).dump(), /* .id = */ "", }, - } + }, + /* .reasoning_content = */ "", }; } static std::regex function_regex(R"()"); @@ -872,10 +1051,10 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex); } -static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { +static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; // (content)?({"name": "foo", "arguments": {"a": 1}})* - data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; foreach_function(inputs.tools, [&](const json & tool) { @@ -915,7 +1094,9 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) return { /* .role = */ "assistant", /* .content = */ input, + /* .content_parts = */ {}, /* .tool_calls = */ {}, + /* .reasoning_content = */ "", }; } @@ -952,12 +1133,14 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) return { /* .role = */ "assistant", /* .content = */ input, + /* .content_parts = */ {}, /* .tool_calls = */ {}, + /* .reasoning_content = */ "", }; } } -static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { +static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; @@ -973,81 +1156,194 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha return data; } -common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { - const auto & src = tmpl.source(); - const auto & caps = tmpl.original_caps(); - - if (inputs.tools.is_array()) { - if (inputs.tool_choice != "none" && !inputs.grammar.empty()) { - throw std::runtime_error("Cannot specify grammar with tools"); +static json messages_to_json(const std::vector & msgs) { + json messages = json::array(); + for (const auto & msg : msgs) { + if (!msg.content.empty() && !msg.content_parts.empty()) { + throw std::runtime_error("Cannot specify both content and content_parts"); } - if (caps.supports_tool_calls && !caps.supports_tools) { - LOG_WRN("Template supports tool calls but does not natively describe tools. 
The fallback behaviour used may produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n"); + json jmsg { + {"role", msg.role}, + }; + if (!msg.content.empty()) { + jmsg["content"] = msg.content; + } else if (!msg.content_parts.empty()) { + auto & parts = jmsg["content"] = json::array(); + for (const auto & part : msg.content_parts) { + parts.push_back({ + {"type", part.type}, + {"text", part.text}, + }); + } + } else { + jmsg["content"] = json(); // null + } + if (!msg.reasoning_content.empty()) { + jmsg["reasoning_content"] = msg.reasoning_content; } + if (!msg.tool_calls.empty()) { + auto & tool_calls = jmsg["tool_calls"] = json::array(); + for (const auto & tool_call : msg.tool_calls) { + json tc { + {"type", "function"}, + {"function", { + {"name", tool_call.name}, + {"arguments", tool_call.arguments}, + }}, + }; + if (!tool_call.id.empty()) { + tc["id"] = tool_call.id; + } + tool_calls.push_back(tc); + } + } + messages.push_back(jmsg); } + return messages; +} - // DeepSeek R1: use handler in all cases except json schema (thinking / tools). - if (src.find("<|tool▁calls▁begin|>") != std::string::npos && inputs.json_schema.is_null()) { - return common_chat_params_init_deepseek_r1(tmpl, inputs); - } +common_chat_params common_chat_templates_apply( + const struct common_chat_templates * tmpls, + const struct common_chat_templates_inputs & inputs) +{ + GGML_ASSERT(tmpls != nullptr); + if (inputs.use_jinja) { + templates_params params; + params.messages = messages_to_json(inputs.messages); + params.add_generation_prompt = inputs.add_generation_prompt; + params.extract_reasoning = inputs.extract_reasoning; + params.tool_choice = inputs.tool_choice; + params.grammar = inputs.grammar; + if (!inputs.json_schema.empty()) { + params.json_schema = json::parse(inputs.json_schema); + } + if (!inputs.tools.empty()) { + params.tools = json::parse(inputs.tools); + } + const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use + ? *tmpls->template_tool_use + : *tmpls->template_default; + const auto & src = tmpl.source(); + const auto & caps = tmpl.original_caps(); + + if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) { + LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n"); + params.parallel_tool_calls = false; + } else { + params.parallel_tool_calls = inputs.parallel_tool_calls; + } - // Command R7B: : use handler in all cases except json schema (thinking / tools). - if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && inputs.json_schema.is_null()) { - return common_chat_params_init_command_r7b(tmpl, inputs); - } + if (params.tools.is_array()) { + if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) { + throw std::runtime_error("Cannot specify grammar with tools"); + } + if (caps.supports_tool_calls && !caps.supports_tools) { + LOG_WRN("Template supports tool calls but does not natively describe tools. The fallback behaviour used may produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n"); + } + } - // Use generic handler when mixing tools + JSON schema. - // TODO: support that mix in handlers below. - if ((!inputs.tools.is_array() && inputs.json_schema.is_object())) { - return common_chat_params_init_generic(tmpl, inputs); - } + // DeepSeek R1: use handler in all cases except json schema (thinking / tools). 
+ if (src.find("<|tool▁calls▁begin|>") != std::string::npos && params.json_schema.is_null()) { + return common_chat_params_init_deepseek_r1(tmpl, params); + } - // Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases. - if (src.find(">>>all") != std::string::npos) { - return common_chat_params_init_functionary_v3_2(tmpl, inputs); - } + // Command R7B: : use handler in all cases except json schema (thinking / tools). + if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && params.json_schema.is_null()) { + return common_chat_params_init_command_r7b(tmpl, params); + } - // Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases. - if (src.find(" functools[") != std::string::npos) { - return common_chat_params_init_firefunction_v2(tmpl, inputs); - } + // Use generic handler when mixing tools + JSON schema. + // TODO: support that mix in handlers below. + if ((!params.tools.is_array() && params.json_schema.is_object())) { + return common_chat_params_init_generic(tmpl, params); + } - // Plain handler (no tools) - if (inputs.tools.is_null() || inputs.tool_choice == "none") { - return common_chat_params_init_without_tools(tmpl, inputs); - } + // Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases. + if (src.find(">>>all") != std::string::npos) { + return common_chat_params_init_functionary_v3_2(tmpl, params); + } - // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools) - if (src.find("") != std::string::npos) { - return common_chat_params_init_hermes_2_pro(tmpl, inputs); - } + // Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases. + if (src.find(" functools[") != std::string::npos) { + return common_chat_params_init_firefunction_v2(tmpl, params); + } - // Functionary v3.1 (w/ tools) - if (src.find("<|start_header_id|>") != std::string::npos - && src.find("ipython<|end_header_id|>") != std::string::npos) { - auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos; - return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools); - } + // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools) + if (src.find("") != std::string::npos) { + return common_chat_params_init_hermes_2_pro(tmpl, params); + } - // Mistral Nemo (w/ tools) - if (src.find("[TOOL_CALLS]") != std::string::npos) { - return common_chat_params_init_mistral_nemo(tmpl, inputs); - } + // Functionary v3.1 (w/ tools) + if (src.find("<|start_header_id|>") != std::string::npos + && src.find("ipython<|end_header_id|>") != std::string::npos) { + auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos; + return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools); + } + + // Mistral Nemo (w/ tools) + if (src.find("[TOOL_CALLS]") != std::string::npos) { + return common_chat_params_init_mistral_nemo(tmpl, params); + } + + // Generic fallback + return common_chat_params_init_generic(tmpl, params); + } else { + // Legacy template route (adhoc C++ implementation of known templates), forward to llama_chat_apply_template. 
- // Generic fallback - return common_chat_params_init_generic(tmpl, inputs); + int alloc_size = 0; + std::vector chat; + for (const auto & msg : inputs.messages) { + chat.push_back({msg.role.c_str(), msg.content.c_str()}); + alloc_size += (msg.role.size() + msg.content.size()) * 1.25; + } + + std::vector buf(alloc_size); + + // run the first time to get the total output length + const auto & src = tmpls->template_default->source(); + int32_t res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size()); + + // error: chat template is not supported + if (res < 0) { + // if the custom "tmpl" is not supported, we throw an error + // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template() + throw std::runtime_error("this custom template is not supported"); + } + + // if it turns out that our buffer is too small, we resize it + if ((size_t) res > buf.size()) { + buf.resize(res); + res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size()); + } + + common_chat_params params; + params.prompt = std::string(buf.data(), res); + if (!inputs.json_schema.empty()) { + params.grammar = json_schema_to_grammar(json::parse(inputs.json_schema)); + } else { + params.grammar = inputs.grammar; + } + return params; + } } static common_chat_msg common_chat_parse_content_only(const std::string & input) { return { /* .role = */ "assistant", /* .content = */ input, + /* .content_parts = */ {}, /* .tool_calls = */ {}, + /* .reasoning_content = */ "", }; } @@ -1083,3 +1379,10 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format throw std::runtime_error("Unsupported format: " + common_chat_format_name(format)); } } + +common_chat_tool_choice common_chat_tool_choice_parse(const std::string & tool_choice) { + if (tool_choice == "auto") return COMMON_CHAT_TOOL_CHOICE_AUTO; + if (tool_choice == "none") return COMMON_CHAT_TOOL_CHOICE_NONE; + if (tool_choice == "required") return COMMON_CHAT_TOOL_CHOICE_REQUIRED; + throw std::runtime_error("Invalid tool_choice: " + tool_choice); +} diff --git a/common/chat.hpp b/common/chat.hpp index 4147c80cdbafc..bbd5daebf9f18 100644 --- a/common/chat.hpp +++ b/common/chat.hpp @@ -3,23 +3,36 @@ #pragma once #include "common.h" -#include #include #include #include -using json = nlohmann::ordered_json; +struct common_chat_templates; -struct common_chat_inputs { - json messages; - json tools; - json tool_choice; - json json_schema; - bool parallel_tool_calls; - bool stream; - std::string grammar; - bool add_generation_prompt = true; - bool extract_reasoning = true; +struct common_chat_tool_call { + std::string name; + std::string arguments; + std::string id; +}; + +struct common_chat_msg_content_part { + std::string type; + std::string text; +}; + +// same with llama_chat_message, but uses std::string +struct common_chat_msg { + std::string role; + std::string content; + std::vector content_parts; + std::vector tool_calls; + std::string reasoning_content; +}; + +enum common_chat_tool_choice { + COMMON_CHAT_TOOL_CHOICE_AUTO, + COMMON_CHAT_TOOL_CHOICE_REQUIRED, + COMMON_CHAT_TOOL_CHOICE_NONE, }; enum common_chat_format { @@ -42,7 +55,7 @@ enum common_chat_format { struct common_chat_params { common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - json prompt; + std::string prompt; std::string grammar; bool grammar_lazy = false; std::vector grammar_triggers; 
@@ -50,6 +63,51 @@ struct common_chat_params { std::vector additional_stops; }; -struct common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs); +// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid +bool common_chat_verify_template(const std::string & tmpl, bool use_jinja); + +struct common_chat_templates * common_chat_templates_init( + const struct llama_model * model, + const std::string & chat_template_override, + const std::string & bos_token_override = "", + const std::string & eos_token_override = ""); + +bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls); +void common_chat_templates_free(struct common_chat_templates * tmpls); + +typedef std::unique_ptr common_chat_templates_ptr; + +struct common_chat_templates_inputs { + std::vector messages; + std::string grammar; + std::string json_schema; + std::string tools; + common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO; + bool add_generation_prompt = true; + bool use_jinja = true; + // Parameters below only supported when use_jinja is true + bool parallel_tool_calls = false; + bool extract_reasoning = true; +}; + +struct common_chat_params common_chat_templates_apply( + const struct common_chat_templates * tmpls, + const struct common_chat_templates_inputs & inputs); + +// Format single message, while taking into account the position of that message in chat history +std::string common_chat_format_single( + const struct common_chat_templates * tmpls, + const std::vector & past_msg, + const common_chat_msg & new_msg, + bool add_ass, + bool use_jinja); + +// Returns an example of formatted chat +std::string common_chat_format_example( + const struct common_chat_templates * tmpls, + bool use_jinja); + std::string common_chat_format_name(common_chat_format format); common_chat_msg common_chat_parse( const std::string & input, common_chat_format format); + +common_chat_tool_choice common_chat_tool_choice_parse(const std::string & tool_choice); diff --git a/common/common.cpp b/common/common.cpp index 8661e164ada6b..f005f1459938c 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1768,174 +1768,6 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto return text; } -// -// Chat template utils -// - -bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { - if (use_jinja) { - try { - auto chat_template = common_chat_template(tmpl, "", ""); - common_chat_inputs inputs; - inputs.messages = json::array({{ - {"role", "user"}, - {"content", "test"}, - }}); - common_chat_params_init(chat_template, inputs); - return true; - } catch (const std::exception & e) { - LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what()); - return false; - } - } - llama_chat_message chat[] = {{"user", "test"}}; - const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0); - return res >= 0; -} - -std::string common_chat_apply_template( - const common_chat_template & tmpl, - const std::vector & msgs, - bool add_ass, - bool use_jinja) { - if (use_jinja) { - auto messages = json::array(); - for (const auto & msg : msgs) { - messages.push_back({{"role", msg.role}, {"content", msg.content}}); - } - common_chat_inputs inputs; - inputs.messages = messages; - inputs.add_generation_prompt = add_ass; - return common_chat_params_init(tmpl, inputs).prompt; - } - - int alloc_size = 0; - std::vector chat; - for (const auto & msg : 
msgs) { - chat.push_back({msg.role.c_str(), msg.content.c_str()}); - alloc_size += (msg.role.size() + msg.content.size()) * 1.25; - } - - std::vector buf(alloc_size); - - // run the first time to get the total output length - int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size()); - - // error: chat template is not supported - if (res < 0) { - // if the custom "tmpl" is not supported, we throw an error - // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template() - throw std::runtime_error("this custom template is not supported"); - } - - // if it turns out that our buffer is too small, we resize it - if ((size_t) res > buf.size()) { - buf.resize(res); - res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size()); - } - - std::string formatted_chat(buf.data(), res); - return formatted_chat; -} - -std::string common_chat_format_single( - const common_chat_template & tmpl, - const std::vector & past_msg, - const common_chat_msg & new_msg, - bool add_ass, - bool use_jinja) { - std::ostringstream ss; - auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja); - std::vector chat_new(past_msg); - // if the past_msg ends with a newline, we must preserve it in the formatted version - if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') { - ss << "\n"; - }; - // format chat with new_msg - chat_new.push_back(new_msg); - auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja); - // get the diff part - ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); - return ss.str(); -} - -std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) { - std::vector msgs = { - {"system", "You are a helpful assistant", {}}, - {"user", "Hello", {}}, - {"assistant", "Hi there", {}}, - {"user", "How are you?", {}}, - }; - return common_chat_apply_template(tmpl, msgs, true, use_jinja); -} - -#define CHATML_TEMPLATE_SRC \ - "{%- for message in messages -%}\n" \ - " {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \ - "{%- endfor -%}\n" \ - "{%- if add_generation_prompt -%}\n" \ - " {{- '<|im_start|>assistant\n' -}}\n" \ - "{%- endif -%}" - -common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override) -{ - std::string default_template_src; - std::string template_tool_use_src; - - bool has_explicit_template = !chat_template_override.empty(); - if (chat_template_override.empty()) { - auto str = llama_model_chat_template(model, /* name */ nullptr); - if (str) { - default_template_src = str; - has_explicit_template = true; - } - str = llama_model_chat_template(model, /* name */ "tool_use"); - if (str) { - template_tool_use_src = str; - has_explicit_template = true; - } - } else { - default_template_src = chat_template_override; - } - if (default_template_src.empty() || default_template_src == "chatml") { - if (!template_tool_use_src.empty()) { - default_template_src = template_tool_use_src; - } else { - default_template_src = CHATML_TEMPLATE_SRC; - } - } - auto vocab = llama_model_get_vocab(model); - const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) { - if (token == LLAMA_TOKEN_NULL) { - if 
(default_template_src.find(jinja_variable_name) != std::string::npos - || template_tool_use_src.find(jinja_variable_name) != std::string::npos) { - LOG_WRN("%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name); - } - return std::string(); - } else { - return common_token_to_piece(vocab, token, true); - } - }; - auto token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token"); - auto token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token"); - try { - return { - has_explicit_template, - std::make_unique(default_template_src, token_bos, token_eos), - template_tool_use_src.empty() - ? nullptr - : std::make_unique(template_tool_use_src, token_bos, token_eos), - }; - } catch (const std::exception & e) { - LOG_ERR("%s: failed to parse chat template: %s\n", __func__, e.what()); - return { - has_explicit_template, - std::make_unique(CHATML_TEMPLATE_SRC, token_bos, token_eos), - nullptr, - }; - } -} - // // KV cache utils // diff --git a/common/common.h b/common/common.h index 98b9a4464787a..10bcc10d51bb5 100644 --- a/common/common.h +++ b/common/common.h @@ -616,62 +616,6 @@ std::string common_detokenize( const std::vector & tokens, bool special = true); -// -// Chat template utils -// - -struct common_tool_call { - std::string name; - std::string arguments; - std::string id; -}; - -// same with llama_chat_message, but uses std::string -struct common_chat_msg { - std::string role; - std::string content; - std::vector tool_calls; - std::string reasoning_content = ""; -}; - -// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid -bool common_chat_verify_template(const std::string & tmpl, bool use_jinja); - -namespace minja { - class chat_template; -} - -typedef minja::chat_template common_chat_template; - -struct common_chat_templates { - bool has_explicit_template; // Model had builtin template or template overridde was specified. 
- std::unique_ptr template_default; // always set (defaults to chatml) - std::unique_ptr template_tool_use; -}; - -// CPP wrapper for llama_chat_apply_template -// If the built-in template is not supported, we default to chatml -// If the custom "tmpl" is not supported, we throw an error -std::string common_chat_apply_template( - const common_chat_template & tmpl, - const std::vector & chat, - bool add_ass, - bool use_jinja); - -// Format single message, while taking into account the position of that message in chat history -std::string common_chat_format_single( - const common_chat_template & tmpl, - const std::vector & past_msg, - const common_chat_msg & new_msg, - bool add_ass, - bool use_jinja); - -// Returns an example of formatted chat -std::string common_chat_format_example( - const common_chat_template & tmpl, bool use_jinja); - -common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override); - // // KV cache utils // diff --git a/examples/main/main.cpp b/examples/main/main.cpp index e654d3542c6c3..4e953675f8de7 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -4,7 +4,7 @@ #include "log.h" #include "sampling.h" #include "llama.h" -#include "chat-template.hpp" +#include "chat.hpp" #include #include @@ -158,7 +158,9 @@ int main(int argc, char ** argv) { } const llama_vocab * vocab = llama_model_get_vocab(model); - auto chat_templates = common_chat_templates_from_model(model, params.chat_template); + common_chat_templates_ptr chat_templates( + common_chat_templates_init(model, params.chat_template), + &common_chat_templates_free); LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads); @@ -201,7 +203,7 @@ int main(int argc, char ** argv) { } // auto enable conversation mode if chat template is available - const bool has_chat_template = chat_templates.has_explicit_template && chat_templates.template_default; + const bool has_chat_template = common_chat_templates_was_explicit(chat_templates.get()); if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) { if (has_chat_template) { LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__); @@ -219,7 +221,7 @@ int main(int argc, char ** argv) { // print chat template example in conversation mode if (params.conversation_mode) { if (params.enable_chat_template) { - LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.template_default, params.use_jinja).c_str()); + LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates.get(), params.use_jinja).c_str()); } else { LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__); } @@ -264,9 +266,15 @@ int main(int argc, char ** argv) { std::vector embd_inp; auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) { - common_chat_msg new_msg{role, content, {}}; - auto formatted = common_chat_format_single(*chat_templates.template_default, chat_msgs, new_msg, role == "user", g_params->use_jinja); - chat_msgs.push_back({role, content, {}}); + common_chat_msg new_msg { + role, + content, + /* .content_parts = */ {}, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "", + }; + auto formatted = common_chat_format_single(chat_templates.get(), chat_msgs, new_msg, role == "user", g_params->use_jinja); + chat_msgs.push_back(new_msg); 
LOG_DBG("formatted: '%s'\n", formatted.c_str()); return formatted; }; diff --git a/examples/run/run.cpp b/examples/run/run.cpp index 9362da22083d3..a70222ccbe20e 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -24,7 +24,7 @@ #include #include -#include "chat-template.hpp" +#include "chat.hpp" #include "common.h" #include "json.hpp" #include "linenoise.cpp/linenoise.h" @@ -557,7 +557,7 @@ class LlamaData { llama_model_ptr model; llama_sampler_ptr sampler; llama_context_ptr context; - std::vector messages; + std::vector messages; // TODO: switch to common_chat_msg std::list msg_strs; std::vector fmtted; @@ -834,44 +834,20 @@ static void add_message(const char * role, const std::string & text, LlamaData & } // Function to apply the chat template and resize `formatted` if needed -static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) { - if (use_jinja) { - json messages = json::array(); - for (const auto & msg : llama_data.messages) { - messages.push_back({ - {"role", msg.role}, - {"content", msg.content}, - }); - } - try { - minja::chat_template_inputs tmpl_inputs; - tmpl_inputs.messages = messages; - tmpl_inputs.add_generation_prompt = append; - - minja::chat_template_options tmpl_opts; - tmpl_opts.use_bos_token = false; - tmpl_opts.use_eos_token = false; - - auto result = tmpl.apply(tmpl_inputs, tmpl_opts); - llama_data.fmtted.resize(result.size() + 1); - memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1); - return result.size(); - } catch (const std::exception & e) { - printe("failed to render the chat template: %s\n", e.what()); - return -1; - } - } - int result = llama_chat_apply_template( - tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append, - append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0); - if (append && result > static_cast(llama_data.fmtted.size())) { - llama_data.fmtted.resize(result); - result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(), - llama_data.messages.size(), append, llama_data.fmtted.data(), - llama_data.fmtted.size()); - } - - return result; +static int apply_chat_template(const struct common_chat_templates * tmpls, LlamaData & llama_data, const bool append, bool use_jinja) { + common_chat_templates_inputs inputs; + for (const auto & msg : llama_data.messages) { + inputs.messages.push_back({ msg.role, msg.content, {}, {}, "" }); + } + inputs.add_generation_prompt = append; + inputs.use_jinja = use_jinja; + + auto chat_params = common_chat_templates_apply(tmpls, inputs); + // TODO: use other params for tool calls. 
+ auto result = chat_params.prompt; + llama_data.fmtted.resize(result.size() + 1); + memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1); + return result.size(); } // Function to tokenize the prompt @@ -1015,8 +991,8 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt, } // Helper function to apply the chat template and handle errors -static int apply_chat_template_with_error_handling(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) { - const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja); +static int apply_chat_template_with_error_handling(const common_chat_templates * tmpls, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) { + const int new_len = apply_chat_template(tmpls, llama_data, append, use_jinja); if (new_len < 0) { printe("failed to apply the chat template\n"); return -1; @@ -1078,8 +1054,9 @@ static int get_user_input(std::string & user_input, const std::string & user) { static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) { int prev_len = 0; llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get())); - auto chat_templates = common_chat_templates_from_model(llama_data.model.get(), ""); - GGML_ASSERT(chat_templates.template_default); + common_chat_templates_ptr chat_templates( + common_chat_templates_init(llama_data.model.get(), ""), + &common_chat_templates_free); static const bool stdout_a_terminal = is_stdout_a_terminal(); while (true) { // Get user input @@ -1090,7 +1067,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_ add_message("user", user.empty() ? user_input : user, llama_data); int new_len; - if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, true, new_len, use_jinja) < 0) { + if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, use_jinja) < 0) { return 1; } @@ -1105,7 +1082,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_ } add_message("assistant", response, llama_data); - if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, false, prev_len, use_jinja) < 0) { + if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, use_jinja) < 0) { return 1; } } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 71151183b81da..ad754036118d3 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1807,7 +1807,7 @@ struct server_context { // Necessary similarity of prompt for slot selection float slot_prompt_similarity = 0.0f; - common_chat_templates chat_templates; + struct common_chat_templates * chat_templates = nullptr; ~server_context() { // Clear any sampling context @@ -1825,6 +1825,7 @@ struct server_context { } llama_batch_free(batch); + common_chat_templates_free(chat_templates); } bool load_model(const common_params & params) { @@ -1891,45 +1892,17 @@ struct server_context { llama_init_dft.context.reset(); } - if (params_base.chat_template.empty() && !validate_builtin_chat_template(params.use_jinja)) { + chat_templates = common_chat_templates_init(model, params_base.chat_template); + try { + common_chat_format_example(chat_templates, params.use_jinja); + } catch (const std::exception & e) { SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. 
This may cause the model to output suboptimal responses\n", __func__); - chat_templates = common_chat_templates_from_model(model, "chatml"); - } else { - chat_templates = common_chat_templates_from_model(model, params_base.chat_template); + chat_templates = common_chat_templates_init(model, "chatml"); } - GGML_ASSERT(chat_templates.template_default.get() != nullptr); return true; } - bool validate_builtin_chat_template(bool use_jinja) const { - llama_chat_message chat[] = {{"user", "test"}}; - - if (use_jinja) { - auto templates = common_chat_templates_from_model(model, ""); - common_chat_inputs inputs; - inputs.messages = json::array({{ - {"role", "user"}, - {"content", "test"}, - }}); - GGML_ASSERT(templates.template_default); - try { - common_chat_params_init(*templates.template_default, inputs); - if (templates.template_tool_use) { - common_chat_params_init(*templates.template_tool_use, inputs); - } - return true; - } catch (const std::exception & e) { - SRV_ERR("failed to apply template: %s\n", e.what()); - return false; - } - } else { - const char * tmpl = llama_model_chat_template(model, /* name */ nullptr); - const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0); - return chat_res > 0; - } - } - void init() { const int32_t n_ctx_slot = n_ctx / params_base.n_parallel; @@ -3822,14 +3795,14 @@ int main(int argc, char ** argv) { { "default_generation_settings", ctx_server.default_generation_settings_for_props }, { "total_slots", ctx_server.params_base.n_parallel }, { "model_path", ctx_server.params_base.model }, - { "chat_template", ctx_server.chat_templates.template_default->source() }, - { "bos_token", ctx_server.chat_templates.template_default->bos_token() }, - { "eos_token", ctx_server.chat_templates.template_default->eos_token() }, + // { "chat_template", ctx_server.chat_templates.template_default->source() }, + // { "bos_token", ctx_server.chat_templates.template_default->bos_token() }, + // { "eos_token", ctx_server.chat_templates.template_default->eos_token() }, { "build_info", build_info }, }; - if (ctx_server.params_base.use_jinja && ctx_server.chat_templates.template_tool_use) { - data["chat_template_tool_use"] = ctx_server.chat_templates.template_tool_use->source(); - } + // if (ctx_server.params_base.use_jinja && ctx_server.chat_templates.template_tool_use) { + // data["chat_template_tool_use"] = ctx_server.chat_templates.template_tool_use->source(); + // } res_ok(res, data); }; @@ -4481,9 +4454,9 @@ int main(int argc, char ** argv) { LOG_INF("%s: model loaded\n", __func__); // print sample chat example to make it clear which template is used - LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, - ctx_server.chat_templates.template_default->source().c_str(), - common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str()); + // LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, + // ctx_server.chat_templates.template_default->source().c_str(), + // common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str()); ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) { ctx_server.process_single_task(task); diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 86de0e6d78977..8ced4396793f4 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -347,39 +347,52 @@ static llama_tokens format_infill( return embd_inp; } -// Format 
given chat. If tmpl is empty, we take the template from model metadata -inline std::string format_chat(const common_chat_template & tmpl, const std::vector & messages) { - std::vector chat; - - for (size_t i = 0; i < messages.size(); ++i) { - const auto & curr_msg = messages[i]; - - std::string role = json_value(curr_msg, "role", std::string("")); - - std::string content; - if (curr_msg.contains("content")) { - if (curr_msg["content"].is_string()) { - content = curr_msg["content"].get(); - } else if (curr_msg["content"].is_array()) { - for (const auto & part : curr_msg["content"]) { - if (part.contains("text")) { - content += "\n" + part["text"].get(); - } +static std::vector oaicompat_messages_parse(const json & messages) { + std::vector msgs; + + for (const auto & message : messages) { + common_chat_msg msg; + msg.role = json_value(message, "role", std::string("")); + + if (message.contains("content")) { + const auto & content = message.at("content"); + if (content.is_string()) { + msg.content = content; + } else if (content.is_array()) { + for (const auto & part : content) { + if (!part.contains("type")) throw std::runtime_error("Missing content part type: " + part.dump()); + const auto & type = part.at("type"); + if (type != "text") throw std::runtime_error("Unsupported content part type: " + type.dump()); + common_chat_msg_content_part msg_part; + msg_part.type = type; + msg_part.text = part.at("text"); + msg.content_parts.push_back(msg_part); } - } else { + } else if (!content.is_null()) { throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); } } else { - throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); + throw std::runtime_error("Expected 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); + } + if (message.contains("reasoning_content")) { + msg.reasoning_content = message.at("reasoning_content"); + } + if (message.contains("tool_calls")) { + for (const auto & tool_call : message.at("tool_calls")) { + common_chat_tool_call tc; + tc.name = json_value(tool_call, "tool", std::string("")); + tc.arguments = tool_call.at("arguments"); + if (tool_call.contains("id")) { + tc.id = tool_call.at("id"); + } + msg.tool_calls.push_back(tc); + } } - chat.push_back({role, content, /* tool_calls= */ {}}); + msgs.push_back(msg); } - const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false); - LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str()); - - return formatted_chat; + return msgs; } // @@ -579,12 +592,9 @@ static json oaicompat_completion_params_parse( const json & body, /* openai api json semantics */ bool use_jinja, common_reasoning_format reasoning_format, - const common_chat_templates & chat_templates) + const struct common_chat_templates * tmpls) { json llama_params; - const auto & tmpl = body.contains("tools") && chat_templates.template_tool_use - ? 
*chat_templates.template_tool_use - : *chat_templates.template_default; auto tools = json_value(body, "tools", json()); auto stream = json_value(body, "stream", false); @@ -610,62 +620,58 @@ static json oaicompat_completion_params_parse( llama_params["stop"] = json_value(body, "stop", json::array()); } + auto json_schema = json_value(llama_params, "json_schema", json()); + auto grammar = json_value(llama_params, "grammar", std::string()); + if (!json_schema.is_null() && !grammar.empty()) { + throw std::runtime_error("Cannot use both json_schema and grammar"); + } + // Handle "response_format" field if (body.contains("response_format")) { json response_format = json_value(body, "response_format", json::object()); std::string response_type = json_value(response_format, "type", std::string()); if (response_type == "json_object") { - llama_params["json_schema"] = json_value(response_format, "schema", json::object()); + json_schema = json_value(response_format, "schema", json::object()); } else if (response_type == "json_schema") { json json_schema = json_value(response_format, "json_schema", json::object()); - llama_params["json_schema"] = json_value(json_schema, "schema", json::object()); + json_schema = json_value(json_schema, "schema", json::object()); } else if (!response_type.empty() && response_type != "text") { throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type); } } + common_chat_templates_inputs inputs; + inputs.messages = oaicompat_messages_parse(body.at("messages")); + inputs.add_generation_prompt = true; + inputs.use_jinja = use_jinja; + inputs.grammar = grammar; + inputs.tools = tools.is_null() ? "" : tools.dump(); + inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump(); + inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); + inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE; + inputs.tool_choice = common_chat_tool_choice_parse(json_value(body, "tool_choice", std::string("auto"))); + if (inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && llama_params.contains("grammar")) { + throw std::runtime_error("Cannot use custom grammar constraints with tools."); + } + // Apply chat template to the list of messages - if (use_jinja) { - auto tool_choice = json_value(body, "tool_choice", std::string("auto")); - if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") { - throw std::runtime_error("Invalid tool_choice: " + tool_choice); - } - if (tool_choice != "none" && llama_params.contains("grammar")) { - throw std::runtime_error("Cannot use custom grammar constraints with tools."); - } - common_chat_inputs inputs; - inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE; - inputs.messages = body.at("messages"); - inputs.tools = tools; - inputs.tool_choice = tool_choice; - inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); - if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) { - LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n"); - inputs.parallel_tool_calls = false; - } - inputs.stream = stream; - // TODO: support mixing schema w/ tools beyond generic format. 
- inputs.json_schema = json_value(llama_params, "json_schema", json()); - auto chat_params = common_chat_params_init(tmpl, inputs); - - llama_params["chat_format"] = static_cast(chat_params.format); - llama_params["prompt"] = chat_params.prompt; - llama_params["grammar"] = chat_params.grammar; - llama_params["grammar_lazy"] = chat_params.grammar_lazy; - auto grammar_triggers = json::array(); - for (const auto & trigger : chat_params.grammar_triggers) { - grammar_triggers.push_back({ - {"word", trigger.word}, - {"at_start", trigger.at_start}, - }); - } - llama_params["grammar_triggers"] = grammar_triggers; - llama_params["preserved_tokens"] = chat_params.preserved_tokens; - for (const auto & stop : chat_params.additional_stops) { - llama_params["stop"].push_back(stop); - } - } else { - llama_params["prompt"] = format_chat(tmpl, body.at("messages")); + auto chat_params = common_chat_templates_apply(tmpls, inputs); + + llama_params["chat_format"] = static_cast(chat_params.format); + llama_params["prompt"] = chat_params.prompt; + llama_params["grammar"] = chat_params.grammar; + llama_params["grammar_lazy"] = chat_params.grammar_lazy; + auto grammar_triggers = json::array(); + for (const auto & trigger : chat_params.grammar_triggers) { + grammar_triggers.push_back({ + {"word", trigger.word}, + {"at_start", trigger.at_start}, + }); + } + llama_params["grammar_triggers"] = grammar_triggers; + llama_params["preserved_tokens"] = chat_params.preserved_tokens; + for (const auto & stop : chat_params.additional_stops) { + llama_params["stop"].push_back(stop); } // Handle "n" field diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index e0314ae1d6296..80d12b83cbbb9 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -7,7 +7,7 @@ #include "llama.h" #include "common.h" -#include "chat-template.hpp" +#include "chat.hpp" static std::string normalize_newlines(const std::string & s) { #ifdef _WIN32 @@ -304,11 +304,14 @@ int main(void) { } } - json messages = json::array(); + std::vector messages; for (const auto & msg : conversation) { messages.push_back({ - {"role", msg.role}, - {"content", msg.content}, + msg.role, + msg.content, + /* .content_parts = */ {}, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "", }); } for (const auto & test_case : test_cases) { @@ -317,8 +320,12 @@ int main(void) { } printf("\n\n=== %s (jinja) ===\n\n", test_case.name.c_str()); try { - minja::chat_template tmpl(test_case.template_str, test_case.bos_token, test_case.eos_token); - auto output = normalize_newlines(tmpl.apply(messages, json(), add_generation_prompt)); + common_chat_templates_ptr tmpls(common_chat_templates_init(/* model= */ nullptr, test_case.template_str.c_str(), test_case.bos_token, test_case.eos_token), &common_chat_templates_free); + common_chat_templates_inputs inputs; + inputs.messages = messages; + inputs.add_generation_prompt = add_generation_prompt; + auto output = common_chat_templates_apply(tmpls.get(), inputs).prompt; + output = normalize_newlines(output); auto expected_output = normalize_newlines(test_case.expected_output_jinja.empty() ? 
test_case.expected_output : test_case.expected_output_jinja); if (output != expected_output) { printf("Expected:\n%s\n", expected_output.c_str()); @@ -336,11 +343,17 @@ int main(void) { // test llama_chat_format_single for system message printf("\n\n=== llama_chat_format_single (system message) ===\n\n"); std::vector chat2; - common_chat_msg sys_msg{"system", "You are a helpful assistant", {}}; + common_chat_msg sys_msg { + "system", + "You are a helpful assistant", + /* .content_parts = */ {}, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "", + }; auto fmt_sys = [&](std::string tmpl_str) { - minja::chat_template tmpl(tmpl_str, "", ""); - auto output = common_chat_format_single(tmpl, chat2, sys_msg, false, /* use_jinja= */ false); + common_chat_templates_ptr tmpls(common_chat_templates_init(/* model= */ nullptr, tmpl_str), &common_chat_templates_free); + auto output = common_chat_format_single(tmpls.get(), chat2, sys_msg, false, /* use_jinja= */ false); printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str()); printf("-------------------------\n"); return output; @@ -360,14 +373,20 @@ int main(void) { // test llama_chat_format_single for user message printf("\n\n=== llama_chat_format_single (user message) ===\n\n"); - chat2.push_back({"system", "You are a helpful assistant", {}}); - chat2.push_back({"user", "Hello", {}}); - chat2.push_back({"assistant", "I am assistant", {}}); - common_chat_msg new_msg{"user", "How are you", {}}; + chat2.push_back({"system", "You are a helpful assistant", {}, {}, ""}); + chat2.push_back({"user", "Hello", {}, {}, ""}); + chat2.push_back({"assistant", "I am assistant", {}, {}, ""}); + common_chat_msg new_msg { + "user", + "How are you", + /* .content_parts = */ {}, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "", + }; - auto fmt_single = [&](std::string tmpl_str) { - minja::chat_template tmpl(tmpl_str, "", ""); - auto output = common_chat_format_single(tmpl, chat2, new_msg, true, /* use_jinja= */ false); + auto fmt_single = [&](const std::string & tmpl_str) { + common_chat_templates_ptr tmpls(common_chat_templates_init(/* model= */ nullptr, tmpl_str.c_str()), &common_chat_templates_free); + auto output = common_chat_format_single(tmpls.get(), chat2, new_msg, true, /* use_jinja= */ false); printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str()); printf("-------------------------\n"); return output; diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 2836caf6a71a3..e52019da71754 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -17,31 +17,6 @@ using json = nlohmann::ordered_json; -static common_chat_msg msg_from_json(const json & message) { - common_chat_msg ret; - ret.role = "assistant"; - if (message.contains("content") && !message.at("content").is_null()) { - ret.content = message.at("content"); - } - if (message.contains("tool_plan")) { - ret.reasoning_content = message.at("tool_plan"); - } - if (message.contains("reasoning_content")) { - ret.reasoning_content = message.at("reasoning_content"); - } - auto has_tool_calls = message.contains("tool_calls"); - if (has_tool_calls) { - for (const auto & tc : message.at("tool_calls")) { - const auto & arguments = tc.at("function").at("arguments"); - ret.tool_calls.push_back({ - tc.at("function").at("name").get(), - arguments.is_string() ? arguments.get() : arguments.dump(), - tc.contains("id") ? 
tc.at("id").get() : "", - }); - } - } - return ret; -} template static void assert_equals(const T & expected, const T & actual) { if (expected != actual) { @@ -70,6 +45,10 @@ static std::string read_file(const std::string & path) { return out; } +static common_chat_templates_ptr read_templates(const std::string & path) { + return common_chat_templates_ptr(common_chat_templates_init(/* model= */ nullptr, read_file(path)), &common_chat_templates_free); +} + static std::unique_ptr build_grammar(const std::string & grammar_str) { return std::unique_ptr( llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0)); @@ -108,6 +87,13 @@ static std::string dump(const json & j) { static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) { assert_equals(expected.role, actual.role); assert_equals(expected.content, actual.content); + assert_equals(expected.content_parts.size(), actual.content_parts.size()); + for (size_t i = 0; i < expected.content_parts.size(); i++) { + const auto & expected_part = expected.content_parts[i]; + const auto & actual_part = actual.content_parts[i]; + assert_equals(expected_part.type, actual_part.type); + assert_equals(expected_part.text, actual_part.text); + } assert_equals(expected.reasoning_content, actual.reasoning_content); assert_equals(expected.tool_calls.size(), actual.tool_calls.size()); for (size_t i = 0; i < expected.tool_calls.size(); i++) { @@ -170,30 +156,29 @@ const auto code_interpreter_tool = json::parse(R"({ } } })"); -const json tools = { special_function_tool, python_tool }; -const json llama_3_1_tools = { special_function_tool, code_interpreter_tool }; +const auto tools = json::array({ special_function_tool, python_tool }).dump(); +const auto llama_3_1_tools = json::array({ special_function_tool, code_interpreter_tool }).dump(); struct delta_data { std::string delta; common_chat_params params; }; -static delta_data init_delta(const common_chat_template & tmpl, const std::vector & end_tokens, - const json & user_message, const json & delta_message, const json & tools, - const json & tool_choice, +static delta_data init_delta(const struct common_chat_templates * tmpls, const std::vector & end_tokens, + const common_chat_msg & user_message, const common_chat_msg & delta_message, const json & tools, + const common_chat_tool_choice & tool_choice, bool think = false) { - common_chat_inputs inputs; + common_chat_templates_inputs inputs; inputs.parallel_tool_calls = true; - inputs.messages = json::array(); inputs.messages.push_back(user_message); inputs.tools = tools; inputs.tool_choice = tool_choice; inputs.extract_reasoning = think; - auto params_prefix = common_chat_params_init(tmpl, inputs); + auto params_prefix = common_chat_templates_apply(tmpls, inputs); inputs.messages.push_back(delta_message); inputs.add_generation_prompt = false; - auto params_full = common_chat_params_init(tmpl, inputs); + auto params_full = common_chat_templates_apply(tmpls, inputs); std::string prefix = params_prefix.prompt; std::string full = params_full.prompt; @@ -234,30 +219,25 @@ static delta_data init_delta(const common_chat_template & tmpl, const std::vecto gets the diff, removes any end tokens and parses the result w/ the grammar, checking that the parsed message is the same as the test_message */ -static void test_template(const common_chat_template & tmpl, const std::vector & end_tokens, - const json & test_message, const json & tools = {}, const std::string & expected_delta = "", +static void 
test_templates(const struct common_chat_templates * tmpls, const std::vector & end_tokens, + const common_chat_msg & test_message, const std::string & tools = "", const std::string & expected_delta = "", bool expect_grammar_triggered = true, bool test_grammar_if_triggered = true, bool think = false) { - common_chat_msg expected_msg = msg_from_json(test_message); + common_chat_msg user_message = { "user", "Hello, world!", {}, {}, "" }; - auto user_message = json{ - { "role", "user" }, - { "content", "Hello, world!" } - }; - - for (const auto & tool_choice : json({ "auto", "required" })) { - auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice, think); + for (const auto & tool_choice : std::vector {COMMON_CHAT_TOOL_CHOICE_AUTO, COMMON_CHAT_TOOL_CHOICE_REQUIRED}) { + auto data = init_delta(tmpls, end_tokens, user_message, test_message, tools, tool_choice, think); if (!expected_delta.empty()) { assert_equals(expected_delta, data.delta); } if (expect_grammar_triggered) { const auto msg = common_chat_parse(data.delta, data.params.format); - assert_msg_equals(expected_msg, msg); + assert_msg_equals(test_message, msg); } - if (!expected_msg.tool_calls.empty()) { + if (!test_message.tool_calls.empty()) { GGML_ASSERT(!data.params.grammar.empty()); } if (!data.params.grammar.empty()) { @@ -298,245 +278,196 @@ static void test_template(const common_chat_template & tmpl, const std::vectorI'm thinkingHello, world!\nWhat's up?", + /* .content_parts = */ {}, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "", }; - json message_assist { - { "role", "assistant" }, - { "content", "Hello, world!\nWhat's up?" }, + common_chat_msg message_assist_thoughts_unparsed_r7b { + "assistant", + "<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?", + /* .content_parts = */ {}, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "", }; - json message_assist_thoughts_unparsed_think { - { "role", "assistant" }, - { "content", "I'm thinkingHello, world!\nWhat's up?" }, + common_chat_msg message_assist_thoughts { + "assistant", + "Hello, world!\nWhat's up?", + /* .content_parts = */ {}, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "I'm thinking", }; - json message_assist_thoughts_unparsed_r7b { - { "role", "assistant" }, - { "content", "<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?" }, + std::vector tool_calls { + { "special_function", "{\"arg1\": 1}", /* .id = */ "" }, }; - json message_assist_thoughts { - { "role", "assistant" }, - { "content", "Hello, world!\nWhat's up?" 
}, - { "reasoning_content", "I'm thinking" }, + std::vector tool_calls_idx { + { "special_function", "{\"arg1\": 1}", /* .id = */ "0" }, }; - json tool_calls = json::array({{ - { "type", "function" }, - { "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } }, - }}); - - json message_assist_call { - { "role", "assistant"}, - { "content", {}}, - { "tool_calls", { - { - { "type", "function" }, - { "function", { - { "name", "special_function" }, - { "arguments", "{\"arg1\": 1}" }, - }}, - }, - }}, + std::vector tool_calls_id { + { "special_function", "{\"arg1\": 1}", /* .id = */ "123456789" }, }; - json message_assist_call_thoughts = { - { "role", "assistant" }, - { "content", nullptr }, - { "reasoning_content", "I'm\nthinking" }, - { "tool_calls", { - { - { "type", "function" }, - { "function", { - { "name", "special_function" }, - { "arguments", "{\"arg1\": 1}" }, - }}, - }, - }}, + + common_chat_msg message_assist_call { + "assistant", + "", + /* .content_parts = */ {}, + tool_calls, + /* .reasoning_content = */ "", + }; + common_chat_msg message_assist_call_thoughts = { + "assistant", + /* .content = */ "", + /* .content_parts = */ {}, + tool_calls, + /* .reasoning_content = */ "I'm\nthinking", }; - json message_assist_call_thoughts_unparsed = { - { "role", "assistant" }, - { "content", "I'm\nthinking" }, - { "tool_calls", { - { - { "type", "function" }, - { "function", { - { "name", "special_function" }, - { "arguments", "{\"arg1\": 1}" }, - }}, - }, - }}, + common_chat_msg message_assist_call_thoughts_unparsed = { + "assistant", + /* .content = */ "I'm\nthinking", + /* .content_parts = */ {}, + tool_calls, + /* .reasoning_content = */ "", }; - json message_assist_call_id { - { "role", "assistant"}, - { "content", {}}, - { "tool_calls", { - { - { "type", "function" }, - { "function", { - { "name", "special_function" }, - { "arguments", "{\"arg1\": 1}" }, - }}, - {"id", "123456789"}, - }, - }}, - { "role", "assistant" }, - { "content", {} }, - { "tool_calls", tool_calls } + common_chat_msg message_assist_call_id { + "assistant", + "", + /* .content_parts = */ {}, + tool_calls_id, + /* .reasoning_content = */ "", }; - json message_assist_call_idx { - { "role", "assistant"}, - { "content", {}}, - { "tool_calls", { - { - { "type", "function" }, - { "function", { - { "name", "special_function" }, - { "arguments", "{\"arg1\": 1}" }, - }}, - // Index of the tool call in the tool_calls array - {"id", "0"}, - }, - }}, - { "role", "assistant" }, - { "content", {} }, - { "tool_calls", tool_calls } + common_chat_msg message_assist_call_idx { + "assistant", + "", + /* .content_parts = */ {}, + tool_calls_idx, + /* .reasoning_content = */ "", }; - json message_assist_call_tool_plan_idx = message_assist_call_idx; - message_assist_call_tool_plan_idx["tool_plan"] = "I'm thinking"; - - auto python_message_assist_call = json{ - { "role", "assistant" }, - { "content", {} }, - { "tool_calls", json{ { - { "type", "function" }, - { "function", - { - { "name", "python" }, - { "arguments", - { - { "code", "print('hey')" }, - } }, - } }, - } } } + common_chat_msg message_assist_call_python { + "assistant", + "", + /* .content_parts = */ {}, + { { "python", "{\"code\": \"print('hey')\"}", /* .id = */ "" } }, + /* .reasoning_content = */ "", }; - auto code_interpreter_message_assist_call = json{ - { "role", "assistant" }, - { "content", {} }, - { "tool_calls", json{ { - { "type", "function" }, - { "function", - { - { "name", "code_interpreter" }, - { "arguments", - { - { "code", 
"print('hey')" }, - } }, - } }, - } } } + common_chat_msg message_assist_call_code_interpreter { + "assistant", + "", + /* .content_parts = */ {}, + { { "code_interpreter", "{\"code\": \"print('hey')\"}", /* .id = */ "" } }, + /* .reasoning_content = */ "", }; - common_chat_inputs inputs_no_tools; - inputs_no_tools.messages = json::array({message_user}); + common_chat_templates_inputs inputs_no_tools; + inputs_no_tools.messages = {message_user}; inputs_no_tools.extract_reasoning = false; - common_chat_inputs inputs_no_tools_think; - inputs_no_tools_think.messages = json::array({message_user}); + common_chat_templates_inputs inputs_no_tools_think; + inputs_no_tools_think.messages = {message_user}; inputs_no_tools_think.extract_reasoning = true; - common_chat_inputs inputs_tools; - inputs_tools.messages = json::array({message_user}); - inputs_tools.tools = json::array({special_function_tool}); + common_chat_templates_inputs inputs_tools; + inputs_tools.messages = {message_user}; + inputs_tools.tools = json::array({special_function_tool}).dump(); inputs_tools.extract_reasoning = false; - common_chat_inputs inputs_tools_think; - inputs_tools_think.messages = json::array({message_user}); - inputs_tools_think.tools = json::array({special_function_tool}); + common_chat_templates_inputs inputs_tools_think; + inputs_tools_think.messages = {message_user}; + inputs_tools_think.tools = json::array({special_function_tool}).dump(); inputs_tools_think.extract_reasoning = true; - common_chat_inputs inputs_tools_builtin; - inputs_tools_builtin.messages = json::array({message_user}); - inputs_tools_builtin.tools = json::array({python_tool}); + common_chat_templates_inputs inputs_tools_builtin; + inputs_tools_builtin.messages = {message_user}; + inputs_tools_builtin.tools = json::array({python_tool}).dump(); inputs_tools_builtin.extract_reasoning = false; { // Not supported yet - const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja"), "", ""); - assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format); + auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja"); + assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format); } { - const common_chat_template tmpl(read_file("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"), "", ""); + auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"); std::vector end_tokens{ "<|END_OF_TURN_TOKEN|>" }; - assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_no_tools).format); - assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_params_init(tmpl, inputs_tools).format); - assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format); + assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); + assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, common_chat_templates_apply(tmpls.get(), inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format); - assert_msg_equals(msg_from_json(message_assist), + assert_msg_equals(message_assist, common_chat_parse( "Hello, world!\nWhat's up?", COMMON_CHAT_FORMAT_COMMAND_R7B)); - assert_msg_equals(msg_from_json(message_assist), + assert_msg_equals(message_assist, 
common_chat_parse( "Hello, world!\nWhat's up?<|END_RESPONSE|>", COMMON_CHAT_FORMAT_COMMAND_R7B)); - assert_msg_equals(msg_from_json(message_assist), + assert_msg_equals(message_assist, common_chat_parse( "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", COMMON_CHAT_FORMAT_COMMAND_R7B)); - assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_r7b), + assert_msg_equals(message_assist_thoughts_unparsed_r7b, common_chat_parse( "<|START_THINKING|>I'm thinking<|END_THINKING|>" "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", COMMON_CHAT_FORMAT_COMMAND_R7B)); - assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_r7b), + assert_msg_equals(message_assist_thoughts_unparsed_r7b, common_chat_parse( "<|START_THINKING|>I'm thinking<|END_THINKING|>" "Hello, world!\nWhat's up?<|END_RESPONSE|>", COMMON_CHAT_FORMAT_COMMAND_R7B)); - assert_msg_equals(msg_from_json(message_assist_thoughts), + assert_msg_equals(message_assist_thoughts, common_chat_parse( "<|START_THINKING|>I'm thinking<|END_THINKING|>" "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING)); - test_template(tmpl, end_tokens, message_assist_call_idx, tools, + test_templates(tmpls.get(), end_tokens, message_assist_call_idx, tools, "<|START_THINKING|><|END_THINKING|>" "<|START_ACTION|>[\n" " {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n" "]<|END_ACTION|>"); - test_template(tmpl, end_tokens, message_assist_call_tool_plan_idx, tools, - "<|START_THINKING|>I'm thinking<|END_THINKING|>" - "<|START_ACTION|>[\n" - " {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n" - "]<|END_ACTION|>", - /* expect_grammar_triggered= */ true, - /* test_grammar_if_triggered= */ true, - /* think= */ true); - test_template(tmpl, end_tokens, message_assist, tools, + test_templates(tmpls.get(), end_tokens, message_assist, tools, "<|START_RESPONSE|>Hello, world!\n" "What's up?<|END_RESPONSE|>", /* expect_grammar_triggered= */ false); } { - const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "", ""); + auto tmpls = read_templates("models/templates/google-gemma-2-2b-it.jinja"); std::vector end_tokens{ "" }; - assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format); - assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); + assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format); assert_equals(COMMON_CHAT_FORMAT_GENERIC, - common_chat_params_init( - common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), - "", ""), + common_chat_templates_apply( + read_templates("models/templates/microsoft-Phi-3.5-mini-instruct.jinja").get(), inputs_tools) .format); // Generic tool calls doesn't generate / parse content-only messages symmetrically. 
- assert_msg_equals(msg_from_json(message_assist), + assert_msg_equals(message_assist, common_chat_parse("{\n" " \"response\": \"Hello, world!\\nWhat's up?\"\n" "}", - common_chat_params_init(tmpl, inputs_tools).format)); - test_template(tmpl, end_tokens, message_assist_call_id, tools, + common_chat_templates_apply(tmpls.get(), inputs_tools).format)); + test_templates(tmpls.get(), end_tokens, message_assist_call_id, tools, "{\n" " \"tool_calls\": [\n" " {\n" @@ -550,143 +481,133 @@ static void test_template_output_parsers() { "}"); } { - const common_chat_template tmpl(read_file("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "", - ""); + auto tmpls = read_templates("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"); std::vector end_tokens{ "" }; - assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_template( - tmpl, end_tokens, message_assist_call_id, tools, + test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_templates( + tmpls.get(), end_tokens, message_assist_call_id, tools, "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]"); } { - const common_chat_template tmpl( - read_file("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "", ""); + auto tmpls = read_templates("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"); std::vector end_tokens{ "<|im_end|>" }; - assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format); assert_equals( COMMON_CHAT_FORMAT_HERMES_2_PRO, - common_chat_params_init( - common_chat_template(read_file("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"), - "", ""), + common_chat_templates_apply( + read_templates("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja").get(), inputs_tools) .format); assert_equals( COMMON_CHAT_FORMAT_HERMES_2_PRO, - common_chat_params_init( - common_chat_template(read_file("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "", ""), + common_chat_templates_apply( + read_templates("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja").get(), inputs_tools) .format); - test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, message_assist_call, tools, + test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_templates(tmpls.get(), end_tokens, message_assist_call, tools, "\n" "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" ""); - test_template(tmpl, end_tokens, python_message_assist_call, tools, + test_templates(tmpls.get(), end_tokens, message_assist_call_python, tools, "\n" "{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n" ""); } { - const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "", - ""); + auto tmpls = read_templates("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"); std::vector 
end_tokens{ "<|eom_id|>", "<|eot_id|>" }; - assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format); assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, - common_chat_params_init(tmpl, inputs_tools_builtin).format); + common_chat_templates_apply(tmpls.get(), inputs_tools_builtin).format); assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, - common_chat_params_init( - common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"), - "", ""), + common_chat_templates_apply( + read_templates("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja").get(), inputs_tools_builtin) .format); - // test_template(tmpl, end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, code_interpreter_message_assist_call, llama_3_1_tools, + // test_templates(tmpls.get(), end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false); + test_templates(tmpls.get(), end_tokens, message_assist_call_code_interpreter, llama_3_1_tools, "<|python_tag|>code_interpreter.call(code=\"print('hey')\")"); - test_template(tmpl, end_tokens, python_message_assist_call, tools, + test_templates(tmpls.get(), end_tokens, message_assist_call_python, tools, "<|python_tag|>python.call(code=\"print('hey')\")"); - test_template(tmpl, end_tokens, message_assist_call, tools, + test_templates(tmpls.get(), end_tokens, message_assist_call, tools, "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}"); } { - const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "", - ""); + auto tmpls = read_templates("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"); std::vector end_tokens{ "<|eom_id|>", "<|eot_id|>" }; - assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, message_assist_call, tools, + test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_templates(tmpls.get(), end_tokens, message_assist_call, tools, "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}"); } { - const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.1.jinja"), "", - ""); + auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.1.jinja"); std::vector end_tokens{ "<|eom_id|>", "<|eot_id|>" }; assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, - common_chat_params_init(tmpl, inputs_tools).format); + common_chat_templates_apply(tmpls.get(), inputs_tools).format); - test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, message_assist_call, tools, + test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_templates(tmpls.get(), end_tokens, message_assist_call, tools, "{\"arg1\": 1}"); } { - const common_chat_template 
tmpl(read_file("models/templates/meetkai-functionary-medium-v3.2.jinja"), "", - ""); + auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.2.jinja"); std::vector end_tokens{ "<|eom_id|>", "<|eot_id|>" }; - assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_no_tools).format); - assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); + assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - test_template(tmpl, end_tokens, message_assist, {}, + test_templates(tmpls.get(), end_tokens, message_assist, {}, "all\n" "Hello, world!\n" "What's up?", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, message_assist_call, tools, + test_templates(tmpls.get(), end_tokens, message_assist_call, tools, "special_function\n" "{\"arg1\": 1}"); } { - const common_chat_template tmpl(read_file("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "", - ""); + auto tmpls = read_templates("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"); std::vector end_tokens{ "<|eot_id|>" }; - assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, message_assist_call, tools, + test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_templates(tmpls.get(), end_tokens, message_assist_call, tools, " functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]"); } { // Original DeepSeek R1 template. Leaves <|tool▁calls▁begin|> and others unclosed. Our logic fixes the prompt. 
- const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"), - "", ""); + auto tmpls = read_templates("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"); std::vector end_tokens{ "<|end▁of▁sentence|>" }; - assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format); - assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format); + assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format); - test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_think), + test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + assert_msg_equals(message_assist_thoughts_unparsed_think, common_chat_parse("I'm thinkingHello, world!\nWhat's up?", COMMON_CHAT_FORMAT_DEEPSEEK_R1)); - assert_msg_equals(msg_from_json(message_assist_thoughts), + assert_msg_equals(message_assist_thoughts, common_chat_parse("I'm thinkingHello, world!\nWhat's up?", COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING)); - assert_msg_equals(msg_from_json(message_assist_thoughts), + assert_msg_equals(message_assist_thoughts, // Latest template update (ast of 20250209) adds a trailing \n if add_generation_prompt is true. common_chat_parse("I'm thinkingHello, world!\nWhat's up?", COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING)); - // test_template(tmpl, end_tokens, message_assist_call, tools, + // test_templates(tmpls.get(), end_tokens, message_assist_call, tools, // "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" // "```json\n" // "{\"arg1\": 1}\n" @@ -697,23 +618,22 @@ static void test_template_output_parsers() { } { // Replacement DeepSeek R1 template. Makes the Distill Qwen 7B/32B models happy to call tools and all. 
- const common_chat_template tmpl(read_file("models/templates/llama-cpp-deepseek-r1.jinja"), - "", ""); + auto tmpls = read_templates("models/templates/llama-cpp-deepseek-r1.jinja"); std::vector end_tokens{ "<|end▁of▁sentence|>" }; - assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format); - assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_params_init(tmpl, inputs_tools_think).format); + assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format); + assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format); - test_template(tmpl, end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_template(tmpl, end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - assert_msg_equals(msg_from_json(message_assist_thoughts_unparsed_think), + test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); + assert_msg_equals(message_assist_thoughts_unparsed_think, common_chat_parse("I'm thinkingHello, world!\nWhat's up?", COMMON_CHAT_FORMAT_DEEPSEEK_R1)); - assert_msg_equals(msg_from_json(message_assist_thoughts), + assert_msg_equals(message_assist_thoughts, common_chat_parse("I'm thinkingHello, world!\nWhat's up?", COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING)); - assert_msg_equals(msg_from_json(message_assist_call_thoughts_unparsed), + assert_msg_equals(message_assist_call_thoughts_unparsed, common_chat_parse( "I'm\nthinking\n\n" "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" @@ -721,7 +641,7 @@ static void test_template_output_parsers() { "{\"arg1\": 1}\n" "```<|tool▁call▁end|><|tool▁calls▁end|>", COMMON_CHAT_FORMAT_DEEPSEEK_R1)); - assert_msg_equals(msg_from_json(message_assist_call_thoughts), + assert_msg_equals(message_assist_call_thoughts, common_chat_parse( "I'm\nthinking\n\n" "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" @@ -729,7 +649,7 @@ static void test_template_output_parsers() { "{\"arg1\": 1}\n" "```<|tool▁call▁end|><|tool▁calls▁end|>", COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING)); - test_template(tmpl, end_tokens, message_assist_call, tools, + test_templates(tmpls.get(), end_tokens, message_assist_call, tools, "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" "```json\n" "{\"arg1\": 1}\n" @@ -740,11 +660,11 @@ static void test_template_output_parsers() { int main(int argc, char ** argv) { #ifndef _WIN32 if (argc > 1) { - common_chat_inputs inputs; + common_chat_templates_inputs inputs; inputs.messages = { - { { "role", "user" }, { "content", "Hey" } } + { "user", "Hey", {}, {}, "" }, }; - inputs.tools = json::array({ special_function_tool }); + inputs.tools = json::array({ special_function_tool }).dump(); std::cout << "| Template | Format |\n"; std::cout << "|----------|--------|\n"; @@ -756,10 +676,10 @@ int main(int argc, char ** argv) { std::cerr << "Skipping non-jinja file: " << path << std::endl; continue; } - common_chat_template tmpl(read_file(path), "", ""); + auto tmpls = read_templates(path); auto parts = string_split(path, "/"); auto name = parts[parts.size() - 1]; 
- auto format = common_chat_format_name(common_chat_params_init(tmpl, inputs).format); + auto format = common_chat_format_name(common_chat_templates_apply(tmpls.get(), inputs).format); std::cout << "| " << name << " | " << format << " |\n"; } catch (const std::exception & e) { std::cerr << "Failed to process " << argv[i] << ": " << e.what() << std::endl; From 2f683f08af03c4f94edc906860c98bc5030e1261 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Sat, 15 Feb 2025 17:10:13 +0000 Subject: [PATCH 03/41] rm minja dep from util & common --- common/common.cpp | 2 -- examples/server/utils.hpp | 2 -- tests/test-chat.cpp | 8 +------- 3 files changed, 1 insertion(+), 11 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index f005f1459938c..d2b0d50e3ee39 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -12,8 +12,6 @@ #include "json.hpp" #include "json-schema-to-grammar.h" #include "llama.h" -#include "chat.hpp" -#include "chat-template.hpp" #include #include diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 8ced4396793f4..e9f31cd8da8e2 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -12,9 +12,7 @@ // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT #include "json.hpp" -#include "minja.hpp" #include "chat.hpp" -#include "chat-template.hpp" #include #include diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index e52019da71754..8cf9a1634d159 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -10,7 +10,6 @@ #include #include -#include "chat-template.hpp" #include "chat.hpp" #include "llama-grammar.h" #include "unicode.h" @@ -79,11 +78,6 @@ static bool match_string(const std::string & input, llama_grammar * grammar) { return false; } -// Dumps `{"a": 1}` as `"{\"a\": 1}"`, unlike nlohmann::json::dump which would dump it as `"{\"a\":1}"`. 
-static std::string dump(const json & j) { - return minja::Value(j).dump(-1, /* to_json= */ true); -} - static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) { assert_equals(expected.role, actual.role); assert_equals(expected.content, actual.content); @@ -100,7 +94,7 @@ static void assert_msg_equals(const common_chat_msg & expected, const common_cha const auto & expected_tool_call = expected.tool_calls[i]; const auto & actual_tool_call = actual.tool_calls[i]; assert_equals(expected_tool_call.name, actual_tool_call.name); - assert_equals(dump(json::parse(expected_tool_call.arguments)), dump(json::parse(actual_tool_call.arguments))); + assert_equals(json::parse(expected_tool_call.arguments).dump(), json::parse(actual_tool_call.arguments).dump()); assert_equals(expected_tool_call.id, actual_tool_call.id); } } From 7a04ebcb6d754b6a6fb146cd0fae3d2ae5e556c2 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Sat, 15 Feb 2025 17:10:30 +0000 Subject: [PATCH 04/41] move minja to common/minja --- common/CMakeLists.txt | 4 ++-- common/chat.cpp | 4 ++-- common/{ => minja}/chat-template.hpp | 0 common/{ => minja}/minja.hpp | 0 4 files changed, 4 insertions(+), 4 deletions(-) rename common/{ => minja}/chat-template.hpp (100%) rename common/{ => minja}/minja.hpp (100%) diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index c2b4aa7d09f1c..bf391c2ad90f0 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -58,7 +58,6 @@ add_library(${TARGET} STATIC base64.hpp chat.cpp chat.hpp - chat-template.hpp common.cpp common.h console.cpp @@ -68,7 +67,8 @@ add_library(${TARGET} STATIC llguidance.cpp log.cpp log.h - minja.hpp + minja/chat-template.hpp + minja/minja.hpp ngram-cache.cpp ngram-cache.h sampling.cpp diff --git a/common/chat.cpp b/common/chat.cpp index c6d8dbb394927..134f56b12fce7 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1,9 +1,9 @@ #include "chat.hpp" #include -#include "chat-template.hpp" #include "json-schema-to-grammar.h" #include "log.h" -#include "minja.hpp" +#include "minja/chat-template.hpp" +#include "minja/minja.hpp" namespace minja { class chat_template; diff --git a/common/chat-template.hpp b/common/minja/chat-template.hpp similarity index 100% rename from common/chat-template.hpp rename to common/minja/chat-template.hpp diff --git a/common/minja.hpp b/common/minja/minja.hpp similarity index 100% rename from common/minja.hpp rename to common/minja/minja.hpp From ece941b862b9d126492f054bb088c31367a1976c Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Sat, 15 Feb 2025 17:14:25 +0000 Subject: [PATCH 05/41] Update utils.hpp --- examples/server/utils.hpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index e9f31cd8da8e2..3d800df946e75 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -348,7 +348,11 @@ static llama_tokens format_infill( static std::vector oaicompat_messages_parse(const json & messages) { std::vector msgs; + if (!messages.is_array()) throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump()); + for (const auto & message : messages) { + if (!message.is_object()) throw std::runtime_error("Expected 'message' to be an object, got " + message.dump()); + common_chat_msg msg; msg.role = json_value(message, "role", std::string("")); From aa09a3ca31cf3cba846349c089cd2854129b9b99 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Sat, 15 Feb 2025 22:55:02 +0000 Subject: [PATCH 06/41] add 
common_chat_tool --- common/chat.cpp | 21 ++++++- common/chat.hpp | 8 ++- examples/server/utils.hpp | 10 +++- tests/test-chat.cpp | 113 ++++++++++++++++++-------------------- 4 files changed, 88 insertions(+), 64 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index 134f56b12fce7..d1d524c73a12c 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1202,6 +1202,23 @@ static json messages_to_json(const std::vector & msgs) { return messages; } +static json tools_to_json(const std::vector & tools) { + if (tools.empty()) return json(); + + auto result = json::array(); + for (const auto & tool : tools) { + result.push_back({ + {"type", "function"}, + {"function", { + {"name", tool.name}, + {"description", tool.description}, + {"parameters", json::parse(tool.parameters)}, + }}, + }); + } + return result; +} + common_chat_params common_chat_templates_apply( const struct common_chat_templates * tmpls, const struct common_chat_templates_inputs & inputs) @@ -1212,14 +1229,12 @@ common_chat_params common_chat_templates_apply( params.messages = messages_to_json(inputs.messages); params.add_generation_prompt = inputs.add_generation_prompt; params.extract_reasoning = inputs.extract_reasoning; + params.tools = tools_to_json(inputs.tools); params.tool_choice = inputs.tool_choice; params.grammar = inputs.grammar; if (!inputs.json_schema.empty()) { params.json_schema = json::parse(inputs.json_schema); } - if (!inputs.tools.empty()) { - params.tools = json::parse(inputs.tools); - } const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use ? *tmpls->template_tool_use : *tmpls->template_default; diff --git a/common/chat.hpp b/common/chat.hpp index bbd5daebf9f18..bfd59f9dbaac2 100644 --- a/common/chat.hpp +++ b/common/chat.hpp @@ -77,11 +77,17 @@ void common_chat_templates_free(struct common_chat_templates * tmpls); typedef std::unique_ptr common_chat_templates_ptr; +struct common_chat_tool { + std::string name; + std::string description; + std::string parameters; +}; + struct common_chat_templates_inputs { std::vector messages; std::string grammar; std::string json_schema; - std::string tools; + std::vector tools; common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO; bool add_generation_prompt = true; bool use_jinja = true; diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 3d800df946e75..ce58fa398dcab 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -647,7 +647,15 @@ static json oaicompat_completion_params_parse( inputs.add_generation_prompt = true; inputs.use_jinja = use_jinja; inputs.grammar = grammar; - inputs.tools = tools.is_null() ? "" : tools.dump(); + if (tools.is_array()) { + for (const auto & tool : tools) { + inputs.tools.push_back({ + /* .name = */ tool.at("name"), + /* .description = */ tool.at("description"), + /* .arguments = */ tool.at("arguments"), + }); + } + } inputs.json_schema = json_schema.is_null() ? 
"" : json_schema.dump(); inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE; diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 8cf9a1634d159..6f809d334946b 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -99,59 +99,50 @@ static void assert_msg_equals(const common_chat_msg & expected, const common_cha } } -const auto special_function_tool = json::parse(R"({ - "type": "function", - "function": { - "name": "special_function", - "description": "I'm special", - "parameters": { - "type": "object", - "properties": { - "arg1": { - "type": "integer", - "description": "The arg." - } - }, - "required": ["arg1"] - } - } -})"); -const auto python_tool = json::parse(R"({ - "type": "function", - "function": { - "name": "python", - "description": "an ipython interpreter", - "parameters": { - "type": "object", - "properties": { - "code": { - "type": "string", - "description": "Python code to execute." - } - }, - "required": ["code"] - } - } -})"); -const auto code_interpreter_tool = json::parse(R"({ - "type": "function", - "function": { - "name": "code_interpreter", - "description": "an ipython interpreter", - "parameters": { - "type": "object", - "properties": { - "code": { - "type": "string", - "description": "Python code to execute." - } - }, - "required": ["code"] - } - } -})"); -const auto tools = json::array({ special_function_tool, python_tool }).dump(); -const auto llama_3_1_tools = json::array({ special_function_tool, code_interpreter_tool }).dump(); +common_chat_tool special_function_tool { + /* .name = */ "special_function", + /* .description = */ "I'm special", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "arg1": { + "type": "integer", + "description": "The arg." + } + }, + "required": ["arg1"] + })", +}; +common_chat_tool python_tool { + /* .name = */ "python", + /* .description = */ "an ipython interpreter", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "Python code to execute." + } + }, + "required": ["code"] + })", +}; +common_chat_tool code_interpreter_tool { + /* .name = */ "code_interpreter", + /* .description = */ "an ipython interpreter", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "Python code to execute." 
+ } + }, + "required": ["code"] + })", +}; +std::vector tools { special_function_tool, python_tool }; +std::vector llama_3_1_tools { special_function_tool, code_interpreter_tool }; struct delta_data { std::string delta; @@ -159,7 +150,9 @@ struct delta_data { }; static delta_data init_delta(const struct common_chat_templates * tmpls, const std::vector & end_tokens, - const common_chat_msg & user_message, const common_chat_msg & delta_message, const json & tools, + const common_chat_msg & user_message, + const common_chat_msg & delta_message, + const std::vector & tools, const common_chat_tool_choice & tool_choice, bool think = false) { common_chat_templates_inputs inputs; @@ -214,7 +207,9 @@ static delta_data init_delta(const struct common_chat_templates * tmpls, const s the parsed message is the same as the test_message */ static void test_templates(const struct common_chat_templates * tmpls, const std::vector & end_tokens, - const common_chat_msg & test_message, const std::string & tools = "", const std::string & expected_delta = "", + const common_chat_msg & test_message, + const std::vector & tools = {}, + const std::string & expected_delta = "", bool expect_grammar_triggered = true, bool test_grammar_if_triggered = true, bool think = false) { @@ -377,17 +372,17 @@ static void test_template_output_parsers() { common_chat_templates_inputs inputs_tools; inputs_tools.messages = {message_user}; - inputs_tools.tools = json::array({special_function_tool}).dump(); + inputs_tools.tools = {special_function_tool}; inputs_tools.extract_reasoning = false; common_chat_templates_inputs inputs_tools_think; inputs_tools_think.messages = {message_user}; - inputs_tools_think.tools = json::array({special_function_tool}).dump(); + inputs_tools_think.tools = {special_function_tool}; inputs_tools_think.extract_reasoning = true; common_chat_templates_inputs inputs_tools_builtin; inputs_tools_builtin.messages = {message_user}; - inputs_tools_builtin.tools = json::array({python_tool}).dump(); + inputs_tools_builtin.tools = {python_tool}; inputs_tools_builtin.extract_reasoning = false; { @@ -658,7 +653,7 @@ int main(int argc, char ** argv) { inputs.messages = { { "user", "Hey", {}, {}, "" }, }; - inputs.tools = json::array({ special_function_tool }).dump(); + inputs.tools = { special_function_tool }; std::cout << "| Template | Format |\n"; std::cout << "|----------|--------|\n"; From 7ae756005e1857dd23a51c4452837018aed17982 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sat, 15 Feb 2025 23:54:12 +0000 Subject: [PATCH 07/41] force utf8 encoding in get_chat_template --- scripts/get_chat_template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/get_chat_template.py b/scripts/get_chat_template.py index d8143e4005dec..08439169de315 100755 --- a/scripts/get_chat_template.py +++ b/scripts/get_chat_template.py @@ -21,7 +21,7 @@ def get_chat_template(model_id, variant=None): # Use huggingface_hub library if available. # Allows access to gated models if the user has access and ran `huggingface-cli login`. 
from huggingface_hub import hf_hub_download - with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json")) as f: + with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json"), "r", encoding="utf-8") as f: config_str = f.read() except ImportError: import requests From 646528a5c4038bce079cbbd08d98a5d6b78bcfcb Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 01:18:02 +0000 Subject: [PATCH 08/41] fix json tools parsing --- examples/server/utils.hpp | 37 ++++++++++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index ce58fa398dcab..4277f9e047fb3 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -397,6 +397,33 @@ static std::vector oaicompat_messages_parse(const json & messag return msgs; } +static std::vector oaicompat_tools_parse(const json & tools) { + std::vector result; + + try { + if (!tools.is_null()) { + if (!tools.is_array()) throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump()); + for (const auto & tool : tools) { + if (!tool.contains("type")) throw std::runtime_error("Missing tool type: " + tool.dump()); + const auto & type = tool.at("type"); + if (!type.is_string() || type != "function") throw std::runtime_error("Unsupported tool type: " + tool.dump()); + if (!tool.contains("function")) throw std::runtime_error("Missing tool function: " + tool.dump()); + + const auto & function = tool.at("function"); + result.push_back({ + /* .name = */ function.at("name"), + /* .description = */ function.at("description"), + /* .parameters = */ function.at("parameters").dump(), + }); + } + } + } catch (const std::exception & e) { + throw std::runtime_error("Failed to parse tools: " + std::string(e.what()) + "; tools = " + tools.dump(2)); + } + + return result; +} + // // base64 utils (TODO: move to common in the future) // @@ -647,15 +674,7 @@ static json oaicompat_completion_params_parse( inputs.add_generation_prompt = true; inputs.use_jinja = use_jinja; inputs.grammar = grammar; - if (tools.is_array()) { - for (const auto & tool : tools) { - inputs.tools.push_back({ - /* .name = */ tool.at("name"), - /* .description = */ tool.at("description"), - /* .arguments = */ tool.at("arguments"), - }); - } - } + inputs.tools = oaicompat_tools_parse(tools); inputs.json_schema = json_schema.is_null() ? 
"" : json_schema.dump(); inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE; From db2b44eb7ed7f97feedac5d2fa23bc48dcad1082 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 01:39:05 +0000 Subject: [PATCH 09/41] add json tools / messages parsing helpers to common --- common/chat.cpp | 99 +++++++++++++++++++++++++++++++++++++++ common/chat.hpp | 46 ++++++++++-------- examples/server/utils.hpp | 82 +------------------------------- 3 files changed, 128 insertions(+), 99 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index d1d524c73a12c..ec8bff6b07e9d 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1401,3 +1401,102 @@ common_chat_tool_choice common_chat_tool_choice_parse(const std::string & tool_c if (tool_choice == "required") return COMMON_CHAT_TOOL_CHOICE_REQUIRED; throw std::runtime_error("Invalid tool_choice: " + tool_choice); } + + +template <> +std::vector common_chat_msgs_parse_oaicompat(const json & messages) { + std::vector msgs; + + try { + + if (!messages.is_array()) throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump()); + + for (const auto & message : messages) { + if (!message.is_object()) throw std::runtime_error("Expected 'message' to be an object, got " + message.dump()); + + common_chat_msg msg; + if (!message.contains("role")) throw std::runtime_error("Missing 'role' in message: " + message.dump()); + msg.role = message.at("role"); + + if (message.contains("content")) { + const auto & content = message.at("content"); + if (content.is_string()) { + msg.content = content; + } else if (content.is_array()) { + for (const auto & part : content) { + if (!part.contains("type")) throw std::runtime_error("Missing content part type: " + part.dump()); + const auto & type = part.at("type"); + if (type != "text") throw std::runtime_error("Unsupported content part type: " + type.dump()); + common_chat_msg_content_part msg_part; + msg_part.type = type; + msg_part.text = part.at("text"); + msg.content_parts.push_back(msg_part); + } + } else if (!content.is_null()) { + throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); + } + } else { + throw std::runtime_error("Expected 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); + } + if (message.contains("reasoning_content")) { + msg.reasoning_content = message.at("reasoning_content"); + } + if (message.contains("tool_calls")) { + for (const auto & tool_call : message.at("tool_calls")) { + common_chat_tool_call tc; + if (!tool_call.contains("name")) throw std::runtime_error("Missing tool call name: " + tool_call.dump()); + tc.name = tool_call.at("name"); + tc.arguments = tool_call.at("arguments"); + if (tool_call.contains("id")) { + tc.id = tool_call.at("id"); + } + msg.tool_calls.push_back(tc); + } + } + + msgs.push_back(msg); + } + } catch (const std::exception & e) { + throw std::runtime_error("Failed to parse messages: " + std::string(e.what()) + "; messages = " + messages.dump(2)); + } + + return msgs; +} + +template <> +std::vector common_chat_msgs_parse_oaicompat(const std::string & messages) { + return common_chat_msgs_parse_oaicompat(json::parse(messages)); +} + +template <> +std::vector common_chat_tools_parse_oaicompat(const json & tools) { + std::vector result; + + try { + if (!tools.is_null()) { + if (!tools.is_array()) throw std::runtime_error("Expected 'tools' to be an array, got " + 
tools.dump()); + for (const auto & tool : tools) { + if (!tool.contains("type")) throw std::runtime_error("Missing tool type: " + tool.dump()); + const auto & type = tool.at("type"); + if (!type.is_string() || type != "function") throw std::runtime_error("Unsupported tool type: " + tool.dump()); + if (!tool.contains("function")) throw std::runtime_error("Missing tool function: " + tool.dump()); + + const auto & function = tool.at("function"); + result.push_back({ + /* .name = */ function.at("name"), + /* .description = */ function.at("description"), + /* .parameters = */ function.at("parameters").dump(), + }); + } + } + } catch (const std::exception & e) { + throw std::runtime_error("Failed to parse tools: " + std::string(e.what()) + "; tools = " + tools.dump(2)); + } + + return result; +} + +template <> +std::vector common_chat_tools_parse_oaicompat(const std::string & tools) { + return common_chat_tools_parse_oaicompat(json::parse(tools)); +} diff --git a/common/chat.hpp b/common/chat.hpp index bfd59f9dbaac2..031db2b303ebf 100644 --- a/common/chat.hpp +++ b/common/chat.hpp @@ -29,6 +29,12 @@ struct common_chat_msg { std::string reasoning_content; }; +struct common_chat_tool { + std::string name; + std::string description; + std::string parameters; +}; + enum common_chat_tool_choice { COMMON_CHAT_TOOL_CHOICE_AUTO, COMMON_CHAT_TOOL_CHOICE_REQUIRED, @@ -53,6 +59,19 @@ enum common_chat_format { COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats }; +struct common_chat_templates_inputs { + std::vector messages; + std::string grammar; + std::string json_schema; + bool add_generation_prompt = true; + bool use_jinja = true; + // Parameters below only supported when use_jinja is true + std::vector tools; + common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO; + bool parallel_tool_calls = false; + bool extract_reasoning = true; +}; + struct common_chat_params { common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY; std::string prompt; @@ -77,25 +96,6 @@ void common_chat_templates_free(struct common_chat_templates * tmpls); typedef std::unique_ptr common_chat_templates_ptr; -struct common_chat_tool { - std::string name; - std::string description; - std::string parameters; -}; - -struct common_chat_templates_inputs { - std::vector messages; - std::string grammar; - std::string json_schema; - std::vector tools; - common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO; - bool add_generation_prompt = true; - bool use_jinja = true; - // Parameters below only supported when use_jinja is true - bool parallel_tool_calls = false; - bool extract_reasoning = true; -}; - struct common_chat_params common_chat_templates_apply( const struct common_chat_templates * tmpls, const struct common_chat_templates_inputs & inputs); @@ -117,3 +117,11 @@ std::string common_chat_format_name(common_chat_format format); common_chat_msg common_chat_parse( const std::string & input, common_chat_format format); common_chat_tool_choice common_chat_tool_choice_parse(const std::string & tool_choice); + +// Parses a JSON array of messages in OpenAI's chat completion API format. +// T can be std::string containing JSON or nlohmann::ordered_json +template std::vector common_chat_msgs_parse_oaicompat(const T & messages); + +// Parses a JSON array of tools in OpenAI's chat completion tool call API format. 
+// T can be std::string containing JSON or nlohmann::ordered_json +template std::vector common_chat_tools_parse_oaicompat(const T & tools); \ No newline at end of file diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 4277f9e047fb3..5b9d3b1a24d1c 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -345,84 +345,6 @@ static llama_tokens format_infill( return embd_inp; } -static std::vector oaicompat_messages_parse(const json & messages) { - std::vector msgs; - - if (!messages.is_array()) throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump()); - - for (const auto & message : messages) { - if (!message.is_object()) throw std::runtime_error("Expected 'message' to be an object, got " + message.dump()); - - common_chat_msg msg; - msg.role = json_value(message, "role", std::string("")); - - if (message.contains("content")) { - const auto & content = message.at("content"); - if (content.is_string()) { - msg.content = content; - } else if (content.is_array()) { - for (const auto & part : content) { - if (!part.contains("type")) throw std::runtime_error("Missing content part type: " + part.dump()); - const auto & type = part.at("type"); - if (type != "text") throw std::runtime_error("Unsupported content part type: " + type.dump()); - common_chat_msg_content_part msg_part; - msg_part.type = type; - msg_part.text = part.at("text"); - msg.content_parts.push_back(msg_part); - } - } else if (!content.is_null()) { - throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); - } - } else { - throw std::runtime_error("Expected 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); - } - if (message.contains("reasoning_content")) { - msg.reasoning_content = message.at("reasoning_content"); - } - if (message.contains("tool_calls")) { - for (const auto & tool_call : message.at("tool_calls")) { - common_chat_tool_call tc; - tc.name = json_value(tool_call, "tool", std::string("")); - tc.arguments = tool_call.at("arguments"); - if (tool_call.contains("id")) { - tc.id = tool_call.at("id"); - } - msg.tool_calls.push_back(tc); - } - } - - msgs.push_back(msg); - } - - return msgs; -} - -static std::vector oaicompat_tools_parse(const json & tools) { - std::vector result; - - try { - if (!tools.is_null()) { - if (!tools.is_array()) throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump()); - for (const auto & tool : tools) { - if (!tool.contains("type")) throw std::runtime_error("Missing tool type: " + tool.dump()); - const auto & type = tool.at("type"); - if (!type.is_string() || type != "function") throw std::runtime_error("Unsupported tool type: " + tool.dump()); - if (!tool.contains("function")) throw std::runtime_error("Missing tool function: " + tool.dump()); - - const auto & function = tool.at("function"); - result.push_back({ - /* .name = */ function.at("name"), - /* .description = */ function.at("description"), - /* .parameters = */ function.at("parameters").dump(), - }); - } - } - } catch (const std::exception & e) { - throw std::runtime_error("Failed to parse tools: " + std::string(e.what()) + "; tools = " + tools.dump(2)); - } - - return result; -} // // base64 utils (TODO: move to common in the future) @@ -670,11 +592,11 @@ static json oaicompat_completion_params_parse( } common_chat_templates_inputs inputs; - inputs.messages = oaicompat_messages_parse(body.at("messages")); + inputs.messages = 
common_chat_msgs_parse_oaicompat(body.at("messages")); inputs.add_generation_prompt = true; inputs.use_jinja = use_jinja; inputs.grammar = grammar; - inputs.tools = oaicompat_tools_parse(tools); + inputs.tools = common_chat_tools_parse_oaicompat(tools); inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump(); inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE; From c7c890707834a1f439d02acfbc90b5490ac6e291 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 01:46:21 +0000 Subject: [PATCH 10/41] fix common_chat_msgs_parse_oaicompat --- common/chat.cpp | 17 +++++++++++------ examples/server/tests/unit/test_tool_call.py | 2 +- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index ec8bff6b07e9d..56e016e42c4d0 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1433,7 +1433,7 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa msg.content_parts.push_back(msg_part); } } else if (!content.is_null()) { - throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); + throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); } } else { throw std::runtime_error("Expected 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); @@ -1444,11 +1444,16 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa if (message.contains("tool_calls")) { for (const auto & tool_call : message.at("tool_calls")) { common_chat_tool_call tc; - if (!tool_call.contains("name")) throw std::runtime_error("Missing tool call name: " + tool_call.dump()); - tc.name = tool_call.at("name"); - tc.arguments = tool_call.at("arguments"); - if (tool_call.contains("id")) { - tc.id = tool_call.at("id"); + if (!tool_call.contains("type")) throw std::runtime_error("Missing tool call type: " + tool_call.dump()); + const auto & type = tool_call.at("type"); + if (type != "function") throw std::runtime_error("Unsupported tool call type: " + tool_call.dump()); + if (!tool_call.contains("function")) throw std::runtime_error("Missing tool call function: " + tool_call.dump()); + const auto & fc = tool_call.at("function"); + if (!fc.contains("name")) throw std::runtime_error("Missing tool call name: " + tool_call.dump()); + tc.name = fc.at("name"); + tc.arguments = fc.at("arguments"); + if (fc.contains("id")) { + tc.id = fc.at("id"); } msg.tool_calls.push_back(tc); } diff --git a/examples/server/tests/unit/test_tool_call.py b/examples/server/tests/unit/test_tool_call.py index ba3367b4f332d..8aca10011d148 100644 --- a/examples/server/tests/unit/test_tool_call.py +++ b/examples/server/tests/unit/test_tool_call.py @@ -401,7 +401,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, { "role": "tool", "name": "calculate", - "content": 0.55644242476, + "content": "0.55644242476", "tool_call_id": "call_6789" } ], From 5f17156db3c2ff4ef9b80c370fcd24b5a4669721 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 01:58:32 +0000 Subject: [PATCH 11/41] concat multipart content in legacy template path --- common/chat.cpp | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index 56e016e42c4d0..d0def3f31f951 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1317,9 +1317,18 @@ 
common_chat_params common_chat_templates_apply( int alloc_size = 0; std::vector chat; + std::vector contents; for (const auto & msg : inputs.messages) { - chat.push_back({msg.role.c_str(), msg.content.c_str()}); - alloc_size += (msg.role.size() + msg.content.size()) * 1.25; + auto content = msg.content; + for (const auto & part : msg.content_parts) { + if (!content.empty()) { + content += "\n";; + } + content += part.text; + } + contents.emplace_back(std::move(content)); + chat.push_back({msg.role.c_str(), contents.back().c_str()}); + alloc_size += (msg.role.size() + contents.back().size()) * 1.25; } std::vector buf(alloc_size); From ee9b9d69c9aa766e1b0ae31842c6ba14489dd194 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 02:18:40 +0000 Subject: [PATCH 12/41] add name & tool_call_id to common_chat_msg --- common/chat.cpp | 17 +++++++++++++---- common/chat.hpp | 4 +++- examples/server/utils.hpp | 24 ++++++++++++------------ 3 files changed, 28 insertions(+), 17 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index d0def3f31f951..169bcd915c054 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -5,9 +5,6 @@ #include "minja/chat-template.hpp" #include "minja/minja.hpp" -namespace minja { - class chat_template; -} typedef minja::chat_template common_chat_template; @@ -1181,6 +1178,12 @@ static json messages_to_json(const std::vector & msgs) { if (!msg.reasoning_content.empty()) { jmsg["reasoning_content"] = msg.reasoning_content; } + if (!msg.tool_name.empty()) { + jmsg["name"] = msg.tool_name; + } + if (!msg.tool_call_id.empty()) { + jmsg["tool_call_id"] = json::parse(msg.tool_call_id); + } if (!msg.tool_calls.empty()) { auto & tool_calls = jmsg["tool_calls"] = json::array(); for (const auto & tool_call : msg.tool_calls) { @@ -1404,7 +1407,7 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format } } -common_chat_tool_choice common_chat_tool_choice_parse(const std::string & tool_choice) { +common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) { if (tool_choice == "auto") return COMMON_CHAT_TOOL_CHOICE_AUTO; if (tool_choice == "none") return COMMON_CHAT_TOOL_CHOICE_NONE; if (tool_choice == "required") return COMMON_CHAT_TOOL_CHOICE_REQUIRED; @@ -1450,6 +1453,12 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa if (message.contains("reasoning_content")) { msg.reasoning_content = message.at("reasoning_content"); } + if (message.contains("name")) { + msg.tool_name = message.at("name"); + } + if (message.contains("tool_call_id")) { + msg.tool_call_id = message.at("tool_call_id"); + } if (message.contains("tool_calls")) { for (const auto & tool_call : message.at("tool_calls")) { common_chat_tool_call tc; diff --git a/common/chat.hpp b/common/chat.hpp index 031db2b303ebf..e67a3d67f32ab 100644 --- a/common/chat.hpp +++ b/common/chat.hpp @@ -27,6 +27,8 @@ struct common_chat_msg { std::vector content_parts; std::vector tool_calls; std::string reasoning_content; + std::string tool_name; + std::string tool_call_id; }; struct common_chat_tool { @@ -116,7 +118,7 @@ std::string common_chat_format_example( std::string common_chat_format_name(common_chat_format format); common_chat_msg common_chat_parse( const std::string & input, common_chat_format format); -common_chat_tool_choice common_chat_tool_choice_parse(const std::string & tool_choice); +common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice); // Parses a JSON array of messages in OpenAI's chat 
completion API format. // T can be std::string containing JSON or nlohmann::ordered_json diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 5b9d3b1a24d1c..d7ddfd0ae43c2 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -592,15 +592,15 @@ static json oaicompat_completion_params_parse( } common_chat_templates_inputs inputs; - inputs.messages = common_chat_msgs_parse_oaicompat(body.at("messages")); + inputs.messages = common_chat_msgs_parse_oaicompat(body.at("messages")); + inputs.tools = common_chat_tools_parse_oaicompat(tools); + inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto"))); + inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump(); + inputs.grammar = grammar; inputs.add_generation_prompt = true; - inputs.use_jinja = use_jinja; - inputs.grammar = grammar; - inputs.tools = common_chat_tools_parse_oaicompat(tools); - inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump(); - inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); - inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE; - inputs.tool_choice = common_chat_tool_choice_parse(json_value(body, "tool_choice", std::string("auto"))); + inputs.use_jinja = use_jinja; + inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); + inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE; if (inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && llama_params.contains("grammar")) { throw std::runtime_error("Cannot use custom grammar constraints with tools."); } @@ -608,10 +608,10 @@ static json oaicompat_completion_params_parse( // Apply chat template to the list of messages auto chat_params = common_chat_templates_apply(tmpls, inputs); - llama_params["chat_format"] = static_cast(chat_params.format); - llama_params["prompt"] = chat_params.prompt; - llama_params["grammar"] = chat_params.grammar; - llama_params["grammar_lazy"] = chat_params.grammar_lazy; + llama_params["chat_format"] = static_cast(chat_params.format); + llama_params["prompt"] = chat_params.prompt; + llama_params["grammar"] = chat_params.grammar; + llama_params["grammar_lazy"] = chat_params.grammar_lazy; auto grammar_triggers = json::array(); for (const auto & trigger : chat_params.grammar_triggers) { grammar_triggers.push_back({ From 07f0ad0e9bc6cee3a2c34c499a9e1e3746062d5b Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 02:20:25 +0000 Subject: [PATCH 13/41] Update test-chat.cpp --- tests/test-chat.cpp | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 6f809d334946b..7e0c08b4ad54d 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -318,6 +318,8 @@ static void test_template_output_parsers() { /* .content_parts = */ {}, tool_calls, /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", }; common_chat_msg message_assist_call_thoughts = { "assistant", @@ -325,6 +327,8 @@ static void test_template_output_parsers() { /* .content_parts = */ {}, tool_calls, /* .reasoning_content = */ "I'm\nthinking", + /* .tool_name = */ "", + /* .tool_call_id = */ "", }; common_chat_msg message_assist_call_thoughts_unparsed = { "assistant", @@ -332,6 +336,8 @@ static void test_template_output_parsers() { /* .content_parts = */ {}, tool_calls, /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", }; common_chat_msg message_assist_call_id 
{ "assistant", @@ -339,6 +345,8 @@ static void test_template_output_parsers() { /* .content_parts = */ {}, tool_calls_id, /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", }; common_chat_msg message_assist_call_idx { "assistant", @@ -346,6 +354,8 @@ static void test_template_output_parsers() { /* .content_parts = */ {}, tool_calls_idx, /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", }; common_chat_msg message_assist_call_python { "assistant", @@ -353,6 +363,8 @@ static void test_template_output_parsers() { /* .content_parts = */ {}, { { "python", "{\"code\": \"print('hey')\"}", /* .id = */ "" } }, /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", }; common_chat_msg message_assist_call_code_interpreter { "assistant", @@ -360,6 +372,8 @@ static void test_template_output_parsers() { /* .content_parts = */ {}, { { "code_interpreter", "{\"code\": \"print('hey')\"}", /* .id = */ "" } }, /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", }; common_chat_templates_inputs inputs_no_tools; From 1acda5f597019deda29ed3d6f5681d7f7f8c0c83 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 02:45:26 +0000 Subject: [PATCH 14/41] test & fix json<->msg conversions --- common/chat.cpp | 377 ++++++++++++++++++++++---------------------- common/chat.hpp | 14 +- tests/test-chat.cpp | 236 ++++++++++++++------------- 3 files changed, 326 insertions(+), 301 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index 169bcd915c054..e906fb6aa2fc1 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -26,6 +26,193 @@ struct templates_params { bool extract_reasoning = true; }; +common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) { + if (tool_choice == "auto") return COMMON_CHAT_TOOL_CHOICE_AUTO; + if (tool_choice == "none") return COMMON_CHAT_TOOL_CHOICE_NONE; + if (tool_choice == "required") return COMMON_CHAT_TOOL_CHOICE_REQUIRED; + throw std::runtime_error("Invalid tool_choice: " + tool_choice); +} + +template <> +std::vector common_chat_msgs_parse_oaicompat(const json & messages) { + std::vector msgs; + + try { + + if (!messages.is_array()) throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump()); + + for (const auto & message : messages) { + if (!message.is_object()) throw std::runtime_error("Expected 'message' to be an object, got " + message.dump()); + + common_chat_msg msg; + if (!message.contains("role")) throw std::runtime_error("Missing 'role' in message: " + message.dump()); + msg.role = message.at("role"); + + if (message.contains("content")) { + const auto & content = message.at("content"); + if (content.is_string()) { + msg.content = content; + } else if (content.is_array()) { + for (const auto & part : content) { + if (!part.contains("type")) throw std::runtime_error("Missing content part type: " + part.dump()); + const auto & type = part.at("type"); + if (type != "text") throw std::runtime_error("Unsupported content part type: " + type.dump()); + common_chat_msg_content_part msg_part; + msg_part.type = type; + msg_part.text = part.at("text"); + msg.content_parts.push_back(msg_part); + } + } else if (!content.is_null()) { + throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); + } + } else { + throw std::runtime_error("Expected 'content' (ref: 
https://github.com/ggerganov/llama.cpp/issues/8367)"); + } + if (message.contains("reasoning_content")) { + msg.reasoning_content = message.at("reasoning_content"); + } + if (message.contains("name")) { + msg.tool_name = message.at("name"); + } + if (message.contains("tool_call_id")) { + msg.tool_call_id = message.at("tool_call_id"); + } + if (message.contains("tool_calls")) { + for (const auto & tool_call : message.at("tool_calls")) { + common_chat_tool_call tc; + if (!tool_call.contains("type")) throw std::runtime_error("Missing tool call type: " + tool_call.dump()); + const auto & type = tool_call.at("type"); + if (type != "function") throw std::runtime_error("Unsupported tool call type: " + tool_call.dump()); + if (!tool_call.contains("function")) throw std::runtime_error("Missing tool call function: " + tool_call.dump()); + const auto & fc = tool_call.at("function"); + if (!fc.contains("name")) throw std::runtime_error("Missing tool call name: " + tool_call.dump()); + tc.name = fc.at("name"); + tc.arguments = fc.at("arguments"); + if (tool_call.contains("id")) { + tc.id = tool_call.at("id"); + } + msg.tool_calls.push_back(tc); + } + } + + msgs.push_back(msg); + } + } catch (const std::exception & e) { + throw std::runtime_error("Failed to parse messages: " + std::string(e.what()) + "; messages = " + messages.dump(2)); + } + + return msgs; +} + +template <> +json common_chat_msgs_to_json_oaicompat(const std::vector & msgs) { + json messages = json::array(); + for (const auto & msg : msgs) { + if (!msg.content.empty() && !msg.content_parts.empty()) { + throw std::runtime_error("Cannot specify both content and content_parts"); + } + json jmsg { + {"role", msg.role}, + }; + if (!msg.content.empty()) { + jmsg["content"] = msg.content; + } else if (!msg.content_parts.empty()) { + auto & parts = jmsg["content"] = json::array(); + for (const auto & part : msg.content_parts) { + parts.push_back({ + {"type", part.type}, + {"text", part.text}, + }); + } + } else { + jmsg["content"] = json(); // null + } + if (!msg.reasoning_content.empty()) { + jmsg["reasoning_content"] = msg.reasoning_content; + } + if (!msg.tool_name.empty()) { + jmsg["name"] = msg.tool_name; + } + if (!msg.tool_call_id.empty()) { + jmsg["tool_call_id"] = json::parse(msg.tool_call_id); + } + if (!msg.tool_calls.empty()) { + auto & tool_calls = jmsg["tool_calls"] = json::array(); + for (const auto & tool_call : msg.tool_calls) { + json tc { + {"type", "function"}, + {"function", { + {"name", tool_call.name}, + {"arguments", tool_call.arguments}, + }}, + }; + if (!tool_call.id.empty()) { + tc["id"] = tool_call.id; + } + tool_calls.push_back(tc); + } + } + messages.push_back(jmsg); + } + return messages; +} + +template <> +std::vector common_chat_msgs_parse_oaicompat(const std::string & messages) { + return common_chat_msgs_parse_oaicompat(json::parse(messages)); +} + +template <> +std::vector common_chat_tools_parse_oaicompat(const json & tools) { + std::vector result; + + try { + if (!tools.is_null()) { + if (!tools.is_array()) throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump()); + for (const auto & tool : tools) { + if (!tool.contains("type")) throw std::runtime_error("Missing tool type: " + tool.dump()); + const auto & type = tool.at("type"); + if (!type.is_string() || type != "function") throw std::runtime_error("Unsupported tool type: " + tool.dump()); + if (!tool.contains("function")) throw std::runtime_error("Missing tool function: " + tool.dump()); + + const auto & function = 
tool.at("function"); + result.push_back({ + /* .name = */ function.at("name"), + /* .description = */ function.at("description"), + /* .parameters = */ function.at("parameters").dump(), + }); + } + } + } catch (const std::exception & e) { + throw std::runtime_error("Failed to parse tools: " + std::string(e.what()) + "; tools = " + tools.dump(2)); + } + + return result; +} + +template <> +std::vector common_chat_tools_parse_oaicompat(const std::string & tools) { + return common_chat_tools_parse_oaicompat(json::parse(tools)); +} + +template <> +json common_chat_tools_to_json_oaicompat(const std::vector & tools) { + if (tools.empty()) return json(); + + auto result = json::array(); + for (const auto & tool : tools) { + result.push_back({ + {"type", "function"}, + {"function", { + {"name", tool.name}, + {"description", tool.description}, + {"parameters", json::parse(tool.parameters)}, + }}, + }); + } + return result; +} + bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { if (use_jinja) { try { @@ -1153,75 +1340,6 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha return data; } -static json messages_to_json(const std::vector & msgs) { - json messages = json::array(); - for (const auto & msg : msgs) { - if (!msg.content.empty() && !msg.content_parts.empty()) { - throw std::runtime_error("Cannot specify both content and content_parts"); - } - json jmsg { - {"role", msg.role}, - }; - if (!msg.content.empty()) { - jmsg["content"] = msg.content; - } else if (!msg.content_parts.empty()) { - auto & parts = jmsg["content"] = json::array(); - for (const auto & part : msg.content_parts) { - parts.push_back({ - {"type", part.type}, - {"text", part.text}, - }); - } - } else { - jmsg["content"] = json(); // null - } - if (!msg.reasoning_content.empty()) { - jmsg["reasoning_content"] = msg.reasoning_content; - } - if (!msg.tool_name.empty()) { - jmsg["name"] = msg.tool_name; - } - if (!msg.tool_call_id.empty()) { - jmsg["tool_call_id"] = json::parse(msg.tool_call_id); - } - if (!msg.tool_calls.empty()) { - auto & tool_calls = jmsg["tool_calls"] = json::array(); - for (const auto & tool_call : msg.tool_calls) { - json tc { - {"type", "function"}, - {"function", { - {"name", tool_call.name}, - {"arguments", tool_call.arguments}, - }}, - }; - if (!tool_call.id.empty()) { - tc["id"] = tool_call.id; - } - tool_calls.push_back(tc); - } - } - messages.push_back(jmsg); - } - return messages; -} - -static json tools_to_json(const std::vector & tools) { - if (tools.empty()) return json(); - - auto result = json::array(); - for (const auto & tool : tools) { - result.push_back({ - {"type", "function"}, - {"function", { - {"name", tool.name}, - {"description", tool.description}, - {"parameters", json::parse(tool.parameters)}, - }}, - }); - } - return result; -} - common_chat_params common_chat_templates_apply( const struct common_chat_templates * tmpls, const struct common_chat_templates_inputs & inputs) @@ -1229,10 +1347,10 @@ common_chat_params common_chat_templates_apply( GGML_ASSERT(tmpls != nullptr); if (inputs.use_jinja) { templates_params params; - params.messages = messages_to_json(inputs.messages); + params.messages = common_chat_msgs_to_json_oaicompat(inputs.messages); params.add_generation_prompt = inputs.add_generation_prompt; params.extract_reasoning = inputs.extract_reasoning; - params.tools = tools_to_json(inputs.tools); + params.tools = common_chat_tools_to_json_oaicompat(inputs.tools); params.tool_choice = inputs.tool_choice; params.grammar = 
inputs.grammar; if (!inputs.json_schema.empty()) { @@ -1406,120 +1524,3 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format throw std::runtime_error("Unsupported format: " + common_chat_format_name(format)); } } - -common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) { - if (tool_choice == "auto") return COMMON_CHAT_TOOL_CHOICE_AUTO; - if (tool_choice == "none") return COMMON_CHAT_TOOL_CHOICE_NONE; - if (tool_choice == "required") return COMMON_CHAT_TOOL_CHOICE_REQUIRED; - throw std::runtime_error("Invalid tool_choice: " + tool_choice); -} - - -template <> -std::vector common_chat_msgs_parse_oaicompat(const json & messages) { - std::vector msgs; - - try { - - if (!messages.is_array()) throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump()); - - for (const auto & message : messages) { - if (!message.is_object()) throw std::runtime_error("Expected 'message' to be an object, got " + message.dump()); - - common_chat_msg msg; - if (!message.contains("role")) throw std::runtime_error("Missing 'role' in message: " + message.dump()); - msg.role = message.at("role"); - - if (message.contains("content")) { - const auto & content = message.at("content"); - if (content.is_string()) { - msg.content = content; - } else if (content.is_array()) { - for (const auto & part : content) { - if (!part.contains("type")) throw std::runtime_error("Missing content part type: " + part.dump()); - const auto & type = part.at("type"); - if (type != "text") throw std::runtime_error("Unsupported content part type: " + type.dump()); - common_chat_msg_content_part msg_part; - msg_part.type = type; - msg_part.text = part.at("text"); - msg.content_parts.push_back(msg_part); - } - } else if (!content.is_null()) { - throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); - } - } else { - throw std::runtime_error("Expected 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); - } - if (message.contains("reasoning_content")) { - msg.reasoning_content = message.at("reasoning_content"); - } - if (message.contains("name")) { - msg.tool_name = message.at("name"); - } - if (message.contains("tool_call_id")) { - msg.tool_call_id = message.at("tool_call_id"); - } - if (message.contains("tool_calls")) { - for (const auto & tool_call : message.at("tool_calls")) { - common_chat_tool_call tc; - if (!tool_call.contains("type")) throw std::runtime_error("Missing tool call type: " + tool_call.dump()); - const auto & type = tool_call.at("type"); - if (type != "function") throw std::runtime_error("Unsupported tool call type: " + tool_call.dump()); - if (!tool_call.contains("function")) throw std::runtime_error("Missing tool call function: " + tool_call.dump()); - const auto & fc = tool_call.at("function"); - if (!fc.contains("name")) throw std::runtime_error("Missing tool call name: " + tool_call.dump()); - tc.name = fc.at("name"); - tc.arguments = fc.at("arguments"); - if (fc.contains("id")) { - tc.id = fc.at("id"); - } - msg.tool_calls.push_back(tc); - } - } - - msgs.push_back(msg); - } - } catch (const std::exception & e) { - throw std::runtime_error("Failed to parse messages: " + std::string(e.what()) + "; messages = " + messages.dump(2)); - } - - return msgs; -} - -template <> -std::vector common_chat_msgs_parse_oaicompat(const std::string & messages) { - return 
common_chat_msgs_parse_oaicompat(json::parse(messages)); -} - -template <> -std::vector common_chat_tools_parse_oaicompat(const json & tools) { - std::vector result; - - try { - if (!tools.is_null()) { - if (!tools.is_array()) throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump()); - for (const auto & tool : tools) { - if (!tool.contains("type")) throw std::runtime_error("Missing tool type: " + tool.dump()); - const auto & type = tool.at("type"); - if (!type.is_string() || type != "function") throw std::runtime_error("Unsupported tool type: " + tool.dump()); - if (!tool.contains("function")) throw std::runtime_error("Missing tool function: " + tool.dump()); - - const auto & function = tool.at("function"); - result.push_back({ - /* .name = */ function.at("name"), - /* .description = */ function.at("description"), - /* .parameters = */ function.at("parameters").dump(), - }); - } - } - } catch (const std::exception & e) { - throw std::runtime_error("Failed to parse tools: " + std::string(e.what()) + "; tools = " + tools.dump(2)); - } - - return result; -} - -template <> -std::vector common_chat_tools_parse_oaicompat(const std::string & tools) { - return common_chat_tools_parse_oaicompat(json::parse(tools)); -} diff --git a/common/chat.hpp b/common/chat.hpp index e67a3d67f32ab..1bfc22dc199b1 100644 --- a/common/chat.hpp +++ b/common/chat.hpp @@ -24,11 +24,11 @@ struct common_chat_msg_content_part { struct common_chat_msg { std::string role; std::string content; - std::vector content_parts; - std::vector tool_calls; - std::string reasoning_content; - std::string tool_name; - std::string tool_call_id; + std::vector content_parts = {}; + std::vector tool_calls = {}; + std::string reasoning_content = ""; + std::string tool_name = ""; + std::string tool_call_id = ""; }; struct common_chat_tool { @@ -123,7 +123,9 @@ common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::strin // Parses a JSON array of messages in OpenAI's chat completion API format. // T can be std::string containing JSON or nlohmann::ordered_json template std::vector common_chat_msgs_parse_oaicompat(const T & messages); +template T common_chat_msgs_to_json_oaicompat(const std::vector & messages); // Parses a JSON array of tools in OpenAI's chat completion tool call API format. 
// T can be std::string containing JSON or nlohmann::ordered_json -template std::vector common_chat_tools_parse_oaicompat(const T & tools); \ No newline at end of file +template std::vector common_chat_tools_parse_oaicompat(const T & tools); +template T common_chat_tools_to_json_oaicompat(const std::vector & tools); \ No newline at end of file diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 7e0c08b4ad54d..df7a19a579fca 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -266,115 +266,136 @@ static void test_templates(const struct common_chat_templates * tmpls, const std } } -static void test_template_output_parsers() { - common_chat_msg message_user { - "user", - "Hey there!", - /* .content_parts = */ {}, - /* .tool_calls = */ {}, - /* .reasoning_content = */ "", - }; - common_chat_msg message_assist { - "assistant", - "Hello, world!\nWhat's up?", - /* .content_parts = */ {}, - /* .tool_calls = */ {}, - /* .reasoning_content = */ "", - }; - common_chat_msg message_assist_thoughts_unparsed_think { - "assistant", - "I'm thinkingHello, world!\nWhat's up?", - /* .content_parts = */ {}, - /* .tool_calls = */ {}, - /* .reasoning_content = */ "", - }; - common_chat_msg message_assist_thoughts_unparsed_r7b { - "assistant", - "<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?", - /* .content_parts = */ {}, - /* .tool_calls = */ {}, - /* .reasoning_content = */ "", - }; - common_chat_msg message_assist_thoughts { - "assistant", - "Hello, world!\nWhat's up?", - /* .content_parts = */ {}, - /* .tool_calls = */ {}, - /* .reasoning_content = */ "I'm thinking", - }; - std::vector tool_calls { - { "special_function", "{\"arg1\": 1}", /* .id = */ "" }, - }; - std::vector tool_calls_idx { - { "special_function", "{\"arg1\": 1}", /* .id = */ "0" }, - }; - std::vector tool_calls_id { - { "special_function", "{\"arg1\": 1}", /* .id = */ "123456789" }, - }; +const common_chat_msg message_user { + "user", + "Hey there!", + /* .content_parts = */ {}, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "", +}; +const common_chat_msg message_assist { + "assistant", + "Hello, world!\nWhat's up?", + /* .content_parts = */ {}, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "", +}; +const common_chat_msg message_assist_thoughts_unparsed_think { + "assistant", + "I'm thinkingHello, world!\nWhat's up?", + /* .content_parts = */ {}, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "", +}; +const common_chat_msg message_assist_thoughts_unparsed_r7b { + "assistant", + "<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?", + /* .content_parts = */ {}, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "", +}; +const common_chat_msg message_assist_thoughts { + "assistant", + "Hello, world!\nWhat's up?", + /* .content_parts = */ {}, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "I'm thinking", +}; +const std::vector tool_calls { + { "special_function", "{\"arg1\": 1}", /* .id = */ "" }, +}; +const std::vector tool_calls_idx { + { "special_function", "{\"arg1\": 1}", /* .id = */ "0" }, +}; +const std::vector tool_calls_id { + { "special_function", "{\"arg1\": 1}", /* .id = */ "123456789" }, +}; - common_chat_msg message_assist_call { - "assistant", - "", - /* .content_parts = */ {}, - tool_calls, - /* .reasoning_content = */ "", - /* .tool_name = */ "", - /* .tool_call_id = */ "", - }; - common_chat_msg message_assist_call_thoughts = { - "assistant", - /* .content = */ "", - /* .content_parts = */ {}, - tool_calls, - /* 
.reasoning_content = */ "I'm\nthinking", - /* .tool_name = */ "", - /* .tool_call_id = */ "", - }; - common_chat_msg message_assist_call_thoughts_unparsed = { - "assistant", - /* .content = */ "I'm\nthinking", - /* .content_parts = */ {}, - tool_calls, - /* .reasoning_content = */ "", - /* .tool_name = */ "", - /* .tool_call_id = */ "", - }; - common_chat_msg message_assist_call_id { - "assistant", - "", - /* .content_parts = */ {}, - tool_calls_id, - /* .reasoning_content = */ "", - /* .tool_name = */ "", - /* .tool_call_id = */ "", - }; - common_chat_msg message_assist_call_idx { - "assistant", - "", - /* .content_parts = */ {}, - tool_calls_idx, - /* .reasoning_content = */ "", - /* .tool_name = */ "", - /* .tool_call_id = */ "", - }; - common_chat_msg message_assist_call_python { - "assistant", - "", - /* .content_parts = */ {}, - { { "python", "{\"code\": \"print('hey')\"}", /* .id = */ "" } }, - /* .reasoning_content = */ "", - /* .tool_name = */ "", - /* .tool_call_id = */ "", - }; - common_chat_msg message_assist_call_code_interpreter { - "assistant", - "", - /* .content_parts = */ {}, - { { "code_interpreter", "{\"code\": \"print('hey')\"}", /* .id = */ "" } }, - /* .reasoning_content = */ "", - /* .tool_name = */ "", - /* .tool_call_id = */ "", +const common_chat_msg message_assist_call { + "assistant", + "", + /* .content_parts = */ {}, + tool_calls, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; +const common_chat_msg message_assist_call_thoughts = { + "assistant", + /* .content = */ "", + /* .content_parts = */ {}, + tool_calls, + /* .reasoning_content = */ "I'm\nthinking", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; +const common_chat_msg message_assist_call_thoughts_unparsed = { + "assistant", + /* .content = */ "I'm\nthinking", + /* .content_parts = */ {}, + tool_calls, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; +const common_chat_msg message_assist_call_id { + "assistant", + "", + /* .content_parts = */ {}, + tool_calls_id, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; +const common_chat_msg message_assist_call_idx { + "assistant", + "", + /* .content_parts = */ {}, + tool_calls_idx, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; +const common_chat_msg message_assist_call_python { + "assistant", + "", + /* .content_parts = */ {}, + { { "python", "{\"code\": \"print('hey')\"}", /* .id = */ "" } }, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; +const common_chat_msg message_assist_call_code_interpreter { + "assistant", + "", + /* .content_parts = */ {}, + { { "code_interpreter", "{\"code\": \"print('hey')\"}", /* .id = */ "" } }, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; + +static void test_oaicompat_json_conversion() { + std::vector msgs{ + message_assist_call, + message_assist_call_thoughts, + message_assist_call_thoughts_unparsed, + message_assist_call_id, + message_assist_call_idx, + message_assist_call_python, + message_assist_call_code_interpreter, }; + for (const auto & msg : msgs) { + auto oai_json = common_chat_msgs_to_json_oaicompat({msg}); + fprintf(stderr, "OAI JSON: %s\n", oai_json.dump(2).c_str()); + auto msgs2 = common_chat_msgs_parse_oaicompat(oai_json); + assert_equals((size_t) 1, msgs2.size()); + auto msg2 = msgs2[0]; + assert_msg_equals(msg, msg2); + } +} + +static void 
test_template_output_parsers() { common_chat_templates_inputs inputs_no_tools; inputs_no_tools.messages = {message_user}; @@ -691,6 +712,7 @@ int main(int argc, char ** argv) { } else #endif { + test_oaicompat_json_conversion(); test_template_output_parsers(); std::cout << "\n[chat] All tests passed!" << std::endl; } From a58e1fca5bd0d3c0f1edda40d61b622752aef6a2 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 02:51:38 +0000 Subject: [PATCH 15/41] fix typo --- common/chat.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/chat.cpp b/common/chat.cpp index e906fb6aa2fc1..ddef7250361fe 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -134,7 +134,7 @@ json common_chat_msgs_to_json_oaicompat(const std::vector & msg jmsg["name"] = msg.tool_name; } if (!msg.tool_call_id.empty()) { - jmsg["tool_call_id"] = json::parse(msg.tool_call_id); + jmsg["tool_call_id"] = msg.tool_call_id; } if (!msg.tool_calls.empty()) { auto & tool_calls = jmsg["tool_calls"] = json::array(); From 103c84057eeabe048180ba506525ff94be4e0ec9 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 12:41:47 +0000 Subject: [PATCH 16/41] fix content part string concat in legacy template branch --- common/chat.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index ddef7250361fe..33c35b9311105 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1448,8 +1448,12 @@ common_chat_params common_chat_templates_apply( content += part.text; } contents.emplace_back(std::move(content)); - chat.push_back({msg.role.c_str(), contents.back().c_str()}); - alloc_size += (msg.role.size() + contents.back().size()) * 1.25; + } + for (size_t i = 0; i < contents.size(); ++i) { + const auto & msg = inputs.messages[i]; + const auto & content = contents[i]; + chat.push_back({msg.role.c_str(), content.c_str()}); + alloc_size += (msg.role.size() + content.size()) * 1.25; } std::vector buf(alloc_size); From c154c02a3cc4f0846c6ed69c8f38cc84f61dbfa9 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 17:35:27 +0000 Subject: [PATCH 17/41] test tools json conversions --- tests/test-chat.cpp | 62 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index df7a19a579fca..23807ae4a9f75 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -387,12 +387,72 @@ static void test_oaicompat_json_conversion() { }; for (const auto & msg : msgs) { auto oai_json = common_chat_msgs_to_json_oaicompat({msg}); - fprintf(stderr, "OAI JSON: %s\n", oai_json.dump(2).c_str()); auto msgs2 = common_chat_msgs_parse_oaicompat(oai_json); assert_equals((size_t) 1, msgs2.size()); auto msg2 = msgs2[0]; assert_msg_equals(msg, msg2); } + assert_equals( + std::string( + "[\n" + " {\n" + " \"role\": \"assistant\",\n" + " \"content\": null,\n" + " \"tool_calls\": [\n" + " {\n" + " \"type\": \"function\",\n" + " \"function\": {\n" + " \"name\": \"python\",\n" + " \"arguments\": \"{\\\"code\\\": \\\"print('hey')\\\"}\"\n" + " }\n" + " }\n" + " ]\n" + " }\n" + "]" + ), + common_chat_msgs_to_json_oaicompat({message_assist_call_python}).dump(2)); + + std::vector tools{ + special_function_tool, + python_tool, + code_interpreter_tool, + }; + + for (const auto & tool : tools) { + auto oai_json = common_chat_tools_to_json_oaicompat({tool}); + auto tools2 = common_chat_tools_parse_oaicompat(oai_json); + assert_equals((size_t) 1, tools2.size()); + auto tool2 = tools2[0]; + 
assert_equals(tool.name, tool2.name); + assert_equals(tool.description, tool2.description); + assert_equals(json::parse(tool.parameters).dump(2), json::parse(tool2.parameters).dump(2)); + } + + assert_equals( + std::string( + "[\n" + " {\n" + " \"type\": \"function\",\n" + " \"function\": {\n" + " \"name\": \"special_function\",\n" + " \"description\": \"I'm special\",\n" + " \"parameters\": {\n" + " \"type\": \"object\",\n" + " \"properties\": {\n" + " \"arg1\": {\n" + " \"type\": \"integer\",\n" + " \"description\": \"The arg.\"\n" + " }\n" + " },\n" + " \"required\": [\n" + " \"arg1\"\n" + " ]\n" + " }\n" + " }\n" + " }\n" + "]" + ), + common_chat_tools_to_json_oaicompat({special_function_tool}).dump(2)); } static void test_template_output_parsers() { From 3d41f1b901fc204fce5f448043375b729e53a075 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 18:42:02 +0000 Subject: [PATCH 18/41] test content parts in test-chat --- tests/test-chat.cpp | 40 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 38 insertions(+), 2 deletions(-) diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 23807ae4a9f75..f14101b2671a6 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -273,6 +273,17 @@ const common_chat_msg message_user { /* .tool_calls = */ {}, /* .reasoning_content = */ "", }; + +const common_chat_msg message_user_parts { + "user", + /* .content = */ "", + /* .content_parts = */ { + { "text", "Hey" }, + { "text", "there" }, + }, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "", +}; const common_chat_msg message_assist { "assistant", "Hello, world!\nWhat's up?", @@ -375,8 +386,10 @@ const common_chat_msg message_assist_call_code_interpreter { /* .tool_call_id = */ "", }; -static void test_oaicompat_json_conversion() { +static void test_msgs_oaicompat_json_conversion() { std::vector msgs{ + message_user, + message_user_parts, message_assist_call, message_assist_call_thoughts, message_assist_call_thoughts_unparsed, @@ -392,6 +405,26 @@ static void test_oaicompat_json_conversion() { auto msg2 = msgs2[0]; assert_msg_equals(msg, msg2); } + assert_equals( + std::string( + "[\n" + " {\n" + " \"role\": \"user\",\n" + " \"content\": [\n" + " {\n" + " \"type\": \"text\",\n" + " \"text\": \"Hey\"\n" + " },\n" + " {\n" + " \"type\": \"text\",\n" + " \"text\": \"there\"\n" + " }\n" + " ]\n" + " }\n" + "]" + ), + common_chat_msgs_to_json_oaicompat({message_user_parts}).dump(2)); + assert_equals( std::string( "[\n" @@ -411,7 +444,9 @@ static void test_oaicompat_json_conversion() { "]" ), common_chat_msgs_to_json_oaicompat({message_assist_call_python}).dump(2)); +} +static void test_tools_oaicompat_json_conversion() { std::vector tools{ special_function_tool, python_tool, @@ -772,7 +807,8 @@ int main(int argc, char ** argv) { } else #endif { - test_oaicompat_json_conversion(); + test_msgs_oaicompat_json_conversion(); + test_tools_oaicompat_json_conversion(); test_template_output_parsers(); std::cout << "\n[chat] All tests passed!" 
<< std::endl; } From 59c8059d531a1f21445d867daa90eb0fbdd624d1 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 18:45:40 +0000 Subject: [PATCH 19/41] fix clang-tidy lints in [test-]chat.* --- common/chat.cpp | 435 ++++++++++++++++++++++++-------------------- common/chat.hpp | 9 +- tests/test-chat.cpp | 96 ++++++---- 3 files changed, 294 insertions(+), 246 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index 33c35b9311105..6779c5e3029a8 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -27,9 +27,15 @@ struct templates_params { }; common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) { - if (tool_choice == "auto") return COMMON_CHAT_TOOL_CHOICE_AUTO; - if (tool_choice == "none") return COMMON_CHAT_TOOL_CHOICE_NONE; - if (tool_choice == "required") return COMMON_CHAT_TOOL_CHOICE_REQUIRED; + if (tool_choice == "auto") { + return COMMON_CHAT_TOOL_CHOICE_AUTO; + } + if (tool_choice == "none") { + return COMMON_CHAT_TOOL_CHOICE_NONE; + } + if (tool_choice == "required") { + return COMMON_CHAT_TOOL_CHOICE_REQUIRED; + } throw std::runtime_error("Invalid tool_choice: " + tool_choice); } @@ -39,13 +45,19 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa try { - if (!messages.is_array()) throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump()); + if (!messages.is_array()) { + throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump()); + } for (const auto & message : messages) { - if (!message.is_object()) throw std::runtime_error("Expected 'message' to be an object, got " + message.dump()); + if (!message.is_object()) { + throw std::runtime_error("Expected 'message' to be an object, got " + message.dump()); + } common_chat_msg msg; - if (!message.contains("role")) throw std::runtime_error("Missing 'role' in message: " + message.dump()); + if (!message.contains("role")) { + throw std::runtime_error("Missing 'role' in message: " + message.dump()); + } msg.role = message.at("role"); if (message.contains("content")) { @@ -54,9 +66,13 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa msg.content = content; } else if (content.is_array()) { for (const auto & part : content) { - if (!part.contains("type")) throw std::runtime_error("Missing content part type: " + part.dump()); + if (!part.contains("type")) { + throw std::runtime_error("Missing content part type: " + part.dump()); + } const auto & type = part.at("type"); - if (type != "text") throw std::runtime_error("Unsupported content part type: " + type.dump()); + if (type != "text") { + throw std::runtime_error("Unsupported content part type: " + type.dump()); + } common_chat_msg_content_part msg_part; msg_part.type = type; msg_part.text = part.at("text"); @@ -80,12 +96,20 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa if (message.contains("tool_calls")) { for (const auto & tool_call : message.at("tool_calls")) { common_chat_tool_call tc; - if (!tool_call.contains("type")) throw std::runtime_error("Missing tool call type: " + tool_call.dump()); + if (!tool_call.contains("type")) { + throw std::runtime_error("Missing tool call type: " + tool_call.dump()); + } const auto & type = tool_call.at("type"); - if (type != "function") throw std::runtime_error("Unsupported tool call type: " + tool_call.dump()); - if (!tool_call.contains("function")) throw std::runtime_error("Missing tool call function: " + tool_call.dump()); + if (type != "function") { + throw 
std::runtime_error("Unsupported tool call type: " + tool_call.dump()); + } + if (!tool_call.contains("function")) { + throw std::runtime_error("Missing tool call function: " + tool_call.dump()); + } const auto & fc = tool_call.at("function"); - if (!fc.contains("name")) throw std::runtime_error("Missing tool call name: " + tool_call.dump()); + if (!fc.contains("name")) { + throw std::runtime_error("Missing tool call name: " + tool_call.dump()); + } tc.name = fc.at("name"); tc.arguments = fc.at("arguments"); if (tool_call.contains("id")) { @@ -168,12 +192,20 @@ std::vector common_chat_tools_parse_oaicompat(const json & too try { if (!tools.is_null()) { - if (!tools.is_array()) throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump()); + if (!tools.is_array()) { + throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump()); + } for (const auto & tool : tools) { - if (!tool.contains("type")) throw std::runtime_error("Missing tool type: " + tool.dump()); + if (!tool.contains("type")) { + throw std::runtime_error("Missing tool type: " + tool.dump()); + } const auto & type = tool.at("type"); - if (!type.is_string() || type != "function") throw std::runtime_error("Unsupported tool type: " + tool.dump()); - if (!tool.contains("function")) throw std::runtime_error("Missing tool function: " + tool.dump()); + if (!type.is_string() || type != "function") { + throw std::runtime_error("Unsupported tool type: " + tool.dump()); + } + if (!tool.contains("function")) { + throw std::runtime_error("Missing tool function: " + tool.dump()); + } const auto & function = tool.at("function"); result.push_back({ @@ -197,7 +229,9 @@ std::vector common_chat_tools_parse_oaicompat(const std::strin template <> json common_chat_tools_to_json_oaicompat(const std::vector & tools) { - if (tools.empty()) return json(); + if (tools.empty()) { + return json(); + } auto result = json::array(); for (const auto & tool : tools) { @@ -220,7 +254,7 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { msg.role = "user"; msg.content = "test"; - auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl); + auto * tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl); common_chat_templates_inputs inputs; inputs.messages = {msg}; @@ -270,12 +304,16 @@ std::string common_chat_format_single( std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) { common_chat_templates_inputs inputs; inputs.use_jinja = use_jinja; - inputs.messages = { - {"system", "You are a helpful assistant", {}, {}, ""}, - {"user", "Hello", {}, {}, ""}, - {"assistant", "Hi there", {}, {}, ""}, - {"user", "How are you?", {}, {}, ""}, + auto add_simple_msg = [&](auto role, auto content) { + common_chat_msg msg; + msg.role = role; + msg.content = content; + inputs.messages.push_back(msg); }; + add_simple_msg("system", "You are a helpful assistant"); + add_simple_msg("user", "Hello"); + add_simple_msg("assistant", "Hi there"); + add_simple_msg("user", "How are you?"); return common_chat_templates_apply(tmpls, inputs).prompt; } @@ -307,7 +345,7 @@ struct common_chat_templates * common_chat_templates_init( bool has_explicit_template = !chat_template_override.empty(); if (chat_template_override.empty()) { GGML_ASSERT(model != nullptr); - auto str = llama_model_chat_template(model, /* name */ nullptr); + const auto * str = llama_model_chat_template(model, /* name */ nullptr); if (str) { default_template_src = str; has_explicit_template = true; @@ 
-330,22 +368,21 @@ struct common_chat_templates * common_chat_templates_init( std::string token_bos = bos_token_override; std::string token_eos = eos_token_override; if (model) { - auto vocab = llama_model_get_vocab(model); + const auto * vocab = llama_model_get_vocab(model); const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) { if (token == LLAMA_TOKEN_NULL) { if (default_template_src.find(jinja_variable_name) != std::string::npos || template_tool_use_src.find(jinja_variable_name) != std::string::npos) { - LOG_WRN("%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name); + LOG_WRN("common_chat_templates_init: warning: vocab does not have a %s token, jinja template won't work as intended.\n", name); } return std::string(); - } else { - return common_token_to_piece(vocab, token, true); } + return common_token_to_piece(vocab, token, true); }; token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token"); token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token"); } - auto tmpls = new common_chat_templates(); + auto * tmpls = new common_chat_templates(); tmpls->has_explicit_template = has_explicit_template; try { tmpls->template_default = std::make_unique(default_template_src, token_bos, token_eos); @@ -917,21 +954,17 @@ static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bo auto arg_value_str = raw_args.substr(it_eq + 1); auto arg_value = json::parse(arg_value_str); - return { - /* .role = */ "assistant", - /* .content = */ match.prefix().str(), - /* .content_parts = */ {}, - /* .tool_calls = */ { - { - /* .name = */ match[1], - /* .arguments = */ (json { - {arg_name, arg_value}, - }).dump(), - /* .id = */ "", - }, - }, - /* .reasoning_content = */ "", - }; + common_chat_msg msg; + msg.role = "assistant"; + msg.content = match.prefix().str(); + msg.tool_calls.push_back({ + /* .name = */ name, + /* .arguments = */ (json { + {arg_name, arg_value}, + }).dump(), + /* .id = */ "", + }); + return msg; } } return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex); @@ -1215,19 +1248,15 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s std::smatch match; if (std::regex_search(input, match, python_tag_regex)) { auto code = match[1].str(); - return { - /* .role = */ "assistant", - /* .content = */ match.prefix().str(), - /* .content_parts = */ {}, - /* .tool_calls = */ { - { - /* .name = */ "python", - /* .arguments = */ (json {{"code", code}}).dump(), - /* .id = */ "", - }, - }, - /* .reasoning_content = */ "", - }; + common_chat_msg msg; + msg.role = "assistant"; + msg.content = match.prefix().str(); + msg.tool_calls.push_back({ + /* .name = */ "python", + /* .arguments = */ (json {{"code", code}}).dump(), + /* .id = */ "", + }); + return msg; } static std::regex function_regex(R"()"); static std::regex close_regex(R"()"); @@ -1271,22 +1300,18 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) std::regex middle_pattern(R"([\n\s]*[\n\s]*)"); std::regex end_pattern(R"([\n\s]*[\n\s]*$)"); + common_chat_msg msg; + msg.role = "assistant"; + auto end = input.end(); std::sregex_iterator rend; std::sregex_iterator rit(input.begin(), end, start_pattern); if (rit == rend) { - return { - /* .role = */ "assistant", - /* .content = */ input, - /* .content_parts = */ {}, - /* .tool_calls = */ {}, - /* .reasoning_content = */ "", - }; + msg.content = input; + return msg; } - common_chat_msg result; 
- result.role = "assistant"; - result.content = rit->prefix(); + msg.content = rit->prefix(); auto it = rit->suffix().first; while (it != end) { @@ -1295,7 +1320,7 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) throw std::runtime_error("Failed to parse json tool call"); } const auto & arguments = call.at("arguments"); - result.tool_calls.push_back({ + msg.tool_calls.push_back({ call.at("name"), arguments.dump(), // arguments.is_string() ? arguments.get() : arguments.dump(), @@ -1312,15 +1337,13 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) break; } } - return result; + return msg; } catch (const std::exception & e) { - return { - /* .role = */ "assistant", - /* .content = */ input, - /* .content_parts = */ {}, - /* .tool_calls = */ {}, - /* .reasoning_content = */ "", - }; + LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what()); + common_chat_msg msg; + msg.role = "assistant"; + msg.content = input; + return msg; } } @@ -1340,160 +1363,168 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha return data; } -common_chat_params common_chat_templates_apply( +static common_chat_params common_chat_templates_apply_jinja( const struct common_chat_templates * tmpls, const struct common_chat_templates_inputs & inputs) { - GGML_ASSERT(tmpls != nullptr); - if (inputs.use_jinja) { - templates_params params; - params.messages = common_chat_msgs_to_json_oaicompat(inputs.messages); - params.add_generation_prompt = inputs.add_generation_prompt; - params.extract_reasoning = inputs.extract_reasoning; - params.tools = common_chat_tools_to_json_oaicompat(inputs.tools); - params.tool_choice = inputs.tool_choice; - params.grammar = inputs.grammar; - if (!inputs.json_schema.empty()) { - params.json_schema = json::parse(inputs.json_schema); - } - const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use - ? *tmpls->template_tool_use - : *tmpls->template_default; - const auto & src = tmpl.source(); - const auto & caps = tmpl.original_caps(); - - if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) { - LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n"); - params.parallel_tool_calls = false; - } else { - params.parallel_tool_calls = inputs.parallel_tool_calls; - } + templates_params params; + params.messages = common_chat_msgs_to_json_oaicompat(inputs.messages); + params.add_generation_prompt = inputs.add_generation_prompt; + params.extract_reasoning = inputs.extract_reasoning; + params.tools = common_chat_tools_to_json_oaicompat(inputs.tools); + params.tool_choice = inputs.tool_choice; + params.grammar = inputs.grammar; + if (!inputs.json_schema.empty()) { + params.json_schema = json::parse(inputs.json_schema); + } + const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use + ? 
*tmpls->template_tool_use + : *tmpls->template_default; + const auto & src = tmpl.source(); + const auto & caps = tmpl.original_caps(); + + if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) { + LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n"); + params.parallel_tool_calls = false; + } else { + params.parallel_tool_calls = inputs.parallel_tool_calls; + } - if (params.tools.is_array()) { - if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) { - throw std::runtime_error("Cannot specify grammar with tools"); - } - if (caps.supports_tool_calls && !caps.supports_tools) { - LOG_WRN("Template supports tool calls but does not natively describe tools. The fallback behaviour used may produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n"); - } + if (params.tools.is_array()) { + if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) { + throw std::runtime_error("Cannot specify grammar with tools"); } - - // DeepSeek R1: use handler in all cases except json schema (thinking / tools). - if (src.find("<|tool▁calls▁begin|>") != std::string::npos && params.json_schema.is_null()) { - return common_chat_params_init_deepseek_r1(tmpl, params); + if (caps.supports_tool_calls && !caps.supports_tools) { + LOG_WRN("Template supports tool calls but does not natively describe tools. The fallback behaviour used may produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n"); } + } - // Command R7B: : use handler in all cases except json schema (thinking / tools). - if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && params.json_schema.is_null()) { - return common_chat_params_init_command_r7b(tmpl, params); - } + // DeepSeek R1: use handler in all cases except json schema (thinking / tools). + if (src.find("<|tool▁calls▁begin|>") != std::string::npos && params.json_schema.is_null()) { + return common_chat_params_init_deepseek_r1(tmpl, params); + } - // Use generic handler when mixing tools + JSON schema. - // TODO: support that mix in handlers below. - if ((!params.tools.is_array() && params.json_schema.is_object())) { - return common_chat_params_init_generic(tmpl, params); - } + // Command R7B: : use handler in all cases except json schema (thinking / tools). + if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && params.json_schema.is_null()) { + return common_chat_params_init_command_r7b(tmpl, params); + } - // Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases. - if (src.find(">>>all") != std::string::npos) { - return common_chat_params_init_functionary_v3_2(tmpl, params); - } + // Use generic handler when mixing tools + JSON schema. + // TODO: support that mix in handlers below. + if ((!params.tools.is_array() && params.json_schema.is_object())) { + return common_chat_params_init_generic(tmpl, params); + } - // Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases. - if (src.find(" functools[") != std::string::npos) { - return common_chat_params_init_firefunction_v2(tmpl, params); - } + // Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases. 
+ if (src.find(">>>all") != std::string::npos) { + return common_chat_params_init_functionary_v3_2(tmpl, params); + } - // Plain handler (no tools) - if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { - return common_chat_params_init_without_tools(tmpl, params); - } + // Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases. + if (src.find(" functools[") != std::string::npos) { + return common_chat_params_init_firefunction_v2(tmpl, params); + } - // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools) - if (src.find("") != std::string::npos) { - return common_chat_params_init_hermes_2_pro(tmpl, params); - } + // Plain handler (no tools) + if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { + return common_chat_params_init_without_tools(tmpl, params); + } - // Functionary v3.1 (w/ tools) - if (src.find("<|start_header_id|>") != std::string::npos - && src.find("") != std::string::npos) { + return common_chat_params_init_hermes_2_pro(tmpl, params); + } - // Llama 3.1, 3.2, 3.3 (w/ tools) - if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) { - auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos; - return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools); - } + // Functionary v3.1 (w/ tools) + if (src.find("<|start_header_id|>") != std::string::npos + && src.find("ipython<|end_header_id|>") != std::string::npos) { + auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos; + return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools); + } - // Generic fallback - return common_chat_params_init_generic(tmpl, params); - } else { - // Legacy template route (adhoc C++ implementation of known templates), forward to llama_chat_apply_template. + // Mistral Nemo (w/ tools) + if (src.find("[TOOL_CALLS]") != std::string::npos) { + return common_chat_params_init_mistral_nemo(tmpl, params); + } - int alloc_size = 0; - std::vector chat; - std::vector contents; - for (const auto & msg : inputs.messages) { - auto content = msg.content; - for (const auto & part : msg.content_parts) { - if (!content.empty()) { - content += "\n";; - } - content += part.text; + // Generic fallback + return common_chat_params_init_generic(tmpl, params); +} + +// Legacy template route (adhoc C++ implementation of known templates), forward to llama_chat_apply_template. 
+static common_chat_params common_chat_templates_apply_legacy( + const struct common_chat_templates * tmpls, + const struct common_chat_templates_inputs & inputs) +{ + int alloc_size = 0; + std::vector chat; + std::vector contents; + for (const auto & msg : inputs.messages) { + auto content = msg.content; + for (const auto & part : msg.content_parts) { + if (!content.empty()) { + content += "\n";; } - contents.emplace_back(std::move(content)); - } - for (size_t i = 0; i < contents.size(); ++i) { - const auto & msg = inputs.messages[i]; - const auto & content = contents[i]; - chat.push_back({msg.role.c_str(), content.c_str()}); - alloc_size += (msg.role.size() + content.size()) * 1.25; + content += part.text; } + contents.emplace_back(std::move(content)); + } + for (size_t i = 0; i < contents.size(); ++i) { + const auto & msg = inputs.messages[i]; + const auto & content = contents[i]; + chat.push_back({msg.role.c_str(), content.c_str()}); + alloc_size += (msg.role.size() + content.size()) * 1.25; + } - std::vector buf(alloc_size); + std::vector buf(alloc_size); - // run the first time to get the total output length - const auto & src = tmpls->template_default->source(); - int32_t res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size()); + // run the first time to get the total output length + const auto & src = tmpls->template_default->source(); + int32_t res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size()); - // error: chat template is not supported - if (res < 0) { - // if the custom "tmpl" is not supported, we throw an error - // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template() - throw std::runtime_error("this custom template is not supported"); - } + // error: chat template is not supported + if (res < 0) { + // if the custom "tmpl" is not supported, we throw an error + // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template() + throw std::runtime_error("this custom template is not supported"); + } - // if it turns out that our buffer is too small, we resize it - if ((size_t) res > buf.size()) { - buf.resize(res); - res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size()); - } + // if it turns out that our buffer is too small, we resize it + if ((size_t) res > buf.size()) { + buf.resize(res); + res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size()); + } - common_chat_params params; - params.prompt = std::string(buf.data(), res); - if (!inputs.json_schema.empty()) { - params.grammar = json_schema_to_grammar(json::parse(inputs.json_schema)); - } else { - params.grammar = inputs.grammar; - } - return params; + common_chat_params params; + params.prompt = std::string(buf.data(), res); + if (!inputs.json_schema.empty()) { + params.grammar = json_schema_to_grammar(json::parse(inputs.json_schema)); + } else { + params.grammar = inputs.grammar; } + return params; +} + +common_chat_params common_chat_templates_apply( + const struct common_chat_templates * tmpls, + const struct common_chat_templates_inputs & inputs) +{ + GGML_ASSERT(tmpls != nullptr); + return inputs.use_jinja + ? 
common_chat_templates_apply_jinja(tmpls, inputs) + : common_chat_templates_apply_legacy(tmpls, inputs); } static common_chat_msg common_chat_parse_content_only(const std::string & input) { - return { - /* .role = */ "assistant", - /* .content = */ input, - /* .content_parts = */ {}, - /* .tool_calls = */ {}, - /* .reasoning_content = */ "", - }; + common_chat_msg msg; + msg.role = "assistant"; + msg.content = input; + return msg; } common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) { diff --git a/common/chat.hpp b/common/chat.hpp index 1bfc22dc199b1..8fb7dff6db17c 100644 --- a/common/chat.hpp +++ b/common/chat.hpp @@ -3,7 +3,6 @@ #pragma once #include "common.h" -#include #include #include @@ -26,9 +25,9 @@ struct common_chat_msg { std::string content; std::vector content_parts = {}; std::vector tool_calls = {}; - std::string reasoning_content = ""; - std::string tool_name = ""; - std::string tool_call_id = ""; + std::string reasoning_content; + std::string tool_name; + std::string tool_call_id; }; struct common_chat_tool { @@ -123,7 +122,7 @@ common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::strin // Parses a JSON array of messages in OpenAI's chat completion API format. // T can be std::string containing JSON or nlohmann::ordered_json template std::vector common_chat_msgs_parse_oaicompat(const T & messages); -template T common_chat_msgs_to_json_oaicompat(const std::vector & messages); +template T common_chat_msgs_to_json_oaicompat(const std::vector & msgs); // Parses a JSON array of tools in OpenAI's chat completion tool call API format. // T can be std::string containing JSON or nlohmann::ordered_json diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index f14101b2671a6..0d8596f8675c6 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -27,7 +27,7 @@ template static void assert_equals(const T & expected, const T & actua } static std::string read_file(const std::string & path) { - std::cerr << "# Reading: " << path << std::endl << std::flush; + std::cerr << "# Reading: " << path << '\n' << std::flush; std::ifstream fs(path, std::ios_base::binary); if (!fs.is_open()) { fs = std::ifstream("../" + path, std::ios_base::binary); @@ -40,7 +40,7 @@ static std::string read_file(const std::string & path) { fs.seekg(0); std::string out; out.resize(static_cast(size)); - fs.read(&out[0], static_cast(size)); + fs.read(out.data(), static_cast(size)); return out; } @@ -68,11 +68,9 @@ static bool match_string(const std::string & input, llama_grammar * grammar) { } } - for (const auto & stack : stacks_cur) { - if (stack.empty()) { - // An empty stack means that the grammar has been completed - return true; - } + if (std::any_of(stacks_cur.begin(), stacks_cur.end(), [](const auto & stack) { return stack.empty(); })) { + // An empty stack means that the grammar has been completed + return true; } return false; @@ -213,7 +211,9 @@ static void test_templates(const struct common_chat_templates * tmpls, const std bool expect_grammar_triggered = true, bool test_grammar_if_triggered = true, bool think = false) { - common_chat_msg user_message = { "user", "Hello, world!", {}, {}, "" }; + common_chat_msg user_message; + user_message.role = "user"; + user_message.content = "Hello, world!"; for (const auto & tool_choice : std::vector {COMMON_CHAT_TOOL_CHOICE_AUTO, COMMON_CHAT_TOOL_CHOICE_REQUIRED}) { auto data = init_delta(tmpls, end_tokens, user_message, test_message, tools, tool_choice, think); @@ -272,6 +272,8 @@ const common_chat_msg 
message_user { /* .content_parts = */ {}, /* .tool_calls = */ {}, /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", }; const common_chat_msg message_user_parts { @@ -283,6 +285,8 @@ const common_chat_msg message_user_parts { }, /* .tool_calls = */ {}, /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", }; const common_chat_msg message_assist { "assistant", @@ -290,6 +294,8 @@ const common_chat_msg message_assist { /* .content_parts = */ {}, /* .tool_calls = */ {}, /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", }; const common_chat_msg message_assist_thoughts_unparsed_think { "assistant", @@ -297,6 +303,8 @@ const common_chat_msg message_assist_thoughts_unparsed_think { /* .content_parts = */ {}, /* .tool_calls = */ {}, /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", }; const common_chat_msg message_assist_thoughts_unparsed_r7b { "assistant", @@ -304,6 +312,8 @@ const common_chat_msg message_assist_thoughts_unparsed_r7b { /* .content_parts = */ {}, /* .tool_calls = */ {}, /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", }; const common_chat_msg message_assist_thoughts { "assistant", @@ -311,6 +321,8 @@ const common_chat_msg message_assist_thoughts { /* .content_parts = */ {}, /* .tool_calls = */ {}, /* .reasoning_content = */ "I'm thinking", + /* .tool_name = */ "", + /* .tool_call_id = */ "", }; const std::vector tool_calls { { "special_function", "{\"arg1\": 1}", /* .id = */ "" }, @@ -777,40 +789,46 @@ static void test_template_output_parsers() { } int main(int argc, char ** argv) { + try { #ifndef _WIN32 - if (argc > 1) { - common_chat_templates_inputs inputs; - inputs.messages = { - { "user", "Hey", {}, {}, "" }, - }; - inputs.tools = { special_function_tool }; - - std::cout << "| Template | Format |\n"; - std::cout << "|----------|--------|\n"; - - for (int i = 1; i < argc; i++) { - try { - std::string path = argv[i]; - if (path.rfind(".jinja") != path.size() - 6) { - std::cerr << "Skipping non-jinja file: " << path << std::endl; - continue; + if (argc > 1) { + common_chat_templates_inputs inputs; + common_chat_msg msg; + msg.role = "user"; + msg.content = "Hey"; + inputs.messages = {msg}; + inputs.tools = { special_function_tool }; + + std::cout << "| Template | Format |\n"; + std::cout << "|----------|--------|\n"; + + for (int i = 1; i < argc; i++) { + try { + std::string path = argv[i]; + if (path.rfind(".jinja") != path.size() - 6) { + std::cerr << "Skipping non-jinja file: " << path << '\n'; + continue; + } + auto tmpls = read_templates(path); + auto parts = string_split(path, "/"); + auto name = parts[parts.size() - 1]; + auto format = common_chat_format_name(common_chat_templates_apply(tmpls.get(), inputs).format); + std::cout << "| " << name << " | " << format << " |\n"; + } catch (const std::exception & e) { + std::cerr << "Failed to process " << argv[i] << ": " << e.what() << '\n'; } - auto tmpls = read_templates(path); - auto parts = string_split(path, "/"); - auto name = parts[parts.size() - 1]; - auto format = common_chat_format_name(common_chat_templates_apply(tmpls.get(), inputs).format); - std::cout << "| " << name << " | " << format << " |\n"; - } catch (const std::exception & e) { - std::cerr << "Failed to process " << argv[i] << ": " << e.what() << std::endl; } - } - } else + } else #endif - { - test_msgs_oaicompat_json_conversion(); - test_tools_oaicompat_json_conversion(); - 
test_template_output_parsers(); - std::cout << "\n[chat] All tests passed!" << std::endl; + { + test_msgs_oaicompat_json_conversion(); + test_tools_oaicompat_json_conversion(); + test_template_output_parsers(); + std::cout << "\n[chat] All tests passed!" << '\n'; + } + return 0; + } catch (const std::exception & e) { + std::cerr << "Error: " << e.what() << '\n'; + return 1; } - return 0; } From 1847cae2cd8eba7e365b9619dc6e9704be85e873 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 19:12:05 +0000 Subject: [PATCH 20/41] fix deepseek r1 slow test (no longer opening w/ new template) --- examples/server/tests/unit/test_tool_call.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/server/tests/unit/test_tool_call.py b/examples/server/tests/unit/test_tool_call.py index 8aca10011d148..40e279b5b92c7 100644 --- a/examples/server/tests/unit/test_tool_call.py +++ b/examples/server/tests/unit/test_tool_call.py @@ -444,7 +444,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, (128, None, "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), (1024, 'deepseek', "To find the sum of.*", "I need to calculate the sum of 102 and 7.*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), - (1024, 'none', "\n?I need[\\s\\S]*?\n?To find.*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + (1024, 'none', "^I need[\\s\\S]*?\n?To find.*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), (1024, 'deepseek', "To find the sum of.*", "First, I [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), ]) From 8462a51544a77378ced001411884b7b58da13566 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 19:48:12 +0000 Subject: [PATCH 21/41] fix lints in test-chat-template.cpp --- common/chat.cpp | 1 - tests/test-chat-template.cpp | 38 +++++++++++++----------------------- 2 files changed, 14 insertions(+), 25 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index 6779c5e3029a8..a8b2a268eb150 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1,5 +1,4 @@ #include "chat.hpp" -#include #include "json-schema-to-grammar.h" #include "log.h" #include "minja/chat-template.hpp" diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index 80d12b83cbbb9..4a6e50a052139 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -18,6 +18,13 @@ static std::string normalize_newlines(const std::string & s) { #endif } +static common_chat_msg simple_msg(const std::string & role, const std::string & content) { + common_chat_msg msg; + msg.role = role; + msg.content = content; + return msg; +} + int main(void) { std::vector conversation { {"system", "You are a helpful assistant"}, @@ -306,13 +313,7 @@ int main(void) { std::vector messages; for (const auto & msg : conversation) { - messages.push_back({ - msg.role, - msg.content, - /* .content_parts = */ {}, - /* .tool_calls = */ {}, - /* .reasoning_content = */ "", - }); + messages.push_back(simple_msg(msg.role, msg.content)); } for (const auto & test_case : test_cases) { if (!test_case.supported_with_jinja) { @@ -322,6 +323,7 @@ int main(void) { try { common_chat_templates_ptr tmpls(common_chat_templates_init(/* model= */ nullptr, test_case.template_str.c_str(), test_case.bos_token, test_case.eos_token), &common_chat_templates_free); common_chat_templates_inputs inputs; + inputs.use_jinja = false; inputs.messages = messages; 
inputs.add_generation_prompt = add_generation_prompt; auto output = common_chat_templates_apply(tmpls.get(), inputs).prompt; @@ -343,13 +345,7 @@ int main(void) { // test llama_chat_format_single for system message printf("\n\n=== llama_chat_format_single (system message) ===\n\n"); std::vector chat2; - common_chat_msg sys_msg { - "system", - "You are a helpful assistant", - /* .content_parts = */ {}, - /* .tool_calls = */ {}, - /* .reasoning_content = */ "", - }; + auto sys_msg = simple_msg("system", "You are a helpful assistant"); auto fmt_sys = [&](std::string tmpl_str) { common_chat_templates_ptr tmpls(common_chat_templates_init(/* model= */ nullptr, tmpl_str), &common_chat_templates_free); @@ -373,16 +369,10 @@ int main(void) { // test llama_chat_format_single for user message printf("\n\n=== llama_chat_format_single (user message) ===\n\n"); - chat2.push_back({"system", "You are a helpful assistant", {}, {}, ""}); - chat2.push_back({"user", "Hello", {}, {}, ""}); - chat2.push_back({"assistant", "I am assistant", {}, {}, ""}); - common_chat_msg new_msg { - "user", - "How are you", - /* .content_parts = */ {}, - /* .tool_calls = */ {}, - /* .reasoning_content = */ "", - }; + chat2.push_back(simple_msg("system", "You are a helpful assistant")); + chat2.push_back(simple_msg("user", "Hello")); + chat2.push_back(simple_msg("assistant", "I am assistant")); + auto new_msg = simple_msg("user", "How are you"); auto fmt_single = [&](const std::string & tmpl_str) { common_chat_templates_ptr tmpls(common_chat_templates_init(/* model= */ nullptr, tmpl_str.c_str()), &common_chat_templates_free); From 80c432b659d12edb3f85e24822d4aba485dec2ed Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 19:49:03 +0000 Subject: [PATCH 22/41] tweak test_calc_result expectations --- examples/server/tests/unit/test_tool_call.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/server/tests/unit/test_tool_call.py b/examples/server/tests/unit/test_tool_call.py index 40e279b5b92c7..a91a2f3333ca3 100644 --- a/examples/server/tests/unit/test_tool_call.py +++ b/examples/server/tests/unit/test_tool_call.py @@ -356,12 +356,12 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | (None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)), (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), - ("^> 0.56$", 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), + (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), (None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), # TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value) - ("^The y-coordinate [\\s\\S]*?\\*\\*0.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), - ("[\\s\\S]*?\\*\\*0\\.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), + ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + # ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), ]) def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None): global server From 42b29e12d0f3b22524081e9ac5ac261f5905221c Mon Sep 17 00:00:00 2001 From: ochafik Date: 
Sun, 16 Feb 2025 20:11:51 +0000 Subject: [PATCH 23/41] fix double bos/eos jinja avoidance hack (was preventing inner bos/eos tokens) --- common/chat.cpp | 15 +++++++++++---- tests/test-chat-template.cpp | 10 +++++----- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index a8b2a268eb150..7a0fb3ae5dbe7 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -582,10 +582,17 @@ static std::string apply( // tmpl_inputs.now = std::chrono::system_clock::now(); minja::chat_template_options tmpl_opts; - tmpl_opts.use_bos_token = false; - tmpl_opts.use_eos_token = false; - - return tmpl.apply(tmpl_inputs, tmpl_opts); + // To avoid double BOS / EOS tokens, we're manually removing begining / trailing tokens + // instead of using `chat_template_options.use_bos_token = false`, since these tokens + // may be needed inside the template / between messages too. + auto result = tmpl.apply(tmpl_inputs, tmpl_opts); + if (string_starts_with(result, tmpl.bos_token())) { + result = result.substr(tmpl.bos_token().size()); + } + if (string_ends_with(result, tmpl.eos_token())) { + result = result.substr(0, result.size() - tmpl.eos_token().size()); + } + return result; } static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) { diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index 4a6e50a052139..e70d6b26dc47f 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -57,7 +57,7 @@ int main(void) { /* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", /* .expected_output= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there[INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", /* .expected_output_jinja= */ "", - /* .bos_token= */ "", + /* .bos_token= */ "", /* .eos_token= */ "", }, { @@ -79,8 +79,8 @@ int main(void) { { /* .name= */ "mlabonne/AlphaMonarch-7B", /* .template_str= */ "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}", - /* .expected_output= */ "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", - /* .expected_output_jinja= */ "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", + /* .expected_output= */ "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", + /* .expected_output_jinja= */ "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", /* .bos_token= */ "", /* .eos_token= */ "", }, @@ -94,7 +94,7 @@ int main(void) { /* .name= */ "OrionStarAI/Orion-14B-Chat", /* .template_str= */ "{% for message in 
messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}", /* .expected_output= */ "Human: You are a helpful assistant\n\nHello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant: I am an assistant Human: Another question\n\nAssistant: ", - /* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant: I am an assistant Human: Another question\n\nAssistant: ", + /* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant: I am an assistant Human: Another question\n\nAssistant: ", /* .bos_token= */ "", /* .eos_token= */ "", }, @@ -323,7 +323,7 @@ int main(void) { try { common_chat_templates_ptr tmpls(common_chat_templates_init(/* model= */ nullptr, test_case.template_str.c_str(), test_case.bos_token, test_case.eos_token), &common_chat_templates_free); common_chat_templates_inputs inputs; - inputs.use_jinja = false; + inputs.use_jinja = true; inputs.messages = messages; inputs.add_generation_prompt = add_generation_prompt; auto output = common_chat_templates_apply(tmpls.get(), inputs).prompt; From ce4ccf03d34306110aeb870784bc38e5f228452a Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 20:38:56 +0000 Subject: [PATCH 24/41] add common_chat_templates_source + rehab server template logs --- common/chat.cpp | 14 ++++++++++++++ common/chat.hpp | 7 ++++--- examples/server/server.cpp | 21 ++++++++++++--------- 3 files changed, 30 insertions(+), 12 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index 7a0fb3ae5dbe7..d6cfb4e2a806b 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -332,6 +332,20 @@ bool common_chat_templates_was_explicit(const struct common_chat_templates * tmp return tmpls->has_explicit_template; } +const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant) { + if (variant != nullptr) { + if (strcmp(variant, "tool_use") == 0) { + if (tmpls->template_tool_use) { + return tmpls->template_tool_use->source().c_str(); + } + return nullptr; + } else { + LOG_DBG("%s: unknown template variant: %s\n", __func__, variant); + } + } + return tmpls->template_default->source().c_str(); +} + struct common_chat_templates * common_chat_templates_init( const struct llama_model * model, const std::string & chat_template_override, diff --git a/common/chat.hpp b/common/chat.hpp index 8fb7dff6db17c..cb5bc81b7cad2 100644 --- a/common/chat.hpp +++ b/common/chat.hpp @@ -92,8 +92,9 @@ struct common_chat_templates * common_chat_templates_init( const std::string & bos_token_override = "", const std::string & eos_token_override = ""); -bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls); -void common_chat_templates_free(struct common_chat_templates * tmpls); +bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls); +const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr); +void common_chat_templates_free(struct common_chat_templates * tmpls); typedef std::unique_ptr common_chat_templates_ptr; @@ -127,4 +128,4 @@ template T common_chat_msgs_to_json_oaicompat(const std::vector std::vector common_chat_tools_parse_oaicompat(const T & tools); -template T 
common_chat_tools_to_json_oaicompat(const std::vector & tools); \ No newline at end of file +template T common_chat_tools_to_json_oaicompat(const std::vector & tools); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index ad754036118d3..84e18dec99b40 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1897,6 +1897,7 @@ struct server_context { common_chat_format_example(chat_templates, params.use_jinja); } catch (const std::exception & e) { SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__); + common_chat_templates_free(chat_templates); chat_templates = common_chat_templates_init(model, "chatml"); } @@ -3795,14 +3796,16 @@ int main(int argc, char ** argv) { { "default_generation_settings", ctx_server.default_generation_settings_for_props }, { "total_slots", ctx_server.params_base.n_parallel }, { "model_path", ctx_server.params_base.model }, - // { "chat_template", ctx_server.chat_templates.template_default->source() }, - // { "bos_token", ctx_server.chat_templates.template_default->bos_token() }, - // { "eos_token", ctx_server.chat_templates.template_default->eos_token() }, + { "chat_template", common_chat_templates_source(ctx_server.chat_templates) }, + { "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)}, + { "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)}, { "build_info", build_info }, }; - // if (ctx_server.params_base.use_jinja && ctx_server.chat_templates.template_tool_use) { - // data["chat_template_tool_use"] = ctx_server.chat_templates.template_tool_use->source(); - // } + if (ctx_server.params_base.use_jinja) { + if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates, "tool_use")) { + data["chat_template_tool_use"] = tool_use_src; + } + } res_ok(res, data); }; @@ -4454,9 +4457,9 @@ int main(int argc, char ** argv) { LOG_INF("%s: model loaded\n", __func__); // print sample chat example to make it clear which template is used - // LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, - // ctx_server.chat_templates.template_default->source().c_str(), - // common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str()); + LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, + common_chat_templates_source(ctx_server.chat_templates), + common_chat_format_example(ctx_server.chat_templates, ctx_server.params_base.use_jinja).c_str()); ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) { ctx_server.process_single_task(task); From cb31f087b466c1a9c2bbe316d0514b04ff6ddc4a Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 20:39:08 +0000 Subject: [PATCH 25/41] fix msg lints --- examples/main/main.cpp | 10 +++------- examples/run/run.cpp | 5 ++++- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 4e953675f8de7..48b3389e69efe 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -266,13 +266,9 @@ int main(int argc, char ** argv) { std::vector embd_inp; auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) { - common_chat_msg new_msg { - role, - content, - /* .content_parts = */ {}, - /* .tool_calls = */ {}, - /* 
.reasoning_content = */ "", - }; + common_chat_msg new_msg; + new_msg.role = role; + new_msg.content = content; auto formatted = common_chat_format_single(chat_templates.get(), chat_msgs, new_msg, role == "user", g_params->use_jinja); chat_msgs.push_back(new_msg); LOG_DBG("formatted: '%s'\n", formatted.c_str()); diff --git a/examples/run/run.cpp b/examples/run/run.cpp index a70222ccbe20e..c43d3d283c848 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -837,7 +837,10 @@ static void add_message(const char * role, const std::string & text, LlamaData & static int apply_chat_template(const struct common_chat_templates * tmpls, LlamaData & llama_data, const bool append, bool use_jinja) { common_chat_templates_inputs inputs; for (const auto & msg : llama_data.messages) { - inputs.messages.push_back({ msg.role, msg.content, {}, {}, "" }); + common_chat_msg cmsg; + cmsg.role = msg.role; + cmsg.content = msg.content; + inputs.messages.push_back(cmsg); } inputs.add_generation_prompt = append; inputs.use_jinja = use_jinja; From 76f5d27b94d20ca1c65c16dd3038020e8bef0c5b Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 20:47:36 +0000 Subject: [PATCH 26/41] tool-call: allow empty tools w/ auto + grammar --- examples/server/utils.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index d7ddfd0ae43c2..1e000087d7fc3 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -601,7 +601,7 @@ static json oaicompat_completion_params_parse( inputs.use_jinja = use_jinja; inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE; - if (inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && llama_params.contains("grammar")) { + if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && llama_params.contains("grammar")) { throw std::runtime_error("Cannot use custom grammar constraints with tools."); } From 34e4e2244f0828ae912227f4fd61d99b8a23ec00 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 21:29:09 +0000 Subject: [PATCH 27/41] fix & test grammar & json_schema w/ & w/o --jinja --- common/chat.cpp | 2 +- examples/server/server.cpp | 3 -- .../server/tests/unit/test_chat_completion.py | 41 +++++++++++++++++++ examples/server/utils.hpp | 6 +-- 4 files changed, 45 insertions(+), 7 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index d6cfb4e2a806b..df6d2ae1a8994 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1431,7 +1431,7 @@ static common_chat_params common_chat_templates_apply_jinja( // Use generic handler when mixing tools + JSON schema. // TODO: support that mix in handlers below. 
- if ((!params.tools.is_array() && params.json_schema.is_object())) { + if ((params.tools.is_array() && params.json_schema.is_object())) { return common_chat_params_init_generic(tmpl, params); } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 84e18dec99b40..6c76bcb5c2215 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -329,9 +329,6 @@ struct server_task { } // process "json_schema" and "grammar" - if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) { - throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both"); - } if (data.contains("json_schema") && !data.contains("grammar")) { try { auto schema = json_value(data, "json_schema", json::object()); diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py index f23d5cff49abc..6a980f0c1b19b 100644 --- a/examples/server/tests/unit/test_chat_completion.py +++ b/examples/server/tests/unit/test_chat_completion.py @@ -169,6 +169,47 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int assert "error" in res.body +@pytest.mark.parametrize("jinja,json_schema,n_predicted,re_content", [ + (False, {"const": "42"}, 6, "\"42\""), + (True, {"const": "42"}, 6, "\"42\""), +]) +def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str): + global server + server.jinja = jinja + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": n_predicted, + "messages": [ + {"role": "system", "content": "You are a coding assistant."}, + {"role": "user", "content": "Write an example"}, + ], + "json_schema": json_schema, + }) + assert res.status_code == 200, f'Expected 200, got {res.status_code}' + choice = res.body["choices"][0] + assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}' + + +@pytest.mark.parametrize("jinja,grammar,n_predicted,re_content", [ + (False, 'root ::= "a"{5,5}', 6, "a{5,5}"), + (True, 'root ::= "a"{5,5}', 6, "a{5,5}"), +]) +def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re_content: str): + global server + server.jinja = jinja + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": n_predicted, + "messages": [ + {"role": "user", "content": "Does not matter what I say, does it?"}, + ], + "grammar": grammar, + }) + assert res.status_code == 200, res.body + choice = res.body["choices"][0] + assert match_regex(re_content, choice["message"]["content"]), choice["message"]["content"] + + @pytest.mark.parametrize("messages", [ None, "string", diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 1e000087d7fc3..c18c31e56ac3a 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -571,8 +571,8 @@ static json oaicompat_completion_params_parse( llama_params["stop"] = json_value(body, "stop", json::array()); } - auto json_schema = json_value(llama_params, "json_schema", json()); - auto grammar = json_value(llama_params, "grammar", std::string()); + auto json_schema = json_value(body, "json_schema", json()); + auto grammar = json_value(body, "grammar", std::string()); if (!json_schema.is_null() && !grammar.empty()) { throw std::runtime_error("Cannot use both json_schema and grammar"); } @@ -601,7 +601,7 @@ static json 
oaicompat_completion_params_parse( inputs.use_jinja = use_jinja; inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE; - if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && llama_params.contains("grammar")) { + if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) { throw std::runtime_error("Cannot use custom grammar constraints with tools."); } From 1c6168bcf8d53cdc2f1ee5ca489efcffcbe1d1e0 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 21:30:53 +0000 Subject: [PATCH 28/41] Update test-chat-template.cpp --- tests/test-chat-template.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index e70d6b26dc47f..7d3a0eb839a73 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -79,8 +79,8 @@ int main(void) { { /* .name= */ "mlabonne/AlphaMonarch-7B", /* .template_str= */ "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}", - /* .expected_output= */ "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", - /* .expected_output_jinja= */ "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", + /* .expected_output= */ "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", + /* .expected_output_jinja= */ "", /* .bos_token= */ "", /* .eos_token= */ "", }, From ae6b870d13c0573dccc826a2fb950f46f27f0d56 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 21:48:33 +0000 Subject: [PATCH 29/41] test & fix array message.content --- common/chat.cpp | 49 +++++++++++++------ common/chat.hpp | 2 +- .../server/tests/unit/test_chat_completion.py | 4 +- 3 files changed, 38 insertions(+), 17 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index df6d2ae1a8994..f3118ab757bf8 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -128,7 +128,7 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa } template <> -json common_chat_msgs_to_json_oaicompat(const std::vector & msgs) { +json common_chat_msgs_to_json_oaicompat(const std::vector & msgs, bool concat_typed_text) { json messages = json::array(); for (const auto & msg : msgs) { if (!msg.content.empty() && !msg.content_parts.empty()) { @@ -140,12 +140,27 @@ json common_chat_msgs_to_json_oaicompat(const std::vector & msg if (!msg.content.empty()) { jmsg["content"] = msg.content; } else if (!msg.content_parts.empty()) { - auto & parts = jmsg["content"] = json::array(); - for (const auto & part : msg.content_parts) { - parts.push_back({ - {"type", part.type}, - {"text", part.text}, - }); + if (concat_typed_text) { + std::string text; + for (const auto & part : msg.content_parts) { + if (part.type != "text") { + LOG_WRN("Ignoring content part type: %s\n", part.type.c_str()); + continue; + } + if (!text.empty()) { + text += '\n'; + } + text += part.text; + } + jmsg["content"] = text; + } else { + auto & parts = jmsg["content"] = json::array(); + for (const auto & part : 
msg.content_parts) { + parts.push_back({ + {"type", part.type}, + {"text", part.text}, + }); + } } } else { jmsg["content"] = json(); // null @@ -1388,21 +1403,21 @@ static common_chat_params common_chat_templates_apply_jinja( const struct common_chat_templates_inputs & inputs) { templates_params params; - params.messages = common_chat_msgs_to_json_oaicompat(inputs.messages); + params.tools = common_chat_tools_to_json_oaicompat(inputs.tools); + const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use + ? *tmpls->template_tool_use + : *tmpls->template_default; + const auto & src = tmpl.source(); + const auto & caps = tmpl.original_caps(); + params.messages = common_chat_msgs_to_json_oaicompat(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content); params.add_generation_prompt = inputs.add_generation_prompt; params.extract_reasoning = inputs.extract_reasoning; - params.tools = common_chat_tools_to_json_oaicompat(inputs.tools); params.tool_choice = inputs.tool_choice; params.grammar = inputs.grammar; if (!inputs.json_schema.empty()) { params.json_schema = json::parse(inputs.json_schema); } - const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use - ? *tmpls->template_tool_use - : *tmpls->template_default; - const auto & src = tmpl.source(); - const auto & caps = tmpl.original_caps(); - + if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) { LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n"); params.parallel_tool_calls = false; @@ -1487,6 +1502,10 @@ static common_chat_params common_chat_templates_apply_legacy( for (const auto & msg : inputs.messages) { auto content = msg.content; for (const auto & part : msg.content_parts) { + if (part.type != "text") { + LOG_WRN("Ignoring non-text content part: %s\n", part.type.c_str()); + continue; + } if (!content.empty()) { content += "\n";; } diff --git a/common/chat.hpp b/common/chat.hpp index cb5bc81b7cad2..d5c3dca780ae1 100644 --- a/common/chat.hpp +++ b/common/chat.hpp @@ -123,7 +123,7 @@ common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::strin // Parses a JSON array of messages in OpenAI's chat completion API format. // T can be std::string containing JSON or nlohmann::ordered_json template std::vector common_chat_msgs_parse_oaicompat(const T & messages); -template T common_chat_msgs_to_json_oaicompat(const std::vector & msgs); +template T common_chat_msgs_to_json_oaicompat(const std::vector & msgs, bool concat_typed_text = false); // Parses a JSON array of tools in OpenAI's chat completion tool call API format. 
// T can be std::string containing JSON or nlohmann::ordered_json diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py index 6a980f0c1b19b..af1dcb5b96554 100644 --- a/examples/server/tests/unit/test_chat_completion.py +++ b/examples/server/tests/unit/test_chat_completion.py @@ -21,6 +21,8 @@ def create_server(): (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"), ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None), ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None), + (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None), + (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None), ] ) def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template): @@ -44,7 +46,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte assert res.body["usage"]["completion_tokens"] == n_predicted choice = res.body["choices"][0] assert "assistant" == choice["message"]["role"] - assert match_regex(re_content, choice["message"]["content"]) + assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}' assert choice["finish_reason"] == finish_reason From 142103712151bf94bbcda7a8c05e4743beb1d342 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 21:49:53 +0000 Subject: [PATCH 30/41] fix links to prepare merge --- common/chat.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index f3118ab757bf8..82c3bd0f183ff 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -78,10 +78,10 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa msg.content_parts.push_back(msg_part); } } else if (!content.is_null()) { - throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); + throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)"); } } else { - throw std::runtime_error("Expected 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); + throw std::runtime_error("Expected 'content' (ref: https://github.com/ggml-org/llama.cpp/issues/8367)"); } if (message.contains("reasoning_content")) { msg.reasoning_content = message.at("reasoning_content"); From 5a5ed7bfd5fca2e8e5dc6e40c730757b459aca11 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 21:52:27 +0000 Subject: [PATCH 31/41] fix merge --- examples/server/utils.hpp | 35 ----------------------------------- 1 file changed, 35 deletions(-) diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index e2cff4a330e55..5f485bebb38bd 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -345,41 +345,6 @@ static llama_tokens format_infill( return embd_inp; } -// Format given chat. 
If tmpl is empty, we take the template from model metadata -inline std::string format_chat(const common_chat_template & tmpl, const std::vector & messages) { - std::vector chat; - - for (size_t i = 0; i < messages.size(); ++i) { - const auto & curr_msg = messages[i]; - - std::string role = json_value(curr_msg, "role", std::string("")); - - std::string content; - if (curr_msg.contains("content")) { - if (curr_msg["content"].is_string()) { - content = curr_msg["content"].get(); - } else if (curr_msg["content"].is_array()) { - for (const auto & part : curr_msg["content"]) { - if (part.contains("text")) { - content += "\n" + part["text"].get(); - } - } - } else { - throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggml-org/llama.cpp/issues/8367)"); - } - } else { - throw std::runtime_error("Missing 'content' (ref: https://github.com/ggml-org/llama.cpp/issues/8367)"); - } - - chat.push_back({role, content, /* tool_calls= */ {}}); - } - - const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false); - LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str()); - - return formatted_chat; -} - // // base64 utils (TODO: move to common in the future) // From dd5ef85f204b8b0b1cb00c1b9ebe28914ac0a9d1 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 21:57:32 +0000 Subject: [PATCH 32/41] rm trailing spaces --- common/chat.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index 82c3bd0f183ff..264746b29fd4d 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -41,7 +41,7 @@ common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::strin template <> std::vector common_chat_msgs_parse_oaicompat(const json & messages) { std::vector msgs; - + try { if (!messages.is_array()) { @@ -203,7 +203,7 @@ std::vector common_chat_msgs_parse_oaicompat(const std::string template <> std::vector common_chat_tools_parse_oaicompat(const json & tools) { std::vector result; - + try { if (!tools.is_null()) { if (!tools.is_array()) { @@ -1417,7 +1417,7 @@ static common_chat_params common_chat_templates_apply_jinja( if (!inputs.json_schema.empty()) { params.json_schema = json::parse(inputs.json_schema); } - + if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) { LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n"); params.parallel_tool_calls = false; @@ -1554,7 +1554,7 @@ common_chat_params common_chat_templates_apply( const struct common_chat_templates_inputs & inputs) { GGML_ASSERT(tmpls != nullptr); - return inputs.use_jinja + return inputs.use_jinja ? 
common_chat_templates_apply_jinja(tmpls, inputs) : common_chat_templates_apply_legacy(tmpls, inputs); } From 2f2f0fa11435c737f425588fe5a88fa0d418b39d Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 21:59:09 +0000 Subject: [PATCH 33/41] Add missing include to chat.cpp --- common/chat.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/common/chat.cpp b/common/chat.cpp index 264746b29fd4d..825e17d399155 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -4,6 +4,7 @@ #include "minja/chat-template.hpp" #include "minja/minja.hpp" +#include typedef minja::chat_template common_chat_template; From a58b9e5edb31fed090b1240e2f0f53f84bf596c9 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 22:19:02 +0000 Subject: [PATCH 34/41] tiny fix: somehow llama_token being defined in an extern c makes it less than a c++ typedef --- examples/main/main.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 48b3389e69efe..b1d6e8a64382f 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -759,7 +759,7 @@ int main(int argc, char ** argv) { // check for reverse prompt using special tokens llama_token last_token = common_sampler_last(smpl); - if (std::find(antiprompt_token.begin(), antiprompt_token.end(), last_token) != antiprompt_token.end()) { + if (std::find(antiprompt_token.begin(), antiprompt_token.end(), (int32_t) last_token) != antiprompt_token.end()) { if (params.interactive) { is_interacting = true; } From f999ff565f2a5162b9b2043ca47fe9883d704d98 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 22:30:54 +0000 Subject: [PATCH 35/41] alternative fix for gcc c vs. c++ weirdness --- examples/main/main.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index b1d6e8a64382f..b16f7f56772a7 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -759,11 +759,14 @@ int main(int argc, char ** argv) { // check for reverse prompt using special tokens llama_token last_token = common_sampler_last(smpl); - if (std::find(antiprompt_token.begin(), antiprompt_token.end(), (int32_t) last_token) != antiprompt_token.end()) { - if (params.interactive) { - is_interacting = true; + for (auto token : antiprompt_token) { + if (token == last_token) { + if (params.interactive) { + is_interacting = true; + } + is_antiprompt = true; + break; } - is_antiprompt = true; } if (is_antiprompt) { From 55a7614332e9af02f569c154658b90c80086d3c1 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 22:32:06 +0000 Subject: [PATCH 36/41] add missing include to test-chat-template --- tests/test-chat-template.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index 7d3a0eb839a73..bcb61955b38ae 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #undef NDEBUG #include From 9d62f62fcca8b884c010919830486dbde1b5085a Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 16 Feb 2025 23:35:43 +0000 Subject: [PATCH 37/41] Update chat.hpp --- common/chat.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/common/chat.hpp b/common/chat.hpp index d5c3dca780ae1..a05665621b4b5 100644 --- a/common/chat.hpp +++ b/common/chat.hpp @@ -19,7 +19,6 @@ struct common_chat_msg_content_part { std::string text; }; -// same with llama_chat_message, but uses std::string struct common_chat_msg { std::string role; 
std::string content; From da0982a00e0011bcd6a03b22cdf47f33992151fd Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 17 Feb 2025 14:44:04 +0000 Subject: [PATCH 38/41] have common_chat_templates_init return a unique_ptr --- common/chat.cpp | 8 ++++---- common/chat.hpp | 8 +++++--- examples/main/main.cpp | 4 +--- examples/run/run.cpp | 4 +--- examples/server/server.cpp | 20 ++++++++++---------- tests/test-chat-template.cpp | 6 +++--- tests/test-chat.cpp | 2 +- 7 files changed, 25 insertions(+), 27 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index 825e17d399155..ec19d359b30c9 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -269,12 +269,12 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { msg.role = "user"; msg.content = "test"; - auto * tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl); + auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl); common_chat_templates_inputs inputs; inputs.messages = {msg}; - common_chat_templates_apply(tmpls, inputs); + common_chat_templates_apply(tmpls.get(), inputs); return true; } catch (const std::exception & e) { LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what()); @@ -362,7 +362,7 @@ const char * common_chat_templates_source(const struct common_chat_templates * t return tmpls->template_default->source().c_str(); } -struct common_chat_templates * common_chat_templates_init( +common_chat_templates_ptr common_chat_templates_init( const struct llama_model * model, const std::string & chat_template_override, const std::string & bos_token_override, @@ -426,7 +426,7 @@ struct common_chat_templates * common_chat_templates_init( LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what()); } } - return tmpls; + return {tmpls, common_chat_templates_free}; } std::string common_chat_format_name(common_chat_format format) { diff --git a/common/chat.hpp b/common/chat.hpp index a05665621b4b5..ba98d030a280b 100644 --- a/common/chat.hpp +++ b/common/chat.hpp @@ -85,7 +85,11 @@ struct common_chat_params { // Check if the template supplied via "--chat-template" is supported or not. 
Returns true if it's valid bool common_chat_verify_template(const std::string & tmpl, bool use_jinja); -struct common_chat_templates * common_chat_templates_init( + +void common_chat_templates_free(struct common_chat_templates * tmpls); +typedef std::unique_ptr common_chat_templates_ptr; + +common_chat_templates_ptr common_chat_templates_init( const struct llama_model * model, const std::string & chat_template_override, const std::string & bos_token_override = "", @@ -93,9 +97,7 @@ struct common_chat_templates * common_chat_templates_init( bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls); const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr); -void common_chat_templates_free(struct common_chat_templates * tmpls); -typedef std::unique_ptr common_chat_templates_ptr; struct common_chat_params common_chat_templates_apply( const struct common_chat_templates * tmpls, diff --git a/examples/main/main.cpp b/examples/main/main.cpp index b16f7f56772a7..9620849ccd51c 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -158,9 +158,7 @@ int main(int argc, char ** argv) { } const llama_vocab * vocab = llama_model_get_vocab(model); - common_chat_templates_ptr chat_templates( - common_chat_templates_init(model, params.chat_template), - &common_chat_templates_free); + auto chat_templates = common_chat_templates_init(model, params.chat_template); LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads); diff --git a/examples/run/run.cpp b/examples/run/run.cpp index c43d3d283c848..7fcc762d25501 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -1057,9 +1057,7 @@ static int get_user_input(std::string & user_input, const std::string & user) { static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) { int prev_len = 0; llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get())); - common_chat_templates_ptr chat_templates( - common_chat_templates_init(llama_data.model.get(), ""), - &common_chat_templates_free); + auto chat_templates = common_chat_templates_init(llama_data.model.get(), ""); static const bool stdout_a_terminal = is_stdout_a_terminal(); while (true) { // Get user input diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 51a2ff8ddcfaa..f72701e734a95 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1804,7 +1804,9 @@ struct server_context { // Necessary similarity of prompt for slot selection float slot_prompt_similarity = 0.0f; - struct common_chat_templates * chat_templates = nullptr; + common_chat_templates_ptr chat_templates; + + server_context() : chat_templates(nullptr, nullptr) {} ~server_context() { // Clear any sampling context @@ -1822,7 +1824,6 @@ struct server_context { } llama_batch_free(batch); - common_chat_templates_free(chat_templates); } bool load_model(const common_params & params) { @@ -1891,10 +1892,9 @@ struct server_context { chat_templates = common_chat_templates_init(model, params_base.chat_template); try { - common_chat_format_example(chat_templates, params.use_jinja); + common_chat_format_example(chat_templates.get(), params.use_jinja); } catch (const std::exception & e) { SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. 
This may cause the model to output suboptimal responses\n", __func__); - common_chat_templates_free(chat_templates); chat_templates = common_chat_templates_init(model, "chatml"); } @@ -3793,13 +3793,13 @@ int main(int argc, char ** argv) { { "default_generation_settings", ctx_server.default_generation_settings_for_props }, { "total_slots", ctx_server.params_base.n_parallel }, { "model_path", ctx_server.params_base.model }, - { "chat_template", common_chat_templates_source(ctx_server.chat_templates) }, + { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) }, { "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)}, { "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)}, { "build_info", build_info }, }; if (ctx_server.params_base.use_jinja) { - if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates, "tool_use")) { + if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) { data["chat_template_tool_use"] = tool_use_src; } } @@ -4036,7 +4036,7 @@ int main(int argc, char ** argv) { } auto body = json::parse(req.body); - json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates); + json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get()); return handle_completions_impl( SERVER_TASK_TYPE_COMPLETION, @@ -4049,7 +4049,7 @@ int main(int argc, char ** argv) { // same with handle_chat_completions, but without inference part const auto handle_apply_template = [&ctx_server, ¶ms, &res_ok](const httplib::Request & req, httplib::Response & res) { auto body = json::parse(req.body); - json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates); + json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get()); res_ok(res, {{ "prompt", std::move(data.at("prompt")) }}); }; @@ -4455,8 +4455,8 @@ int main(int argc, char ** argv) { // print sample chat example to make it clear which template is used LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, - common_chat_templates_source(ctx_server.chat_templates), - common_chat_format_example(ctx_server.chat_templates, ctx_server.params_base.use_jinja).c_str()); + common_chat_templates_source(ctx_server.chat_templates.get()), + common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja).c_str()); ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) { ctx_server.process_single_task(task); diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index bcb61955b38ae..7dc61bd442da1 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -322,7 +322,7 @@ int main(void) { } printf("\n\n=== %s (jinja) ===\n\n", test_case.name.c_str()); try { - common_chat_templates_ptr tmpls(common_chat_templates_init(/* model= */ nullptr, test_case.template_str.c_str(), test_case.bos_token, test_case.eos_token), &common_chat_templates_free); + auto tmpls = common_chat_templates_init(/* model= */ nullptr, test_case.template_str.c_str(), test_case.bos_token, test_case.eos_token); common_chat_templates_inputs inputs; inputs.use_jinja = true; inputs.messages = messages; @@ -349,7 +349,7 @@ int main(void) { auto 
sys_msg = simple_msg("system", "You are a helpful assistant"); auto fmt_sys = [&](std::string tmpl_str) { - common_chat_templates_ptr tmpls(common_chat_templates_init(/* model= */ nullptr, tmpl_str), &common_chat_templates_free); + auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str); auto output = common_chat_format_single(tmpls.get(), chat2, sys_msg, false, /* use_jinja= */ false); printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str()); printf("-------------------------\n"); @@ -376,7 +376,7 @@ int main(void) { auto new_msg = simple_msg("user", "How are you"); auto fmt_single = [&](const std::string & tmpl_str) { - common_chat_templates_ptr tmpls(common_chat_templates_init(/* model= */ nullptr, tmpl_str.c_str()), &common_chat_templates_free); + auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str.c_str()); auto output = common_chat_format_single(tmpls.get(), chat2, new_msg, true, /* use_jinja= */ false); printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str()); printf("-------------------------\n"); diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 0d8596f8675c6..a4cbf4a5e1fe6 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -45,7 +45,7 @@ static std::string read_file(const std::string & path) { } static common_chat_templates_ptr read_templates(const std::string & path) { - return common_chat_templates_ptr(common_chat_templates_init(/* model= */ nullptr, read_file(path)), &common_chat_templates_free); + return common_chat_templates_ptr(common_chat_templates_init(/* model= */ nullptr, read_file(path))); } static std::unique_ptr build_grammar(const std::string & grammar_str) { From 7ddb4540344d5ce8c7683688a7be02b83f52164b Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 17 Feb 2025 14:56:22 +0000 Subject: [PATCH 39/41] chat.{hpp -> h} --- Makefile | 2 +- common/CMakeLists.txt | 2 +- common/arg.cpp | 2 +- common/chat.cpp | 2 +- common/{chat.hpp => chat.h} | 0 examples/main/main.cpp | 2 +- examples/run/run.cpp | 2 +- examples/server/utils.hpp | 2 +- tests/test-chat-template.cpp | 2 +- tests/test-chat.cpp | 2 +- 10 files changed, 9 insertions(+), 9 deletions(-) rename common/{chat.hpp => chat.h} (100%) diff --git a/Makefile b/Makefile index 662194086eaaf..fb9a3b44890a0 100644 --- a/Makefile +++ b/Makefile @@ -1364,7 +1364,7 @@ llama-server: \ examples/server/index.html.hpp \ examples/server/loading.html.hpp \ common/chat.cpp \ - common/chat.hpp \ + common/chat.h \ common/chat-template.hpp \ common/json.hpp \ common/minja.hpp \ diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index bf391c2ad90f0..17146fffc1168 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -57,7 +57,7 @@ add_library(${TARGET} STATIC arg.h base64.hpp chat.cpp - chat.hpp + chat.h common.cpp common.h console.cpp diff --git a/common/arg.cpp b/common/arg.cpp index 702777222008b..eb8beccac2ee7 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2,7 +2,7 @@ #include "log.h" #include "sampling.h" -#include "chat.hpp" +#include "chat.h" #include #include diff --git a/common/chat.cpp b/common/chat.cpp index ec19d359b30c9..b41756beea725 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1,4 +1,4 @@ -#include "chat.hpp" +#include "chat.h" #include "json-schema-to-grammar.h" #include "log.h" #include "minja/chat-template.hpp" diff --git a/common/chat.hpp b/common/chat.h similarity index 100% rename from common/chat.hpp rename to common/chat.h diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 
9620849ccd51c..cf8659b037ee3 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -4,7 +4,7 @@ #include "log.h" #include "sampling.h" #include "llama.h" -#include "chat.hpp" +#include "chat.h" #include #include diff --git a/examples/run/run.cpp b/examples/run/run.cpp index 7fcc762d25501..ed8644ef78d97 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -24,7 +24,7 @@ #include #include -#include "chat.hpp" +#include "chat.h" #include "common.h" #include "json.hpp" #include "linenoise.cpp/linenoise.h" diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 5f485bebb38bd..d25570b816ded 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -12,7 +12,7 @@ // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT #include "json.hpp" -#include "chat.hpp" +#include "chat.h" #include #include diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index 7dc61bd442da1..9231c517afb0b 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -8,7 +8,7 @@ #include "llama.h" #include "common.h" -#include "chat.hpp" +#include "chat.h" static std::string normalize_newlines(const std::string & s) { #ifdef _WIN32 diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index a4cbf4a5e1fe6..6435923054859 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -10,7 +10,7 @@ #include #include -#include "chat.hpp" +#include "chat.h" #include "llama-grammar.h" #include "unicode.h" From d2969b8730dba15d0d050a2c93ccbd339dd5e9c2 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 17 Feb 2025 15:45:18 +0000 Subject: [PATCH 40/41] build common_chat_templates_ptr earlier --- common/chat.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index b41756beea725..83e716b02918a 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -411,7 +411,7 @@ common_chat_templates_ptr common_chat_templates_init( token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token"); token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token"); } - auto * tmpls = new common_chat_templates(); + common_chat_templates_ptr tmpls(new common_chat_templates(), common_chat_templates_free); tmpls->has_explicit_template = has_explicit_template; try { tmpls->template_default = std::make_unique(default_template_src, token_bos, token_eos); @@ -426,7 +426,7 @@ common_chat_templates_ptr common_chat_templates_init( LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what()); } } - return {tmpls, common_chat_templates_free}; + return tmpls; } std::string common_chat_format_name(common_chat_format format) { From fd2b8e10e9e1518156e1c068389e92fe4c6dd0b3 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 17 Feb 2025 15:50:27 +0000 Subject: [PATCH 41/41] use deleter functor for common_chat_templates_ptr --- common/chat.cpp | 2 +- common/chat.h | 6 ++++-- examples/server/server.cpp | 2 -- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index 83e716b02918a..9ebe4c5784cbc 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -411,7 +411,7 @@ common_chat_templates_ptr common_chat_templates_init( token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token"); token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token"); } - common_chat_templates_ptr tmpls(new common_chat_templates(), common_chat_templates_free); + common_chat_templates_ptr tmpls(new common_chat_templates()); 
tmpls->has_explicit_template = has_explicit_template; try { tmpls->template_default = std::make_unique(default_template_src, token_bos, token_eos); diff --git a/common/chat.h b/common/chat.h index ba98d030a280b..e77bef82b9edd 100644 --- a/common/chat.h +++ b/common/chat.h @@ -85,9 +85,11 @@ struct common_chat_params { // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid bool common_chat_verify_template(const std::string & tmpl, bool use_jinja); - void common_chat_templates_free(struct common_chat_templates * tmpls); -typedef std::unique_ptr common_chat_templates_ptr; + +struct common_chat_templates_deleter { void operator()(common_chat_templates * tmpls) { common_chat_templates_free(tmpls); } }; + +typedef std::unique_ptr common_chat_templates_ptr; common_chat_templates_ptr common_chat_templates_init( const struct llama_model * model, diff --git a/examples/server/server.cpp b/examples/server/server.cpp index f72701e734a95..c50ef03c5b353 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1806,8 +1806,6 @@ struct server_context { common_chat_templates_ptr chat_templates; - server_context() : chat_templates(nullptr, nullptr) {} - ~server_context() { // Clear any sampling context for (server_slot & slot : slots) {