Commit 3505450

1 parent 889dcd6 commit 3505450

File tree: 5 files changed (+10, −22 lines changed)

common/common.h

Lines changed: 1 addition & 0 deletions
@@ -413,6 +413,7 @@ struct common_params {
     std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
     std::string embd_sep = "\n"; // separator of embeddings
     std::string cls_sep = "\t"; // separator of classification sequences
+    bool reranking = false; // enable reranking support on server

     // server params
     int32_t port = 8080; // server listens on this network port
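
A minimal usage sketch of the new field, assuming a caller that fills common_params directly instead of going through the CLI; only `reranking` and `embedding` come from this commit, everything else here is illustrative.

    #include "common.h"   // assumed include path for common_params

    int main() {
        common_params params;
        params.reranking = true;   // field added by this commit
        params.embedding = false;  // rerank requests are rejected when embedding mode is on (see server.cpp below)
        // ... hand params to the usual server / context setup ...
        return 0;
    }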

convert_hf_to_gguf.py

Lines changed: 5 additions & 8 deletions
@@ -887,9 +887,6 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756":
             # ref: https://huggingface.co/JetBrains/Mellum-4b-base
             res = "mellum"
-        if chkhsh == "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c":
-            # ref: https://huggingface.co/Qwen/Qwen3-Embedding-0.6B
-            res = "qwen2"

         if res is None:
             logger.warning("\n")

@@ -3665,6 +3662,8 @@ class Qwen3Model(Qwen2Model):

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
+        self.origin_hf_arch = hparams.get('architectures', [None])[0]
         # a bit hacky, but currently the only way to detect if this is a rerank model
         # ref: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B
         readme_path = self.dir_model / "README.md"

@@ -3674,6 +3673,8 @@ def __init__(self, *args, **kwargs):
                 readme_text = f.read()
         if "# Qwen3-Reranker" in readme_text:
             self._find_rerank_config()
+        else:
+            logger.info("gguf: not a rerank model")

     def _find_rerank_config(self):
         from transformers import AutoTokenizer

@@ -3688,6 +3689,7 @@ def _find_rerank_config(self):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         is_rerank = self.token_false_id is not None and self.token_true_id is not None
+        logger.info(f"gguf: is_rerank = {is_rerank}")
         if is_rerank:
             self.gguf_writer.add_pooling_type(gguf.PoolingType.RANK)
             self.gguf_writer.add_classifier_output_labels(["yes", "no"])

@@ -3723,11 +3725,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

         return super().modify_tensors(data_torch, name, bid)

-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
-        self.origin_hf_arch = hparams.get('architectures', [None])[0]
-
     def set_vocab(self):
         # deal with intern-s1-mini
         if self.origin_hf_arch == 'InternS1ForConditionalGeneration':

src/llama-graph.cpp

Lines changed: 1 addition & 1 deletion
@@ -204,7 +204,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
         std::vector<int> target_pos(n_seqs_unq, -1);
         std::vector<int> target_row(n_seqs_unq, -1);

-        bool last = cparams.pooling_type == LLAMA_POOLING_TYPE_LAST;
+        bool last = cparams.pooling_type == LLAMA_POOLING_TYPE_LAST || arch == LLM_ARCH_QWEN3; // qwen3 reranking & embedding

         for (int i = 0; i < n_tokens; ++i) {
             const llama_pos pos = ubatch->pos[i];
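
Forcing the `last` path for QWEN3 matters because the Qwen3 rerank and embedding models read their output from the final token of each sequence. A standalone sketch of the idea (assumed shapes, not the llama.cpp internals): given per-token positions and sequence ids, keep the row of the highest position seen per sequence and pool from those rows.

    #include <vector>

    // Pick the row index of the last token of every sequence in a batch.
    std::vector<int> last_rows(const std::vector<int> & seq_id,  // sequence id per token
                               const std::vector<int> & pos,     // position per token
                               int n_seqs) {
        std::vector<int> best_pos(n_seqs, -1);
        std::vector<int> best_row(n_seqs, -1);
        for (int i = 0; i < (int) seq_id.size(); ++i) {
            if (pos[i] > best_pos[seq_id[i]]) {
                best_pos[seq_id[i]] = pos[i];
                best_row[seq_id[i]] = i; // currently the last token seen for its sequence
            }
        }
        return best_row;
    }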

src/llama-model.cpp

Lines changed: 1 addition & 1 deletion
@@ -8088,7 +8088,7 @@ struct llm_build_stablelm : public llm_graph_context {
 };

 struct llm_build_qwen : public llm_graph_context {
-    llm_build_qwen(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_qwen(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;

         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);

tools/server/server.cpp

Lines changed: 2 additions & 12 deletions
@@ -4798,8 +4798,8 @@ int main(int argc, char ** argv) {
     };

     const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
-        if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) {
-            res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
+        if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
+            res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }

@@ -4838,29 +4838,19 @@
             return;
         }

-        std::vector<server_tokens> tokenized_queries = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, query, /* add_special */ false, true);
-        if (tokenized_queries.size() != 1) {
-            res_error(res, format_error_response("\"query\" must contain only a single prompt", ERROR_TYPE_INVALID_REQUEST));
-        }
-
         // create and queue the task
         json responses = json::array();
         bool error = false;
         std::unordered_set<int> task_ids;
         {
             std::vector<server_task> tasks;
-            auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, documents, /* add_special */ false, true);
-            tasks.reserve(tokenized_docs.size());
-            for (size_t i = 0; i < tokenized_docs.size(); i++) {
-                auto tmp = format_rerank(ctx_server.vocab, tokenized_queries[0], tokenized_docs[i]);
             auto inputs = tokenize_rerank(ctx_server.model, query, documents);
             tasks.reserve(documents.size());
             for (size_t i = 0; i < inputs.size(); i++) {
                 server_task task = server_task(SERVER_TASK_TYPE_RERANK);
                 task.id = ctx_server.queue_tasks.get_new_id();
                 task.index = i;
                 task.prompt_tokens = server_tokens(inputs[i], ctx_server.mctx != nullptr);
-                task.prompt_tokens = std::move(tmp);
                 tasks.push_back(std::move(task));
             }
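
After this change the handler builds one rerank task per document from a single query via `tokenize_rerank`. A hedged sketch of the request body it consumes, built with nlohmann::json as server.cpp does; the field names `query` and `documents` come from the diff, the example strings are made up.

    #include <nlohmann/json.hpp>
    using json = nlohmann::ordered_json;

    int main() {
        // Illustrative rerank request body: one query, several candidate documents.
        json body = {
            {"query",     "what is panda?"},
            {"documents", json::array({"hi", "it is a bear", "the giant panda is a bear species endemic to China"})},
        };
        // One SERVER_TASK_TYPE_RERANK task is queued per entry in "documents".
        return 0;
    }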
