
Commit e992324

Add GenAI reranker pipeline to llm-rag notebook (#3081)

[CVS-170413](https://jira.devtools.intel.com/browse/CVS-170413)

Co-authored-by: Aleksandr Mokrov <[email protected]>
1 parent 8afc843 commit e992324

6 files changed, +190 −36 lines

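For context, the helpers added in this commit wrap the `openvino_genai.TextRerankPipeline` API. A minimal standalone sketch of that API, based only on the calls that appear in the diff below (the model directory is a placeholder for an exported OpenVINO rerank model):

```python
import openvino_genai

# top_n is part of the pipeline config and is fixed at creation time.
config = openvino_genai.TextRerankPipeline.Config()
config.top_n = 2

# Placeholder path: a directory containing an OpenVINO-exported rerank model.
pipe = openvino_genai.TextRerankPipeline("./bge-reranker-ov", "CPU", config)

documents = [
    "OpenVINO runs on CPU, GPU and NPU.",
    "GenAI adds a text rerank pipeline for RAG.",
    "Unrelated text about cooking.",
]
# rerank() yields (index, score) pairs for the most relevant documents.
for index, score in pipe.rerank("What does GenAI add?", documents):
    print(index, score, documents[index])
```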

notebooks/llm-rag-langchain/gradio_helper.py

Lines changed: 3 additions & 2 deletions
@@ -40,6 +40,7 @@ def make_demo(
     update_retriever_fn: Callable,
     model_name: str,
     language: str = "English",
+    rerank_device: str | None = None,
 ):
     examples = chinese_examples if (language == "Chinese") else english_examples

@@ -226,8 +227,8 @@ def make_demo(
                 value=2,
                 step=1,
                 label="Rerank top n",
-                info="Number of rerank results",
-                interactive=True,
+                info="Number of rerank results (set at pipeline creation for GenAI).",
+                interactive=(rerank_device == "NPU"),
             )
         with gr.Row():
             vector_search_top_k = gr.Slider(
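A minimal sketch of the behavior this hunk enables, assuming the surrounding `make_demo` layout (helper name is illustrative): the "Rerank top n" slider stays visible, but is editable only when reranking runs on NPU, since the GenAI pipeline fixes `top_n` when it is created.

```python
import gradio as gr


def build_rerank_slider(rerank_device: str | None) -> gr.Slider:
    # Editable only on NPU; on other devices top_n is fixed in the GenAI pipeline config.
    return gr.Slider(
        1,
        10,
        value=2,
        step=1,
        label="Rerank top n",
        info="Number of rerank results (set at pipeline creation for GenAI).",
        interactive=(rerank_device == "NPU"),
    )
```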

notebooks/llm-rag-langchain/llm-rag-langchain-genai.ipynb

Lines changed: 16 additions & 29 deletions
@@ -147,18 +147,10 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 2,
    "id": "1b2c3f4e",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "LLM config will be updated\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "import os\n",
     "from pathlib import Path\n",
@@ -880,7 +872,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 16,
    "id": "d0bab20b",
    "metadata": {},
    "outputs": [],
@@ -918,14 +910,6 @@
    "id": "e11e73cf",
    "metadata": {},
    "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "D:\\openvino_notebooks\\openvino_env\\Lib\\site-packages\\openvino\\runtime\\__init__.py:10: DeprecationWarning: The `openvino.runtime` module is deprecated and will be removed in the 2026.0 release. Please replace `openvino.runtime` with `openvino`.\n",
-      " warnings.warn(\n"
-     ]
-    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
@@ -1198,19 +1182,19 @@
    "\n",
    "[back to top ⬆️](#Table-of-contents:)\n",
    "\n",
-   "Now a local rerank model of OpenVINO can be accelerated on NPU by using the `OpenVINOReranker` class without PyTorch requirements.\n",
+   "Now a local OpenVINO rerank model can be run via the [GenAI `TextRerankPipeline`](https://github.com/openvinotoolkit/openvino.genai/tree/master/samples/python/rag) through `OpenVINOGenAIReranker`, or accelerated on NPU with the `OpenVINOReranker` class, without PyTorch requirements.\n",
    "\n",
    "> **Note**: Rerank can be skipped in RAG.\n"
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 25,
+  "execution_count": null,
   "id": "b67b39f2-8394-45fb-9b2b-ea63e267a2d3",
   "metadata": {},
   "outputs": [],
   "source": [
-   "from ov_langchain_helper import OpenVINOReranker\n",
+   "from ov_langchain_helper import OpenVINOReranker, OpenVINOGenAIReranker\n",
    "\n",
    "rerank_model_name = rerank_model_id.value\n",
    "rerank_model_kwargs = {\"device_name\": rerank_device.value}\n",
@@ -1235,10 +1219,10 @@
    "        top_n=rerank_top_n,\n",
    "    )\n",
    "else:\n",
-   "    reranker = OpenVINOReranker(\n",
+   "    reranker = OpenVINOGenAIReranker.from_model_path(\n",
    "        model_path=rerank_model_name,\n",
-   "        model_kwargs=rerank_model_kwargs,\n",
-   "        top_n=rerank_top_n,\n",
+   "        device=rerank_device.value,\n",
+   "        config_kwargs={\"top_n\": rerank_top_n},\n",
    "    )"
   ]
  },
@@ -1567,8 +1551,8 @@
    "\n",
    "    if rerank_device.value == \"NPU\":\n",
    "        vector_search_top_k = vector_search_top_k_npu\n",
-   "    if vector_rerank_top_n > vector_search_top_k:\n",
-   "        gr.Warning(\"Search top k must >= Rerank top n\")\n",
+   "        if vector_rerank_top_n > vector_search_top_k:\n",
+   "            gr.Warning(\"Search top k must >= Rerank top n\")\n",
    "\n",
    "    documents = []\n",
    "    for doc in docs:\n",
@@ -1586,7 +1570,8 @@
    "    search_kwargs = {\"k\": vector_search_top_k}\n",
    "    retriever = db.as_retriever(search_kwargs=search_kwargs, search_type=search_method)\n",
    "    if run_rerank:\n",
-   "        reranker.top_n = vector_rerank_top_n\n",
+   "        if rerank_device.value == \"NPU\":\n",
+   "            reranker.top_n = vector_rerank_top_n\n",
    "        retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=retriever)\n",
    "    prompt = PromptTemplate.from_template(rag_prompt_template)\n",
    "    combine_docs_chain = create_stuff_documents_chain(llm, prompt)\n",
@@ -1623,7 +1608,8 @@
    "    retriever = db.as_retriever(search_kwargs=search_kwargs, search_type=search_method)\n",
    "    if run_rerank:\n",
    "        retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=retriever)\n",
-   "        reranker.top_n = vector_rerank_top_n\n",
+   "        if rerank_device.value == \"NPU\":\n",
+   "            reranker.top_n = vector_rerank_top_n\n",
    "    rag_chain = create_retrieval_chain(retriever, combine_docs_chain)\n",
    "\n",
    "    return \"Vector database is Ready\"\n",
@@ -1735,6 +1721,7 @@
    "    update_retriever_fn=update_retriever,\n",
    "    model_name=llm_model_id.value,\n",
    "    language=model_language.value,\n",
+   "    rerank_device=rerank_device.value,\n",
    ")\n",
    "\n",
    "try:\n",

notebooks/llm-rag-langchain/ov_langchain_helper.py

Lines changed: 65 additions & 0 deletions
@@ -736,3 +736,68 @@ def compress_documents(
             )
             final_results.append(doc)
         return final_results
+
+
+class OpenVINOGenAIReranker(BaseDocumentCompressor):
+    """OpenVINO reranking models.
+
+    To use, you should have the ``openvino-genai`` python package installed.
+
+    Example:
+        .. code-block:: python
+
+            from ov_langchain_helper import OpenVINOGenAIReranker
+
+            model_path = "./sentence-transformers/all-mpnet-base-v2"
+            config_kwargs = {'top_n': 3}
+            ov = OpenVINOGenAIReranker.from_model_path(
+                model_path=model_path,
+                device='CPU',
+                config_kwargs=config_kwargs,
+            )
+    """
+
+    ov_pipe: Any = None
+    """OpenVINO pipeline object."""
+    top_n: int = 3
+    """Return top n texts. Cannot be changed after the pipeline is created."""
+
+    model_config = ConfigDict(extra="forbid", protected_namespaces=())
+
+    @classmethod
+    def from_model_path(cls, model_path: str, device: str = "CPU", config_kwargs: Dict[str, Any] = {}) -> OpenVINOGenAIReranker:
+        """Construct the OpenVINO text rerank pipeline from model_path."""
+        try:
+            import openvino_genai
+
+        except ImportError:
+            raise ImportError("Could not import OpenVINO GenAI package. " "Please install it with `pip install openvino-genai`.")
+
+        config = openvino_genai.TextRerankPipeline.Config()
+        top_n = config_kwargs.get("top_n", 3)
+        if "top_n" in config_kwargs:
+            config.top_n = config_kwargs["top_n"]
+        if "max_length" in config_kwargs:
+            config.max_length = config_kwargs["max_length"]
+
+        ov_pipe = openvino_genai.TextRerankPipeline(model_path, device, config)
+
+        return cls(ov_pipe=ov_pipe, top_n=top_n)
+
+    def compress_documents(
+        self,
+        documents: Sequence[Document],
+        query: str,
+        callbacks: Optional[Callbacks] = None,
+    ) -> Sequence[Document]:
+        docs = [doc.page_content for doc in documents]
+        rerank_response = self.ov_pipe.rerank(query, docs)
+        final_results = []
+        for index, score in rerank_response:
+            doc = Document(
+                page_content=documents[index].page_content,
+                metadata={"id": index, "relevance_score": score},
+            )
+            final_results.append(doc)
+
+        return final_results
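A hedged usage sketch for the new `OpenVINOGenAIReranker` (the model directory and texts are placeholders, not values from the notebook):

```python
from langchain_core.documents import Document

from ov_langchain_helper import OpenVINOGenAIReranker

reranker = OpenVINOGenAIReranker.from_model_path(
    model_path="./bge-reranker-ov",  # placeholder: exported OpenVINO rerank model directory
    device="CPU",
    config_kwargs={"top_n": 2},
)

docs = [
    Document(page_content="OpenVINO runs on CPU, GPU and NPU."),
    Document(page_content="GenAI adds a TextRerankPipeline for RAG."),
    Document(page_content="Unrelated text about cooking."),
]
# Returns the top_n most relevant documents with a relevance_score in metadata.
top_docs = reranker.compress_documents(docs, query="What does GenAI add for RAG?")
for doc in top_docs:
    print(doc.metadata["relevance_score"], doc.page_content)
```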

notebooks/llm-rag-llamaindex/gradio_helper.py

Lines changed: 3 additions & 2 deletions
@@ -39,6 +39,7 @@ def make_demo(
     update_retriever_fn: Callable,
     model_name: str,
     language: Literal["English", "Chinese"] = "English",
+    rerank_device: str = "auto",
 ):
     examples = chinese_examples if (language == "Chinese") else english_examples

@@ -189,8 +190,8 @@ def make_demo(
                 value=2,
                 step=1,
                 label="Rerank top n",
-                info="Number of rerank results",
-                interactive=True,
+                info="Number of rerank results (set at pipeline creation for GenAI).",
+                interactive=(rerank_device == "NPU"),
             )
         with gr.Row():
             vector_search_top_k = gr.Slider(

notebooks/llm-rag-llamaindex/llm-rag-llamaindex.ipynb

Lines changed: 8 additions & 3 deletions
@@ -1154,14 +1154,14 @@
    "\n",
    "[back to top ⬆️](#Table-of-contents:)\n",
    "\n",
-   "Now a Hugging Face rerank model can be supported by OpenVINO through [`OpenVINORerank`](https://docs.llamaindex.ai/en/stable/examples/node_postprocessor/openvino_rerank/) class of LlamaIndex.\n",
+   "Now a Hugging Face rerank model can be run with OpenVINO via [GenAI](https://github.com/openvinotoolkit/openvino.genai) through the `OpenVINOGenAIReranking` class. To run the model on NPU, use the [`OpenVINORerank`](https://docs.llamaindex.ai/en/stable/examples/node_postprocessor/openvino_rerank/) class of LlamaIndex.\n",
    "\n",
    "> **Note**: Rerank can be skipped in RAG.\n"
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 24,
+  "execution_count": null,
   "id": "b67b39f2-8394-45fb-9b2b-ea63e267a2d3",
   "metadata": {},
   "outputs": [
@@ -1175,8 +1175,12 @@
   ],
   "source": [
    "from llama_index.postprocessor.openvino_rerank import OpenVINORerank\n",
+   "from ov_llamaindex_helper import OpenVINOGenAIReranking\n",
    "\n",
-   "reranker = OpenVINORerank(model_id_or_path=rerank_model_id.value, device=rerank_device.value, top_n=2)"
+   "if USING_NPU:\n",
+   "    reranker = OpenVINORerank(model_id_or_path=rerank_model_id.value, device=rerank_device.value, top_n=2)\n",
+   "else:\n",
+   "    reranker = OpenVINOGenAIReranking(model_id_or_path=rerank_model_id.value, device=rerank_device.value, top_n=2)"
   ]
  },
  {
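How the selected reranker then feeds the LlamaIndex query engine could look roughly like this (a sketch; `index`, `llm`, and `vector_search_top_k` are assumed to come from other cells of the notebook, and the keyword arguments are standard LlamaIndex options, not values taken from this diff):

```python
# Plug the reranker in as a node postprocessor; with the GenAI pipeline,
# top_n was already fixed when the reranker was created above.
query_engine = index.as_query_engine(
    llm=llm,
    similarity_top_k=vector_search_top_k,
    node_postprocessors=[reranker],
    streaming=True,
)
response = query_engine.query("What is OpenVINO GenAI?")
```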
@@ -1765,6 +1769,7 @@
    "    update_retriever_fn=update_retriever,\n",
    "    model_name=llm_model_id.value,\n",
    "    language=model_language.value,\n",
+   "    rerank_device=rerank_device.value,\n",
    ")\n",
    "\n",
    "try:\n",

notebooks/llm-rag-llamaindex/ov_llamaindex_helper.py

Lines changed: 95 additions & 0 deletions
@@ -2,9 +2,21 @@
     DEFAULT_EMBED_BATCH_SIZE,
     BaseEmbedding,
 )
+from llama_index.core.postprocessor.types import BaseNodePostprocessor
 from typing import Any, List, Optional, Dict
 from llama_index.core.bridge.pydantic import Field, PrivateAttr
 from llama_index.core.callbacks import CallbackManager
+from llama_index.core.callbacks import CBEventType, EventPayload
+from llama_index.core.instrumentation import get_dispatcher
+from llama_index.core.instrumentation.events.rerank import (
+    ReRankEndEvent,
+    ReRankStartEvent,
+)
+from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle
+from llama_index.core.instrumentation import get_dispatcher
+
+
+dispatcher = get_dispatcher(__name__)


 class OpenVINOGenAIEmbedding(BaseEmbedding):
@@ -86,3 +98,86 @@ def _get_text_embedding(self, text: str) -> List[float]:
     def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
         """Get text embeddings."""
         return self._ov_pipe.embed_documents(texts)
+
+
+class OpenVINOGenAIReranking(BaseNodePostprocessor):
+    model_id_or_path: str = Field(description="Huggingface model id or local path.")
+    top_n: int = Field(description="Number of nodes to return sorted by score.")
+    keep_retrieval_score: bool = Field(
+        default=False,
+        description="Whether to keep the retrieval score in metadata.",
+    )
+
+    _ov_pipe: Any = PrivateAttr()
+
+    def __init__(
+        self,
+        model_id_or_path: str,
+        max_length: Optional[int] = None,
+        top_n: Optional[int] = 3,
+        device: Optional[str] = "auto",
+        model_kwargs: Dict[str, Any] = {},
+        keep_retrieval_score: Optional[bool] = False,
+    ):
+        try:
+            import openvino_genai
+        except ImportError:
+            raise ImportError("Could not import OpenVINO GenAI package. " "Please install it with `pip install openvino-genai`.")
+
+        super().__init__(top_n=top_n, max_length=max_length, model_id_or_path=model_id_or_path, device=device, keep_retrieval_score=keep_retrieval_score)
+
+        config = openvino_genai.TextRerankPipeline.Config()
+        config.top_n = top_n
+        if max_length:
+            config.max_length = max_length
+
+        ov_pipe = openvino_genai.TextRerankPipeline(model_id_or_path, device, config, **model_kwargs)
+
+        self._ov_pipe = ov_pipe
+
+    @classmethod
+    def class_name(cls) -> str:
+        return "OpenVINOGenAIReranking"
+
+    def _postprocess_nodes(
+        self,
+        nodes: List[NodeWithScore],
+        query_bundle: Optional[QueryBundle] = None,
+    ) -> List[NodeWithScore]:
+        dispatcher.event(
+            ReRankStartEvent(
+                query=query_bundle,
+                nodes=nodes,
+                top_n=self.top_n,
+                model_name=self.model_id_or_path,
+            )
+        )
+
+        if query_bundle is None:
+            raise ValueError("Missing query bundle in extra info.")
+        if len(nodes) == 0:
+            return []
+
+        nodes_text_list = [str(node.node.get_content(metadata_mode=MetadataMode.EMBED)) for node in nodes]
+
+        with self.callback_manager.event(
+            CBEventType.RERANKING,
+            payload={
+                EventPayload.NODES: nodes,
+                EventPayload.MODEL_NAME: self.model_id_or_path,
+                EventPayload.QUERY_STR: query_bundle.query_str,
+                EventPayload.TOP_K: self.top_n,
+            },
+        ) as event:
+            outputs = self._ov_pipe.rerank(query_bundle.query_str, nodes_text_list)
+
+            for index, score in outputs:
+                if self.keep_retrieval_score:
+                    # keep the retrieval score in metadata
+                    nodes[index].node.metadata["retrieval_score"] = nodes[index].score
+                nodes[index].score = score
+            reranked_nodes = [nodes[index] for index, _ in outputs]
+            event.on_end(payload={EventPayload.NODES: reranked_nodes})
+
+        dispatcher.event(ReRankEndEvent(nodes=reranked_nodes))
+        return reranked_nodes
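A hedged standalone usage sketch for the new postprocessor (model directory and texts are placeholders, not values from the notebook):

```python
from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode

from ov_llamaindex_helper import OpenVINOGenAIReranking

reranker = OpenVINOGenAIReranking(
    model_id_or_path="./bge-reranker-ov",  # placeholder: exported OpenVINO rerank model directory
    device="CPU",
    top_n=2,
)

nodes = [
    NodeWithScore(node=TextNode(text="OpenVINO runs on CPU, GPU and NPU."), score=0.5),
    NodeWithScore(node=TextNode(text="GenAI adds a TextRerankPipeline for RAG."), score=0.4),
    NodeWithScore(node=TextNode(text="Unrelated text about cooking."), score=0.3),
]
# Returns the top_n nodes in relevance order, with rerank scores attached.
reranked = reranker.postprocess_nodes(nodes, QueryBundle(query_str="What does GenAI add for RAG?"))
for n in reranked:
    print(n.score, n.node.get_content())
```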
