Skip to content

Commit 128d522

Browse files
chat : support Magistral thinking (#16413)
* feat: added a dedicated Magistral chat format that preserves [THINK] spans and parses reasoning before tool calls
* feat: new flow in the chat template test suite for Magistral
1 parent f6dcda3 commit 128d522

File tree

3 files changed

+92
-0
lines changed

3 files changed

+92
-0
lines changed

common/chat.cpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,7 @@ const char * common_chat_format_name(common_chat_format format) {
625625
case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "Content-only";
626626
case COMMON_CHAT_FORMAT_GENERIC: return "Generic";
627627
case COMMON_CHAT_FORMAT_MISTRAL_NEMO: return "Mistral Nemo";
628+
case COMMON_CHAT_FORMAT_MAGISTRAL: return "Magistral";
628629
case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x";
629630
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools";
630631
case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1";
@@ -984,6 +985,65 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
984985
data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
985986
return data;
986987
}
988+
989+
// Build the chat params for the Magistral format: applies the Jinja template,
// preserves the [THINK]/[/THINK] reasoning markers, and — when tools are
// supplied — constrains generation with a grammar for a
// [TOOL_CALLS][{name, arguments, id}, ...] JSON array.
static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) {
    common_chat_params data;
    data.format = COMMON_CHAT_FORMAT_MAGISTRAL;
    data.prompt = apply(tmpl, inputs);
    // Keep the reasoning delimiters intact so the parser can split them out later.
    data.preserved_tokens = {
        "[THINK]",
        "[/THINK]",
    };

    const bool has_tools = inputs.tools.is_array() && !inputs.tools.empty();
    if (!has_tools) {
        // No tools: plain (or JSON-schema / user-grammar constrained) output.
        data.grammar_lazy = false;
        if (inputs.json_schema.is_null()) {
            data.grammar = inputs.grammar;
        } else {
            if (!inputs.grammar.empty()) {
                throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
            }
            data.grammar = json_schema_to_grammar(inputs.json_schema);
        }
        return data;
    }

    // With tools: trigger the grammar lazily unless a tool call is mandatory.
    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
        // One object schema per declared function.
        auto fn_schemas = json::array();
        foreach_function(inputs.tools, [&](const json & tool) {
            const auto & fn = tool.at("function");
            fn_schemas.push_back({
                {"type", "object"},
                {"properties", {
                    {"name", {
                        {"type", "string"},
                        {"const", fn.at("name")},
                    }},
                    {"arguments", fn.at("parameters")},
                    {"id", {
                        {"type", "string"},
                        // Mistral-style 9-character alphanumeric call id.
                        {"pattern", "^[a-zA-Z0-9]{9}$"},
                    }},
                }},
                {"required", json::array({"name", "arguments", "id"})},
            });
        });
        // Tool calls arrive as a non-empty JSON array of the above objects.
        json calls_schema {
            {"type", "array"},
            {"items", fn_schemas.size() == 1 ? fn_schemas[0] : json {{"anyOf", fn_schemas}}},
            {"minItems", 1},
        };
        if (!inputs.parallel_tool_calls) {
            calls_schema["maxItems"] = 1;
        }
        builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", calls_schema));
    });
    data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"});
    data.preserved_tokens.push_back("[TOOL_CALLS]");

    return data;
}
1046+
9871047
static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) {
9881048
if (!builder.syntax().parse_tool_calls) {
9891049
builder.add_content(builder.consume_rest());
@@ -994,6 +1054,18 @@ static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) {
9941054
parse_prefixed_json_tool_call_array(builder, prefix);
9951055
}
9961056

1057+
// Parse a Magistral response: split off the [THINK]...[/THINK] reasoning span,
// then either hand the remainder through as content or extract the
// [TOOL_CALLS]-prefixed JSON tool-call array.
static void common_chat_parse_magistral(common_chat_msg_parser & builder) {
    builder.try_parse_reasoning("[THINK]", "[/THINK]");

    if (builder.syntax().parse_tool_calls) {
        static const common_regex prefix(regex_escape("[TOOL_CALLS]"));
        parse_prefixed_json_tool_call_array(builder, prefix);
    } else {
        // Tool-call parsing disabled: everything after the reasoning is content.
        builder.add_content(builder.consume_rest());
    }
}
1068+
9971069
static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) {
9981070
common_chat_params data;
9991071

@@ -2702,6 +2774,10 @@ static common_chat_params common_chat_templates_apply_jinja(
27022774
return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools);
27032775
}
27042776

2777+
if (src.find("[THINK]") != std::string::npos && src.find("[/THINK]") != std::string::npos) {
2778+
return common_chat_params_init_magistral(tmpl, params);
2779+
}
2780+
27052781
// Plain handler (no tools)
27062782
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
27072783
return common_chat_params_init_without_tools(tmpl, params);
@@ -2802,6 +2878,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
28022878
case COMMON_CHAT_FORMAT_MISTRAL_NEMO:
28032879
common_chat_parse_mistral_nemo(builder);
28042880
break;
2881+
case COMMON_CHAT_FORMAT_MAGISTRAL:
2882+
common_chat_parse_magistral(builder);
2883+
break;
28052884
case COMMON_CHAT_FORMAT_LLAMA_3_X:
28062885
common_chat_parse_llama_3_1(builder);
28072886
break;

common/chat.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ enum common_chat_format {
101101
COMMON_CHAT_FORMAT_CONTENT_ONLY,
102102
COMMON_CHAT_FORMAT_GENERIC,
103103
COMMON_CHAT_FORMAT_MISTRAL_NEMO,
104+
COMMON_CHAT_FORMAT_MAGISTRAL,
104105
COMMON_CHAT_FORMAT_LLAMA_3_X,
105106
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
106107
COMMON_CHAT_FORMAT_DEEPSEEK_R1,

tests/test-chat.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,7 @@ const common_chat_msg message_assist_thoughts_unparsed_md = simple_assis
411411
const common_chat_msg message_assist_thoughts_unparsed_md_partial = simple_assist_msg("<think>I'm\nthinking</think>Hello, world!\nWhat's up?\n```json\n{}");
412412

413413
const common_chat_msg message_assist_thoughts_unparsed_r7b = simple_assist_msg("<|START_THINKING|>I'm\nthinking<|END_THINKING|>Hello, world!\nWhat's up?");
414+
const common_chat_msg message_assist_thoughts_unparsed_magistral = simple_assist_msg("[THINK]raisonnement[/THINK]Réponse");
414415
const common_chat_msg message_assist_thoughts = simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking");
415416
const common_chat_msg message_assist_thoughts_unopened_unparsed = simple_assist_msg("I'm\nthinking</think>Hello, world!\nWhat's up?");
416417
const common_chat_msg message_assist_thoughts_no_content = simple_assist_msg("", "I'm\nthinking");
@@ -745,6 +746,17 @@ static void test_template_output_parsers() {
745746
tmpls.get(), end_tokens, message_assist_call_id, tools,
746747
"[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]");
747748
}
749+
{
750+
assert_msg_equals(
751+
simple_assist_msg("Réponse", "raisonnement"),
752+
common_chat_parse(
753+
message_assist_thoughts_unparsed_magistral.content,
754+
/* is_partial= */ false,
755+
{
756+
/* .format = */ COMMON_CHAT_FORMAT_MAGISTRAL,
757+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
758+
}));
759+
}
748760
{
749761
auto tmpls = read_templates("models/templates/Qwen-QwQ-32B.jinja");
750762
std::vector<std::string> end_tokens{ "<|im_end|>" };

0 commit comments

Comments
 (0)