diff --git a/.github/labeler.yml b/.github/labeler.yml
index 1ddc2bc6df..73b7c90ea3 100644
--- a/.github/labeler.yml
+++ b/.github/labeler.yml
@@ -107,7 +107,7 @@
 'category: RAG samples':
 - 'samples/cpp/rag/**/*'
 - 'samples/python/rag/**/*'
-- 'tests/python_tests/samples/test_text_embedding_pipeline.py'
+- 'tests/python_tests/samples/test_rag.py'
 
 'category: structured output generation':
 - 'src/cpp/src/sampling/structured_output/*'
diff --git a/samples/cpp/rag/CMakeLists.txt b/samples/cpp/rag/CMakeLists.txt
index c2a46cbc2b..2cc3a29090 100644
--- a/samples/cpp/rag/CMakeLists.txt
+++ b/samples/cpp/rag/CMakeLists.txt
@@ -21,7 +21,7 @@ function(add_sample_executable target_name)
         EXCLUDE_FROM_ALL)
 endfunction()
 
-set(SAMPLE_LIST text_embeddings)
+set(SAMPLE_LIST text_embeddings text_rerank)
 
 foreach(sample ${SAMPLE_LIST})
     add_sample_executable(${sample})
diff --git a/samples/cpp/rag/README.md b/samples/cpp/rag/README.md
index bae856a452..2d627ee303 100644
--- a/samples/cpp/rag/README.md
+++ b/samples/cpp/rag/README.md
@@ -1,6 +1,6 @@
 # Retrieval Augmented Generation Sample
 
-This example showcases inference of Text Embedding Models. The application has limited configuration options to encourage the reader to explore and modify the source code. For example, change the device for inference to GPU. The sample features `ov::genai::TextEmbeddingPipeline` and uses text as an input source.
+This example showcases inference of Text Embedding and Text Rerank Models. The application has limited configuration options to encourage the reader to explore and modify the source code. For example, change the device for inference to GPU. The sample features `ov::genai::TextEmbeddingPipeline` and `ov::genai::TextRerankPipeline` and uses text as an input source.
 
 ## Download and Convert the Model and Tokenizers
 
@@ -10,17 +10,37 @@ Install [../../export-requirements.txt](../../export-requirements.txt) to conver
 
 ```sh
 pip install --upgrade-strategy eager -r ../../export-requirements.txt
+```
+
+Then, run the export with Optimum CLI:
+
+```sh
 optimum-cli export openvino --trust-remote-code --model BAAI/bge-small-en-v1.5 BAAI/bge-small-en-v1.5
 ```
 
 ## Run
 
 Follow [Get Started with Samples](https://docs.openvino.ai/2025/get-started/learn-openvino/openvino-samples/get-started-demos.html) to run the sample.
 
-`text_embeddings BAAI/bge-small-en-v1.5 "Document 1" "Document 2"`
-
+### 1. Text Embedding Sample (`text_embeddings.cpp`)
+- **Description:**
+  Demonstrates inference of text embedding models using OpenVINO GenAI. Converts input text into vector embeddings for downstream tasks such as retrieval or semantic search.
+- **Run Command:**
+  ```sh
+  text_embeddings <MODEL_DIR> "Document 1" "Document 2"
+  ```
 
 Refer to the [Supported Models](https://openvinotoolkit.github.io/openvino.genai/docs/supported-models/#text-embeddings-models) for more details.
 
+### 2. Text Rerank Sample (`text_rerank.cpp`)
+- **Description:**
+  Demonstrates inference of text rerank models using OpenVINO GenAI. Reranks a list of candidate documents based on their relevance to a query using a cross-encoder or reranker model.
+- **Run Command:**
+  ```sh
+  text_rerank <MODEL_DIR> '<query>' '<text 1>' ['<text 2>' ...]
+  ```
+
+
 # Text Embedding Pipeline Usage
 
 ```c++
@@ -29,3 +49,12 @@ Refer to the [Supported Models](https://openvinotoolkit.github.io/openvino.genai
 ov::genai::TextEmbeddingPipeline pipeline(models_path, device, config);
 std::vector<ov::genai::EmbeddingResult> embeddings = pipeline.embed_documents(documents);
 ```
+
+# Text Rerank Pipeline Usage
+
+```c++
+#include "openvino/genai/rag/text_rerank_pipeline.hpp"
+
+ov::genai::TextRerankPipeline pipeline(models_path, device, config);
+std::vector<std::pair<size_t, float>> rerank_result = pipeline.rerank(query, documents);
+```
diff --git a/samples/cpp/rag/text_rerank.cpp b/samples/cpp/rag/text_rerank.cpp
new file mode 100644
index 0000000000..2f6256fc38
--- /dev/null
+++ b/samples/cpp/rag/text_rerank.cpp
@@ -0,0 +1,45 @@
+// Copyright (C) 2025 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#include <iomanip>
+#include <iostream>
+
+#include "openvino/genai/rag/text_rerank_pipeline.hpp"
+
+int main(int argc, char* argv[]) try {
+    if (argc < 4) {
+        throw std::runtime_error(std::string{"Usage: "} + argv[0] +
+                                 " <MODELS_PATH> '<QUERY>' '<TEXT 1>' ['<TEXT 2>' ...]");
+    }
+
+    auto documents = std::vector<std::string>(argv + 3, argv + argc);
+    std::string models_path = argv[1];
+    std::string query = argv[2];
+
+    std::string device = "CPU";  // GPU can be used as well
+
+    ov::genai::TextRerankPipeline::Config config;
+    config.top_n = 3;
+
+    ov::genai::TextRerankPipeline pipeline(models_path, device, config);
+
+    std::vector<std::pair<size_t, float>> rerank_result = pipeline.rerank(query, documents);
+
+    // print reranked documents
+    std::cout << std::fixed << std::setprecision(4);
+    std::cout << "Reranked documents:\n";
+    for (const auto& [index, score] : rerank_result) {
+        std::cout << "Document " << index << " (score: " << score << "): " << documents[index] << '\n';
+    }
+    std::cout << std::defaultfloat;
+
+} catch (const std::exception& error) {
+    try {
+        std::cerr << error.what() << '\n';
+    } catch (const std::ios_base::failure&) {
+    }
+    return EXIT_FAILURE;
+} catch (...) {
+    try {
+        std::cerr << "Non-exception object thrown\n";
+    } catch (const std::ios_base::failure&) {
+    }
+    return EXIT_FAILURE;
+}
diff --git a/samples/python/rag/README.md b/samples/python/rag/README.md
index b5159c14fb..e4ad31c21f 100644
--- a/samples/python/rag/README.md
+++ b/samples/python/rag/README.md
@@ -1,6 +1,6 @@
 # Retrieval Augmented Generation Sample
 
-This example showcases inference of Text Embedding Models. The application limited configuration configuration options to encourage the reader to explore and modify the source code. For example, change the device for inference to GPU. The sample features `openvino_genai.TextEmbeddingPipeline` and uses text as an input source.
+This example showcases inference of Text Embedding and Text Rerank Models. The application has limited configuration options to encourage the reader to explore and modify the source code. For example, change the device for inference to GPU. The sample features `openvino_genai.TextEmbeddingPipeline` and `openvino_genai.TextRerankPipeline` and uses text as an input source.
 
 ## Download and Convert the Model and Tokenizers
 
@@ -38,10 +38,24 @@ export_tokenizer(tokenizer, output_dir)
 
 Install [deployment-requirements.txt](../../deployment-requirements.txt) via `pip install -r ../../deployment-requirements.txt` and then, run a sample:
 
-`python text_embeddings.py BAAI/bge-small-en-v1.5 "Document 1" "Document 2"`
-
+### 1. Text Embedding Sample (`text_embeddings.py`)
+- **Description:**
+  Demonstrates inference of text embedding models using OpenVINO GenAI. Converts input text into vector embeddings for downstream tasks such as retrieval or semantic search.
+- **Run Command:**
+  ```sh
+  python text_embeddings.py <MODEL_DIR> "Document 1" "Document 2"
+  ```
 
 Refer to the [Supported Models](https://openvinotoolkit.github.io/openvino.genai/docs/supported-models/#text-embeddings-models) for more details.
 
+### 2. Text Rerank Sample (`text_rerank.py`)
+- **Description:**
+  Demonstrates inference of text rerank models using OpenVINO GenAI. Reranks a list of candidate documents based on their relevance to a query using a cross-encoder or reranker model.
+- **Run Command:**
+  ```sh
+  python text_rerank.py <MODEL_DIR> "<query>" "<text 1>" ["<text 2>" ...]
+  ```
+
+
 # Text Embedding Pipeline Usage
 
 ```python
@@ -51,3 +65,13 @@ pipeline = openvino_genai.TextEmbeddingPipeline(model_dir, "CPU")
 
 embeddings = pipeline.embed_documents(["document1", "document2"])
 ```
+
+# Text Rerank Pipeline Usage
+
+```python
+import openvino_genai
+
+pipeline = openvino_genai.TextRerankPipeline(model_dir, "CPU")
+
+rerank_result = pipeline.rerank(query, documents)
+```
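+
+The pipeline also exposes an asynchronous API (`start_rerank_async` / `wait_rerank`). A minimal sketch, reusing the `model_dir`, `query`, and `documents` variables assumed above:
+
+```python
+import openvino_genai
+
+pipeline = openvino_genai.TextRerankPipeline(model_dir, "CPU")
+
+# start reranking in the background, then block until the (document index, score) pairs are ready
+pipeline.start_rerank_async(query, documents)
+rerank_result = pipeline.wait_rerank()
+```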
diff --git a/samples/python/rag/text_rerank.py b/samples/python/rag/text_rerank.py
new file mode 100644
index 0000000000..8eb9554a98
--- /dev/null
+++ b/samples/python/rag/text_rerank.py
@@ -0,0 +1,31 @@
+#!/usr/bin/env python3
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import argparse
+import openvino_genai
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("model_dir")
+    parser.add_argument("query")
+    parser.add_argument("texts", nargs="+")
+    args = parser.parse_args()
+
+    device = "CPU"  # GPU can be used as well
+
+    config = openvino_genai.TextRerankPipeline.Config()
+    config.top_n = 3
+
+    pipeline = openvino_genai.TextRerankPipeline(args.model_dir, device, config)
+
+    rerank_result = pipeline.rerank(args.query, args.texts)
+
+    print("Reranked documents:")
+    for index, score in rerank_result:
+        print(f"Document {index} (score: {score:.4f}): {args.texts[index]}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/cpp/include/openvino/genai/rag/text_rerank_pipeline.hpp b/src/cpp/include/openvino/genai/rag/text_rerank_pipeline.hpp
new file mode 100644
index 0000000000..751c1af1ed
--- /dev/null
+++ b/src/cpp/include/openvino/genai/rag/text_rerank_pipeline.hpp
@@ -0,0 +1,107 @@
+// Copyright (C) 2025 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "openvino/genai/tokenizer.hpp"
+
+namespace ov {
+namespace genai {
+
+class OPENVINO_GENAI_EXPORTS TextRerankPipeline {
+public:
+    struct OPENVINO_GENAI_EXPORTS Config {
+        /**
+         * @brief Number of documents to return sorted by score
+         */
+        size_t top_n = 3;
+
+        /**
+         * @brief Maximum length of tokens passed to the rerank model
+         */
+        std::optional<size_t> max_length;
+
+        /**
+         * @brief Constructs text rerank pipeline configuration
+         */
+        Config() = default;
+
+        /**
+         * @brief Constructs text rerank pipeline configuration
+         *
+         * @param properties configuration options
+         *
+         * const ov::AnyMap properties{{"top_n", 3}};
+         * ov::genai::TextRerankPipeline::Config config(properties);
+         *
+         * ov::genai::TextRerankPipeline::Config config({{"top_n", 3}});
+         */
+        explicit Config(const ov::AnyMap& properties);
+    };
+
+    /**
+     * @brief Constructs a pipeline from xml/bin files, tokenizer and configuration in the same dir.
+     *
+     * @param models_path Path to the directory containing model xml/bin files and tokenizer
+     * @param device Device
+     * @param config Pipeline configuration
+     * @param properties Optional plugin properties to pass to ov::Core::compile_model().
+ */ + TextRerankPipeline(const std::filesystem::path& models_path, + const std::string& device, + const Config& config, + const ov::AnyMap& properties = {}); + + /** + * @brief Constructs a pipeline from xml/bin files, tokenizer and configuration in the same dir. + * + * @param models_path Path to the directory containing model xml/bin files and tokenizer + * @param device Device + * @param properties Optional plugin and/or config properties + */ + TextRerankPipeline(const std::filesystem::path& models_path, + const std::string& device, + const ov::AnyMap& properties = {}); + + /** + * @brief Constructs a pipeline from xml/bin files, tokenizer and configuration in the same dir. + * + * @param models_path Path to the directory containing model xml/bin files and tokenizer + * @param device Device + * @param properties Plugin and/or config properties + */ + template ::value, bool>::type = true> + TextRerankPipeline(const std::filesystem::path& models_path, const std::string& device, Properties&&... properties) + : TextRerankPipeline(models_path, device, ov::AnyMap{std::forward(properties)...}) {} + + /** + * @brief Reranks a vector of texts based on the query. + */ + std::vector> rerank(const std::string& query, const std::vector& texts); + + /** + * @brief Asynchronously reranks a vector of texts based on the query. Only one method of async family can be + * active. + */ + void start_rerank_async(const std::string& query, const std::vector& texts); + + /** + * @brief Waits for reranked texts. + */ + std::vector> wait_rerank(); + + ~TextRerankPipeline(); + +private: + class TextRerankPipelineImpl; + std::unique_ptr m_impl; +}; + +/** + * @brief Number of documents to return after reranking sorted by score + */ +static constexpr ov::Property top_n{"top_n"}; + +} // namespace genai +} // namespace ov diff --git a/src/cpp/include/openvino/genai/tokenizer.hpp b/src/cpp/include/openvino/genai/tokenizer.hpp index 8f6963188f..47c57352c4 100644 --- a/src/cpp/include/openvino/genai/tokenizer.hpp +++ b/src/cpp/include/openvino/genai/tokenizer.hpp @@ -7,6 +7,7 @@ #include #include #include +#include #include "openvino/runtime/tensor.hpp" #include "openvino/genai/visibility.hpp" @@ -21,6 +22,7 @@ using Vocab = std::unordered_map; // similar to huggingfa struct TokenizedInputs { ov::Tensor input_ids; ov::Tensor attention_mask; + std::optional token_type_ids; }; /** diff --git a/src/cpp/src/debug_utils.hpp b/src/cpp/src/debug_utils.hpp index 36e3a20650..dc6d2a0262 100644 --- a/src/cpp/src/debug_utils.hpp +++ b/src/cpp/src/debug_utils.hpp @@ -3,14 +3,13 @@ #pragma once -#include -#include #include - +#include #include +#include template -void print_array(T * array, size_t size) { +void print_array(T* array, size_t size) { std::cout << " => [ "; for (size_t i = 0; i < std::min(size, size_t(10)); ++i) { std::cout << array[i] << " "; @@ -18,19 +17,45 @@ void print_array(T * array, size_t size) { std::cout << " ] " << std::endl; } +template +void print_tensor(ov::Tensor tensor) { + const auto shape = tensor.get_shape(); + const size_t rank = shape.size(); + const auto* data = tensor.data(); + + if (rank > 2) { + print_array(data, tensor.get_size()); + return; + } + + const size_t batch_size = shape[0]; + const size_t seq_length = shape[1]; + + std::cout << " => [ \n"; + for (size_t batch = 0; batch < batch_size; ++batch) { + std::cout << " [ "; + const size_t batch_offset = batch * seq_length; + for (size_t j = 0; j < seq_length; ++j) { + std::cout << data[batch_offset + j] << " "; + } + std::cout << "]\n"; + 
} + std::cout << " ]" << std::endl; +} + inline void print_tensor(std::string name, ov::Tensor tensor) { std::cout << name; std::cout << " " << tensor.get_shape().to_string(); if (tensor.get_element_type() == ov::element::i32) { - print_array(tensor.data(), tensor.get_size()); + print_tensor(tensor); } else if (tensor.get_element_type() == ov::element::i64) { - print_array(tensor.data(), tensor.get_size()); + print_tensor(tensor); } else if (tensor.get_element_type() == ov::element::f32) { - print_array(tensor.data(), tensor.get_size()); + print_tensor(tensor); } else if (tensor.get_element_type() == ov::element::boolean) { - print_array(tensor.data(), tensor.get_size()); + print_tensor(tensor); } else if (tensor.get_element_type() == ov::element::f16) { - print_array(tensor.data(), tensor.get_size()); + print_tensor(tensor); } } diff --git a/src/cpp/src/rag/text_rerank_pipeline.cpp b/src/cpp/src/rag/text_rerank_pipeline.cpp new file mode 100644 index 0000000000..d91652d1a8 --- /dev/null +++ b/src/cpp/src/rag/text_rerank_pipeline.cpp @@ -0,0 +1,177 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "openvino/genai/rag/text_rerank_pipeline.hpp" + +#include "openvino/core/except.hpp" +#include "openvino/genai/tokenizer.hpp" +#include "openvino/opsets/opset.hpp" +#include "openvino/opsets/opset1.hpp" +#include "openvino/opsets/opset8.hpp" +#include "utils.hpp" + +namespace { +using namespace ov::genai; +using namespace ov; + +ov::AnyMap remove_config_properties(const ov::AnyMap& properties) { + auto properties_copy = properties; + + properties_copy.erase(top_n.name()); + properties_copy.erase(max_length.name()); + + return properties_copy; +} + +std::shared_ptr apply_postprocessing(std::shared_ptr model) { + PartialShape output_shape = model->get_output_partial_shape(0); + + ov::preprocess::PrePostProcessor processor(model); + + processor.output().postprocess().custom( + [&output_shape](const ov::Output& node) -> std::shared_ptr { + if (output_shape[1] == 1) { + return std::make_shared(node); + } + + // apply softmax to the axis = 1 + const auto softmax = std::make_shared(node, 1); + + // take first class score only + auto start = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{1}); + auto stop = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{2}); + auto step = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{1}); + auto axis = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{1}); + + auto slice = std::make_shared(softmax, start, stop, step, axis); + return slice; + }); + + return processor.build(); +} + +} // namespace + +namespace ov { +namespace genai { +using utils::read_anymap_param; + +TextRerankPipeline::Config::Config(const ov::AnyMap& properties) { + read_anymap_param(properties, ov::genai::top_n.name(), top_n); + read_anymap_param(properties, ov::genai::max_length.name(), max_length); +}; + +class TextRerankPipeline::TextRerankPipelineImpl { +public: + TextRerankPipelineImpl(const std::filesystem::path& models_path, + const std::string& device, + const Config& config, + const ov::AnyMap& properties = {}) + : m_config{config}, + m_tokenizer{models_path} { + ov::Core core = utils::singleton_core(); + + auto model = core.read_model(models_path / "openvino_model.xml", {}, properties); + + model = apply_postprocessing(model); + + if (m_config.max_length) { + m_tokenization_params.insert({max_length.name(), *m_config.max_length}); + } + + ov::CompiledModel compiled_model = 
core.compile_model(model, device, properties); + + utils::print_compiled_model_properties(compiled_model, "text rerank model"); + m_request = compiled_model.create_infer_request(); + }; + + std::vector> rerank(const std::string& query, const std::vector& texts) { + start_rerank_async(query, texts); + return wait_rerank(); + } + + void start_rerank_async(const std::string& query, const std::vector& texts) { + const auto encoded = m_tokenizer.encode({query}, texts, m_tokenization_params); + + m_request.set_tensor("input_ids", encoded.input_ids); + m_request.set_tensor("attention_mask", encoded.attention_mask); + + if (encoded.token_type_ids.has_value()) { + m_request.set_tensor("token_type_ids", *encoded.token_type_ids); + } + + m_request.start_async(); + } + + std::vector> wait_rerank() { + m_request.wait(); + + // postprocessing applied to output, it's the scores tensor + auto scores_tensor = m_request.get_tensor("logits"); + auto scores_tensor_shape = scores_tensor.get_shape(); + const size_t batch_size = scores_tensor_shape[0]; + + auto scores_data = scores_tensor.data(); + + std::vector> results; + results.reserve(batch_size); + + for (size_t batch = 0; batch < batch_size; batch++) { + results.emplace_back(batch, scores_data[batch]); + } + + const size_t top_n = m_config.top_n; + + // partial sort to get top_n results + std::partial_sort(results.begin(), + results.begin() + std::min(top_n, results.size()), + results.end(), + [](const auto& a, const auto& b) { + return a.second > b.second; + }); + + if (top_n < results.size()) { + results.resize(top_n); + } + + return results; + } + +private: + Tokenizer m_tokenizer; + InferRequest m_request; + Config m_config; + AnyMap m_tokenization_params; +}; + +TextRerankPipeline::TextRerankPipeline(const std::filesystem::path& models_path, + const std::string& device, + const Config& config, + const ov::AnyMap& properties) + : m_impl{std::make_unique(models_path, device, config, properties)} {}; + +TextRerankPipeline::TextRerankPipeline(const std::filesystem::path& models_path, + const std::string& device, + const ov::AnyMap& properties) + : m_impl{std::make_unique(models_path, + device, + Config(properties), + remove_config_properties(properties))} {}; + +std::vector> TextRerankPipeline::rerank(const std::string& query, + const std::vector& texts) { + return m_impl->rerank(query, texts); +} + +void TextRerankPipeline::start_rerank_async(const std::string& query, const std::vector& texts) { + m_impl->start_rerank_async(query, texts); +} + +std::vector> TextRerankPipeline::wait_rerank() { + return m_impl->wait_rerank(); +} + +TextRerankPipeline::~TextRerankPipeline() = default; + +} // namespace genai +} // namespace ov diff --git a/src/cpp/src/tokenizer/tokenizer.cpp b/src/cpp/src/tokenizer/tokenizer.cpp index 5fb53913aa..ec651b4356 100644 --- a/src/cpp/src/tokenizer/tokenizer.cpp +++ b/src/cpp/src/tokenizer/tokenizer.cpp @@ -614,7 +614,6 @@ class Tokenizer::TokenizerImpl { TokenizedInputs encode(const std::vector& prompts_1, const std::vector& prompts_2, const ov::AnyMap& tokenization_params = {}) { OPENVINO_ASSERT(m_ireq_queue_tokenizer, "Either openvino_tokenizer.xml was not provided or it was not loaded correctly. 
" "Tokenizer::encode is not available"); - size_t batch_size = prompts_1.size(); OPENVINO_ASSERT(prompts_1.size() == prompts_2.size() || prompts_1.size() == 1 || prompts_2.size() == 1, "prompts_1 and prompts_2 should be of the same size or one of them should be of size 1"); @@ -622,8 +621,8 @@ class Tokenizer::TokenizerImpl { { CircularBufferQueueElementGuard infer_request_guard(this->m_ireq_queue_tokenizer.get()); set_state_if_necessary(infer_request_guard, tokenization_params); - infer_request_guard.get().set_input_tensor(0, ov::Tensor{ov::element::string, {batch_size}, const_cast(prompts_1.data())}); - infer_request_guard.get().set_input_tensor(1, ov::Tensor{ov::element::string, {batch_size}, const_cast(prompts_2.data())}); + infer_request_guard.get().set_input_tensor(0, ov::Tensor{ov::element::string, {prompts_1.size()}, const_cast(prompts_1.data())}); + infer_request_guard.get().set_input_tensor(1, ov::Tensor{ov::element::string, {prompts_2.size()}, const_cast(prompts_2.data())}); infer_request_guard.get().infer(); @@ -631,8 +630,19 @@ class Tokenizer::TokenizerImpl { infer_request_guard.get().get_tensor("input_ids"), infer_request_guard.get().get_tensor("attention_mask") ); + + // If the model has a token_type_ids output, copy it to the result + for (const auto& output : infer_request_guard.get().get_compiled_model().outputs()) { + if (output.get_any_name() != "token_type_ids") { + continue; + } + ov::Tensor token_type_ids = infer_request_guard.get().get_tensor(output); + ov::Tensor token_type_ids_copy = ov::Tensor(token_type_ids.get_element_type(), token_type_ids.get_shape()); + token_type_ids.copy_to(token_type_ids_copy); + result.token_type_ids = token_type_ids_copy; + } } - return {result.input_ids, result.attention_mask}; + return {result.input_ids, result.attention_mask, result.token_type_ids}; } TokenizedInputs encode(const std::vector& prompts, const ov::AnyMap& tokenization_params = {}) { diff --git a/src/python/openvino_genai/__init__.py b/src/python/openvino_genai/__init__.py index 6f478068b4..fc628129f1 100644 --- a/src/python/openvino_genai/__init__.py +++ b/src/python/openvino_genai/__init__.py @@ -99,7 +99,8 @@ # RAG from .py_openvino_genai import ( - TextEmbeddingPipeline + TextEmbeddingPipeline, + TextRerankPipeline ) # Speech generation diff --git a/src/python/openvino_genai/__init__.pyi b/src/python/openvino_genai/__init__.pyi index 28d8b0dd98..696c182af0 100644 --- a/src/python/openvino_genai/__init__.pyi +++ b/src/python/openvino_genai/__init__.pyi @@ -45,6 +45,7 @@ from openvino_genai.py_openvino_genai import Text2ImagePipeline from openvino_genai.py_openvino_genai import Text2SpeechDecodedResults from openvino_genai.py_openvino_genai import Text2SpeechPipeline from openvino_genai.py_openvino_genai import TextEmbeddingPipeline +from openvino_genai.py_openvino_genai import TextRerankPipeline from openvino_genai.py_openvino_genai import TextStreamer from openvino_genai.py_openvino_genai import TokenizedInputs from openvino_genai.py_openvino_genai import Tokenizer @@ -59,5 +60,5 @@ from openvino_genai.py_openvino_genai import draft_model from openvino_genai.py_openvino_genai import get_version import os as os from . 
import py_openvino_genai -__all__ = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedResults', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationFinishReason', 'GenerationResult', 'GenerationStatus', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'ImageGenerationPerfMetrics', 'InpaintingPipeline', 'LLMPipeline', 'PerfMetrics', 'RawImageGenerationPerfMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'Scheduler', 'SchedulerConfig', 'SpeechGenerationConfig', 'SpeechGenerationPerfMetrics', 'StopCriteria', 'StreamerBase', 'StreamingStatus', 'StructuralTagItem', 'StructuralTagsConfig', 'StructuredOutputConfig', 'T5EncoderModel', 'Text2ImagePipeline', 'Text2SpeechDecodedResults', 'Text2SpeechPipeline', 'TextEmbeddingPipeline', 'TextStreamer', 'TokenizedInputs', 'Tokenizer', 'TorchGenerator', 'UNet2DConditionModel', 'VLMPipeline', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model', 'get_version', 'openvino', 'os', 'py_openvino_genai'] +__all__ = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedResults', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationFinishReason', 'GenerationResult', 'GenerationStatus', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'ImageGenerationPerfMetrics', 'InpaintingPipeline', 'LLMPipeline', 'PerfMetrics', 'RawImageGenerationPerfMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'Scheduler', 'SchedulerConfig', 'SpeechGenerationConfig', 'SpeechGenerationPerfMetrics', 'StopCriteria', 'StreamerBase', 'StreamingStatus', 'StructuralTagItem', 'StructuralTagsConfig', 'StructuredOutputConfig', 'T5EncoderModel', 'Text2ImagePipeline', 'Text2SpeechDecodedResults', 'Text2SpeechPipeline', 'TextEmbeddingPipeline', 'TextRerankPipeline', 'TextStreamer', 'TokenizedInputs', 'Tokenizer', 'TorchGenerator', 'UNet2DConditionModel', 'VLMPipeline', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model', 'get_version', 'openvino', 'os', 'py_openvino_genai'] __version__: str diff --git a/src/python/openvino_genai/py_openvino_genai.pyi b/src/python/openvino_genai/py_openvino_genai.pyi index 0431148522..b8d448600b 100644 --- a/src/python/openvino_genai/py_openvino_genai.pyi +++ b/src/python/openvino_genai/py_openvino_genai.pyi @@ -1,11 +1,11 @@ """ -Pybind11 binding for Text-to-speech Pipeline +Pybind11 binding for OpenVINO GenAI library """ from __future__ import annotations import collections.abc import openvino._pyopenvino import typing -__all__ = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedGenerationResult', 'EncodedResults', 'ExtendedPerfMetrics', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationFinishReason', 'GenerationHandle', 'GenerationOutput', 'GenerationResult', 'GenerationStatus', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'ImageGenerationPerfMetrics', 'InpaintingPipeline', 'LLMPipeline', 'MeanStdPair', 'PerfMetrics', 'PipelineMetrics', 
'RawImageGenerationPerfMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'SDPerModelsPerfMetrics', 'SDPerfMetrics', 'Scheduler', 'SchedulerConfig', 'SparseAttentionConfig', 'SparseAttentionMode', 'SpeechGenerationConfig', 'SpeechGenerationPerfMetrics', 'StopCriteria', 'StreamerBase', 'StreamingStatus', 'StructuralTagItem', 'StructuralTagsConfig', 'StructuredOutputConfig', 'SummaryStats', 'T5EncoderModel', 'Text2ImagePipeline', 'Text2SpeechDecodedResults', 'Text2SpeechPipeline', 'TextEmbeddingPipeline', 'TextStreamer', 'TokenizedInputs', 'Tokenizer', 'TorchGenerator', 'UNet2DConditionModel', 'VLMDecodedResults', 'VLMPerfMetrics', 'VLMPipeline', 'VLMRawPerfMetrics', 'WhisperDecodedResultChunk', 'WhisperDecodedResults', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model', 'get_version'] +__all__ = ['Adapter', 'AdapterConfig', 'AggregationMode', 'AutoencoderKL', 'CLIPTextModel', 'CLIPTextModelWithProjection', 'CacheEvictionConfig', 'ChunkStreamerBase', 'ContinuousBatchingPipeline', 'CppStdGenerator', 'DecodedResults', 'EncodedGenerationResult', 'EncodedResults', 'ExtendedPerfMetrics', 'FluxTransformer2DModel', 'GenerationConfig', 'GenerationFinishReason', 'GenerationHandle', 'GenerationOutput', 'GenerationResult', 'GenerationStatus', 'Generator', 'Image2ImagePipeline', 'ImageGenerationConfig', 'ImageGenerationPerfMetrics', 'InpaintingPipeline', 'LLMPipeline', 'MeanStdPair', 'PerfMetrics', 'PipelineMetrics', 'RawImageGenerationPerfMetrics', 'RawPerfMetrics', 'SD3Transformer2DModel', 'SDPerModelsPerfMetrics', 'SDPerfMetrics', 'Scheduler', 'SchedulerConfig', 'SparseAttentionConfig', 'SparseAttentionMode', 'SpeechGenerationConfig', 'SpeechGenerationPerfMetrics', 'StopCriteria', 'StreamerBase', 'StreamingStatus', 'StructuralTagItem', 'StructuralTagsConfig', 'StructuredOutputConfig', 'SummaryStats', 'T5EncoderModel', 'Text2ImagePipeline', 'Text2SpeechDecodedResults', 'Text2SpeechPipeline', 'TextEmbeddingPipeline', 'TextRerankPipeline', 'TextStreamer', 'TokenizedInputs', 'Tokenizer', 'TorchGenerator', 'UNet2DConditionModel', 'VLMDecodedResults', 'VLMPerfMetrics', 'VLMPipeline', 'VLMRawPerfMetrics', 'WhisperDecodedResultChunk', 'WhisperDecodedResults', 'WhisperGenerationConfig', 'WhisperPerfMetrics', 'WhisperPipeline', 'WhisperRawPerfMetrics', 'draft_model', 'get_version'] class Adapter: """ Immutable LoRA Adapter that carries the adaptation matrices and serves as unique adapter identifier. @@ -2854,6 +2854,58 @@ class TextEmbeddingPipeline: """ Waits computed embeddings for a query """ +class TextRerankPipeline: + """ + Text rerank pipeline + """ + class Config: + """ + + Structure to keep TextRerankPipeline configuration parameters. + Attributes: + top_n (int, optional): + Number of documents to return sorted by score. + max_length (int, optional): + Maximum length of tokens passed to the embedding model. + """ + @typing.overload + def __init__(self) -> None: + ... + @typing.overload + def __init__(self, **kwargs) -> None: + ... + @property + def max_length(self) -> int | None: + ... + @max_length.setter + def max_length(self, arg0: typing.SupportsInt | None) -> None: + ... + @property + def top_n(self) -> int: + ... + @top_n.setter + def top_n(self, arg0: typing.SupportsInt) -> None: + ... 
+ def __init__(self, models_path: os.PathLike | str | bytes, device: str, config: openvino_genai.py_openvino_genai.TextRerankPipeline.Config | None = None, **kwargs) -> None: + """ + Constructs a pipeline from xml/bin files, tokenizer and configuration in the same dir + models_path (os.PathLike): Path to the directory containing model xml/bin files and tokenizer + device (str): Device to run the model on (e.g., CPU, GPU). + config: (TextRerankPipeline.Config): Optional pipeline configuration + kwargs: Plugin and/or config properties + """ + def rerank(self, query: str, texts: collections.abc.Sequence[str]) -> list[tuple[int, float]]: + """ + Reranks a vector of texts based on the query. + """ + def start_rerank_async(self, query: str, texts: collections.abc.Sequence[str]) -> None: + """ + Asynchronously reranks a vector of texts based on the query. + """ + def wait_rerank(self) -> list[tuple[int, float]]: + """ + Waits for reranked texts. + """ class TextStreamer(StreamerBase): """ diff --git a/src/python/py_rag.cpp b/src/python/py_rag.cpp index 7884616075..59e6fa388b 100644 --- a/src/python/py_rag.cpp +++ b/src/python/py_rag.cpp @@ -8,6 +8,7 @@ #include #include "openvino/genai/rag/text_embedding_pipeline.hpp" +#include "openvino/genai/rag/text_rerank_pipeline.hpp" #include "py_utils.hpp" #include "tokenizer/tokenizers_path.hpp" @@ -15,6 +16,7 @@ namespace py = pybind11; using ov::genai::EmbeddingResult; using ov::genai::EmbeddingResults; using ov::genai::TextEmbeddingPipeline; +using ov::genai::TextRerankPipeline; namespace pyutils = ov::genai::pybind::utils; @@ -34,7 +36,17 @@ Structure to keep TextEmbeddingPipeline configuration parameters. embed_instruction (str, optional): Instruction to use for embedding a document. )"; -} + +const auto text_reranking_config_docstring = R"( +Structure to keep TextRerankPipeline configuration parameters. +Attributes: + top_n (int, optional): + Number of documents to return sorted by score. + max_length (int, optional): + Maximum length of tokens passed to the embedding model. +)"; + +} // namespace void init_rag_pipelines(py::module_& m) { auto text_embedding_pipeline = @@ -150,5 +162,84 @@ models_path (os.PathLike): Path to the directory containing model xml/bin files device (str): Device to run the model on (e.g., CPU, GPU). 
config: (TextEmbeddingPipeline.Config): Optional pipeline configuration kwargs: Plugin and/or config properties +)"); + + auto text_rerank_pipeline = + py::class_(m, "TextRerankPipeline", "Text rerank pipeline") + .def( + "rerank", + [](ov::genai::TextRerankPipeline& pipe, + const std::string& query, + const std::vector& texts) -> py::typing::Union>> { + std::vector> res; + { + py::gil_scoped_release rel; + res = pipe.rerank(query, texts); + } + return py::cast(res); + }, + py::arg("query"), + py::arg("texts"), + "Reranks a vector of texts based on the query.") + .def( + "start_rerank_async", + [](ov::genai::TextRerankPipeline& pipe, + const std::string& query, + const std::vector& texts) -> void { + py::gil_scoped_release rel; + pipe.start_rerank_async(query, texts); + }, + py::arg("query"), + py::arg("texts"), + "Asynchronously reranks a vector of texts based on the query.") + .def( + "wait_rerank", + [](ov::genai::TextRerankPipeline& pipe) -> py::typing::Union>> { + std::vector> res; + { + py::gil_scoped_release rel; + res = pipe.wait_rerank(); + } + return py::cast(res); + }, + "Waits for reranked texts."); + + py::class_(text_rerank_pipeline, "Config", text_reranking_config_docstring) + .def(py::init<>()) + .def(py::init([](py::kwargs kwargs) { + return ov::genai::TextRerankPipeline::Config(pyutils::kwargs_to_any_map(kwargs)); + })) + .def_readwrite("top_n", &ov::genai::TextRerankPipeline::Config::top_n) + .def_readwrite("max_length", &ov::genai::TextRerankPipeline::Config::max_length); + + text_rerank_pipeline.def( + py::init([](const std::filesystem::path& models_path, + const std::string& device, + const std::optional& config, + const py::kwargs& kwargs) { + ScopedVar env_manager(pyutils::ov_tokenizers_module_path()); + if (config.has_value()) { + return std::make_unique(models_path, + device, + *config, + pyutils::kwargs_to_any_map(kwargs)); + } + return std::make_unique(models_path, + device, + pyutils::kwargs_to_any_map(kwargs)); + }), + py::arg("models_path"), + "Path to the directory containing model xml/bin files and tokenizer", + py::arg("device"), + "Device to run the model on (e.g., CPU, GPU)", + py::arg("config") = std::nullopt, + "Optional pipeline configuration", + "Plugin and/or config properties", + R"( +Constructs a pipeline from xml/bin files, tokenizer and configuration in the same dir +models_path (os.PathLike): Path to the directory containing model xml/bin files and tokenizer +device (str): Device to run the model on (e.g., CPU, GPU). 
+config: (TextRerankPipeline.Config): Optional pipeline configuration +kwargs: Plugin and/or config properties )"); } diff --git a/src/python/py_speech_generation.cpp b/src/python/py_speech_generation.cpp index acfc64da27..585419ba12 100644 --- a/src/python/py_speech_generation.cpp +++ b/src/python/py_speech_generation.cpp @@ -91,8 +91,6 @@ SpeechGenerationConfig update_speech_generation_config_from_kwargs(const SpeechG } // namespace void init_speech_generation_pipeline(py::module_& m) { - m.doc() = "Pybind11 binding for Text-to-speech Pipeline"; - // Binding for SpeechGenerationConfig py::class_(m, "SpeechGenerationConfig", diff --git a/tests/python_tests/requirements.txt b/tests/python_tests/requirements.txt index fb10c49a02..041176ff41 100644 --- a/tests/python_tests/requirements.txt +++ b/tests/python_tests/requirements.txt @@ -6,9 +6,12 @@ onnx==1.18.0 pytest pytest-html hf_transfer -langchain_community gguf>=0.10.0 +# rag requirements +langchain_community==0.3.27 +langchain-core==0.3.69 + # requirements for specific models # - hf-tiny-model-private/tiny-random-RoFormerForCausalLM rjieba diff --git a/tests/python_tests/samples/conftest.py b/tests/python_tests/samples/conftest.py index df0ba18898..adbc46f4cc 100644 --- a/tests/python_tests/samples/conftest.py +++ b/tests/python_tests/samples/conftest.py @@ -6,9 +6,14 @@ import logging import gc import requests +from pathlib import Path from utils.network import retry_request from utils.constants import get_ov_cache_dir +from transformers import AutoTokenizer +from openvino_tokenizers import convert_tokenizer +from openvino import save_model + # Configure logging logging.basicConfig(level=logging.INFO) @@ -115,9 +120,14 @@ "name": "katuni4ka/tiny-random-llava", "convert_args": ["--trust-remote-code", "--task", "image-text-to-text"] }, - "BAAI/bge-small-en-v1.5": { + "bge-small-en-v1.5": { "name": "BAAI/bge-small-en-v1.5", - "convert_args": ['--trust-remote-code'] + "convert_args": ["--trust-remote-code"] + }, + "ms-marco-TinyBERT-L2-v2": { + "name": "cross-encoder/ms-marco-TinyBERT-L2-v2", + "convert_args": ["--trust-remote-code", "--task", "text-classification"], + "convert_2_input_tokenizer": True }, "tiny-random-SpeechT5ForTextToSpeech": { "name": "hf-internal-testing/tiny-random-SpeechT5ForTextToSpeech", @@ -171,6 +181,12 @@ def setup_and_teardown(request, tmp_path_factory): logger.info(f"Skipping cleanup of temporary directory: {ov_cache}") +def convert_2_input_tokenizer(models_path): + hf_tokenizer = AutoTokenizer.from_pretrained(models_path, trust_remote_code=True) + ov_tokenizer = convert_tokenizer(hf_tokenizer, with_detokenizer=False, number_of_inputs=2) + save_model(ov_tokenizer, models_path / "openvino_tokenizer.xml") + + @pytest.fixture(scope="session") def convert_model(request): """Fixture to convert the model once for the session.""" @@ -194,9 +210,12 @@ def convert_model(request): command.extend(model_args) logger.info(f"Conversion command: {' '.join(command)}") retry_request(lambda: subprocess.run(command, check=True, capture_output=True, text=True, env=sub_env)) - + + if MODELS[model_id].get("convert_2_input_tokenizer", False): + convert_2_input_tokenizer(Path(model_path)) + yield model_path - + # Cleanup the model after tests if os.environ.get("CLEANUP_CACHE", "false").lower() == "true": if os.path.exists(model_cache): diff --git a/tests/python_tests/samples/test_rag.py b/tests/python_tests/samples/test_rag.py new file mode 100644 index 0000000000..79b4b59323 --- /dev/null +++ b/tests/python_tests/samples/test_rag.py 
@@ -0,0 +1,60 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import os +import pytest +import sys + +from conftest import SAMPLES_PY_DIR, SAMPLES_CPP_DIR, SAMPLES_JS_DIR +from test_utils import run_sample + + +class TestTextEmbeddingPipeline: + @pytest.mark.rag + @pytest.mark.samples + @pytest.mark.parametrize("convert_model", ["bge-small-en-v1.5"], indirect=True) + def test_sample_text_embedding_pipeline(self, convert_model): + # Run Python sample + py_script = os.path.join(SAMPLES_PY_DIR, "rag/text_embeddings.py") + py_command = [sys.executable, py_script, convert_model, "Document 1", "Document 2"] + py_result = run_sample(py_command) + + # Run C++ sample + cpp_sample = os.path.join(SAMPLES_CPP_DIR, "text_embeddings") + cpp_command = [cpp_sample, convert_model, "Document 1", "Document 2"] + cpp_result = run_sample(cpp_command) + + # Run JS sample + js_sample = os.path.join(SAMPLES_JS_DIR, "rag/text_embeddings.js") + js_command = ["node", js_sample, convert_model, "Document 1", "Document 2"] + js_result = run_sample(js_command) + + # Compare results + assert py_result.stdout == cpp_result.stdout, "Python and C++ results should match" + assert py_result.stdout == js_result.stdout, "Python and JS results should match" + + +class TestTextRerankPipeline: + @pytest.mark.rag + @pytest.mark.samples + @pytest.mark.parametrize("convert_model", ["ms-marco-TinyBERT-L2-v2"], indirect=True) + def test_sample_text_rerank_pipeline(self, convert_model): + # Run Python sample + py_script = os.path.join(SAMPLES_PY_DIR, "rag/text_rerank.py") + document_1 = "Intel Core Ultra processors incorporate an AI-optimized\ +architecture that supports new user experiences and the\ +next wave of commercial applications." + document_2 = "Intel Core Ultra processors are designed to\ +provide enhanced performance and efficiency for a wide\ +range of computing tasks." 
+ py_command = [sys.executable, py_script, convert_model, "What are the main features of Intel Core Ultra processors?", document_1, document_2] + py_result = run_sample(py_command) + + # Run C++ sample + cpp_sample = os.path.join(SAMPLES_CPP_DIR, "text_rerank") + + cpp_command = [cpp_sample, convert_model, "What are the main features of Intel Core Ultra processors?", document_1, document_2] + cpp_result = run_sample(cpp_command) + + # Compare results + assert py_result.stdout == cpp_result.stdout, "Python and C++ results should match" diff --git a/tests/python_tests/samples/test_text_embedding_pipeline.py b/tests/python_tests/samples/test_text_embedding_pipeline.py deleted file mode 100644 index 8b88e448cd..0000000000 --- a/tests/python_tests/samples/test_text_embedding_pipeline.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (C) 2025 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -import os -import pytest -import sys - -from conftest import SAMPLES_PY_DIR, SAMPLES_CPP_DIR, SAMPLES_JS_DIR -from test_utils import run_sample - -class TestTextEmbeddingPipeline: - @pytest.mark.rag - @pytest.mark.samples - @pytest.mark.parametrize("convert_model", ["BAAI/bge-small-en-v1.5"], indirect=True) - def test_sample_text_embedding_pipeline(self, convert_model): - # Run Python sample - py_script = os.path.join(SAMPLES_PY_DIR, "rag/text_embeddings.py") - py_command = [sys.executable, py_script, convert_model, "Document 1", "Document 2"] - py_result = run_sample(py_command) - - # Run C++ sample - cpp_sample = os.path.join(SAMPLES_CPP_DIR, "text_embeddings") - cpp_command = [cpp_sample, convert_model, "Document 1", "Document 2"] - cpp_result = run_sample(cpp_command) - - # Run JS sample - js_sample = os.path.join(SAMPLES_JS_DIR, "rag/text_embeddings.js") - js_command =['node', js_sample, convert_model, "Document 1", "Document 2"] - js_result = run_sample(js_command) - - # Compare results - assert py_result.stdout == cpp_result.stdout, "Python and C++ results should match" - assert py_result.stdout == js_result.stdout, "Python and JS results should match" diff --git a/tests/python_tests/samples/test_tools_llm_benchmark.py b/tests/python_tests/samples/test_tools_llm_benchmark.py index fc7a08b4d3..07e6663ee3 100644 --- a/tests/python_tests/samples/test_tools_llm_benchmark.py +++ b/tests/python_tests/samples/test_tools_llm_benchmark.py @@ -203,7 +203,7 @@ def test_python_tool_llm_benchmark_optimum(self, convert_model, download_test_co run_sample(benchmark_py_command) @pytest.mark.samples - @pytest.mark.parametrize("convert_model", ["BAAI/bge-small-en-v1.5"], indirect=True) + @pytest.mark.parametrize("convert_model", ["bge-small-en-v1.5"], indirect=True) @pytest.mark.parametrize("sample_args", [ ["-d", "cpu", "-n", "2"], ["-d", "cpu", "-n", "2", "--embedding_max_length", "128", "--embedding_normalize", "--embedding_pooling", "mean"], diff --git a/tests/python_tests/test_rag.py b/tests/python_tests/test_rag.py index e7e0cff4d8..b85fa37b3a 100644 --- a/tests/python_tests/test_rag.py +++ b/tests/python_tests/test_rag.py @@ -5,19 +5,55 @@ import pytest import gc from pathlib import Path -from openvino_genai import TextEmbeddingPipeline -from utils.hugging_face import download_and_convert_embeddings_models +from openvino_genai import TextEmbeddingPipeline, TextRerankPipeline +from utils.hugging_face import download_and_convert_embeddings_models, download_and_convert_rerank_model +from langchain_core.documents.base import Document from langchain_community.embeddings import OpenVINOBgeEmbeddings +from 
langchain_community.document_compressors.openvino_rerank import OpenVINOReranker from typing import Literal -TEST_MODELS = [ +EMBEDDINGS_TEST_MODELS = [ "BAAI/bge-small-en-v1.5", "mixedbread-ai/mxbai-embed-xsmall-v1", ] -TEXT_DATASET = f"Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aliquam tempus mollis suscipit. Pellentesque id suscipit magna. Pellentesque condimentum magna vel nisi condimentum suscipit. Interdum et malesuada fames ac ante ipsum primis in faucibus. Duis a urna ac eros accumsan commodo non non magna. Aenean mattis commodo urna, ac interdum turpis semper eu. Ut ut pharetra quam. Suspendisse erat tortor, vulputate sit amet quam eu, accumsan facilisis est. Ut varius nibh quis suscipit tempor. Pellentesque luctus, turpis id hendrerit condimentum, urna augue ultricies elit, vitae lacinia mauris enim id lacus.\ -Sed blandit in odio sed sagittis. Donec et turpis tincidunt nisl suscipit sodales id et eros. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Praesent tincidunt quam feugiat, luctus lacus semper, pharetra lorem. Aenean tincidunt, mi id porttitor fringilla, nisi nibh posuere dui, a vulputate libero ex in velit. Nunc non mauris faucibus nibh mattis venenatis. Integer maximus lectus eu mollis sodales. Donec sed varius tortor. Morbi sagittis at ex at semper. Sed in euismod mi, at viverra ante.\ -Nulla luctus accumsan varius. Cras ut tempor est, vitae vehicula velit. In consectetur mi eget elit tempus, at pellentesque felis auctor. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Cras tempor feugiat congue. Interdum et malesuada fames ac ante ipsum primis in faucibus. Suspendisse ac leo ut felis fermentum sagittis rutrum iaculis lectus. Aenean a lorem lobortis, suscipit leo id, malesuada velit. Phasellus vitae eros placerat, sodales ex eget, porta orci. Aliquam erat volutpat. Nam eget nisi at nibh euismod tempus vel at massa. Vivamus malesuada dui quis nibh congue facilisis." +RERANK_TEST_MODELS = [ + "cross-encoder/ms-marco-TinyBERT-L2-v2", # sigmoid applied + # "answerdotai/ModernBERT-base", # 2 classes output, softmax applied. Skip until langchain OpenVINORerank supports it. +] + +TEXT_DATASET = f"The commercial PC market is propelled by premium\ +computing solutions that drive user productivity and help\ +service organizations protect and maintain devices.\ +Corporations must empower mobile and hybrid workers\ +while extracting value from artificial intelligence (AI) to\ +improve business outcomes. Moreover, both public and\ +private sectors must address sustainability initiatives\ +pertaining to the full life cycle of computing fleets. An\ +inflection point in computing architecture is needed to stay\ +ahead of evolving requirements.\ +Introducing Intel® Core™ Ultra Processors\ +Intel® Core™ Ultra processors shape the future of\ +commercial computing in four major ways:\ +Power Efficiency\ +The new product line features a holistic approach to powerefficiency that benefits mobile work. Substantial changes to\ +the microarchitecture, manufacturing process, packaging\ +technology, and power management software result in up to\ +40% lower processor power consumption for modern tasks\ +such as video conferencing with a virtual camera. \ +Artificial Intelligence\ +Intel Core Ultra processors incorporate an AI-optimized\ +architecture that supports new user experiences and the\ +next wave of commercial applications. 
The CPU, GPU, and\ +the new neural processing unit (NPU) are all capable of\ +executing AI tasks as directed by application developers.\ +For example, elevated mobile collaboration is possible with\ +support for AI assisted background blur, noise suppression,\ +eye tracking, and picture framing. Intel Core Ultra\ +processors are capable of up to 2.5x the AI inference\ +performance per watt as compared to Intel’s previous\ +mobile processor offering.2\ +" @pytest.fixture(scope="class", autouse=True) @@ -32,13 +68,10 @@ def run_gc_after_test(): @pytest.fixture(scope="module") def dataset_documents(chunk_size=200): - return [ - TEXT_DATASET[i : i + chunk_size] - for i in range(0, len(TEXT_DATASET), chunk_size) - ] + return [TEXT_DATASET[i : i + chunk_size] for i in range(0, len(TEXT_DATASET), chunk_size)] -def run_genai( +def run_text_embedding_genai( models_path: Path, documents: list[str], config: TextEmbeddingPipeline.Config | None = None, @@ -55,7 +88,7 @@ def run_genai( return pipeline.embed_query(documents[0]) -def run_langchain( +def run_text_embedding_langchain( models_path: Path, documents: list[str], config: TextEmbeddingPipeline.Config | None = None, @@ -88,30 +121,86 @@ def run_langchain( return ov_embeddings.embed_query(documents[0]) -def run_pipeline_with_ref( +def run_text_embedding_pipeline_with_ref( models_path: Path, documents: list[str], config: TextEmbeddingPipeline.Config | None = None, task: Literal["embed_documents", "embed_query"] = "embed_documents", ): - genai_result = run_genai(models_path, documents, config, task) - langchain_result = run_langchain(models_path, documents, config, task) + genai_result = run_text_embedding_genai(models_path, documents, config, task) + langchain_result = run_text_embedding_langchain(models_path, documents, config, task) np_genai_result = np.array(genai_result) np_langchain_result = np.array(langchain_result) max_error = np.abs(np_genai_result - np_langchain_result).max() print(f"Max error: {max_error}") - assert ( - np.abs(np_genai_result - np_langchain_result).max() < 2e-6 - ), f"Max error: {max_error}" + assert np.abs(np_genai_result - np_langchain_result).max() < 2e-6, f"Max error: {max_error}" -@pytest.mark.parametrize( - "download_and_convert_embeddings_models", ["BAAI/bge-small-en-v1.5"], indirect=True -) +def assert_rerank_results(result_1: list[tuple[int, float]], result_2: list[tuple[int, float]]): + assert len(result_1) == len(result_2), f"Results length mismatch: {len(result_1)} != {len(result_2)}" + for pair_1, pair_2 in zip(result_1, result_2): + assert pair_1[0] == pair_2[0], f"Document IDs do not match: {pair_1[0]} != {pair_2[0]}" + assert abs(pair_1[1] - pair_2[1]) < 1e-6, f"Scores do not match for document ID {pair_1[0]}: " f"{pair_1[1]} != {pair_2[1]}" + + +def run_text_rerank_langchain( + models_path: Path, + query: str, + documents: list[str], + config: TextRerankPipeline.Config | None = None, +): + if not config: + config = TextRerankPipeline.Config() + + reranker = OpenVINOReranker(model_name_or_path=str(models_path), top_n=config.top_n) + + langchain_documents = [Document(page_content=text) for text in documents] + + reranked_documents = reranker.compress_documents(documents=langchain_documents, query=query) + + return [(doc.metadata["id"], float(doc.metadata.get("relevance_score", -1))) for doc in reranked_documents] + + +def run_text_rerank_genai( + models_path: Path, + query: str, + documents: list[str], + config: TextRerankPipeline.Config | None = None, +): + if not config: + config = 
TextRerankPipeline.Config() + + reranker = TextRerankPipeline(models_path, "CPU", config=config) + + sync_result = reranker.rerank( + query=query, + texts=documents, + ) + + reranker.start_rerank_async(query, documents) + async_result = reranker.wait_rerank() + + assert_rerank_results(sync_result, async_result) + return async_result + + +def run_text_rerank_pipeline_with_ref( + models_path: Path, + query: str, + documents: list[str], + config: TextRerankPipeline.Config | None = None, +): + genai_result = run_text_rerank_genai(models_path, query, documents, config) + langchain_result = run_text_rerank_langchain(models_path, query, documents, config) + + assert_rerank_results(genai_result, langchain_result) + + +@pytest.mark.parametrize("download_and_convert_embeddings_models", ["BAAI/bge-small-en-v1.5"], indirect=True) @pytest.mark.precommit -def test_constructors(download_and_convert_embeddings_models): +def test_embedding_constructors(download_and_convert_embeddings_models): _, _, models_path = download_and_convert_embeddings_models TextEmbeddingPipeline(models_path, "CPU") @@ -137,20 +226,14 @@ def test_constructors(download_and_convert_embeddings_models): ) -@pytest.mark.parametrize( - "download_and_convert_embeddings_models", TEST_MODELS, indirect=True -) +@pytest.mark.parametrize("download_and_convert_embeddings_models", EMBEDDINGS_TEST_MODELS, indirect=True) @pytest.mark.parametrize( "config", [ TextEmbeddingPipeline.Config(normalize=False), - TextEmbeddingPipeline.Config( - normalize=False, pooling_type=TextEmbeddingPipeline.PoolingType.MEAN - ), + TextEmbeddingPipeline.Config(normalize=False, pooling_type=TextEmbeddingPipeline.PoolingType.MEAN), TextEmbeddingPipeline.Config(normalize=True), - TextEmbeddingPipeline.Config( - normalize=True, pooling_type=TextEmbeddingPipeline.PoolingType.MEAN - ), + TextEmbeddingPipeline.Config(normalize=True, pooling_type=TextEmbeddingPipeline.PoolingType.MEAN), TextEmbeddingPipeline.Config( normalize=False, embed_instruction="Represent this document for searching relevant passages: ", @@ -165,27 +248,19 @@ def test_constructors(download_and_convert_embeddings_models): ], ) @pytest.mark.precommit -def test_embed_documents( - download_and_convert_embeddings_models, dataset_documents, config -): +def test_embed_documents(download_and_convert_embeddings_models, dataset_documents, config): _, _, models_path = download_and_convert_embeddings_models - run_pipeline_with_ref(models_path, dataset_documents, config, "embed_documents") + run_text_embedding_pipeline_with_ref(models_path, dataset_documents, config, "embed_documents") -@pytest.mark.parametrize( - "download_and_convert_embeddings_models", TEST_MODELS, indirect=True -) +@pytest.mark.parametrize("download_and_convert_embeddings_models", EMBEDDINGS_TEST_MODELS, indirect=True) @pytest.mark.parametrize( "config", [ TextEmbeddingPipeline.Config(normalize=False), - TextEmbeddingPipeline.Config( - normalize=False, pooling_type=TextEmbeddingPipeline.PoolingType.MEAN - ), + TextEmbeddingPipeline.Config(normalize=False, pooling_type=TextEmbeddingPipeline.PoolingType.MEAN), TextEmbeddingPipeline.Config(normalize=True), - TextEmbeddingPipeline.Config( - normalize=True, pooling_type=TextEmbeddingPipeline.PoolingType.MEAN - ), + TextEmbeddingPipeline.Config(normalize=True, pooling_type=TextEmbeddingPipeline.PoolingType.MEAN), TextEmbeddingPipeline.Config( normalize=False, query_instruction="Represent this query for searching relevant passages: ", @@ -202,4 +277,49 @@ def test_embed_documents( 
@pytest.mark.precommit def test_embed_query(download_and_convert_embeddings_models, dataset_documents, config): _, _, models_path = download_and_convert_embeddings_models - run_pipeline_with_ref(models_path, dataset_documents[:1], config, "embed_query") + run_text_embedding_pipeline_with_ref(models_path, dataset_documents[:1], config, "embed_query") + + +@pytest.mark.parametrize("download_and_convert_rerank_model", [RERANK_TEST_MODELS[0]], indirect=True) +@pytest.mark.precommit +def test_rerank_constructors(download_and_convert_rerank_model): + _, _, models_path = download_and_convert_rerank_model + + TextRerankPipeline(models_path, "CPU") + TextRerankPipeline(models_path, "CPU", TextRerankPipeline.Config()) + TextRerankPipeline( + models_path, + "CPU", + TextRerankPipeline.Config(), + PERFORMANCE_HINT_NUM_REQUESTS=2, + ) + TextRerankPipeline( + models_path, + "CPU", + top_n=2, + ) + TextRerankPipeline( + models_path, + "CPU", + top_n=2, + PERFORMANCE_HINT_NUM_REQUESTS=2, + ) + + +@pytest.mark.parametrize("download_and_convert_rerank_model", RERANK_TEST_MODELS, indirect=True) +@pytest.mark.parametrize("query", ["What are the main features of Intel Core Ultra processors?"]) +@pytest.mark.parametrize( + "config", + [ + TextRerankPipeline.Config(), + TextRerankPipeline.Config(top_n=10), + ], + ids=[ + "top_n=default", + "top_n=10", + ], +) +@pytest.mark.precommit +def test_rerank_documents(download_and_convert_rerank_model, dataset_documents, query, config): + _, _, models_path = download_and_convert_rerank_model + run_text_rerank_pipeline_with_ref(models_path, query, dataset_documents, config) diff --git a/tests/python_tests/test_tokenizer.py b/tests/python_tests/test_tokenizer.py index ef461d1712..f48af9bada 100644 --- a/tests/python_tests/test_tokenizer.py +++ b/tests/python_tests/test_tokenizer.py @@ -454,7 +454,7 @@ def test_two_inputs_string_list_of_lists(hf_ov_genai_models, input_pair): @pytest.mark.parametrize("input_pair", [ [["Eng... test, string?!" 
* 100], ["Multiline\nstring!\nWow!"]], [["hi" * 20], ["buy" * 90]], - # [["What is the capital of Great Britain"] * 4, ["London is capital of Great Britain"]], # TODO: broadcast of [4, 1] to [4, 4] fails + [["What is the capital of Great Britain"] * 4, ["London is capital of Great Britain"]], [["What is the capital of Great Britain"], ["London is capital of Great Britain"] * 4], ]) def test_two_inputs_string(hf_ov_genai_models, input_pair): diff --git a/tests/python_tests/utils/hugging_face.py b/tests/python_tests/utils/hugging_face.py index 20ed0cf8fe..3e87431c82 100644 --- a/tests/python_tests/utils/hugging_face.py +++ b/tests/python_tests/utils/hugging_face.py @@ -9,7 +9,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM from transformers import GenerationConfig as HFGenerationConfig -from optimum.intel import OVModelForCausalLM, OVModelForFeatureExtraction +from optimum.intel import OVModelForCausalLM, OVModelForFeatureExtraction, OVModelForSequenceClassification from optimum.intel.openvino.modeling import OVModel from huggingface_hub import hf_hub_download @@ -204,6 +204,15 @@ def download_and_convert_embeddings_models(request): return _download_and_convert_model(model_id, OVModelForFeatureExtraction) +@pytest.fixture() +def download_and_convert_rerank_model(request): + model_id = request.param + opt_model, hf_tokenizer, models_path = _download_and_convert_model(model_id, OVModelForSequenceClassification) + ov_tokenizer = convert_tokenizer(hf_tokenizer, with_detokenizer=False, number_of_inputs=2) + save_model(ov_tokenizer, models_path / "openvino_tokenizer.xml") + return opt_model, hf_tokenizer, models_path + + def _download_and_convert_model(model_id: str, model_class: Type[OVModel], **tokenizer_kwargs): dir_name = str(model_id).replace(sep, "_") ov_cache_models_dir = get_ov_cache_models_dir()