
Commit b5bd037

llama : add support for qwen3 reranker (#15824)
1 parent dfcd53f

9 files changed: 166 additions & 78 deletions

common/common.cpp

Lines changed: 3 additions & 5 deletions
@@ -961,15 +961,13 @@ struct common_init_result common_init_from_params(common_params & params) {
 
         bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
         bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
+        bool has_rerank_prompt = llama_model_chat_template(model, "rerank") != NULL;
 
-        if (!has_eos && !has_sep) {
-            LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__);
+        if (!has_eos && !has_sep && !has_rerank_prompt) {
+            LOG_WRN("%s: warning: vocab does not have an EOS token, SEP token, or rerank prompt. Reranking will not work\n", __func__);
             ok = false;
         } else if (!has_eos) {
             LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
-        } else if (!has_sep) {
-            LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
-            ok = false;
         }
 
         if (!ok) {
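
The same capability check can be expressed as a small standalone helper against the llama.h functions used above. A minimal sketch; the helper name model_supports_rerank is hypothetical and not part of llama.cpp:

#include "llama.h"

// A model can serve rerank requests if the vocab provides the EOS/SEP tokens used by
// the classic [BOS]query[EOS][SEP]doc[EOS] layout, or if the GGUF ships a dedicated
// "rerank" chat template (as the qwen3 reranker now does).
static bool model_supports_rerank(const llama_model * model, const llama_vocab * vocab) {
    const bool has_eos           = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
    const bool has_sep           = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
    const bool has_rerank_prompt = llama_model_chat_template(model, "rerank") != NULL;

    return has_eos || has_sep || has_rerank_prompt;
}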

convert_hf_to_gguf.py

Lines changed: 65 additions & 0 deletions
@@ -3717,11 +3717,29 @@ def prepare_tensors(self):
 class Qwen3Model(Qwen2Model):
     model_arch = gguf.MODEL_ARCH.QWEN3
 
+    # extra logic for rerank models
+    is_rerank: bool = False
+    is_tied_embeddings: bool = False
+    token_false_id: int | None = None
+    token_true_id: int | None = None
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+
+        # track for intern-s1-mini
         hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
         self.origin_hf_arch = hparams.get('architectures', [None])[0]
 
+        # a bit hacky, but currently the only way to detect if this is a rerank model
+        # ref: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B
+        readme_path = self.dir_model / "README.md"
+        readme_text = ""
+        if readme_path.exists():
+            with readme_path.open("r", encoding="utf-8") as f:
+                readme_text = f.read()
+        if "# Qwen3-Reranker" in readme_text:
+            self._find_rerank_config()
+
     def set_vocab(self):
         # deal with intern-s1-mini
         if self.origin_hf_arch == 'InternS1ForConditionalGeneration':
@@ -3730,6 +3748,53 @@ def set_vocab(self):
 
         super().set_vocab()
 
+    def _find_rerank_config(self):
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
+
+        self.is_rerank = True
+        self.is_tied_embeddings = self.hparams.get("tie_word_embeddings", False)
+        self.token_false_id = tokenizer.convert_tokens_to_ids("no")
+        self.token_true_id = tokenizer.convert_tokens_to_ids("yes")
+        self.sep_token_id = tokenizer.convert_tokens_to_ids("|")
+
+        assert self.token_false_id is not None and self.token_true_id is not None
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if self.is_rerank:
+            self.gguf_writer.add_pooling_type(gguf.PoolingType.RANK)
+            self.gguf_writer.add_classifier_output_labels(["yes", "no"])
+            self.gguf_writer.add_chat_template([{
+                "name": "rerank",
+                "template": "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n"
+                            "<|im_start|>user\n<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n<Query>: {query}\n<Document>: {document}<|im_end|>\n"
+                            "<|im_start|>assistant\n<think>\n\n</think>\n\n"
+            }])
+
+    def _get_cls_out_tensor(self, data_torch: Tensor) -> Tensor:
+        # extract "yes" and "no" tokens from the output lm_head tensor
+        false_row = data_torch[self.token_false_id]
+        true_row = data_torch[self.token_true_id]
+        return torch.stack([true_row, false_row], dim=0)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if self.is_rerank:
+            is_tied_head = self.is_tied_embeddings and "embed_tokens" in name
+            is_real_head = not self.is_tied_embeddings and "lm_head" in name
+            if is_tied_head or is_real_head:
+                cls_out_head = (
+                    gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.CLS_OUT] + ".weight",
+                    self._get_cls_out_tensor(data_torch),
+                )
+                if is_tied_head:
+                    embed = (self.map_tensor_name(name), data_torch)
+                    return [cls_out_head, embed]
+                if is_real_head:
+                    return [cls_out_head]
+
+        return super().modify_tensors(data_torch, name, bid)
+
 
 @ModelBase.register("Qwen3MoeForCausalLM")
 class Qwen3MoeModel(Qwen2MoeModel):
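
The converter replaces the full lm_head with a two-row classification head: row 0 holds the output row of the "yes" token and row 1 the "no" token, matching the ["yes", "no"] classifier labels written above. A self-contained sketch of that row extraction on a plain row-major matrix, independent of gguf/torch and illustrative only:

#include <algorithm>
#include <cstdint>
#include <vector>

// Build a 2 x n_embd cls.output matrix from an n_vocab x n_embd lm_head matrix
// (row-major) by stacking the "yes" row first and the "no" row second.
static std::vector<float> make_cls_out(const std::vector<float> & lm_head,
                                       int64_t n_embd,
                                       int64_t token_yes_id,
                                       int64_t token_no_id) {
    std::vector<float> cls_out(2 * n_embd);
    std::copy(lm_head.begin() +  token_yes_id      * n_embd,
              lm_head.begin() + (token_yes_id + 1) * n_embd, cls_out.begin());
    std::copy(lm_head.begin() +  token_no_id       * n_embd,
              lm_head.begin() + (token_no_id  + 1) * n_embd, cls_out.begin() + n_embd);
    return cls_out;
}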

examples/embedding/embedding.cpp

Lines changed: 28 additions & 15 deletions
@@ -95,8 +95,13 @@ int main(int argc, char ** argv) {
         params.n_batch = params.n_ctx;
     }
 
-    // For non-causal models, batch size must be equal to ubatch size
-    params.n_ubatch = params.n_batch;
+    // for non-causal models, batch size must be equal to ubatch size
+    if (params.attention_type != LLAMA_ATTENTION_TYPE_CAUSAL) {
+        params.n_ubatch = params.n_batch;
+    }
+
+    // get max number of sequences per batch
+    const int n_seq_max = llama_max_parallel_sequences();
 
     llama_backend_init();
     llama_numa_init(params.numa);
@@ -144,6 +149,7 @@ int main(int argc, char ** argv) {
     // get added sep and eos token, if any
     const std::string added_sep_token = llama_vocab_get_add_sep(vocab) ? llama_vocab_get_text(vocab, llama_vocab_sep(vocab)) : "";
    const std::string added_eos_token = llama_vocab_get_add_eos(vocab) ? llama_vocab_get_text(vocab, llama_vocab_eos(vocab)) : "";
+    const char * rerank_prompt = llama_model_chat_template(model, "rerank");
 
     // tokenize the prompts and trim
     std::vector<std::vector<int32_t>> inputs;
@@ -153,21 +159,28 @@
         // split classification pairs and insert expected separator tokens
         if (pooling_type == LLAMA_POOLING_TYPE_RANK && prompt.find(params.cls_sep) != std::string::npos) {
             std::vector<std::string> pairs = split_lines(prompt, params.cls_sep);
-            std::string final_prompt;
-
-            for (size_t i = 0; i < pairs.size(); i++) {
-                final_prompt += pairs[i];
-                if (i != pairs.size() - 1) {
-                    if (!added_eos_token.empty()) {
-                        final_prompt += added_eos_token;
-                    }
-                    if (!added_sep_token.empty()) {
-                        final_prompt += added_sep_token;
+            if (rerank_prompt != nullptr) {
+                const std::string query = pairs[0];
+                const std::string doc   = pairs[1];
+                std::string final_prompt = rerank_prompt;
+                string_replace_all(final_prompt, "{query}"   , query);
+                string_replace_all(final_prompt, "{document}", doc  );
+                inp = common_tokenize(vocab, final_prompt, true, true);
+            } else {
+                std::string final_prompt;
+                for (size_t i = 0; i < pairs.size(); i++) {
+                    final_prompt += pairs[i];
+                    if (i != pairs.size() - 1) {
+                        if (!added_eos_token.empty()) {
+                            final_prompt += added_eos_token;
+                        }
+                        if (!added_sep_token.empty()) {
+                            final_prompt += added_sep_token;
+                        }
                     }
                 }
+                inp = common_tokenize(ctx, final_prompt, true, true);
             }
-
-            inp = common_tokenize(ctx, final_prompt, true, true);
         } else {
             inp = common_tokenize(ctx, prompt, true, true);
         }
@@ -229,7 +242,7 @@ int main(int argc, char ** argv) {
         const uint64_t n_toks = inp.size();
 
         // encode if at capacity
-        if (batch.n_tokens + n_toks > n_batch) {
+        if (batch.n_tokens + n_toks > n_batch || s >= n_seq_max) {
             float * out = emb + e * n_embd;
             batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
             e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
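
The "rerank" template stored in the GGUF uses literal {query} and {document} placeholders, which the example fills by plain string substitution before tokenizing. A self-contained sketch of that step; replace_all here is a local stand-in for llama.cpp's string_replace_all, and the function names are illustrative:

#include <string>

// Replace every occurrence of `from` in `s` with `to`.
static void replace_all(std::string & s, const std::string & from, const std::string & to) {
    for (size_t pos = 0; (pos = s.find(from, pos)) != std::string::npos; pos += to.size()) {
        s.replace(pos, from.size(), to);
    }
}

// Fill the rerank chat template for one query/document pair; the result is the
// text that gets tokenized as a single rerank input.
static std::string fill_rerank_prompt(std::string tmpl, const std::string & query, const std::string & doc) {
    replace_all(tmpl, "{query}",    query);
    replace_all(tmpl, "{document}", doc);
    return tmpl;
}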

src/llama-arch.cpp

Lines changed: 1 addition & 0 deletions
@@ -721,6 +721,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
             { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
             { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_CLS_OUT,         "cls.output" },
             { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
             { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
             { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },

src/llama-graph.cpp

Lines changed: 21 additions & 20 deletions
@@ -204,7 +204,10 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
     std::vector<int> target_pos(n_seqs_unq, -1);
     std::vector<int> target_row(n_seqs_unq, -1);
 
-    bool last = cparams.pooling_type == LLAMA_POOLING_TYPE_LAST;
+    const bool last = (
+        cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
+        (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token
+    );
 
     for (int i = 0; i < n_tokens; ++i) {
         const llama_pos pos = ubatch->pos[i];
@@ -1177,7 +1180,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_cls() const {
-    auto inp = std::make_unique<llm_graph_input_cls>(cparams);
+    auto inp = std::make_unique<llm_graph_input_cls>(cparams, arch);
 
     auto & cur = inp->cls;
 
@@ -1877,34 +1880,32 @@ void llm_graph_context::build_pooling(
         case LLAMA_POOLING_TYPE_RANK:
             {
                 ggml_tensor * inp_cls = build_inp_cls();
-                inp = ggml_get_rows(ctx0, inp, inp_cls);
+                cur = ggml_get_rows(ctx0, inp, inp_cls);
 
+                // classification head
+                // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
                 if (cls) {
-                    // classification head
-                    // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
-                    cur = ggml_mul_mat(ctx0, cls, inp);
+                    cur = ggml_mul_mat(ctx0, cls, cur);
                     if (cls_b) {
                         cur = ggml_add(ctx0, cur, cls_b);
                     }
                     cur = ggml_tanh(ctx0, cur);
+                }
 
-                    // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
-                    // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
-                    if (cls_out) {
-                        cur = ggml_mul_mat(ctx0, cls_out, cur);
-                        if (cls_out_b) {
-                            cur = ggml_add(ctx0, cur, cls_out_b);
-                        }
-                    }
-                } else if (cls_out) {
-                    // Single layer classification head (direct projection)
-                    // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
-                    cur = ggml_mul_mat(ctx0, cls_out, inp);
+                // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
+                // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
+                // Single layer classification head (direct projection)
+                // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
+                if (cls_out) {
+                    cur = ggml_mul_mat(ctx0, cls_out, cur);
                     if (cls_out_b) {
                         cur = ggml_add(ctx0, cur, cls_out_b);
                     }
-                } else {
-                    GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
+                }
+
+                // softmax for qwen3 reranker
+                if (arch == LLM_ARCH_QWEN3) {
+                    cur = ggml_soft_max(ctx0, cur);
                 }
             } break;
         default:
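
For the qwen3 reranker, RANK pooling therefore reduces to: take the last token's hidden state, project it through the 2 x n_embd cls.output head to get [yes, no] logits, and apply softmax; the relevance score is the resulting "yes" probability. A plain-C++ sketch of that math on raw floats, illustrative only and not the ggml graph itself:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// cls_out is 2 x n_embd row-major (row 0 = "yes", row 1 = "no"); h is the last
// token's hidden state. Returns softmax(logit_yes, logit_no)[0], i.e. P("yes").
static float rerank_score(const std::vector<float> & cls_out, const std::vector<float> & h) {
    const size_t n_embd = h.size();
    float logit_yes = 0.0f;
    float logit_no  = 0.0f;
    for (size_t i = 0; i < n_embd; ++i) {
        logit_yes += cls_out[i]          * h[i];
        logit_no  += cls_out[n_embd + i] * h[i];
    }
    // numerically stable two-way softmax
    const float m     = std::max(logit_yes, logit_no);
    const float e_yes = std::exp(logit_yes - m);
    const float e_no  = std::exp(logit_no  - m);
    return e_yes / (e_yes + e_no);
}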

src/llama-graph.h

Lines changed: 2 additions & 1 deletion
@@ -206,14 +206,15 @@ class llm_graph_input_mean : public llm_graph_input_i {
 
 class llm_graph_input_cls : public llm_graph_input_i {
 public:
-    llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
+    llm_graph_input_cls(const llama_cparams & cparams, const llm_arch arch) : cparams(cparams), arch(arch) {}
     virtual ~llm_graph_input_cls() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * cls; // I32 [n_batch]
 
     const llama_cparams cparams;
+    const llm_arch arch;
 };
 
 class llm_graph_input_rs : public llm_graph_input_i {

src/llama-model.cpp

Lines changed: 3 additions & 0 deletions
@@ -3167,6 +3167,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
                 }
 
+                // output rerank head
+                cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
+
                 for (int i = 0; i < n_layer; ++i) {
                     auto & layer = layers[i];

tools/server/server.cpp

Lines changed: 3 additions & 9 deletions
@@ -5093,21 +5093,15 @@ int main(int argc, char ** argv) {
             return;
         }
 
-        std::vector<server_tokens> tokenized_queries = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, query, /* add_special */ false, true);
-        if (tokenized_queries.size() != 1) {
-            res_error(res, format_error_response("\"query\" must contain only a single prompt", ERROR_TYPE_INVALID_REQUEST));
-        }
-
         // create and queue the task
         json responses = json::array();
         bool error = false;
         std::unordered_set<int> task_ids;
         {
             std::vector<server_task> tasks;
-            auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, documents, /* add_special */ false, true);
-            tasks.reserve(tokenized_docs.size());
-            for (size_t i = 0; i < tokenized_docs.size(); i++) {
-                auto tmp = format_rerank(ctx_server.vocab, tokenized_queries[0], tokenized_docs[i]);
+            tasks.reserve(documents.size());
+            for (size_t i = 0; i < documents.size(); i++) {
+                auto tmp = format_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]);
                 server_task task = server_task(SERVER_TASK_TYPE_RERANK);
                 task.id = ctx_server.queue_tasks.get_new_id();
                 task.index = i;

tools/server/utils.hpp

Lines changed: 40 additions & 28 deletions
@@ -1368,34 +1368,6 @@ static std::string fnv_hash(const uint8_t * data, size_t len) {
     return std::to_string(hash);
 }
 
-
-// format rerank task: [BOS]query[EOS][SEP]doc[EOS].
-static server_tokens format_rerank(const struct llama_vocab * vocab, server_tokens & query, server_tokens & doc) {
-    server_tokens result = {};
-
-    // Get EOS token - use SEP token as fallback if EOS is not available
-    llama_token eos_token = llama_vocab_eos(vocab);
-    if (eos_token == LLAMA_TOKEN_NULL) {
-        eos_token = llama_vocab_sep(vocab);
-    }
-    if (llama_vocab_get_add_bos(vocab)) {
-        result.push_back(llama_vocab_bos(vocab));
-    }
-    result.push_back(query);
-    if (llama_vocab_get_add_eos(vocab)) {
-        result.push_back(eos_token);
-    }
-    if (llama_vocab_get_add_sep(vocab)) {
-        result.push_back(llama_vocab_sep(vocab));
-    }
-    result.push_back(doc);
-    if (llama_vocab_get_add_eos(vocab)) {
-        result.push_back(eos_token);
-    }
-    return result;
-}
-
-
 static server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector<raw_buffer> files) {
     mtmd::bitmaps bitmaps;
     for (auto & file : files) {
@@ -1501,3 +1473,43 @@ static std::vector<server_tokens> tokenize_input_prompts(const llama_vocab * voc
     }
     return result;
 }
+
+// format rerank task: [BOS]query[EOS][SEP]doc[EOS].
+static server_tokens format_rerank(const struct llama_model * model, const struct llama_vocab * vocab, mtmd_context * mctx, const std::string & query, const std::string & doc) {
+    server_tokens result = {};
+
+    const char * rerank_prompt = llama_model_chat_template(model, "rerank");
+
+    if (rerank_prompt != nullptr) {
+        std::string prompt = rerank_prompt;
+        string_replace_all(prompt, "{query}"   , query);
+        string_replace_all(prompt, "{document}", doc  );
+        server_tokens tokens = tokenize_input_subprompt(vocab, mctx, prompt, false, true);
+        result.push_back(tokens);
+    } else {
+        // Get EOS token - use SEP token as fallback if EOS is not available
+        server_tokens query_tokens = tokenize_input_subprompt(vocab, mctx, query, false, false);
+        server_tokens doc_tokens   = tokenize_input_subprompt(vocab, mctx, doc,   false, false);
+        llama_token eos_token = llama_vocab_eos(vocab);
+        if (eos_token == LLAMA_TOKEN_NULL) {
+            eos_token = llama_vocab_sep(vocab);
+        }
+
+        if (llama_vocab_get_add_bos(vocab)) {
+            result.push_back(llama_vocab_bos(vocab));
+        }
+        result.push_back(query_tokens);
+        if (llama_vocab_get_add_eos(vocab)) {
+            result.push_back(eos_token);
+        }
+        if (llama_vocab_get_add_sep(vocab)) {
+            result.push_back(llama_vocab_sep(vocab));
+        }
+        result.push_back(doc_tokens);
+        if (llama_vocab_get_add_eos(vocab)) {
+            result.push_back(eos_token);
+        }
+    }
+
+    return result;
+}
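
As a concrete illustration, with a made-up query and document the template branch of format_rerank produces a single chat-formatted prompt like the following before tokenization (the fallback branch instead emits the [BOS]query[EOS][SEP]doc[EOS] token layout):

<|im_start|>system
Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>
<|im_start|>user
<Instruct>: Given a web search query, retrieve relevant passages that answer the query
<Query>: what is the capital of France?
<Document>: Paris is the capital and largest city of France.<|im_end|>
<|im_start|>assistant
<think>

</think>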
