Add llama_chat_apply_template() #5538
Changes from 7 commits
4e64440
bba75c7
9c4422f
6012ad6
7a3eac8
011af99
dba4337
73fbd67
649f6f8
@@ -12459,6 +12459,122 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
    return 0;
}

int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_template, std::vector<const llama_chat_message *> conversation, bool add_ass);

// trim whitespace from the beginning and end of a string
static std::string trim(const std::string & str) {
    size_t start = 0;
    size_t end = str.size();
    while (start < end && isspace(str[start])) {
        start += 1;
    }
    while (end > start && isspace(str[end - 1])) {
        end -= 1;
    }
    return str.substr(start, end - start);
}

// Simple version of "llama_apply_chat_template" that only works with strings
// This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_template, std::vector<const llama_chat_message *> conversation, bool add_ass) {
Suggested change:
-int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_template, std::vector<const llama_chat_message *> conversation, bool add_ass) {
+int32_t llama_chat_apply_template_internal(
+        const std::string & chat_template,
+        const std::vector<const llama_chat_message *> & chat,
+        std::string & dest, bool add_ass) {
The terms chat and conversation seem conflated. Propose to use chat universally (and apply this to other places too).
I changed all the occurrences of conversation and msg to chat in this commit: 73fbd67
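The in-code comment in the first hunk says the internal function relies on heuristic substring checks rather than a real Jinja parser, and the hunk above is truncated before the function body. As a rough illustration of what such marker-based detection looks like, here is a minimal sketch; the helper name apply_template_heuristic and the exact formatting rules are assumptions for illustration, not the PR's actual implementation:

// Illustrative sketch only (not the PR's body): detect a known template
// family by searching for a distinctive marker string, then format the
// chat accordingly.
#include <sstream>
#include <string>
#include <vector>
#include "llama.h" // for llama_chat_message

static std::string apply_template_heuristic(
        const std::string & tmpl,
        const std::vector<const llama_chat_message *> & chat,
        bool add_ass) {
    std::stringstream ss;
    if (tmpl.find("<|im_start|>") != std::string::npos) {
        // ChatML-style template (e.g. OpenHermes-2.5)
        for (auto * msg : chat) {
            ss << "<|im_start|>" << msg->role << "\n" << msg->content << "<|im_end|>\n";
        }
        if (add_ass) {
            ss << "<|im_start|>assistant\n";
        }
    } else if (tmpl.find("[INST]") != std::string::npos) {
        // Llama-2/Mistral-style [INST] template (details vary per model)
        for (auto * msg : chat) {
            if (std::string(msg->role) == "user") {
                ss << "[INST] " << msg->content << " [/INST]";
            } else {
                ss << msg->content << "</s>";
            }
        }
    }
    return ss.str();
}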
@@ -0,0 +1,64 @@
#include <iostream>
#include <string>
#include <vector>
#include <sstream>

#undef NDEBUG
#include <cassert>

#include "llama.h"

int main(void) {
    llama_chat_message conversation[] = {
        {"system", "You are a helpful assistant"},
        {"user", "Hello"},
        {"assistant", "Hi there"},
        {"user", "Who are you"},
        {"assistant", " I am an assistant "},
        {"user", "Another question"},
    };
    size_t message_count = 6;
    std::vector<std::string> templates = {
        // teknium/OpenHermes-2.5-Mistral-7B
        "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
        // mistralai/Mistral-7B-Instruct-v0.2
        "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
        // TheBloke/FusionNet_34Bx2_MoE-AWQ
        "{%- for idx in range(0, messages|length) -%}\\n{%- if messages[idx]['role'] == 'user' -%}\\n{%- if idx > 1 -%}\\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\\n{%- else -%}\\n{{- messages[idx]['content'] + ' [/INST]' -}}\\n{%- endif -%}\\n{% elif messages[idx]['role'] == 'system' %}\\n{{- '[INST] <<SYS>>\\\\n' + messages[idx]['content'] + '\\\\n<</SYS>>\\\\n\\\\n' -}}\\n{%- elif messages[idx]['role'] == 'assistant' -%}\\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\\n{% endif %}\\n{% endfor %}",
        // bofenghuang/vigogne-2-70b-chat
        "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\\\n' + system_message + '\\\\n<</SYS>>\\\\n\\\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\\\n' + content.strip() + '\\\\n<</SYS>>\\\\n\\\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
    };
    std::vector<std::string> expected_substr = {
        "<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant",
        "[/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
        "</s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
        "[/INST] Hi there </s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
    };
    std::vector<char> formatted_chat(1024);
    int32_t res;

    // test invalid chat template
    res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
    assert(res < 0);

    for (size_t i = 0; i < templates.size(); i++) {
        std::string custom_template = templates[i];
        std::string substr = expected_substr[i];
        formatted_chat.resize(1024);
        res = llama_chat_apply_template(
            nullptr,
            custom_template.c_str(),
            conversation,
            message_count,
            true,
            formatted_chat.data(),
            formatted_chat.size()
        );
        formatted_chat.resize(res);
        std::string output(formatted_chat.data(), formatted_chat.size());
        std::cout << output << "\n-------------------------\n";
        // expect the "formatted_chat" to contain pre-defined strings
        assert(output.find(substr) != std::string::npos);
    }
    return 0;
}
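The test above also shows the basic calling pattern for the new public API: pass a null model to use the template string directly, write into a caller-owned buffer, and resize to the returned length. A sketch of a caller-side convenience wrapper built on that pattern follows; it assumes the function reports the required length when the buffer is too small (snprintf-style), which the test's resize hints at but does not exercise, and the helper name format_chat is purely illustrative:

// Sketch of a caller-side helper (not part of the PR): apply a chat
// template and retry with a larger buffer if the first attempt does not fit.
#include <string>
#include <vector>
#include "llama.h"

static std::string format_chat(const char * tmpl,
                               const llama_chat_message * chat,
                               size_t n_msg,
                               bool add_ass) {
    std::vector<char> buf(1024);
    int32_t res = llama_chat_apply_template(nullptr, tmpl, chat, n_msg, add_ass,
                                            buf.data(), buf.size());
    if (res > (int32_t) buf.size()) {
        // assumed behavior: res is the required length, so grow and retry
        buf.resize(res);
        res = llama_chat_apply_template(nullptr, tmpl, chat, n_msg, add_ass,
                                        buf.data(), buf.size());
    }
    return res < 0 ? std::string() : std::string(buf.data(), res);
}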