@@ -95,8 +95,13 @@ int main(int argc, char ** argv) {
         params.n_batch = params.n_ctx;
     }
 
-    // For non-causal models, batch size must be equal to ubatch size
-    params.n_ubatch = params.n_batch;
+    // for non-causal models, batch size must be equal to ubatch size
+    if (params.attention_type != LLAMA_ATTENTION_TYPE_CAUSAL) {
+        params.n_ubatch = params.n_batch;
+    }
+
+    // get max number of sequences per batch
+    const int n_seq_max = llama_max_parallel_sequences();
 
     llama_backend_init();
     llama_numa_init(params.numa);
@@ -144,6 +149,7 @@ int main(int argc, char ** argv) {
     // get added sep and eos token, if any
     const std::string added_sep_token = llama_vocab_get_add_sep(vocab) ? llama_vocab_get_text(vocab, llama_vocab_sep(vocab)) : "";
     const std::string added_eos_token = llama_vocab_get_add_eos(vocab) ? llama_vocab_get_text(vocab, llama_vocab_eos(vocab)) : "";
+    const char * rerank_prompt = llama_model_chat_template(model, "rerank");
 
     // tokenize the prompts and trim
     std::vector<std::vector<int32_t>> inputs;
@@ -153,21 +159,28 @@ int main(int argc, char ** argv) {
         // split classification pairs and insert expected separator tokens
         if (pooling_type == LLAMA_POOLING_TYPE_RANK && prompt.find(params.cls_sep) != std::string::npos) {
             std::vector<std::string> pairs = split_lines(prompt, params.cls_sep);
-            std::string final_prompt;
-
-            for (size_t i = 0; i < pairs.size(); i++) {
-                final_prompt += pairs[i];
-                if (i != pairs.size() - 1) {
-                    if (!added_eos_token.empty()) {
-                        final_prompt += added_eos_token;
-                    }
-                    if (!added_sep_token.empty()) {
-                        final_prompt += added_sep_token;
+            if (rerank_prompt != nullptr) {
+                const std::string query = pairs[0];
+                const std::string doc   = pairs[1];
+                std::string final_prompt = rerank_prompt;
+                string_replace_all(final_prompt, "{query}",    query);
+                string_replace_all(final_prompt, "{document}", doc);
+                inp = common_tokenize(vocab, final_prompt, true, false);
+            } else {
+                std::string final_prompt;
+                for (size_t i = 0; i < pairs.size(); i++) {
+                    final_prompt += pairs[i];
+                    if (i != pairs.size() - 1) {
+                        if (!added_eos_token.empty()) {
+                            final_prompt += added_eos_token;
+                        }
+                        if (!added_sep_token.empty()) {
+                            final_prompt += added_sep_token;
+                        }
                     }
                 }
+                inp = common_tokenize(ctx, final_prompt, true, true);
             }
-
-            inp = common_tokenize(ctx, final_prompt, true, true);
         } else {
             inp = common_tokenize(ctx, prompt, true, true);
         }
@@ -229,7 +242,7 @@ int main(int argc, char ** argv) {
         const uint64_t n_toks = inp.size();
 
         // encode if at capacity
-        if (batch.n_tokens + n_toks > n_batch) {
+        if (batch.n_tokens + n_toks > n_batch || s >= n_seq_max) {
             float * out = emb + e * n_embd;
             batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
             e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
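
Below is a minimal standalone sketch (not part of the commit) of what the new rerank path does: the model's "rerank" chat template has its {query} and {document} placeholders filled in before the prompt is tokenized. The template string here is an assumption for illustration only, since the real one is read from the model via llama_model_chat_template(model, "rerank"), and the local replace_all() helper merely stands in for llama.cpp's string_replace_all().

```cpp
#include <iostream>
#include <string>

// stand-in for llama.cpp's string_replace_all() helper
static void replace_all(std::string & s, const std::string & from, const std::string & to) {
    for (size_t pos = 0; (pos = s.find(from, pos)) != std::string::npos; pos += to.size()) {
        s.replace(pos, from.size(), to);
    }
}

int main() {
    // hypothetical "rerank" template with the {query}/{document} placeholders the diff fills in
    std::string final_prompt = "query: {query}\ndocument: {document}";

    replace_all(final_prompt, "{query}",    "what is a panda?");
    replace_all(final_prompt, "{document}", "The giant panda is a bear species endemic to China.");

    std::cout << final_prompt << "\n"; // the text that would then be tokenized and scored
    return 0;
}
```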