
Commit 12c7897

[NPU C++] Update example with conversation mode support (#12510)
1 parent 0918d3b commit 12c7897

2 files changed (+178, -34 lines)


python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/README.md

Lines changed: 21 additions & 2 deletions
@@ -92,16 +92,22 @@ cd Release
 With the built `llm-npu-cli`, you can run the example with specified parameters. For example,
 
 ```cmd
+# Run simple text completion
 llm-npu-cli.exe -m <converted_model_path> -n 64 "AI是什么?"
+
+# Run in conversation mode
+llm-npu-cli.exe -m <converted_model_path> -cnv
 ```
 
 Arguments info:
 - `-m` : argument defining the path of saved converted model.
+- `-cnv` : argument to enable conversation mode.
 - `-n` : argument defining how many tokens will be generated.
 - Last argument is your input prompt.
 
 ### 5. Sample Output
 #### [`Qwen/Qwen2.5-7B-Instruct`](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct)
+##### Text Completion
 ```cmd
 Input:
 <|im_start|>system
@@ -112,10 +118,23 @@ AI是什么?<|im_end|>
 
 
 Prefill 22 tokens cost xxxx ms.
+
+Decode 63 tokens cost xxxx ms (avg xx.xx ms each token).
 Output:
-AI是"人工智能"的缩写,是英文"Artificial Intelligence"的翻译。它是研究如何使计算机也具有智能的一种技术和理论。简而言之,人工智能就是让计算机能够模仿人智能行为的一项技术。
+AI是"人工智能"的缩写,它是一门研究计算机如何能够完成与人类智能相关任务的学科,包括学习、推理、自我修正等能力。简而言之,人工智能就是让计算机模拟或执行人类智能行为的理论、技术和方法。
+
+它涵盖了机器学习、深度学习、自然
+```
+
+##### Conversation
+```cmd
+User:你好
+Assistant:你好!很高兴能为你提供帮助。有什么问题或需要聊天可以找我哦。
+User:AI是什么?
+Assistant: AI代表的是"Artificial Intelligence",中文翻译为人工智能。它是指由计算机或信息技术实现的智能行为。广义的人工智能可以指任何表现出智能行为的计算机或软件系统。狭义的人工智能则指的是模拟、学习、推理、理解自然语言以及自我生成的人工智能系统。
 
-Decode 46 tokens cost xxxx ms (avg xx.xx ms each token).
+简而言之,人工智能是一种利用计算机和机器来模仿、模拟或扩展人类智能的技术或系统。它包括机器学习、深度学习、自然语言处理等多个子领域。
+User:exit
 ```
 
 ### Troubleshooting
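
For quick reference: with the README changes above, `-cnv` switches the binary into an interactive loop, arguments after it are not parsed, each turn's prompt is read from stdin, and typing `exit` at the `User:` prompt ends the session, so `-n` and a trailing prompt are not used in this mode. A hypothetical invocation, with a placeholder model path:

```cmd
REM the model directory below is a placeholder for your own converted model
llm-npu-cli.exe -m D:\converted_models\Qwen2.5-7B-Instruct-npu -cnv
```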

python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/llm-npu-cli.cpp

Lines changed: 157 additions & 32 deletions
@@ -25,11 +25,119 @@
 
 static void print_usage(int, char ** argv) {
     printf("\nexample usage:\n");
-    printf("\n    %s -m npu_model_dir [-n n_predict] [prompt]\n", argv[0]);
+    printf("\n    %s -m npu_model_dir [-cnv] [-n n_predict] [prompt]\n", argv[0]);
     printf("\n");
 }
 
 
+const std::string llama2_template = "<s>[INST] <<SYS>>\n\n<</SYS>>\n\n%s [/INST]";
+const std::string llama3_template = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n";
+const std::string minicpm_template = "<用户>%s<AI>";
+const std::string qwen2_template = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n";
+
+
+std::string add_chat_history(npu_model_params model_params,
+                             std::string new_prompt, std::string chat_history, bool is_input) {
+    char prompt[8092];
+    if (model_params.model_type == std::string("llama") && model_params.vocab_size == 32000) {
+        if (chat_history == ""){
+            sprintf_s(prompt, llama2_template.c_str(), new_prompt.c_str());
+        }else{
+            if (is_input){
+                std::string input_template = "%s%s [/INST]";
+                sprintf_s(prompt, input_template.c_str(), chat_history.c_str(), new_prompt.c_str());
+            }
+            else{
+                std::string res_template = "%s%s </s><s>[INST]";
+                sprintf_s(prompt, res_template.c_str(), chat_history.c_str(), new_prompt.c_str());
+            }
+        }
+    } else if (model_params.model_type == std::string("llama")) {
+        if (chat_history == ""){
+            sprintf_s(prompt, llama3_template.c_str(), new_prompt.c_str());
+        }else{
+            if (is_input){
+                std::string input_template = "%s<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|>";
+                sprintf_s(prompt, input_template.c_str(), chat_history.c_str(), new_prompt.c_str());
+            }
+            else{
+                std::string res_template = "%s<|start_header_id|>assistant<|end_header_id|>\n\n%s<|eot_id|>";
+                sprintf_s(prompt, res_template.c_str(), chat_history.c_str(), new_prompt.c_str());
+            }
+        }
+    } else if (model_params.model_type == std::string("qwen2")) {
+        if (chat_history == ""){
+            sprintf_s(prompt, qwen2_template.c_str(), new_prompt.c_str());
+        }else{
+            if (is_input){
+                std::string input_template = "%s%s<|im_end|>\n<|im_start|>assistant";
+                sprintf_s(prompt, input_template.c_str(), chat_history.c_str(), new_prompt.c_str());
+            }
+            else{
+                std::string res_template = "%s%s<|im_end|>\n<|im_start|>user\n";
+                sprintf_s(prompt, res_template.c_str(), chat_history.c_str(), new_prompt.c_str());
+            }
+        }
+    } else if (model_params.model_type == std::string("minicpm")) {
+        if (chat_history == ""){
+            sprintf_s(prompt, minicpm_template.c_str(), new_prompt.c_str());
+        }else{
+            if (is_input){
+                std::string input_template = "%s%s<AI>";
+                sprintf_s(prompt, input_template.c_str(), chat_history.c_str(), new_prompt.c_str());
+            }
+            else{
+                std::string res_template = "%s%s<用户>";
+                sprintf_s(prompt, res_template.c_str(), chat_history.c_str(), new_prompt.c_str());
+            }
+        }
+    } else {
+        sprintf_s(prompt, chat_history.c_str(), new_prompt.c_str());
+    }
+    return prompt;
+}
+
+
+std::string run_generate(void* void_model, int32_t* embd_inp_ptr, int32_t embd_inp_size,
+                         npu_model_params model_params, tokenizer_params tok_params, int32_t max_new_token, bool do_print){
+    auto start = std::chrono::high_resolution_clock::now();
+    float* logits = run_prefill(void_model, embd_inp_ptr, embd_inp_size);
+    int32_t token = llm_sample_token(logits, true, model_params.vocab_size);
+    auto end = std::chrono::high_resolution_clock::now();
+    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
+    if (do_print){
+        printf("\nPrefill %d tokens cost %d ms.\n", embd_inp_size, duration.count());
+    }
+
+    std::vector<int32_t> embd;  // output ids
+    embd.push_back(token);
+
+    int token_nums = 0;
+    start = std::chrono::high_resolution_clock::now();
+    for (int i = 1; i < max_new_token; i++){
+        auto logits = run_decode(void_model, embd[i-1]);
+        int32_t token = llm_sample_token(logits, true, model_params.vocab_size);
+        if (std::find(tok_params.eos_token_id.begin(), tok_params.eos_token_id.end(), token) == tok_params.eos_token_id.end()){
+            embd.push_back(token);
+            token_nums ++;
+        } else {
+            break;
+        }
+    }
+    end = std::chrono::high_resolution_clock::now();
+    duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
+
+    std::string output = llm_decode(embd);
+
+    if (do_print){
+        printf("\nDecode %d tokens cost %d ms (avg %f ms each token).\n", token_nums, duration.count(), (float)duration.count() / token_nums);
+    }
+
+    return output;
+}
+
+
 int main(int argc, char ** argv) {
     common_params params;
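
The `add_chat_history` helper added above stitches each new turn onto the running history with printf-style, per-model templates. Below is a minimal standalone sketch of the same stitching idea for the qwen2 template only, using `std::string` concatenation instead of a fixed `sprintf_s` buffer; the `kQwen2*` constant names and the turn contents are illustrative, not part of the example's API.

```cpp
#include <iostream>
#include <string>

// The qwen2 template from the example, split into pieces so turns can be
// appended with plain std::string concatenation (no fixed-size char buffer).
static const std::string kQwen2System =
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n";
static const std::string kQwen2UserOpen  = "<|im_start|>user\n";
static const std::string kQwen2UserClose = "<|im_end|>\n<|im_start|>assistant\n";
static const std::string kQwen2AsstClose = "<|im_end|>\n";

int main() {
    // history plays the role of the chat_history string that add_chat_history()
    // keeps extending across iterations of the -cnv loop
    std::string history = kQwen2System;

    // turn 1: the user prompt becomes the prefill input
    history += kQwen2UserOpen + "你好" + kQwen2UserClose;
    std::cout << "=== prompt sent for turn 1 ===\n" << history << "\n";

    // the model reply (hard-coded here) is appended before the next turn
    history += "你好!有什么可以帮你?" + kQwen2AsstClose;

    // turn 2: the new question rides on top of the full history
    history += kQwen2UserOpen + "AI是什么?" + kQwen2UserClose;
    std::cout << "=== prompt sent for turn 2 ===\n" << history << "\n";
    return 0;
}
```

Printing `history` after each turn shows how the prompt grows, which is also why the conversation loop below re-tokenizes without history once `max_prompt_len` is exceeded.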

@@ -39,6 +147,7 @@ int main(int argc, char ** argv) {
     std::string prompt = "AI是什么?";
     // number of tokens to predict
     int n_predict = 32;
+    bool cnv_mode = false;
 
     // parse command line arguments

@@ -52,7 +161,11 @@ int main(int argc, char ** argv) {
                 print_usage(argc, argv);
                 return 1;
             }
-        } else if (strcmp(argv[i], "-n") == 0) {
+        } else if (strcmp(argv[i], "-cnv") == 0){
+            // multi-round conversation mode
+            cnv_mode = true;
+            break;
+        }else if (strcmp(argv[i], "-n") == 0) {
             if (i + 1 < argc) {
                 try {
                     n_predict = std::stoi(argv[++i]);
@@ -86,50 +199,62 @@ int main(int argc, char ** argv) {
     params.model = model_dir;
     params.prompt = prompt;
 
+    // npu_model_params model_params;
     void* model = load_model_from_file(params.model);
     npu_model_params model_params;
     load_config_from_file(model_params, params.model);
 
     tokenizer_params tok_params;
     load_tokenizer(tok_params, params.model);
 
-    std::string full_prompt = add_chat_template(model_params, params.prompt);
-    std::cout << "Input: " << std::endl;
-    std::cout << full_prompt << std::endl;
+    if (cnv_mode){
+        std::string prompt;
+        std::string history = "";
+        std::string response;
+        while(true){
+            std::cout << "User:";
+            std::getline(std::cin, prompt);
+            if (prompt == "exit"){
+                break;
+            }
+            else{
+                // process prompt with chat history
+                std::string full_prompt = add_chat_history(model_params, prompt, history, true);
 
-    // tokenize input
-    std::vector<int32_t> embd_inp = llm_tokenize(full_prompt, false);
+                // tokenize input
+                std::vector<int32_t> embd_inp = llm_tokenize(full_prompt, false);
+                if (embd_inp.size() > model_params.max_prompt_len){
+                    // empty chat history
+                    full_prompt = add_chat_history(model_params, prompt, "", true);
+                    embd_inp = llm_tokenize(full_prompt, false);
+                }
+
+                response = run_generate(model, embd_inp.data(), embd_inp.size(),
+                                        model_params, tok_params, model_params.kv_len - embd_inp.size(), false);
 
-    std::vector<int32_t> embd; // output ids
-    auto start = std::chrono::high_resolution_clock::now();
-    float* logits = run_prefill(model, embd_inp.data(), embd_inp.size());
-    int32_t token = llm_sample_token(logits, true, model_params.vocab_size);
-    auto end = std::chrono::high_resolution_clock::now();
-    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
-    printf("\nPrefill %d tokens cost %d ms.\n", embd_inp.size(), duration.count());
-    embd.push_back(token);
+                std::cout << "Assistant:";
+                std::cout << response << std::endl;
 
-    int token_nums = 0;
-    start = std::chrono::high_resolution_clock::now();
-    for (int i = 1; i < params.n_predict; i++){
-        auto logits = run_decode(model, embd[i-1]);
-        int32_t token = llm_sample_token(logits, true, model_params.vocab_size);
-        if (std::find(tok_params.eos_token_id.begin(), tok_params.eos_token_id.end(), token) == tok_params.eos_token_id.end()){
-            embd.push_back(token);
-            token_nums ++;
-        } else {
-            break;
+                history = add_chat_history(model_params, response, full_prompt, false);
+
+                reset(model);
+            }
         }
     }
-    end = std::chrono::high_resolution_clock::now();
-    duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
-
-    std::string output = llm_decode(embd);
+    else{
+        std::string full_prompt = add_chat_template(model_params, params.prompt);
+        std::cout << "Input: " << std::endl;
+        std::cout << full_prompt << std::endl;
 
-    std::cout << "Output: " << std::endl;
-    std::cout << output << std::endl;
+        // tokenize input
+        std::vector<int32_t> embd_inp = llm_tokenize(full_prompt, false);
 
-    printf("\nDecode %d tokens cost %d ms (avg %f ms each token).\n", token_nums, duration.count(), (float)duration.count() / token_nums);
+        // single text generation
+        std::string output = run_generate(model, embd_inp.data(), embd_inp.size(),
+                                          model_params, tok_params, params.n_predict, true);
 
+        std::cout << "Output: " << std::endl;
+        std::cout << output << std::endl;
+    }
     return 0;
 }
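
Two lengths bound each turn of the `-cnv` loop above: if the history-bearing prompt tokenizes to more than `model_params.max_prompt_len`, the history is dropped and the bare prompt is re-tokenized, and the generation budget passed to `run_generate` is `model_params.kv_len - embd_inp.size()`. A small sketch of that arithmetic with made-up numbers; the two constants stand in for fields read from the converted model's config and the token counts are illustrative only.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Stand-ins for model_params.max_prompt_len and model_params.kv_len;
// the values are illustrative, not taken from a real converted model.
constexpr std::size_t kMaxPromptLen = 512;
constexpr std::size_t kKvLen        = 1024;

int main() {
    // pretend the prompt plus accumulated history tokenized to 600 ids
    std::vector<int32_t> embd_inp(600, 0);

    // same guard as the conversation loop: when the prompt no longer fits
    // the prefill window, drop the history and re-tokenize the bare prompt
    if (embd_inp.size() > kMaxPromptLen) {
        embd_inp.assign(40, 0);  // the bare prompt alone is much shorter
    }

    // per-turn decode budget: whatever KV-cache room is left after prefill
    std::size_t max_new_tokens = kKvLen - embd_inp.size();
    std::printf("prefill %zu tokens, generate at most %zu new tokens this turn\n",
                embd_inp.size(), max_new_tokens);
    return 0;
}
```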
