@@ -634,7 +634,6 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env,
634
634
json error = nullptr ;
635
635
636
636
server_task_result_ptr result = ctx_server->queue_results .recv (id_task);
637
- ctx_server->queue_results .remove_waiting_task_id (id_task);
638
637
639
638
json response_str = result->to_json ();
640
639
if (result->is_error ()) {
@@ -644,6 +643,10 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env,
644
643
return nullptr ;
645
644
}
646
645
646
+ if (result->is_stop ()) {
647
+ ctx_server->queue_results .remove_waiting_task_id (id_task);
648
+ }
649
+
647
650
const auto out_res = result->to_json ();
648
651
649
652
// Extract "embedding" as a vector of vectors (2D array)
@@ -679,6 +682,102 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env,
679
682
return j_embedding;
680
683
}
681
684
685
+ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank (JNIEnv *env, jobject obj, jstring jprompt,
686
+ jobjectArray documents) {
687
+ jlong server_handle = env->GetLongField (obj, f_model_pointer);
688
+ auto *ctx_server = reinterpret_cast <server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
689
+
690
+ if (!ctx_server->params_base .reranking || ctx_server->params_base .embedding ) {
691
+ env->ThrowNew (c_llama_error,
692
+ " This server does not support reranking. Start it with `--reranking` and without `--embedding`" );
693
+ return nullptr ;
694
+ }
695
+
696
+ const std::string prompt = parse_jstring (env, jprompt);
697
+
698
+ const auto tokenized_query = tokenize_mixed (ctx_server->vocab , prompt, true , true );
699
+
700
+ json responses = json::array ();
701
+
702
+ std::vector<server_task> tasks;
703
+ const jsize amount_documents = env->GetArrayLength (documents);
704
+ auto *document_array = parse_string_array (env, documents, amount_documents);
705
+ auto document_vector = std::vector<std::string>(document_array, document_array + amount_documents);
706
+ free_string_array (document_array, amount_documents);
707
+
708
+ std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts (ctx_server->vocab , document_vector, true , true );
709
+
710
+ tasks.reserve (tokenized_docs.size ());
711
+ for (int i = 0 ; i < tokenized_docs.size (); i++) {
712
+ auto task = server_task (SERVER_TASK_TYPE_RERANK);
713
+ task.id = ctx_server->queue_tasks .get_new_id ();
714
+ task.index = i;
715
+ task.prompt_tokens = format_rerank (ctx_server->vocab , tokenized_query, tokenized_docs[i]);
716
+ tasks.push_back (task);
717
+ }
718
+ ctx_server->queue_results .add_waiting_tasks (tasks);
719
+ ctx_server->queue_tasks .post (tasks);
720
+
721
+ // get the result
722
+ std::unordered_set<int > task_ids = server_task::get_list_id (tasks);
723
+ std::vector<server_task_result_ptr> results (task_ids.size ());
724
+
725
+ // Create a new HashMap instance
726
+ jobject o_probabilities = env->NewObject (c_hash_map, cc_hash_map);
727
+ if (o_probabilities == nullptr ) {
728
+ env->ThrowNew (c_llama_error, " Failed to create HashMap object." );
729
+ return nullptr ;
730
+ }
731
+
732
+ for (int i = 0 ; i < (int )task_ids.size (); i++) {
733
+ server_task_result_ptr result = ctx_server->queue_results .recv (task_ids);
734
+ if (result->is_error ()) {
735
+ auto response = result->to_json ()[" message" ].get <std::string>();
736
+ for (const int id_task : task_ids) {
737
+ ctx_server->queue_results .remove_waiting_task_id (id_task);
738
+ }
739
+ env->ThrowNew (c_llama_error, response.c_str ());
740
+ return nullptr ;
741
+ }
742
+
743
+ const auto out_res = result->to_json ();
744
+
745
+ if (result->is_stop ()) {
746
+ for (const int id_task : task_ids) {
747
+ ctx_server->queue_results .remove_waiting_task_id (id_task);
748
+ }
749
+ }
750
+
751
+ int index = out_res[" index" ].get <int >();
752
+ float score = out_res[" score" ].get <float >();
753
+ std::string tok_str = document_vector[index];
754
+ jstring jtok_str = env->NewStringUTF (tok_str.c_str ());
755
+
756
+ jobject jprob = env->NewObject (c_float, cc_float, score);
757
+ env->CallObjectMethod (o_probabilities, m_map_put, jtok_str, jprob);
758
+ env->DeleteLocalRef (jtok_str);
759
+ env->DeleteLocalRef (jprob);
760
+ }
761
+ jbyteArray jbytes = parse_jbytes (env, prompt);
762
+ return env->NewObject (c_output, cc_output, jbytes, o_probabilities, true );
763
+ }
764
+
765
+ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate (JNIEnv *env, jobject obj, jstring jparams) {
766
+ jlong server_handle = env->GetLongField (obj, f_model_pointer);
767
+ auto *ctx_server = reinterpret_cast <server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
768
+
769
+ std::string c_params = parse_jstring (env, jparams);
770
+ json data = json::parse (c_params);
771
+
772
+ json templateData =
773
+ oaicompat_completion_params_parse (data, ctx_server->params_base .use_jinja ,
774
+ ctx_server->params_base .reasoning_format , ctx_server->chat_templates .get ());
775
+ std::string tok_str = templateData.at (" prompt" );
776
+ jstring jtok_str = env->NewStringUTF (tok_str.c_str ());
777
+
778
+ return jtok_str;
779
+ }
780
+
682
781
JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode (JNIEnv *env, jobject obj, jstring jprompt) {
683
782
jlong server_handle = env->GetLongField (obj, f_model_pointer);
684
783
auto *ctx_server = reinterpret_cast <server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
0 commit comments