Merged

Commits (30):
- a2f3528 Add pipeline (as-suvorov, Jul 10, 2025)
- 8da1701 Release gil in python bindings (as-suvorov, Jul 10, 2025)
- 567cea6 apply sigmoid (as-suvorov, Jul 11, 2025)
- 42a4601 Use langchain for reranker tests (as-suvorov, Jul 11, 2025)
- 7504e8c Update cpp samples readme (as-suvorov, Jul 11, 2025)
- 521d926 Add python sample and readme (as-suvorov, Jul 11, 2025)
- 7e488e0 Add samples test (as-suvorov, Jul 11, 2025)
- 3a425db Support 2 classes output models (as-suvorov, Jul 15, 2025)
- 46b654e Merge remote-tracking branch 'upstream/master' into as/rag_text_reranker (as-suvorov, Jul 15, 2025)
- 309c7d1 Fix sample run (as-suvorov, Jul 15, 2025)
- 333d1f9 Fix sample run model id (as-suvorov, Jul 15, 2025)
- 57dcdaf unfix reranker version (as-suvorov, Jul 15, 2025)
- ad4fa6a Use langchain (as-suvorov, Jul 16, 2025)
- 21f52e2 Provide model task (as-suvorov, Jul 16, 2025)
- dc43a10 Convert 2 inputs tokenizer for samples tests (as-suvorov, Jul 17, 2025)
- be22a89 Merge remote-tracking branch 'upstream/master' into as/rag_text_reranker (as-suvorov, Jul 17, 2025)
- cd4422e Unskip tokenizers test (as-suvorov, Jul 17, 2025)
- 35e4929 Fix pybindings docstring (as-suvorov, Jul 17, 2025)
- 8e2c7b6 Address review comments (as-suvorov, Jul 17, 2025)
- ea286ee Merge remote-tracking branch 'upstream/master' into as/rag_text_reranker (as-suvorov, Jul 17, 2025)
- e819507 Add max_length (as-suvorov, Jul 17, 2025)
- 17154f9 Merge remote-tracking branch 'upstream/master' into as/rag_text_reranker (as-suvorov, Jul 18, 2025)
- a5fa86e Merge remote-tracking branch 'upstream/master' into as/rag_text_reranker (as-suvorov, Jul 23, 2025)
- 610052c Update bindings (as-suvorov, Jul 23, 2025)
- 37d1b00 Align pyi (as-suvorov, Jul 24, 2025)
- 4fca17d Merge branch 'master' into as/rag_text_reranker (Wovchena, Jul 24, 2025)
- 244e21d Merge remote-tracking branch 'upstream/master' into as/rag_text_reranker (as-suvorov, Jul 28, 2025)
- 64cef5a Fix pyi (as-suvorov, Jul 28, 2025)
- 6f61fdd Merge remote-tracking branch 'origin/as/rag_text_reranker' into as/ra… (as-suvorov, Jul 28, 2025)
- 0985787 Merge branch 'master' into as/rag_text_reranker (as-suvorov, Jul 29, 2025)
2 changes: 1 addition & 1 deletion .github/labeler.yml
@@ -107,7 +107,7 @@
'category: RAG samples':
- 'samples/cpp/rag/**/*'
- 'samples/python/rag/**/*'
- 'tests/python_tests/samples/test_text_embedding_pipeline.py'
- 'tests/python_tests/samples/test_rag.py'

'category: structured output generation':
- 'src/cpp/src/sampling/structured_output/*'
2 changes: 1 addition & 1 deletion samples/cpp/rag/CMakeLists.txt
@@ -21,7 +21,7 @@ function(add_sample_executable target_name)
EXCLUDE_FROM_ALL)
endfunction()

set(SAMPLE_LIST text_embeddings)
set(SAMPLE_LIST text_embeddings text_rerank)

foreach(sample ${SAMPLE_LIST})
add_sample_executable(${sample})
35 changes: 32 additions & 3 deletions samples/cpp/rag/README.md
@@ -1,6 +1,6 @@
# Retrieval Augmented Generation Sample

This example showcases inference of Text Embedding Models. The application has limited configuration options to encourage the reader to explore and modify the source code. For example, change the device for inference to GPU. The sample features `ov::genai::TextEmbeddingPipeline` and uses text as an input source.
This example showcases inference of Text Embedding and Text Rerank Models. The application has limited configuration options to encourage the reader to explore and modify the source code. For example, change the device for inference to GPU. The sample features `ov::genai::TextEmbeddingPipeline` and `ov::genai::TextRerankPipeline` and uses text as an input source.

## Download and Convert the Model and Tokenizers

@@ -10,17 +10,37 @@

```sh
pip install --upgrade-strategy eager -r ../../export-requirements.txt
```

Then, run the export with Optimum CLI:

```sh
optimum-cli export openvino --trust-remote-code --model BAAI/bge-small-en-v1.5 BAAI/bge-small-en-v1.5
```
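
A rerank model can be exported the same way. As an illustrative sketch (the model name is an assumption; check the Supported Models page for the rerank models actually covered):

```sh
optimum-cli export openvino --trust-remote-code --model BAAI/bge-reranker-base BAAI/bge-reranker-base
```

Some rerank models may additionally need an explicit `--task text-classification` flag, since cross-encoder rerankers are typically exported as sequence-classification models.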


## Run

Follow [Get Started with Samples](https://docs.openvino.ai/2025/get-started/learn-openvino/openvino-samples/get-started-demos.html) to run the sample.

`text_embeddings BAAI/bge-small-en-v1.5 "Document 1" "Document 2"`

### 1. Text Embedding Sample (`text_embeddings.cpp`)
- **Description:**
Demonstrates inference of text embedding models using OpenVINO GenAI. Converts input text into vector embeddings for downstream tasks such as retrieval or semantic search.
- **Run Command:**
```sh
text_embeddings <MODEL_DIR> "Document 1" "Document 2"
```
Refer to the [Supported Models](https://openvinotoolkit.github.io/openvino.genai/docs/supported-models/#text-embeddings-models) for more details.

### 2. Text Rerank Sample (`text_rerank.cpp`)
- **Description:**
Demonstrates inference of text rerank models using OpenVINO GenAI. Reranks a list of candidate documents based on their relevance to a query using a cross-encoder or reranker model.
- **Run Command:**
```sh
text_rerank <MODEL_DIR> '<QUERY>' '<TEXT 1>' ['<TEXT 2>' ...]
```
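
For example, with illustrative arguments (the model directory and texts are placeholders):

```sh
text_rerank ./bge-reranker-base 'What is the capital of France?' 'Paris is the capital of France.' 'The Eiffel Tower is located in Paris.'
```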


# Text Embedding Pipeline Usage

```c++
#include "openvino/genai/rag/text_embedding_pipeline.hpp"

ov::genai::TextEmbeddingPipeline pipeline(models_path, device, config);
std::vector<ov::genai::EmbeddingResult> embeddings = pipeline.embed_documents(documents);
```

# Text Rerank Pipeline Usage

```c++
#include "openvino/genai/rag/text_rerank_pipeline.hpp"

ov::genai::TextRerankPipeline pipeline(models_path, device, config);
std::vector<std::pair<size_t, float>> rerank_result = pipeline.rerank(query, documents);
```
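
The header added in this PR also declares an asynchronous API (`start_rerank_async`/`wait_rerank`); a minimal sketch, assuming `models_path`, `device`, `config`, `query`, and `documents` are set up as above:

```c++
#include "openvino/genai/rag/text_rerank_pipeline.hpp"

ov::genai::TextRerankPipeline pipeline(models_path, device, config);

// Only one async call may be active at a time.
pipeline.start_rerank_async(query, documents);
// ... do other work while reranking runs ...
std::vector<std::pair<size_t, float>> rerank_result = pipeline.wait_rerank();
```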
45 changes: 45 additions & 0 deletions samples/cpp/rag/text_rerank.cpp
@@ -0,0 +1,45 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "openvino/genai/rag/text_rerank_pipeline.hpp"

int main(int argc, char* argv[]) try {
if (argc < 4) {
throw std::runtime_error(std::string{"Usage: "} + argv[0] +
" <MODEL_DIR> '<QUERY>' '<TEXT 1>' ['<TEXT 2>' ...]");
}

auto documents = std::vector<std::string>(argv + 3, argv + argc);
Collaborator: I think it would be more convenient to specify document files, which is closer to a practical use case.

Collaborator (author): Let's keep the sample simple for the moment. We could implement file/network loaders and text splitters, but I think that should be a dedicated sample where we can showcase the full RAG flow.

std::string models_path = argv[1];
std::string query = argv[2];

std::string device = "CPU"; // GPU can be used as well

ov::genai::TextRerankPipeline::Config config;
config.top_n = 3;

ov::genai::TextRerankPipeline pipeline(models_path, device, config);

std::vector<std::pair<size_t, float>> rerank_result = pipeline.rerank(query, documents);

// print reranked documents
std::cout << std::fixed << std::setprecision(4);
std::cout << "Reranked documents:\n";
for (const auto& [index, score] : rerank_result) {
std::cout << "Document " << index << " (score: " << score << "): " << documents[index] << '\n';
}
std::cout << std::defaultfloat;

} catch (const std::exception& error) {
try {
std::cerr << error.what() << '\n';
} catch (const std::ios_base::failure&) {
}
return EXIT_FAILURE;
} catch (...) {
try {
std::cerr << "Non-exception object thrown\n";
} catch (const std::ios_base::failure&) {
}
return EXIT_FAILURE;
}
30 changes: 27 additions & 3 deletions samples/python/rag/README.md
@@ -1,6 +1,6 @@
# Retrieval Augmented Generation Sample

This example showcases inference of Text Embedding Models. The application limited configuration configuration options to encourage the reader to explore and modify the source code. For example, change the device for inference to GPU. The sample features `openvino_genai.TextEmbeddingPipeline` and uses text as an input source.
This example showcases inference of Text Embedding and Text Rerank Models. The application has limited configuration options to encourage the reader to explore and modify the source code. For example, change the device for inference to GPU. The sample features `openvino_genai.TextEmbeddingPipeline` and `openvino_genai.TextRerankPipeline` and uses text as an input source.

## Download and Convert the Model and Tokenizers

@@ -38,10 +38,24 @@ export_tokenizer(tokenizer, output_dir)

Install [deployment-requirements.txt](../../deployment-requirements.txt) via `pip install -r ../../deployment-requirements.txt` and then, run a sample:

`python text_embeddings.py BAAI/bge-small-en-v1.5 "Document 1" "Document 2"`

### 1. Text Embedding Sample (`text_embeddings.py`)
- **Description:**
Demonstrates inference of text embedding models using OpenVINO GenAI. Converts input text into vector embeddings for downstream tasks such as retrieval or semantic search.
- **Run Command:**
```sh
python text_embeddings.py <MODEL_DIR> "Document 1" "Document 2"
```
Refer to the [Supported Models](https://openvinotoolkit.github.io/openvino.genai/docs/supported-models/#text-embeddings-models) for more details.

### 2. Text Rerank Sample (`text_rerank.py`)
- **Description:**
Demonstrates inference of text rerank models using OpenVINO GenAI. Reranks a list of candidate documents based on their relevance to a query using a cross-encoder or reranker model.
- **Run Command:**
```sh
python text_rerank.py <MODEL_DIR> "<QUERY>" "<TEXT 1>" ["<TEXT 2>" ...]
```


# Text Embedding Pipeline Usage

```python
import openvino_genai

pipeline = openvino_genai.TextEmbeddingPipeline(model_dir, "CPU")

embeddings = pipeline.embed_documents(["document1", "document2"])
```

# Text Rerank Pipeline Usage

```python
import openvino_genai

pipeline = openvino_genai.TextRerankPipeline(model_dir, "CPU")

rerank_result = pipeline.rerank(query, documents)
```

Collaborator: I think we can consider combining TextRerankPipeline and TextEmbeddingPipeline into a TextRagPipeline that has separate methods for embedding retrieval and ranking. We would have just one object in memory that solves the RAG problem.

Collaborator (author): There is a ContextualCompressionRetriever in LangChain which wraps a retriever and a reranker. The base method is invoke(query) -> relevant documents. But we would need to implement a retriever which utilizes TextEmbeddingPipeline first. I will look into different combining scenarios, thanks.
31 changes: 31 additions & 0 deletions samples/python/rag/text_rerank.py
@@ -0,0 +1,31 @@
#!/usr/bin/env python3
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import argparse
import openvino_genai


def main():
parser = argparse.ArgumentParser()
parser.add_argument("model_dir")
parser.add_argument("query")
parser.add_argument("texts", nargs="+")
args = parser.parse_args()

device = "CPU" # GPU can be used as well

config = openvino_genai.TextRerankPipeline.Config()
config.top_n = 3

pipeline = openvino_genai.TextRerankPipeline(args.model_dir, device, config)

rerank_result = pipeline.rerank(args.query, args.texts)

print("Reranked documents:")
for index, score in rerank_result:
print(f"Document {index} (score: {score:.4f}): {args.texts[index]}")


if __name__ == "__main__":
main()
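
The sample can be invoked as described in the README; for example (the model directory and texts are placeholders):

```sh
python text_rerank.py ./bge-reranker-base "What is the capital of France?" "Paris is the capital of France." "The Eiffel Tower is located in Paris."
```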
102 changes: 102 additions & 0 deletions src/cpp/include/openvino/genai/rag/text_rerank_pipeline.hpp
@@ -0,0 +1,102 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "openvino/genai/tokenizer.hpp"

namespace ov {
namespace genai {

class OPENVINO_GENAI_EXPORTS TextRerankPipeline {
public:
struct OPENVINO_GENAI_EXPORTS Config {
/**
* @brief Number of documents to return sorted by score
*/
size_t top_n = 3;

/**
* @brief Constructs text rerank pipeline configuration
*/
Config() = default;

/**
* @brief Constructs text rerank pipeline configuration
*
* @param properties configuration options
*
* const ov::AnyMap properties{{"top_n", 3}};
* ov::genai::TextRerankPipeline::Config config(properties);
*
* ov::genai::TextRerankPipeline::Config config({{"top_n", 3}});
*/
explicit Config(const ov::AnyMap& properties);
};
Contributor: Add max_length, pad_to_max_length, and batch_size to the config.

Collaborator (author): I would rather do it in a separate PR, after we confirm the approach works for TextEmbeddingPipeline.

Contributor: max_length should be added in any case; some models might crash during inference otherwise. By the way, we can set the upper bound for the input shape if the max_length parameter is set. It might be beneficial for performance.

Collaborator (author): OK, I will add max_length. Could you please specify which models are expected to fail?

Collaborator (author): Are you sure we need max_length in any case? Neither LlamaIndex nor LangChain has a max_length parameter.


/**
* @brief Constructs a pipeline from xml/bin files, tokenizer and configuration in the same dir.
*
* @param models_path Path to the directory containing model xml/bin files and tokenizer
* @param device Device
* @param config Pipeline configuration
* @param properties Optional plugin properties to pass to ov::Core::compile_model().
*/
TextRerankPipeline(const std::filesystem::path& models_path,
const std::string& device,
const Config& config,
const ov::AnyMap& properties = {});

/**
* @brief Constructs a pipeline from xml/bin files, tokenizer and configuration in the same dir.
*
* @param models_path Path to the directory containing model xml/bin files and tokenizer
* @param device Device
* @param properties Optional plugin and/or config properties
*/
TextRerankPipeline(const std::filesystem::path& models_path,
const std::string& device,
const ov::AnyMap& properties = {});

/**
* @brief Constructs a pipeline from xml/bin files, tokenizer and configuration in the same dir.
*
* @param models_path Path to the directory containing model xml/bin files and tokenizer
* @param device Device
* @param properties Plugin and/or config properties
*/
template <typename... Properties,
typename std::enable_if<ov::util::StringAny<Properties...>::value, bool>::type = true>
TextRerankPipeline(const std::filesystem::path& models_path, const std::string& device, Properties&&... properties)
: TextRerankPipeline(models_path, device, ov::AnyMap{std::forward<Properties>(properties)...}) {}

/**
* @brief Reranks a vector of texts based on the query.
*/
std::vector<std::pair<size_t, float>> rerank(const std::string& query, const std::vector<std::string>& texts);

/**
* @brief Asynchronously reranks a vector of texts based on the query. Only one method of the async family
* can be active at a time.
*/
void start_rerank_async(const std::string& query, const std::vector<std::string>& texts);

/**
* @brief Waits for reranked texts.
*/
std::vector<std::pair<size_t, float>> wait_rerank();

~TextRerankPipeline();

private:
class TextRerankPipelineImpl;
std::unique_ptr<TextRerankPipelineImpl> m_impl;
};

/**
* @brief Number of documents to return after reranking sorted by score
*/
static constexpr ov::Property<size_t> top_n{"top_n"};

} // namespace genai
} // namespace ov
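
Given the `top_n` property and the variadic constructor above, the pipeline can also be configured without building a `Config` explicitly; a minimal sketch (paths are placeholders):

```c++
#include "openvino/genai/rag/text_rerank_pipeline.hpp"

// top_n is forwarded into the pipeline configuration as a named property.
ov::genai::TextRerankPipeline pipeline("path/to/model_dir", "CPU", ov::genai::top_n(3));
```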
2 changes: 2 additions & 0 deletions src/cpp/include/openvino/genai/tokenizer.hpp
@@ -7,6 +7,7 @@
#include <vector>
#include <initializer_list>
#include <filesystem>
#include <optional>

#include "openvino/runtime/tensor.hpp"
#include "openvino/genai/visibility.hpp"
@@ -21,6 +22,7 @@ using Vocab = std::unordered_map<std::string, int64_t>; // similar to huggingface
struct TokenizedInputs {
ov::Tensor input_ids;
ov::Tensor attention_mask;
std::optional<ov::Tensor> token_type_ids;
};

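A hedged sketch of how a caller might consume the new optional `token_type_ids` field (the tokenizer setup and `infer_request` are assumptions for illustration):

```c++
#include "openvino/genai/tokenizer.hpp"

ov::genai::Tokenizer tokenizer(models_path);
ov::genai::TokenizedInputs inputs = tokenizer.encode(text);

// token_type_ids is only populated for tokenizers that emit a second
// segment input, e.g. two-input cross-encoder rerank models.
if (inputs.token_type_ids) {
    infer_request.set_tensor("token_type_ids", *inputs.token_type_ids);
}
```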
43 changes: 34 additions & 9 deletions src/cpp/src/debug_utils.hpp
@@ -3,34 +3,59 @@

#pragma once

#include <string>
#include <iostream>
#include <fstream>

#include <iostream>
#include <openvino/runtime/tensor.hpp>
#include <string>

template <typename T>
void print_array(T * array, size_t size) {
void print_array(T* array, size_t size) {
std::cout << " => [ ";
for (size_t i = 0; i < std::min(size, size_t(10)); ++i) {
std::cout << array[i] << " ";
}
std::cout << " ] " << std::endl;
}

template <typename T>
void print_tensor(ov::Tensor tensor) {
const auto shape = tensor.get_shape();
const size_t rank = shape.size();
const auto* data = tensor.data<T>();

if (rank > 2) {
print_array(data, tensor.get_size());
return;
}

const size_t batch_size = shape[0];
const size_t seq_length = shape[1];

std::cout << " => [ \n";
for (size_t batch = 0; batch < batch_size; ++batch) {
std::cout << " [ ";
const size_t batch_offset = batch * seq_length;
for (size_t j = 0; j < seq_length; ++j) {
std::cout << data[batch_offset + j] << " ";
}
std::cout << "]\n";
}
std::cout << " ]" << std::endl;
}

inline void print_tensor(std::string name, ov::Tensor tensor) {
std::cout << name;
std::cout << " " << tensor.get_shape().to_string();
if (tensor.get_element_type() == ov::element::i32) {
print_array(tensor.data<int>(), tensor.get_size());
print_tensor<int>(tensor);
} else if (tensor.get_element_type() == ov::element::i64) {
print_array(tensor.data<int64_t>(), tensor.get_size());
print_tensor<int64_t>(tensor);
} else if (tensor.get_element_type() == ov::element::f32) {
print_array(tensor.data<float>(), tensor.get_size());
print_tensor<float>(tensor);
} else if (tensor.get_element_type() == ov::element::boolean) {
print_array(tensor.data<bool>(), tensor.get_size());
print_tensor<bool>(tensor);
} else if (tensor.get_element_type() == ov::element::f16) {
print_array(tensor.data<ov::float16>(), tensor.get_size());
print_tensor<ov::float16>(tensor);
}
}
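
Usage is unchanged by this refactor; a small illustration (the tensor name and variable are placeholders):

```c++
// Rank-2 tensors now print row by row instead of a flat 10-element prefix.
print_tensor("attention_mask", attention_mask_tensor);
```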
