From 04e7d53e4156e8dcc172ed219acbbb14115cac8f Mon Sep 17 00:00:00 2001 From: Eero Tamminen Date: Mon, 11 Aug 2025 17:29:56 +0300 Subject: [PATCH] Remove unused "chatqna_wrapper.py" file Signed-off-by: Eero Tamminen --- ChatQnA/chatqna_wrapper.py | 142 ------------------------------------- 1 file changed, 142 deletions(-) delete mode 100644 ChatQnA/chatqna_wrapper.py diff --git a/ChatQnA/chatqna_wrapper.py b/ChatQnA/chatqna_wrapper.py deleted file mode 100644 index 971d9edb4c..0000000000 --- a/ChatQnA/chatqna_wrapper.py +++ /dev/null @@ -1,142 +0,0 @@ -# Copyright (C) 2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -import os - -from comps import MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceRoleType, ServiceType -from comps.cores.mega.utils import handle_message -from comps.cores.proto.api_protocol import ( - ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionResponseChoice, - ChatMessage, - UsageInfo, -) -from comps.cores.proto.docarray import LLMParams, RerankerParms, RetrieverParms -from fastapi import Request -from fastapi.responses import StreamingResponse - -MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0") -MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888)) -EMBEDDING_SERVICE_HOST_IP = os.getenv("EMBEDDING_SERVICE_HOST_IP", "0.0.0.0") -EMBEDDING_SERVICE_PORT = int(os.getenv("EMBEDDING_SERVICE_PORT", 6000)) -RETRIEVER_SERVICE_HOST_IP = os.getenv("RETRIEVER_SERVICE_HOST_IP", "0.0.0.0") -RETRIEVER_SERVICE_PORT = int(os.getenv("RETRIEVER_SERVICE_PORT", 7000)) -RERANK_SERVICE_HOST_IP = os.getenv("RERANK_SERVICE_HOST_IP", "0.0.0.0") -RERANK_SERVICE_PORT = int(os.getenv("RERANK_SERVICE_PORT", 8000)) -LLM_SERVICE_HOST_IP = os.getenv("LLM_SERVICE_HOST_IP", "0.0.0.0") -LLM_SERVICE_PORT = int(os.getenv("LLM_SERVICE_PORT", 9000)) - - -class ChatQnAService: - def __init__(self, host="0.0.0.0", port=8000): - self.host = host - self.port = port - self.megaservice = ServiceOrchestrator() - self.endpoint = str(MegaServiceEndpoint.CHAT_QNA) - - def add_remote_service(self): - embedding = MicroService( - name="embedding", - host=EMBEDDING_SERVICE_HOST_IP, - port=EMBEDDING_SERVICE_PORT, - endpoint="/v1/embeddings", - use_remote_service=True, - service_type=ServiceType.EMBEDDING, - ) - retriever = MicroService( - name="retriever", - host=RETRIEVER_SERVICE_HOST_IP, - port=RETRIEVER_SERVICE_PORT, - endpoint="/v1/retrieval", - use_remote_service=True, - service_type=ServiceType.RETRIEVER, - ) - rerank = MicroService( - name="rerank", - host=RERANK_SERVICE_HOST_IP, - port=RERANK_SERVICE_PORT, - endpoint="/v1/reranking", - use_remote_service=True, - service_type=ServiceType.RERANK, - ) - llm = MicroService( - name="llm", - host=LLM_SERVICE_HOST_IP, - port=LLM_SERVICE_PORT, - endpoint="/v1/chat/completions", - use_remote_service=True, - service_type=ServiceType.LLM, - ) - self.megaservice.add(embedding).add(retriever).add(rerank).add(llm) - self.megaservice.flow_to(embedding, retriever) - self.megaservice.flow_to(retriever, rerank) - self.megaservice.flow_to(rerank, llm) - - async def handle_request(self, request: Request): - data = await request.json() - stream_opt = data.get("stream", True) - chat_request = ChatCompletionRequest.parse_obj(data) - prompt = handle_message(chat_request.messages) - parameters = LLMParams( - max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024, - top_k=chat_request.top_k if chat_request.top_k else 10, - top_p=chat_request.top_p if chat_request.top_p else 0.95, - temperature=chat_request.temperature if chat_request.temperature else 0.01, - frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0, - presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0, - repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03, - stream=stream_opt, - chat_template=chat_request.chat_template if chat_request.chat_template else None, - ) - retriever_parameters = RetrieverParms( - search_type=chat_request.search_type if chat_request.search_type else "similarity", - k=chat_request.k if chat_request.k else 4, - distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None, - fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20, - lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5, - score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2, - ) - reranker_parameters = RerankerParms( - top_n=chat_request.top_n if chat_request.top_n else 1, - ) - result_dict, runtime_graph = await self.megaservice.schedule( - initial_inputs={"text": prompt}, - llm_parameters=parameters, - retriever_parameters=retriever_parameters, - reranker_parameters=reranker_parameters, - ) - for node, response in result_dict.items(): - if isinstance(response, StreamingResponse): - return response - last_node = runtime_graph.all_leaves()[-1] - response = result_dict[last_node]["text"] - choices = [] - usage = UsageInfo() - choices.append( - ChatCompletionResponseChoice( - index=0, - message=ChatMessage(role="assistant", content=response), - finish_reason="stop", - ) - ) - return ChatCompletionResponse(model="chatqna", choices=choices, usage=usage) - - def start(self): - self.service = MicroService( - self.__class__.__name__, - service_role=ServiceRoleType.MEGASERVICE, - host=self.host, - port=self.port, - endpoint=self.endpoint, - input_datatype=ChatCompletionRequest, - output_datatype=ChatCompletionResponse, - ) - self.service.add_route(self.endpoint, self.handle_request, methods=["POST"]) - self.service.start() - - -if __name__ == "__main__": - chatqna = ChatQnAService(port=MEGA_SERVICE_PORT) - chatqna.add_remote_service() - chatqna.start()