static void print_usage(int, char ** argv) {
    printf("\nexample usage:\n");
-     printf("\n    %s -m npu_model_dir [-n n_predict] [prompt]\n", argv[0]);
+     printf("\n    %s -m npu_model_dir [-cnv] [-n n_predict] [prompt]\n", argv[0]);
    printf("\n");
}

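+ // Chat prompt templates (printf-style format strings); "%s" marks where the user prompt is substituted.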
+ const std::string llama2_template = "<s>[INST] <<SYS>>\n\n<</SYS>>\n\n%s [/INST]";
+ const std::string llama3_template = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n";
+ const std::string minicpm_template = "<用户>%s<AI>";
+ const std::string qwen2_template = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n";
+
+
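+ // Build the next prompt for multi-turn chat: formats new_prompt with the model-specific
+ // template and appends it to chat_history, either as a user turn (is_input == true)
+ // or as a model response (is_input == false).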
+ std::string add_chat_history(npu_model_params model_params,
+                              std::string new_prompt, std::string chat_history, bool is_input) {
+     char prompt[8092];
+     if (model_params.model_type == std::string("llama") && model_params.vocab_size == 32000) {
+         if (chat_history == ""){
+             sprintf_s(prompt, llama2_template.c_str(), new_prompt.c_str());
+         } else {
+             if (is_input){
+                 std::string input_template = "%s%s [/INST]";
+                 sprintf_s(prompt, input_template.c_str(), chat_history.c_str(), new_prompt.c_str());
+             }
+             else {
+                 std::string res_template = "%s%s </s><s>[INST]";
+                 sprintf_s(prompt, res_template.c_str(), chat_history.c_str(), new_prompt.c_str());
+             }
+         }
+
+     } else if (model_params.model_type == std::string("llama")) {
+         if (chat_history == ""){
+             sprintf_s(prompt, llama3_template.c_str(), new_prompt.c_str());
+         } else {
+             if (is_input){
+                 std::string input_template = "%s<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|>";
+                 sprintf_s(prompt, input_template.c_str(), chat_history.c_str(), new_prompt.c_str());
+             }
+             else {
+                 std::string res_template = "%s<|start_header_id|>assistant<|end_header_id|>\n\n%s<|eot_id|>";
+                 sprintf_s(prompt, res_template.c_str(), chat_history.c_str(), new_prompt.c_str());
+             }
+         }
+     } else if (model_params.model_type == std::string("qwen2")) {
+         if (chat_history == ""){
+             sprintf_s(prompt, qwen2_template.c_str(), new_prompt.c_str());
+         } else {
+             if (is_input){
+                 std::string input_template = "%s%s<|im_end|>\n<|im_start|>assistant";
+                 sprintf_s(prompt, input_template.c_str(), chat_history.c_str(), new_prompt.c_str());
+             }
+             else {
+                 std::string res_template = "%s%s<|im_end|>\n<|im_start|>user\n";
+                 sprintf_s(prompt, res_template.c_str(), chat_history.c_str(), new_prompt.c_str());
+             }
+         }
+     } else if (model_params.model_type == std::string("minicpm")) {
+         if (chat_history == ""){
+             sprintf_s(prompt, minicpm_template.c_str(), new_prompt.c_str());
+         } else {
+             if (is_input){
+                 std::string input_template = "%s%s<AI>";
+                 sprintf_s(prompt, input_template.c_str(), chat_history.c_str(), new_prompt.c_str());
+             }
+             else {
+                 std::string res_template = "%s%s<用户>";
+                 sprintf_s(prompt, res_template.c_str(), chat_history.c_str(), new_prompt.c_str());
+             }
+         }
+     } else {
+         sprintf_s(prompt, chat_history.c_str(), new_prompt.c_str());
+     }
+     return prompt;
+ }
+
+
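+ // Prefill the full input on the NPU model, then decode one token at a time until an
+ // EOS token id or max_new_token is reached; returns the detokenized output and
+ // optionally prints prefill/decode timings.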
+ std::string run_generate(void * void_model, int32_t * embd_inp_ptr, int32_t embd_inp_size,
+                          npu_model_params model_params, tokenizer_params tok_params, int32_t max_new_token, bool do_print){
+     auto start = std::chrono::high_resolution_clock::now();
+     float * logits = run_prefill(void_model, embd_inp_ptr, embd_inp_size);
+     int32_t token = llm_sample_token(logits, true, model_params.vocab_size);
+     auto end = std::chrono::high_resolution_clock::now();
+     auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
+     if (do_print){
+         printf("\nPrefill %d tokens cost %d ms.\n", embd_inp_size, duration.count());
+     }
+
+     std::vector<int32_t> embd; // output ids
+     embd.push_back(token);
+
+     int token_nums = 0;
+     start = std::chrono::high_resolution_clock::now();
+     for (int i = 1; i < max_new_token; i++){
+         auto logits = run_decode(void_model, embd[i-1]);
+         int32_t token = llm_sample_token(logits, true, model_params.vocab_size);
+         if (std::find(tok_params.eos_token_id.begin(), tok_params.eos_token_id.end(), token) == tok_params.eos_token_id.end()){
+             embd.push_back(token);
+             token_nums++;
+         } else {
+             break;
+         }
+     }
+     end = std::chrono::high_resolution_clock::now();
+     duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
+
+     std::string output = llm_decode(embd);
+
+     if (do_print){
+         printf("\nDecode %d tokens cost %d ms (avg %f ms each token).\n", token_nums, duration.count(), (float)duration.count() / token_nums);
+     }
+
+     return output;
+ }
+
+
int main(int argc, char ** argv) {
    common_params params;

@@ -39,6 +147,7 @@ int main(int argc, char ** argv) {
    std::string prompt = "AI是什么?";
    // number of tokens to predict
    int n_predict = 32;
+     bool cnv_mode = false;

    // parse command line arguments
@@ -52,7 +161,11 @@ int main(int argc, char ** argv) {
                print_usage(argc, argv);
                return 1;
            }
-         } else if (strcmp(argv[i], "-n") == 0) {
+         } else if (strcmp(argv[i], "-cnv") == 0) {
+             // multi-round conversation mode
+             cnv_mode = true;
+             break;
+         } else if (strcmp(argv[i], "-n") == 0) {
            if (i + 1 < argc) {
                try {
                    n_predict = std::stoi(argv[++i]);
@@ -86,50 +199,62 @@ int main(int argc, char ** argv) {
    params.model = model_dir;
    params.prompt = prompt;

+     // npu_model_params model_params;
    void * model = load_model_from_file(params.model);
    npu_model_params model_params;
    load_config_from_file(model_params, params.model);

    tokenizer_params tok_params;
    load_tokenizer(tok_params, params.model);

-     std::string full_prompt = add_chat_template(model_params, params.prompt);
-     std::cout << "Input: " << std::endl;
-     std::cout << full_prompt << std::endl;
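+     // Multi-round conversation mode (-cnv): read user input, generate a response,
+     // and fold the exchange back into the chat history for the next turn; "exit" quits.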
+     if (cnv_mode){
+         std::string prompt;
+         std::string history = "";
+         std::string response;
+         while (true){
+             std::cout << "User:";
+             std::getline(std::cin, prompt);
+             if (prompt == "exit"){
+                 break;
+             }
+             else {
+                 // process prompt with chat history
+                 std::string full_prompt = add_chat_history(model_params, prompt, history, true);

-     // tokenize input
-     std::vector<int32_t> embd_inp = llm_tokenize(full_prompt, false);
+                 // tokenize input
+                 std::vector<int32_t> embd_inp = llm_tokenize(full_prompt, false);
+                 if (embd_inp.size() > model_params.max_prompt_len){
+                     // empty chat history
+                     full_prompt = add_chat_history(model_params, prompt, "", true);
+                     embd_inp = llm_tokenize(full_prompt, false);
+                 }
+
+                 response = run_generate(model, embd_inp.data(), embd_inp.size(),
+                                         model_params, tok_params, model_params.kv_len - embd_inp.size(), false);

-     std::vector<int32_t> embd; // output ids
-     auto start = std::chrono::high_resolution_clock::now();
-     float * logits = run_prefill(model, embd_inp.data(), embd_inp.size());
-     int32_t token = llm_sample_token(logits, true, model_params.vocab_size);
-     auto end = std::chrono::high_resolution_clock::now();
-     auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
-     printf("\nPrefill %d tokens cost %d ms.\n", embd_inp.size(), duration.count());
-     embd.push_back(token);
+                 std::cout << "Assistant:";
+                 std::cout << response << std::endl;

-     int token_nums = 0;
-     start = std::chrono::high_resolution_clock::now();
-     for (int i = 1; i < params.n_predict; i++){
-         auto logits = run_decode(model, embd[i-1]);
-         int32_t token = llm_sample_token(logits, true, model_params.vocab_size);
-         if (std::find(tok_params.eos_token_id.begin(), tok_params.eos_token_id.end(), token) == tok_params.eos_token_id.end()){
-             embd.push_back(token);
-             token_nums++;
-         } else {
-             break;
+                 history = add_chat_history(model_params, response, full_prompt, false);
+
+                 reset(model);
+             }
        }
    }
-     end = std::chrono::high_resolution_clock::now();
-     duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
-
-     std::string output = llm_decode(embd);
+     else {
+         std::string full_prompt = add_chat_template(model_params, params.prompt);
+         std::cout << "Input: " << std::endl;
+         std::cout << full_prompt << std::endl;

-     std::cout << "Output: " << std::endl;
-     std::cout << output << std::endl;
+         // tokenize input
+         std::vector<int32_t> embd_inp = llm_tokenize(full_prompt, false);

-     printf("\nDecode %d tokens cost %d ms (avg %f ms each token).\n", token_nums, duration.count(), (float)duration.count() / token_nums);
+         // single text generation
+         std::string output = run_generate(model, embd_inp.data(), embd_inp.size(),
+                                           model_params, tok_params, params.n_predict, true);

+         std::cout << "Output: " << std::endl;
+         std::cout << output << std::endl;
+     }

    return 0;
}