diff --git a/python/sglang/srt/entrypoints/openai/__init__.py b/python/sglang/srt/entrypoints/openai/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py new file mode 100644 index 00000000000..c7423ed1b43 --- /dev/null +++ b/python/sglang/srt/entrypoints/openai/protocol.py @@ -0,0 +1,539 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Pydantic models for OpenAI API protocol""" + +import time +from typing import Dict, List, Optional, Union + +from pydantic import BaseModel, Field, field_validator, model_validator +from typing_extensions import Literal + + +class ModelCard(BaseModel): + """Model cards.""" + + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "sglang" + root: Optional[str] = None + max_model_len: Optional[int] = None + + +class ModelList(BaseModel): + """Model list consists of model cards.""" + + object: str = "list" + data: List[ModelCard] = Field(default_factory=list) + + +class ErrorResponse(BaseModel): + object: str = "error" + message: str + type: str + param: Optional[str] = None + code: int + + +class LogProbs(BaseModel): + text_offset: List[int] = Field(default_factory=list) + token_logprobs: List[Optional[float]] = Field(default_factory=list) + tokens: List[str] = Field(default_factory=list) + top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list) + + +class TopLogprob(BaseModel): + token: str + bytes: List[int] + logprob: float + + +class ChatCompletionTokenLogprob(BaseModel): + token: str + bytes: List[int] + logprob: float + top_logprobs: List[TopLogprob] + + +class ChoiceLogprobs(BaseModel): + # build for v1/chat/completions response + content: List[ChatCompletionTokenLogprob] + + +class UsageInfo(BaseModel): + prompt_tokens: int = 0 + total_tokens: int = 0 + completion_tokens: Optional[int] = 0 + # only used to return cached tokens when --enable-cache-report is set + prompt_tokens_details: Optional[Dict[str, int]] = None + + +class StreamOptions(BaseModel): + include_usage: Optional[bool] = False + + +class JsonSchemaResponseFormat(BaseModel): + name: str + description: Optional[str] = None + # use alias to workaround pydantic conflict + schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None) + strict: Optional[bool] = False + + +class FileRequest(BaseModel): + # https://platform.openai.com/docs/api-reference/files/create + file: bytes # The File object (not file name) to be uploaded + purpose: str = ( + "batch" # The intended purpose of the uploaded file, default is "batch" + ) + + +class FileResponse(BaseModel): + id: str + object: str = "file" + bytes: int + created_at: int + filename: str + purpose: str + + +class FileDeleteResponse(BaseModel): + id: str + object: str = "file" + deleted: bool + + +class 
BatchRequest(BaseModel): + input_file_id: ( + str # The ID of an uploaded file that contains requests for the new batch + ) + endpoint: str # The endpoint to be used for all requests in the batch + completion_window: str # The time frame within which the batch should be processed + metadata: Optional[dict] = None # Optional custom metadata for the batch + + +class BatchResponse(BaseModel): + id: str + object: str = "batch" + endpoint: str + errors: Optional[dict] = None + input_file_id: str + completion_window: str + status: str = "validating" + output_file_id: Optional[str] = None + error_file_id: Optional[str] = None + created_at: int + in_progress_at: Optional[int] = None + expires_at: Optional[int] = None + finalizing_at: Optional[int] = None + completed_at: Optional[int] = None + failed_at: Optional[int] = None + expired_at: Optional[int] = None + cancelling_at: Optional[int] = None + cancelled_at: Optional[int] = None + request_counts: Optional[dict] = None + metadata: Optional[dict] = None + + +class CompletionRequest(BaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/completions/create + model: str + prompt: Union[List[int], List[List[int]], str, List[str]] + best_of: Optional[int] = None + echo: bool = False + frequency_penalty: float = 0.0 + logit_bias: Optional[Dict[str, float]] = None + logprobs: Optional[int] = None + max_tokens: int = 16 + n: int = 1 + presence_penalty: float = 0.0 + seed: Optional[int] = None + stop: Optional[Union[str, List[str]]] = None + stream: bool = False + stream_options: Optional[StreamOptions] = None + suffix: Optional[str] = None + temperature: float = 1.0 + top_p: float = 1.0 + user: Optional[str] = None + + # Extra parameters for SRT backend only and will be ignored by OpenAI models. 
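+ # Most of these map onto SGLang's native sampling_params (see _build_sampling_params in serving_completions.py);
+ # lora_path and the bootstrap_* fields are instead forwarded directly on GenerateReqInput.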
+ top_k: int = -1 + min_p: float = 0.0 + min_tokens: int = 0 + json_schema: Optional[str] = None + regex: Optional[str] = None + ebnf: Optional[str] = None + repetition_penalty: float = 1.0 + stop_token_ids: Optional[List[int]] = None + no_stop_trim: bool = False + ignore_eos: bool = False + skip_special_tokens: bool = True + lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None + session_params: Optional[Dict] = None + + # For PD disaggregation + bootstrap_host: Optional[str] = None + bootstrap_port: Optional[int] = None + bootstrap_room: Optional[int] = None + + @field_validator("max_tokens") + @classmethod + def validate_max_tokens_positive(cls, v): + if v is not None and v <= 0: + raise ValueError("max_tokens must be positive") + return v + + +class CompletionResponseChoice(BaseModel): + index: int + text: str + logprobs: Optional[LogProbs] = None + finish_reason: Literal["stop", "length", "content_filter", "abort"] + matched_stop: Union[None, int, str] = None + + +class CompletionResponse(BaseModel): + id: str + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseChoice] + usage: UsageInfo + + +class CompletionResponseStreamChoice(BaseModel): + index: int + text: str + logprobs: Optional[LogProbs] = None + finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None + matched_stop: Union[None, int, str] = None + + +class CompletionStreamResponse(BaseModel): + id: str + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseStreamChoice] + usage: Optional[UsageInfo] = None + + +class ChatCompletionMessageContentTextPart(BaseModel): + type: Literal["text"] + text: str + + +class ChatCompletionMessageContentImageURL(BaseModel): + url: str + detail: Optional[Literal["auto", "low", "high"]] = "auto" + + +class ChatCompletionMessageContentAudioURL(BaseModel): + url: str + + +class ChatCompletionMessageContentImagePart(BaseModel): + type: Literal["image_url"] + image_url: ChatCompletionMessageContentImageURL + modalities: Optional[Literal["image", "multi-images", "video"]] = "image" + + +class ChatCompletionMessageContentAudioPart(BaseModel): + type: Literal["audio_url"] + audio_url: ChatCompletionMessageContentAudioURL + + +ChatCompletionMessageContentPart = Union[ + ChatCompletionMessageContentTextPart, + ChatCompletionMessageContentImagePart, + ChatCompletionMessageContentAudioPart, +] + + +class FunctionResponse(BaseModel): + """Function response.""" + + name: Optional[str] = None + arguments: Optional[str] = None + + +class ToolCall(BaseModel): + """Tool call response.""" + + id: Optional[str] = None + index: Optional[int] = None + type: Literal["function"] = "function" + function: FunctionResponse + + +class ChatCompletionMessageGenericParam(BaseModel): + role: Literal["system", "assistant", "tool"] + content: Union[str, List[ChatCompletionMessageContentTextPart], None] + tool_call_id: Optional[str] = None + name: Optional[str] = None + reasoning_content: Optional[str] = None + tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None]) + + +class ChatCompletionMessageUserParam(BaseModel): + role: Literal["user"] + content: Union[str, List[ChatCompletionMessageContentPart]] + + +ChatCompletionMessageParam = Union[ + ChatCompletionMessageGenericParam, ChatCompletionMessageUserParam +] + + +class ResponseFormat(BaseModel): + type: Literal["text", "json_object", 
"json_schema"] + json_schema: Optional[JsonSchemaResponseFormat] = None + + +class StructuresResponseFormat(BaseModel): + begin: str + schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None) + end: str + + +class StructuralTagResponseFormat(BaseModel): + type: Literal["structural_tag"] + structures: List[StructuresResponseFormat] + triggers: List[str] + + +class Function(BaseModel): + """Function descriptions.""" + + description: Optional[str] = Field(default=None, examples=[None]) + name: Optional[str] = None + parameters: Optional[object] = None + strict: bool = False + + +class Tool(BaseModel): + """Function wrapper.""" + + type: str = Field(default="function", examples=["function"]) + function: Function + + +class ToolChoiceFuncName(BaseModel): + """The name of tool choice function.""" + + name: Optional[str] = None + + +class ToolChoice(BaseModel): + """The tool choice definition.""" + + function: ToolChoiceFuncName + type: Literal["function"] = Field(default="function", examples=["function"]) + + +class ChatCompletionRequest(BaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/chat/create + messages: List[ChatCompletionMessageParam] + model: str + frequency_penalty: float = 0.0 + logit_bias: Optional[Dict[str, float]] = None + logprobs: bool = False + top_logprobs: Optional[int] = None + max_tokens: Optional[int] = Field( + default=None, + deprecated="max_tokens is deprecated in favor of the max_completion_tokens field", + description="The maximum number of tokens that can be generated in the chat completion. ", + ) + max_completion_tokens: Optional[int] = Field( + default=None, + description="The maximum number of completion tokens for a chat completion request, " + "including visible output tokens and reasoning tokens. Input tokens are not included. ", + ) + n: int = 1 + presence_penalty: float = 0.0 + response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None + seed: Optional[int] = None + stop: Optional[Union[str, List[str]]] = None + stream: bool = False + stream_options: Optional[StreamOptions] = None + temperature: float = 0.7 + top_p: float = 1.0 + user: Optional[str] = None + tools: Optional[List[Tool]] = Field(default=None, examples=[None]) + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field( + default="auto", examples=["none"] + ) # noqa + + @model_validator(mode="before") + @classmethod + def set_tool_choice_default(cls, values): + if isinstance(values, dict): + if values.get("tool_choice") is None: + if values.get("tools") is None: + values["tool_choice"] = "none" + else: + values["tool_choice"] = "auto" + return values + + @field_validator("messages") + @classmethod + def validate_messages_not_empty(cls, v): + if not v: + raise ValueError("Messages cannot be empty") + return v + + # Extra parameters for SRT backend only and will be ignored by OpenAI models. + top_k: int = -1 + min_p: float = 0.0 + min_tokens: int = 0 + regex: Optional[str] = None + ebnf: Optional[str] = None + repetition_penalty: float = 1.0 + stop_token_ids: Optional[List[int]] = None + no_stop_trim: bool = False + ignore_eos: bool = False + continue_final_message: bool = False + skip_special_tokens: bool = True + lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None + session_params: Optional[Dict] = None + separate_reasoning: bool = True + stream_reasoning: bool = True + chat_template_kwargs: Optional[Dict] = None + + # The request id. 
+ rid: Optional[str] = None + + # For PD disaggregation + bootstrap_host: Optional[str] = None + bootstrap_port: Optional[int] = None + bootstrap_room: Optional[int] = None + + +class ChatMessage(BaseModel): + role: Optional[str] = None + content: Optional[str] = None + reasoning_content: Optional[str] = None + tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None]) + + +class ChatCompletionResponseChoice(BaseModel): + index: int + message: ChatMessage + logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None + finish_reason: Literal[ + "stop", "length", "tool_calls", "content_filter", "function_call", "abort" + ] + matched_stop: Union[None, int, str] = None + + +class ChatCompletionResponse(BaseModel): + id: str + object: str = "chat.completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseChoice] + usage: UsageInfo + + +class DeltaMessage(BaseModel): + role: Optional[str] = None + content: Optional[str] = None + reasoning_content: Optional[str] = None + tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None]) + + +class ChatCompletionResponseStreamChoice(BaseModel): + index: int + delta: DeltaMessage + logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None + finish_reason: Optional[ + Literal["stop", "length", "tool_calls", "content_filter", "function_call"] + ] = None + matched_stop: Union[None, int, str] = None + + +class ChatCompletionStreamResponse(BaseModel): + id: str + object: str = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseStreamChoice] + usage: Optional[UsageInfo] = None + + +class MultimodalEmbeddingInput(BaseModel): + text: Optional[str] = None + image: Optional[str] = None + + +EmbeddingInput = Union[ + List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput] +] + + +class EmbeddingRequest(BaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/embeddings/create + input: EmbeddingInput + model: str + encoding_format: str = "float" + dimensions: int = None + user: Optional[str] = None + + # The request id. 
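+ # Same semantics as in ChatCompletionRequest: a missing rid falls back to an auto-generated "embd-<uuid4 hex>" id.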
+ rid: Optional[str] = None + + +class EmbeddingObject(BaseModel): + embedding: List[float] + index: int + object: str = "embedding" + + +class EmbeddingResponse(BaseModel): + data: List[EmbeddingObject] + model: str + object: str = "list" + usage: Optional[UsageInfo] = None + + +class ScoringRequest(BaseModel): + query: Optional[Union[str, List[int]]] = ( + None # Query text or pre-tokenized token IDs + ) + items: Optional[Union[str, List[str], List[List[int]]]] = ( + None # Item text(s) or pre-tokenized token IDs + ) + label_token_ids: Optional[List[int]] = ( + None # Token IDs to compute probabilities for + ) + apply_softmax: bool = False + item_first: bool = False + model: str + + +class ScoringResponse(BaseModel): + scores: List[ + List[float] + ] # List of lists of probabilities, each in the order of label_token_ids + model: str + usage: Optional[UsageInfo] = None + object: str = "scoring" + + +OpenAIServingRequest = Union[ + ChatCompletionRequest, CompletionRequest, EmbeddingRequest, ScoringRequest +] diff --git a/python/sglang/srt/entrypoints/openai/serving_base.py b/python/sglang/srt/entrypoints/openai/serving_base.py new file mode 100644 index 00000000000..718378f5e0f --- /dev/null +++ b/python/sglang/srt/entrypoints/openai/serving_base.py @@ -0,0 +1,178 @@ +import json +import logging +import uuid +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Union + +from fastapi import Request +from fastapi.responses import ORJSONResponse, StreamingResponse + +from sglang.srt.entrypoints.openai.protocol import ( + ErrorResponse, + OpenAIServingRequest, + UsageInfo, +) +from sglang.srt.managers.io_struct import GenerateReqInput +from sglang.srt.managers.tokenizer_manager import TokenizerManager + +logger = logging.getLogger(__name__) + + +# Base class for specific endpoint handlers +class OpenAIServingBase(ABC): + """Abstract base class for OpenAI endpoint handlers""" + + def __init__(self, tokenizer_manager: TokenizerManager): + self.tokenizer_manager = tokenizer_manager + + async def handle_request( + self, request: OpenAIServingRequest, raw_request: Request + ) -> Union[Any, StreamingResponse, ErrorResponse]: + """Handle the specific request type with common pattern""" + try: + # Validate request + error_msg = self._validate_request(request) + if error_msg: + return self.create_error_response(error_msg) + + # Convert to internal format + adapted_request, processed_request = self._convert_to_internal_request( + [request], [self._generate_request_id_base(request)] + ) + + # Note(Xinyuan): raw_request below is only used for detecting the connection of the client + if hasattr(request, "stream") and request.stream: + return await self._handle_streaming_request( + adapted_request, processed_request, raw_request + ) + else: + return await self._handle_non_streaming_request( + adapted_request, processed_request, raw_request + ) + + except Exception as e: + logger.error(f"Error in request: {e}") + return self.create_error_response( + message=f"Internal server error: {str(e)}", + err_type="InternalServerError", + status_code=500, + ) + + @abstractmethod + def _request_id_prefix(self) -> str: + """Generate request ID based on request type""" + pass + + def _generate_request_id_base(self, request: OpenAIServingRequest) -> str: + """Generate request ID based on request type""" + if rid := getattr(request, "rid", None): + return rid + + return f"{self._request_id_prefix()}{uuid.uuid4().hex}" + + @abstractmethod + def _convert_to_internal_request( + self, + 
all_requests: List[OpenAIServingRequest], + request_ids: List[str], + ) -> tuple[ + GenerateReqInput, Union[OpenAIServingRequest, List[OpenAIServingRequest]] + ]: + """Convert OpenAI request to internal format""" + pass + + async def _handle_streaming_request( + self, + adapted_request: GenerateReqInput, + request: OpenAIServingRequest, + raw_request: Request, + ) -> StreamingResponse: + """Handle streaming request + + Override this method in child classes that support streaming requests. + """ + return self.create_error_response( + message=f"{self.__class__.__name__} does not support streaming requests", + err_type="NotImplementedError", + status_code=501, + ) + + async def _handle_non_streaming_request( + self, + adapted_request: GenerateReqInput, + request: OpenAIServingRequest, + raw_request: Request, + ) -> Union[Any, ErrorResponse]: + """Handle non-streaming request + + Override this method in child classes that support non-streaming requests. + """ + return self.create_error_response( + message=f"{self.__class__.__name__} does not support non-streaming requests", + err_type="NotImplementedError", + status_code=501, + ) + + def _validate_request(self, request: OpenAIServingRequest) -> Optional[str]: + """Validate request""" + pass + + def _calculate_streaming_usage_base( + self, + prompt_tokens: Dict[int, int], + completion_tokens: Dict[int, int], + cached_tokens: Dict[int, int], + n_choices: int, + ) -> UsageInfo: + """Calculate usage information for streaming responses (common logic)""" + total_prompt_tokens = sum( + tokens for i, tokens in prompt_tokens.items() if i % n_choices == 0 + ) + total_completion_tokens = sum(tokens for tokens in completion_tokens.values()) + + cache_report = self.tokenizer_manager.server_args.enable_cache_report + prompt_tokens_details = None + if cache_report: + cached_tokens_sum = sum(tokens for tokens in cached_tokens.values()) + if cached_tokens_sum > 0: + prompt_tokens_details = {"cached_tokens": cached_tokens_sum} + + return UsageInfo( + prompt_tokens=total_prompt_tokens, + completion_tokens=total_completion_tokens, + total_tokens=total_prompt_tokens + total_completion_tokens, + prompt_tokens_details=prompt_tokens_details, + ) + + def create_error_response( + self, + message: str, + err_type: str = "BadRequestError", + status_code: int = 400, + param: Optional[str] = None, + ) -> ORJSONResponse: + """Create an error response""" + error = ErrorResponse( + object="error", + message=message, + type=err_type, + param=param, + code=status_code, + ) + return ORJSONResponse(content=error.model_dump(), status_code=status_code) + + def create_streaming_error_response( + self, + message: str, + err_type: str = "BadRequestError", + status_code: int = 400, + ) -> str: + """Create a streaming error response""" + error = ErrorResponse( + object="error", + message=message, + type=err_type, + param=None, + code=status_code, + ) + return json.dumps({"error": error.model_dump()}) diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py new file mode 100644 index 00000000000..54b490131ab --- /dev/null +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -0,0 +1,938 @@ +import base64 +import json +import logging +import time +import uuid +from typing import Any, Dict, List, Optional, Union + +from fastapi import Request +from fastapi.responses import StreamingResponse + +from sglang.srt.conversation import generate_chat_conv +from sglang.srt.entrypoints.openai.protocol import ( + 
ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, + ChatCompletionTokenLogprob, + ChatMessage, + ChoiceLogprobs, + DeltaMessage, + ErrorResponse, + FunctionResponse, + LogProbs, + ToolCall, + TopLogprob, +) +from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase +from sglang.srt.entrypoints.openai.utils import ( + aggregate_token_usage, + detect_template_content_format, + process_content_for_template_format, + to_openai_style_logprobs, +) +from sglang.srt.function_call.function_call_parser import FunctionCallParser +from sglang.srt.managers.io_struct import GenerateReqInput +from sglang.srt.reasoning_parser import ReasoningParser +from sglang.utils import convert_json_schema_to_str + +logger = logging.getLogger(__name__) + + +class OpenAIServingChat(OpenAIServingBase): + """Handler for chat completion requests""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Instance-specific cache for template content format detection + self._cached_chat_template = None + self._cached_template_format = None + + def _request_id_prefix(self) -> str: + return "chatcmpl-" + + def _validate_request(self, request: ChatCompletionRequest) -> Optional[str]: + """Validate chat messages format and content""" + if not (messages := request.messages): + return "Messages cannot be empty" + + # Check for alternating user/assistant pattern (optional validation) + roles = [msg.role for msg in messages] + + # First message should typically be from user or system + if roles[0] not in ["user", "system"]: + return "First message should be from 'user' or 'system'" + + # Check for consecutive assistant messages (which might indicate an error) + for i in range(1, len(roles)): + if roles[i] == "assistant" and roles[i - 1] == "assistant": + # This is actually allowed in some cases, so just warn + pass + + # Validate message content + for i, msg in enumerate(messages): + if msg.role == "user": + if not msg.content: + return f"User message at index {i} has no content" + elif msg.role == "assistant": + # Assistant messages can have no content if they have tool_calls + if not msg.content and not getattr(msg, "tool_calls", None): + return ( + f"Assistant message at index {i} has no content or tool calls" + ) + + return None + + def _convert_to_internal_request( + self, + all_requests: List[ChatCompletionRequest], + request_ids: List[str], + ) -> tuple[ + GenerateReqInput, Union[ChatCompletionRequest, List[ChatCompletionRequest]] + ]: + """Convert OpenAI chat completion request to internal format""" + input_ids = [] + prompts = [] + sampling_params_list = [] + image_data_list = [] + audio_data_list = [] + return_logprobs = [] + logprob_start_lens = [] + top_logprobs_nums = [] + modalities_list = [] + lora_paths = [] + + is_multimodal = self.tokenizer_manager.model_config.is_multimodal + + for request in all_requests: + # Process messages and apply chat template + ( + prompt, + prompt_ids, + image_data, + audio_data, + modalities, + stop, + tool_call_constraint, + ) = self._process_messages(request, is_multimodal) + + input_ids.append(prompt_ids) + prompts.append(prompt) + return_logprobs.append(request.logprobs) + logprob_start_lens.append(-1) + top_logprobs_nums.append(request.top_logprobs or 0) + lora_paths.append(request.lora_path) + + # Build sampling parameters + sampling_params = self._build_sampling_params( + request, stop, tool_call_constraint + ) + 
sampling_params_list.append(sampling_params) + + image_data_list.append(image_data) + audio_data_list.append(audio_data) + modalities_list.append(modalities) + + # Handle single vs multiple requests + if len(all_requests) == 1: + if is_multimodal: + prompt_kwargs = {"text": prompts[0]} + else: + if isinstance(input_ids[0], str): + prompt_kwargs = {"text": input_ids[0]} + else: + prompt_kwargs = {"input_ids": input_ids[0]} + + sampling_params_list = sampling_params_list[0] + image_data_list = image_data_list[0] + audio_data_list = audio_data_list[0] + return_logprobs = return_logprobs[0] + logprob_start_lens = logprob_start_lens[0] + top_logprobs_nums = top_logprobs_nums[0] + modalities_list = modalities_list[0] + lora_paths = lora_paths[0] + request_ids = request_ids[0] + else: + if is_multimodal: + prompt_kwargs = {"text": prompts} + else: + if isinstance(input_ids[0], str): + prompt_kwargs = {"text": input_ids} + else: + prompt_kwargs = {"input_ids": input_ids} + + adapted_request = GenerateReqInput( + **prompt_kwargs, + image_data=image_data_list, + audio_data=audio_data_list, + sampling_params=sampling_params_list, + return_logprob=return_logprobs, + logprob_start_len=logprob_start_lens, + top_logprobs_num=top_logprobs_nums, + stream=all_requests[0].stream, + return_text_in_logprobs=True, + rid=request_ids, + modalities=modalities_list, + lora_path=lora_paths, + bootstrap_host=all_requests[0].bootstrap_host, + bootstrap_port=all_requests[0].bootstrap_port, + bootstrap_room=all_requests[0].bootstrap_room, + ) + + return adapted_request, ( + all_requests if len(all_requests) > 1 else all_requests[0] + ) + + def _process_messages( + self, request: ChatCompletionRequest, is_multimodal: bool + ) -> tuple[ + str, + Union[str, List[int]], + Optional[Any], + Optional[Any], + List[str], + List[str], + Optional[Any], + ]: + """Process chat messages and apply chat template""" + tool_call_constraint = None + prompt = "" + prompt_ids = [] + + if not isinstance(request.messages, str): + # Apply chat template and its stop strings + tools = None + if request.tools and request.tool_choice != "none": + request.skip_special_tokens = False + if not isinstance(request.tool_choice, str): + tools = [ + item.function.model_dump() + for item in request.tools + if item.function.name == request.tool_choice.function.name + ] + else: + tools = [item.function.model_dump() for item in request.tools] + + tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser + parser = FunctionCallParser(request.tools, tool_call_parser) + tool_call_constraint = parser.get_structure_constraint( + request.tool_choice + ) + + # Use chat template + if ( + hasattr(self.tokenizer_manager, "chat_template_name") + and self.tokenizer_manager.chat_template_name is None + ): + prompt, prompt_ids, image_data, audio_data, modalities, stop = ( + self._apply_jinja_template(request, tools, is_multimodal) + ) + else: + prompt, image_data, audio_data, modalities, stop = ( + self._apply_conversation_template(request) + ) + if not is_multimodal: + prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt) + else: + # Use raw prompt + prompt_ids = request.messages + stop = request.stop or [] + image_data = None + audio_data = None + modalities = [] + prompt = request.messages + + return ( + prompt, + prompt_ids, + image_data, + audio_data, + modalities, + stop, + tool_call_constraint, + ) + + def _apply_jinja_template( + self, + request: ChatCompletionRequest, + tools: Optional[List[Dict]], + is_multimodal: bool, + ) -> tuple[str, 
List[int], Optional[Any], Optional[Any], List[str], List[str]]: + """Apply Jinja chat template""" + openai_compatible_messages = [] + image_data = [] + audio_data = [] + modalities = [] + + # Detect template content format + current_template = self.tokenizer_manager.tokenizer.chat_template + if current_template != self._cached_chat_template: + self._cached_chat_template = current_template + self._cached_template_format = detect_template_content_format( + current_template + ) + logger.info( + f"Detected chat template content format: {self._cached_template_format}" + ) + + template_content_format = self._cached_template_format + + for message in request.messages: + if message.content is None: + message.content = "" + msg_dict = message.model_dump() + + # Process content based on detected template format + processed_msg = process_content_for_template_format( + msg_dict, + template_content_format, + image_data, + audio_data, + modalities, + ) + openai_compatible_messages.append(processed_msg) + + # Handle assistant prefix for continue_final_message + assistant_prefix = None + if ( + openai_compatible_messages + and openai_compatible_messages[-1]["role"] == "assistant" + ): + if request.continue_final_message: + assistant_prefix = openai_compatible_messages[-1]["content"] + openai_compatible_messages = openai_compatible_messages[:-1] + + try: + prompt_ids = self.tokenizer_manager.tokenizer.apply_chat_template( + openai_compatible_messages, + tokenize=True, + add_generation_prompt=True, + tools=tools, + **( + request.chat_template_kwargs if request.chat_template_kwargs else {} + ), + ) + except Exception: + # This except branch will be triggered when the chosen model + # has a different tools input format that is not compatible + # with openAI's apply_chat_template tool_call format, like Mistral. + tools = ( + [t if "function" in t else {"function": t} for t in tools] + if tools + else None + ) + prompt_ids = self.tokenizer_manager.tokenizer.apply_chat_template( + openai_compatible_messages, + tokenize=True, + add_generation_prompt=True, + tools=tools, + **( + request.chat_template_kwargs if request.chat_template_kwargs else {} + ), + ) + + if assistant_prefix: + encoded = self.tokenizer_manager.tokenizer.encode(assistant_prefix) + if encoded and encoded[0] == self.tokenizer_manager.tokenizer.bos_token_id: + encoded = encoded[1:] + prompt_ids += encoded + + if is_multimodal: + prompt = self.tokenizer_manager.tokenizer.decode(prompt_ids) + + stop = request.stop or [] + return prompt, prompt_ids, image_data, audio_data, modalities, stop + + def _apply_conversation_template( + self, request: ChatCompletionRequest + ) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str]]: + """Apply conversation template""" + conv = generate_chat_conv(request, self.tokenizer_manager.chat_template_name) + + # If we should continue the final assistant message, adjust the conversation. + if ( + request.continue_final_message + and request.messages + and request.messages[-1].role == "assistant" + ): + # Remove the auto-added blank assistant turn, if present. + if conv.messages and conv.messages[-1][1] is None: + conv.messages.pop() + # Rebuild the prompt from the conversation. + prompt = conv.get_prompt() + # Strip trailing stop tokens or separators that indicate end-of-assistant. 
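+ # Without this, the template's end-of-turn marker would close the assistant turn and the model could not continue the partial message.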
+ if isinstance(conv.stop_str, list): + for stop_token in conv.stop_str: + if prompt.endswith(stop_token): + prompt = prompt[: -len(stop_token)] + elif isinstance(conv.stop_str, str) and prompt.endswith(conv.stop_str): + prompt = prompt[: -len(conv.stop_str)] + if conv.sep and prompt.endswith(conv.sep): + prompt = prompt[: -len(conv.sep)] + if getattr(conv, "sep2", None) and prompt.endswith(conv.sep2): + prompt = prompt[: -len(conv.sep2)] + else: + prompt = conv.get_prompt() + + image_data = conv.image_data + audio_data = conv.audio_data + modalities = conv.modalities + stop = conv.stop_str or [] if not request.ignore_eos else [] + + if request.stop: + if isinstance(request.stop, str): + stop.append(request.stop) + else: + stop.extend(request.stop) + + return prompt, image_data, audio_data, modalities, stop + + def _build_sampling_params( + self, + request: ChatCompletionRequest, + stop: List[str], + tool_call_constraint: Optional[Any], + ) -> Dict[str, Any]: + """Build sampling parameters for the request""" + + sampling_params = { + "temperature": request.temperature, + "max_new_tokens": request.max_tokens or request.max_completion_tokens, + "min_new_tokens": request.min_tokens, + "stop": stop, + "stop_token_ids": request.stop_token_ids, + "top_p": request.top_p, + "top_k": request.top_k, + "min_p": request.min_p, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, + "repetition_penalty": request.repetition_penalty, + "regex": request.regex, + "ebnf": request.ebnf, + "n": request.n, + "no_stop_trim": request.no_stop_trim, + "ignore_eos": request.ignore_eos, + "skip_special_tokens": request.skip_special_tokens, + "logit_bias": request.logit_bias, + } + + if request.response_format and request.response_format.type == "json_schema": + sampling_params["json_schema"] = convert_json_schema_to_str( + request.response_format.json_schema.schema_ + ) + elif request.response_format and request.response_format.type == "json_object": + sampling_params["json_schema"] = '{"type": "object"}' + elif ( + request.response_format and request.response_format.type == "structural_tag" + ): + sampling_params["structural_tag"] = convert_json_schema_to_str( + request.response_format.model_dump(by_alias=True) + ) + + # Check if there are already existing output constraints + has_existing_constraints = ( + sampling_params.get("regex") + or sampling_params.get("ebnf") + or sampling_params.get("structural_tag") + or sampling_params.get("json_schema") + ) + + if tool_call_constraint and has_existing_constraints: + logger.warning("Constrained decoding is not compatible with tool calls.") + elif tool_call_constraint: + constraint_type, constraint_value = tool_call_constraint + if constraint_type == "structural_tag": + sampling_params[constraint_type] = convert_json_schema_to_str( + constraint_value.model_dump(by_alias=True) + ) + else: + sampling_params[constraint_type] = constraint_value + return sampling_params + + async def _handle_streaming_request( + self, + adapted_request: GenerateReqInput, + request: ChatCompletionRequest, + raw_request: Request, + ) -> StreamingResponse: + """Handle streaming chat completion request""" + + async def generate_stream_resp(): + parser_dict = {} + reasoning_parser_dict = {} + tool_call_first = True + is_firsts = {} + stream_buffers = {} + n_prev_tokens = {} + prompt_tokens = {} + completion_tokens = {} + cached_tokens = {} + + try: + async for content in self.tokenizer_manager.generate_request( + adapted_request, raw_request + ): + index 
= content.get("index", 0) + + is_first = is_firsts.get(index, True) + stream_buffer = stream_buffers.get(index, "") + n_prev_token = n_prev_tokens.get(index, 0) + + prompt_tokens[index] = content["meta_info"]["prompt_tokens"] + completion_tokens[index] = content["meta_info"]["completion_tokens"] + cached_tokens[index] = content["meta_info"].get("cached_tokens", 0) + + # Handle logprobs + choice_logprobs = None + if request.logprobs: + choice_logprobs = self._process_streaming_logprobs( + content, n_prev_token + ) + n_prev_token = len( + content["meta_info"]["output_token_logprobs"] + ) + + finish_reason = content["meta_info"]["finish_reason"] + finish_reason_type = ( + finish_reason["type"] if finish_reason else None + ) + + # First chunk with role + if is_first: + is_first = False + delta = DeltaMessage(role="assistant") + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=delta, + finish_reason=finish_reason_type, + matched_stop=( + finish_reason["matched"] + if finish_reason and "matched" in finish_reason + else None + ), + logprobs=choice_logprobs, + ) + chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + created=int(time.time()), + choices=[choice_data], + model=request.model, + ) + yield f"data: {chunk.model_dump_json()}\n\n" + + # Process content delta + delta = content["text"][len(stream_buffer) :] + new_stream_buffer = stream_buffer + delta + + # Handle reasoning content + enable_thinking = getattr(request, "chat_template_kwargs", {}).get( + "enable_thinking", True + ) + if ( + self.tokenizer_manager.server_args.reasoning_parser + and request.separate_reasoning + and enable_thinking + ): + reasoning_text, delta = self._process_reasoning_stream( + index, delta, reasoning_parser_dict, content, request + ) + if reasoning_text: + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(reasoning_content=reasoning_text), + finish_reason=finish_reason_type, + ) + chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + created=int(time.time()), + choices=[choice_data], + model=request.model, + ) + yield f"data: {chunk.model_dump_json()}\n\n" + + if not delta: + stream_buffers[index] = new_stream_buffer + is_firsts[index] = is_first + n_prev_tokens[index] = n_prev_token + continue + + # Handle tool calls + if request.tool_choice != "none" and request.tools: + async for chunk in self._process_tool_call_stream( + index, + delta, + parser_dict, + content, + request, + finish_reason_type, + ): + yield chunk + else: + # Regular content + if delta or not ( + request.stream_options + and request.stream_options.include_usage + ): + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(content=delta if delta else None), + finish_reason=( + None + if request.stream_options + and request.stream_options.include_usage + else finish_reason_type + ), + matched_stop=( + finish_reason["matched"] + if finish_reason and "matched" in finish_reason + else None + ), + logprobs=choice_logprobs, + ) + chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + created=int(time.time()), + choices=[choice_data], + model=request.model, + ) + yield f"data: {chunk.model_dump_json()}\n\n" + + stream_buffers[index] = new_stream_buffer + is_firsts[index] = is_first + n_prev_tokens[index] = n_prev_token + + # Final chunk with usage + if request.stream_options and request.stream_options.include_usage: + usage = self._calculate_streaming_usage_base( + prompt_tokens, completion_tokens, 
cached_tokens, request.n + ) + else: + usage = None + + final_chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + created=int(time.time()), + choices=[ + ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(), + finish_reason=finish_reason_type, + ) + ], + model=request.model, + usage=usage, + ) + yield f"data: {final_chunk.model_dump_json()}\n\n" + + except Exception as e: + error = self.create_streaming_error_response(str(e)) + yield f"data: {error}\n\n" + + yield "data: [DONE]\n\n" + + return StreamingResponse( + generate_stream_resp(), + media_type="text/event-stream", + background=self.tokenizer_manager.create_abort_task(adapted_request), + ) + + async def _handle_non_streaming_request( + self, + adapted_request: GenerateReqInput, + request: ChatCompletionRequest, + raw_request: Request, + ) -> Union[ChatCompletionResponse, ErrorResponse]: + """Handle non-streaming chat completion request""" + try: + ret = await self.tokenizer_manager.generate_request( + adapted_request, raw_request + ).__anext__() + except ValueError as e: + return self.create_error_response(str(e)) + + if not isinstance(ret, list): + ret = [ret] + + response = self._build_chat_response( + request, + ret, + int(time.time()), + cache_report=self.tokenizer_manager.server_args.enable_cache_report, + tool_call_parser=self.tokenizer_manager.server_args.tool_call_parser, + reasoning_parser=self.tokenizer_manager.server_args.reasoning_parser, + ) + + return response + + def _build_chat_response( + self, + request: ChatCompletionRequest, + ret: List[Dict[str, Any]], + created: int, + cache_report: bool = False, + tool_call_parser: Optional[str] = None, + reasoning_parser: Optional[str] = None, + ) -> ChatCompletionResponse: + """Build chat completion response from generation results""" + choices = [] + + for idx, ret_item in enumerate(ret): + # Process logprobs + choice_logprobs = None + if request.logprobs: + choice_logprobs = self._process_response_logprobs(ret_item) + + finish_reason = ret_item["meta_info"]["finish_reason"] + text = ret_item["text"] + + # Handle reasoning content + reasoning_text = None + enable_thinking = getattr(request, "chat_template_kwargs", {}).get( + "enable_thinking", True + ) + if reasoning_parser and request.separate_reasoning and enable_thinking: + try: + parser = ReasoningParser( + model_type=reasoning_parser, stream_reasoning=False + ) + reasoning_text, text = parser.parse_non_stream(text) + except Exception as e: + logger.error(f"Reasoning parsing error: {e}") + return self.create_error_response( + "Failed to parse reasoning content", + err_type="InternalServerError", + status_code=500, + ) + + # Handle tool calls + tool_calls = None + if request.tool_choice != "none" and request.tools: + tool_calls, text, finish_reason = self._process_tool_calls( + text, request.tools, tool_call_parser, finish_reason + ) + + choice_data = ChatCompletionResponseChoice( + index=idx, + message=ChatMessage( + role="assistant", + content=text if text else None, + tool_calls=tool_calls, + reasoning_content=reasoning_text if reasoning_text else None, + ), + logprobs=choice_logprobs, + finish_reason=finish_reason["type"] if finish_reason else None, + matched_stop=( + finish_reason["matched"] + if finish_reason and "matched" in finish_reason + else None + ), + ) + choices.append(choice_data) + + # Calculate usage + usage = aggregate_token_usage(ret, request.n, cache_report) + + return ChatCompletionResponse( + id=ret[0]["meta_info"]["id"], + created=created, + 
model=request.model, + choices=choices, + usage=usage, + ) + + def _process_logprobs_tokens( + self, logprobs: LogProbs, use_token_index: bool = False + ) -> List[ChatCompletionTokenLogprob]: + """Common helper to process logprobs tokens for both streaming and non-streaming + + Args: + logprobs: LogProbs data from model + use_token_index: True for non-streaming (use token_idx), False for streaming (use index 0) + """ + token_logprobs = [] + + for token_idx, (token, logprob) in enumerate( + zip(logprobs.tokens, logprobs.token_logprobs) + ): + token_bytes = list(token.encode("utf-8")) + top_logprobs = [] + if logprobs.top_logprobs: + # - Non-streaming (use_token_index=True): uses token_idx for full data + # - Streaming (use_token_index=False): uses index 0 for pre-sliced data + top_logprobs_idx = token_idx if use_token_index else 0 + for top_token, top_logprob in logprobs.top_logprobs[ + top_logprobs_idx + ].items(): + top_token_bytes = list(top_token.encode("utf-8")) + top_logprobs.append( + TopLogprob( + token=top_token, + bytes=top_token_bytes, + logprob=top_logprob, + ) + ) + token_logprobs.append( + ChatCompletionTokenLogprob( + token=token, + bytes=token_bytes, + logprob=logprob, + top_logprobs=top_logprobs, + ) + ) + + return token_logprobs + + def _process_response_logprobs(self, ret_item: Dict[str, Any]) -> ChoiceLogprobs: + """Process logprobs for non-streaming response""" + logprobs = to_openai_style_logprobs( + output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"], + output_top_logprobs=ret_item["meta_info"].get("output_top_logprobs", None), + ) + + token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=True) + return ChoiceLogprobs(content=token_logprobs) + + def _process_tool_calls( + self, + text: str, + tools: List[Any], + tool_call_parser: Optional[str], + finish_reason: Dict[str, Any], + ) -> tuple[Optional[List[ToolCall]], str, Dict[str, Any]]: + """Process tool calls in the response""" + parser = FunctionCallParser(tools, tool_call_parser) + if parser.has_tool_call(text): + if finish_reason["type"] == "stop": + finish_reason["type"] = "tool_calls" + finish_reason["matched"] = None + try: + text, call_info_list = parser.parse_non_stream(text) + tool_calls = [ + ToolCall( + id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}", + function=FunctionResponse( + name=call_info.name, arguments=call_info.parameters + ), + ) + for call_info in call_info_list + ] + return tool_calls, text, finish_reason + except Exception as e: + logger.error(f"Tool call parsing error: {e}") + # Return error but don't fail the whole request + return None, text, finish_reason + + return None, text, finish_reason + + def _process_streaming_logprobs( + self, content: Dict[str, Any], n_prev_token: int + ) -> ChoiceLogprobs: + """Process logprobs for streaming response""" + logprobs = to_openai_style_logprobs( + output_token_logprobs=content["meta_info"]["output_token_logprobs"][ + n_prev_token: + ], + output_top_logprobs=content["meta_info"].get("output_top_logprobs", [])[ + n_prev_token: + ], + ) + + token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=False) + return ChoiceLogprobs(content=token_logprobs) + + def _process_reasoning_stream( + self, + index: int, + delta: str, + reasoning_parser_dict: Dict[int, ReasoningParser], + content: Dict[str, Any], + request: ChatCompletionRequest, + ) -> tuple[Optional[str], str]: + """Process reasoning content in streaming response""" + if index not in reasoning_parser_dict: + 
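+ # Lazily create one stateful parser per choice index so reasoning/content boundaries are tracked independently when n > 1.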
reasoning_parser_dict[index] = ReasoningParser( + self.tokenizer_manager.server_args.reasoning_parser, + request.stream_reasoning, + ) + reasoning_parser = reasoning_parser_dict[index] + return reasoning_parser.parse_stream_chunk(delta) + + async def _process_tool_call_stream( + self, + index: int, + delta: str, + parser_dict: Dict[int, FunctionCallParser], + content: Dict[str, Any], + request: ChatCompletionRequest, + finish_reason_type: Optional[str], + ): + """Process tool calls in streaming response""" + if index not in parser_dict: + parser_dict[index] = FunctionCallParser( + tools=request.tools, + tool_call_parser=self.tokenizer_manager.server_args.tool_call_parser, + ) + parser = parser_dict[index] + + normal_text, calls = parser.parse_stream_chunk(delta) + + # Yield normal text + if normal_text: + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(content=normal_text), + finish_reason=finish_reason_type, + ) + chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + created=int(time.time()), + choices=[choice_data], + model=request.model, + ) + yield f"data: {chunk.model_dump_json()}\n\n" + + # Yield tool calls + for call_item in calls: + if finish_reason_type == "stop": + # Handle remaining arguments + latest_delta_len = 0 + if isinstance(call_item.parameters, str): + latest_delta_len = len(call_item.parameters) + + expected_call = json.dumps( + parser.detector.prev_tool_call_arr[index].get("arguments", {}), + ensure_ascii=False, + ) + actual_call = parser.detector.streamed_args_for_tool[index] + if latest_delta_len > 0: + actual_call = actual_call[:-latest_delta_len] + remaining_call = expected_call.replace(actual_call, "", 1) + call_item.parameters = remaining_call + finish_reason_type = "tool_calls" + + tool_call = ToolCall( + id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}", + index=call_item.tool_index, + function=FunctionResponse( + name=call_item.name, + arguments=call_item.parameters, + ), + ) + + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(tool_calls=[tool_call]), + finish_reason=( + None + if request.stream_options and request.stream_options.include_usage + else finish_reason_type + ), + ) + chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + created=int(time.time()), + choices=[choice_data], + model=request.model, + ) + yield f"data: {chunk.model_dump_json()}\n\n" diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py new file mode 100644 index 00000000000..af501727562 --- /dev/null +++ b/python/sglang/srt/entrypoints/openai/serving_completions.py @@ -0,0 +1,467 @@ +import time +from typing import Any, Dict, List, Optional, Union + +from fastapi import Request +from fastapi.responses import StreamingResponse + +from sglang.srt.code_completion_parser import ( + generate_completion_prompt_from_request, + is_completion_template_defined, +) +from sglang.srt.entrypoints.openai.protocol import ( + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + CompletionResponseStreamChoice, + CompletionStreamResponse, + ErrorResponse, +) +from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase +from sglang.srt.entrypoints.openai.utils import ( + aggregate_token_usage, + to_openai_style_logprobs, +) +from sglang.srt.managers.io_struct import GenerateReqInput + + +class OpenAIServingCompletion(OpenAIServingBase): + """Handler for 
completion requests""" + + def _request_id_prefix(self) -> str: + return "cmpl-" + + def _validate_request(self, request: CompletionRequest) -> Optional[str]: + """Validate completion prompt format and content""" + if not (prompt := request.prompt): + return "Prompt cannot be None" + + if isinstance(prompt, str): + if not prompt.strip(): + return "Prompt cannot be empty or whitespace only" + elif isinstance(prompt, list): + if not prompt: + return "Prompt list cannot be empty" + + # Check if it's a list of strings + if all(isinstance(item, str) for item in prompt): + for i, item in enumerate(prompt): + if not item.strip(): + return f"Prompt at index {i} cannot be empty or whitespace only" + + # Check if it's a list of token IDs (integers) + elif all(isinstance(item, int) for item in prompt): + if any(item < 0 for item in prompt): + return "Token IDs must be non-negative" + + # Check if it's a list of lists (multiple token sequences) + elif all(isinstance(item, list) for item in prompt): + for i, item in enumerate(prompt): + if not item: + return f"Token sequence at index {i} cannot be empty" + if not all(isinstance(token, int) for token in item): + return f"Token sequence at index {i} must contain only integers" + if any(token < 0 for token in item): + return ( + f"Token sequence at index {i} contains negative token IDs" + ) + else: + return "Prompt must be string, list of strings, list of integers, or list of integer lists" + else: + return "Prompt must be string or list" + + return None + + def _convert_to_internal_request( + self, + all_requests: List[CompletionRequest], + request_ids: List[str], + ) -> tuple[GenerateReqInput, Union[CompletionRequest, List[CompletionRequest]]]: + """Convert OpenAI completion request to internal format""" + # Validate batch requests + if len(all_requests) > 1: + first_prompt_type = type(all_requests[0].prompt) + for request in all_requests: + assert ( + type(request.prompt) is first_prompt_type + ), "All prompts must be of the same type in file input settings" + if request.n > 1: + raise ValueError( + "Parallel sampling is not supported for completions from files" + ) + + prompts = [] + sampling_params_list = [] + return_logprobs = [] + logprob_start_lens = [] + top_logprobs_nums = [] + lora_paths = [] + + for request in all_requests: + # Process prompt + prompt = request.prompt + if is_completion_template_defined(): + prompt = generate_completion_prompt_from_request(request) + + prompts.append(prompt) + + lora_paths.append(request.lora_path) + + # Set logprob start length based on echo and logprobs + if request.echo and request.logprobs: + current_logprob_start_len = 0 + else: + current_logprob_start_len = -1 + + # Build sampling parameters + sampling_params = self._build_sampling_params(request) + sampling_params_list.append(sampling_params) + + return_logprobs.append(request.logprobs is not None) + logprob_start_lens.append(current_logprob_start_len) + top_logprobs_nums.append( + request.logprobs if request.logprobs is not None else 0 + ) + + # Handle single vs multiple requests + if len(all_requests) == 1: + if isinstance(prompts[0], str) or isinstance(prompts[0][0], str): + prompt_kwargs = {"text": prompts[0]} + else: + prompt_kwargs = {"input_ids": prompts[0]} + sampling_params_list = sampling_params_list[0] + return_logprobs = return_logprobs[0] + logprob_start_lens = logprob_start_lens[0] + top_logprobs_nums = top_logprobs_nums[0] + lora_paths = lora_paths[0] + request_ids = request_ids[0] + else: + if isinstance(prompts[0], str) or 
isinstance(prompts[0][0], str): + prompt_kwargs = {"text": prompts} + else: + prompt_kwargs = {"input_ids": prompts} + + adapted_request = GenerateReqInput( + **prompt_kwargs, + sampling_params=sampling_params_list, + return_logprob=return_logprobs, + top_logprobs_num=top_logprobs_nums, + logprob_start_len=logprob_start_lens, + return_text_in_logprobs=True, + stream=all_requests[0].stream, + rid=request_ids, + lora_path=lora_paths, + bootstrap_host=all_requests[0].bootstrap_host, + bootstrap_port=all_requests[0].bootstrap_port, + bootstrap_room=all_requests[0].bootstrap_room, + ) + + return adapted_request, ( + all_requests if len(all_requests) > 1 else all_requests[0] + ) + + def _build_sampling_params(self, request: CompletionRequest) -> Dict[str, Any]: + """Build sampling parameters for the request""" + # Start with common parameters + sampling_params = { + "temperature": request.temperature, + "max_new_tokens": request.max_tokens, + "min_new_tokens": request.min_tokens, + "stop": request.stop, + "stop_token_ids": request.stop_token_ids, + "top_p": request.top_p, + "top_k": request.top_k, + "min_p": request.min_p, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, + "repetition_penalty": request.repetition_penalty, + "regex": request.regex, + "json_schema": request.json_schema, + "ebnf": request.ebnf, + "n": request.n, + "no_stop_trim": request.no_stop_trim, + "ignore_eos": request.ignore_eos, + "skip_special_tokens": request.skip_special_tokens, + "logit_bias": request.logit_bias, + } + + # No additional completion-specific parameters needed currently + # (json_schema is already handled in base method) + + return sampling_params + + async def _handle_streaming_request( + self, + adapted_request: GenerateReqInput, + request: CompletionRequest, + raw_request: Request, + ) -> StreamingResponse: + """Handle streaming completion request""" + created = int(time.time()) + + async def generate_stream_resp(): + stream_buffers = {} + n_prev_tokens = {} + prompt_tokens = {} + completion_tokens = {} + cached_tokens = {} + + try: + async for content in self.tokenizer_manager.generate_request( + adapted_request, raw_request + ): + index = content.get("index", 0) + + stream_buffer = stream_buffers.get(index, "") + n_prev_token = n_prev_tokens.get(index, 0) + + text = content["text"] + prompt_tokens[index] = content["meta_info"]["prompt_tokens"] + completion_tokens[index] = content["meta_info"]["completion_tokens"] + cached_tokens[index] = content["meta_info"].get("cached_tokens", 0) + + # Handle echo for first chunk + if not stream_buffer: # The first chunk + if request.echo: + echo_text = self._get_echo_text(request, index) + text = echo_text + text + + # Handle logprobs + logprobs = None + if request.logprobs is not None: + # The first chunk and echo is enabled. 
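+ # Input-side logprobs are attached only on this first chunk; later chunks carry just the new output logprobs (sliced by n_prev_token below).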
+ if not stream_buffer and request.echo: + input_token_logprobs = content["meta_info"][ + "input_token_logprobs" + ] + input_top_logprobs = content["meta_info"][ + "input_top_logprobs" + ] + else: + input_token_logprobs = None + input_top_logprobs = None + + logprobs = to_openai_style_logprobs( + input_token_logprobs=input_token_logprobs, + input_top_logprobs=input_top_logprobs, + output_token_logprobs=content["meta_info"][ + "output_token_logprobs" + ][n_prev_token:], + output_top_logprobs=content["meta_info"][ + "output_top_logprobs" + ][n_prev_token:], + ) + n_prev_token = len( + content["meta_info"]["output_token_logprobs"] + ) + + # Generate delta + delta = text[len(stream_buffer) :] + stream_buffer = stream_buffer + delta + finish_reason = content["meta_info"]["finish_reason"] + + choice_data = CompletionResponseStreamChoice( + index=index, + text=delta, + logprobs=logprobs, + finish_reason=finish_reason["type"] if finish_reason else None, + matched_stop=( + finish_reason["matched"] + if finish_reason and "matched" in finish_reason + else None + ), + ) + chunk = CompletionStreamResponse( + id=content["meta_info"]["id"], + created=created, + object="text_completion", + choices=[choice_data], + model=request.model, + ) + + stream_buffers[index] = stream_buffer + n_prev_tokens[index] = n_prev_token + + yield f"data: {chunk.model_dump_json()}\n\n" + + # Handle final usage chunk + if request.stream_options and request.stream_options.include_usage: + usage = self._calculate_streaming_usage_base( + prompt_tokens, completion_tokens, cached_tokens, request.n + ) + final_usage_chunk = CompletionStreamResponse( + id=content["meta_info"]["id"], + created=created, + choices=[], + model=request.model, + usage=usage, + ) + final_usage_data = final_usage_chunk.model_dump_json( + exclude_none=True + ) + yield f"data: {final_usage_data}\n\n" + + except Exception as e: + error = self.create_streaming_error_response(str(e)) + yield f"data: {error}\n\n" + + yield "data: [DONE]\n\n" + + return StreamingResponse( + generate_stream_resp(), + media_type="text/event-stream", + background=self.tokenizer_manager.create_abort_task(adapted_request), + ) + + async def _handle_non_streaming_request( + self, + adapted_request: GenerateReqInput, + request: CompletionRequest, + raw_request: Request, + ) -> Union[CompletionResponse, ErrorResponse]: + """Handle non-streaming completion request""" + try: + generator = self.tokenizer_manager.generate_request( + adapted_request, raw_request + ) + ret = await generator.__anext__() + except ValueError as e: + return self.create_error_response(str(e)) + + if not isinstance(ret, list): + ret = [ret] + + response = self._build_completion_response( + request, + ret, + int(time.time()), + cache_report=self.tokenizer_manager.server_args.enable_cache_report, + ) + + return response + + def _build_completion_response( + self, + request: CompletionRequest, + ret: List[Dict[str, Any]], + created: int, + cache_report: bool = False, + ) -> CompletionResponse: + """Build completion response from generation results""" + choices = [] + echo = False + + # Prepare echo prompts if needed + echo_prompts = [] + if (not isinstance(request, list)) and request.echo: + echo_prompts = self._prepare_echo_prompts(request) + echo = True + + for idx, ret_item in enumerate(ret): + text = ret_item["text"] + + # Handle echo + if isinstance(request, list) and request[idx].echo: + echo = True + text = request[idx].prompt + text + elif echo and not isinstance(request, list): + prompt_index = idx // request.n 
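+ # ret contains n consecutive choices per prompt, so integer division recovers the originating prompt for echo.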
+ text = echo_prompts[prompt_index] + text + + # Handle logprobs + logprobs = None + if isinstance(request, list) and request[idx].logprobs is not None: + logprobs = True + elif (not isinstance(request, list)) and request.logprobs is not None: + logprobs = True + + if logprobs: + if echo: + input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"] + input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"] + else: + input_token_logprobs = None + input_top_logprobs = None + + logprobs = to_openai_style_logprobs( + input_token_logprobs=input_token_logprobs, + input_top_logprobs=input_top_logprobs, + output_token_logprobs=ret_item["meta_info"][ + "output_token_logprobs" + ], + output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"], + ) + + finish_reason = ret_item["meta_info"]["finish_reason"] + + choice_data = CompletionResponseChoice( + index=idx, + text=text, + logprobs=logprobs, + finish_reason=finish_reason["type"] if finish_reason else None, + matched_stop=( + finish_reason["matched"] + if finish_reason and "matched" in finish_reason + else None + ), + ) + choices.append(choice_data) + + # Calculate usage + usage = aggregate_token_usage(ret, request.n, cache_report) + + return CompletionResponse( + id=ret[0]["meta_info"]["id"], + model=request.model, + created=created, + choices=choices, + usage=usage, + ) + + def _get_echo_text(self, request: CompletionRequest, index: int) -> str: + """Get echo text for streaming response""" + if isinstance(request.prompt, str): + # for the case of single str prompts + return request.prompt + elif isinstance(request.prompt, list): + if isinstance(request.prompt[0], str): + # for the case of multiple str prompts + return request.prompt[index // request.n] + elif isinstance(request.prompt[0], int): + # for the case of single token ids prompt + return self.tokenizer_manager.tokenizer.decode( + request.prompt, skip_special_tokens=True + ) + elif isinstance(request.prompt[0], list) and isinstance( + request.prompt[0][0], int + ): + # for the case of multiple token ids prompts + return self.tokenizer_manager.tokenizer.decode( + request.prompt[index // request.n], + skip_special_tokens=True, + ) + return "" + + def _prepare_echo_prompts(self, request: CompletionRequest) -> List[str]: + """Prepare echo prompts for non-streaming response""" + # TODO: handle the case prompt is token ids + if isinstance(request.prompt, list) and isinstance(request.prompt[0], str): + # for the case of multiple str prompts + return request.prompt + elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list): + # for the case of multiple token ids prompts + return [ + self.tokenizer_manager.tokenizer.decode( + prompt, skip_special_tokens=True + ) + for prompt in request.prompt + ] + elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int): + # for the case of single token ids prompt + return [ + self.tokenizer_manager.tokenizer.decode( + request.prompt, skip_special_tokens=True + ) + ] + else: + # for the case of single str prompt + return [request.prompt] diff --git a/python/sglang/srt/entrypoints/openai/serving_embedding.py b/python/sglang/srt/entrypoints/openai/serving_embedding.py new file mode 100644 index 00000000000..33d1e591878 --- /dev/null +++ b/python/sglang/srt/entrypoints/openai/serving_embedding.py @@ -0,0 +1,227 @@ +from typing import Any, Dict, List, Optional, Union + +from fastapi import Request + +from sglang.srt.conversation import generate_embedding_convs +from 
sglang.srt.entrypoints.openai.protocol import ( + EmbeddingObject, + EmbeddingRequest, + EmbeddingResponse, + ErrorResponse, + MultimodalEmbeddingInput, + UsageInfo, +) +from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase +from sglang.srt.managers.io_struct import EmbeddingReqInput + + +class OpenAIServingEmbedding(OpenAIServingBase): + """Handler for embedding requests""" + + def _request_id_prefix(self) -> str: + return "embd-" + + def _validate_request(self, request: EmbeddingRequest) -> Optional[str]: + """Validate that the input is not empty or whitespace only.""" + if not (input := request.input): + return "Input cannot be empty" + + # Handle single string + if isinstance(input, str): + if not input.strip(): + return "Input cannot be empty or whitespace only" + return None + + # Handle list inputs + if isinstance(input, list): + if len(input) == 0: + return "Input cannot be empty" + + # Check first element to determine type + first_item = input[0] + + if isinstance(first_item, str): + # List of strings + for i, item in enumerate(input): + if not isinstance(item, str): + return f"All items in input list must be strings" + if not item.strip(): + return f"Input at index {i} cannot be empty or whitespace only" + elif isinstance(first_item, int): + # List of integers (token IDs) + for i, item in enumerate(input): + if not isinstance(item, int): + return f"All items in input list must be integers" + if item < 0: + return f"Token ID at index {i} must be non-negative" + elif isinstance(first_item, list): + # List of lists (multiple token sequences) + for i, item in enumerate(input): + if not isinstance(item, list): + return f"Input at index {i} must be a list" + if not item: + return f"Input at index {i} cannot be empty" + if not all(isinstance(token, int) for token in item): + return f"Input at index {i} must contain only integers" + if any(token < 0 for token in item): + return f"Input at index {i} contains negative token IDs" + # Note: MultimodalEmbeddingInput validation would be handled by Pydantic + + return None + + def _convert_to_internal_request( + self, + all_requests: List[EmbeddingRequest], + request_ids: List[str], + ) -> tuple[EmbeddingReqInput, Union[EmbeddingRequest, List[EmbeddingRequest]]]: + """Convert OpenAI embedding request to internal format""" + prompts = [request.input for request in all_requests] + + # Handle single vs multiple requests + if len(all_requests) == 1: + prompt = prompts[0] + if isinstance(prompt, str): + # Single string input + prompt_kwargs = {"text": prompt} + elif isinstance(prompt, list): + if len(prompt) > 0 and isinstance(prompt[0], str): + # List of strings + prompt_kwargs = {"text": prompt} + elif len(prompt) > 0 and isinstance( + prompt[0], MultimodalEmbeddingInput + ): + # Handle multimodal embedding inputs + texts = [] + images = [] + for item in prompt: + # Use padding for text if None - this could be improved + texts.append(item.text if item.text is not None else "padding") + images.append(item.image if item.image is not None else None) + + generate_prompts = [] + # Check if we have a chat template for multimodal embeddings + # This would need to be passed in from the server configuration + chat_template_name = getattr( + self.tokenizer_manager, "chat_template_name", None + ) + if chat_template_name is not None: + convs = generate_embedding_convs( + texts, images, chat_template_name + ) + for conv in convs: + generate_prompts.append(conv.get_prompt()) + else: + generate_prompts = texts + + if len(generate_prompts) == 
1: + prompt_kwargs = { + "text": generate_prompts[0], + "image_data": images[0], + } + else: + prompt_kwargs = { + "text": generate_prompts, + "image_data": images, + } + else: + # List of integers (token IDs) or empty list + prompt_kwargs = {"input_ids": prompt} + else: + # Other types (should not happen but handle gracefully) + prompt_kwargs = {"input_ids": prompt} + # Use the passed request_ids for single request + final_request_id = request_ids[0] if len(all_requests) == 1 else request_ids + else: + # Handle batch requests + if len(prompts) > 0: + # Validate that all prompts have the same type + first_prompt = prompts[0] + first_type = type(first_prompt) + for i, prompt in enumerate(prompts[1:], 1): + if type(prompt) != first_type: + raise AssertionError( + f"All prompts in batch must have the same type, but prompt at index {i} has different type" + ) + + if isinstance(first_prompt, str): + # Batch of strings + prompt_kwargs = {"text": prompts} + elif isinstance(first_prompt, list): + if len(first_prompt) > 0 and isinstance(first_prompt[0], str): + # Batch of lists of strings + prompt_kwargs = {"text": prompts} + elif len(first_prompt) > 0 and isinstance( + first_prompt[0], MultimodalEmbeddingInput + ): + # Handle multimodal batch requests + raise NotImplementedError( + "Multiple requests with multimodal inputs are not supported yet" + ) + else: + # Batch of token ID lists + prompt_kwargs = {"input_ids": prompts} + else: + # Other types + prompt_kwargs = {"input_ids": prompts} + else: + prompt_kwargs = {"input_ids": prompts} + # Use the passed request_ids for batch requests + final_request_id = request_ids + + adapted_request = EmbeddingReqInput( + rid=final_request_id, + **prompt_kwargs, + ) + + return adapted_request, ( + all_requests[0] if len(all_requests) == 1 else all_requests + ) + + async def _handle_non_streaming_request( + self, + adapted_request: EmbeddingReqInput, + request: EmbeddingRequest, + raw_request: Request, + ) -> Union[EmbeddingResponse, ErrorResponse]: + """Handle the embedding request""" + try: + ret = await self.tokenizer_manager.generate_request( + adapted_request, raw_request + ).__anext__() + except ValueError as e: + return self.create_error_response(str(e)) + + if not isinstance(ret, list): + ret = [ret] + + response = self._build_embedding_response( + ret, self.tokenizer_manager.model_path + ) + return response + + def _build_embedding_response( + self, ret: List[Dict[str, Any]], model_path: str + ) -> EmbeddingResponse: + """Build the embedding response""" + embedding_objects = [] + prompt_tokens = 0 + + for idx, ret_item in enumerate(ret): + embedding_objects.append( + EmbeddingObject( + embedding=ret_item["embedding"], + index=idx, + ) + ) + # Handle missing prompt_tokens gracefully + meta_info = ret_item.get("meta_info", {}) + prompt_tokens += meta_info.get("prompt_tokens", 0) + + return EmbeddingResponse( + data=embedding_objects, + model=model_path, + usage=UsageInfo( + prompt_tokens=prompt_tokens, + total_tokens=prompt_tokens, + ), + ) diff --git a/python/sglang/srt/entrypoints/openai/utils.py b/python/sglang/srt/entrypoints/openai/utils.py new file mode 100644 index 00000000000..53c67831cdb --- /dev/null +++ b/python/sglang/srt/entrypoints/openai/utils.py @@ -0,0 +1,264 @@ +import logging +from typing import Any, Dict, List, Optional + +import jinja2.nodes +import transformers.utils.chat_template_utils as hf_chat_utils + +from sglang.srt.entrypoints.openai.protocol import LogProbs, UsageInfo + +logger = logging.getLogger(__name__) + + +# 
============================================================================
+# JINJA TEMPLATE CONTENT FORMAT DETECTION
+# ============================================================================
+#
+# This adapts vLLM's approach for detecting chat template content format:
+# https://github.com/vllm-project/vllm/blob/02f0c7b220422792f5e53de2a7d51d2d3ff2df28/vllm/entrypoints/chat_utils.py#L296-L313
+# - Analyzes Jinja template AST to detect content iteration patterns
+# - 'openai' format: templates with {%- for content in message['content'] -%} loops
+# - 'string' format: templates that expect simple string content
+# - Processes content accordingly to match template expectations
+
+
+def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
+    """Check if node is a variable access like {{ varname }}"""
+    if isinstance(node, jinja2.nodes.Name):
+        return node.ctx == "load" and node.name == varname
+    return False
+
+
+def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
+    """Check if node is an attribute access like {{ varname['key'] }} or {{ varname.key }}"""
+    if isinstance(node, jinja2.nodes.Getitem):
+        return (
+            _is_var_access(node.node, varname)
+            and isinstance(node.arg, jinja2.nodes.Const)
+            and node.arg.value == key
+        )
+
+    if isinstance(node, jinja2.nodes.Getattr):
+        return _is_var_access(node.node, varname) and node.attr == key
+
+    return False
+
+
+def _is_var_or_elems_access(
+    node: jinja2.nodes.Node,
+    varname: str,
+    key: Optional[str] = None,
+) -> bool:
+    """Check if node accesses varname or varname[key] with filters/tests"""
+    if isinstance(node, jinja2.nodes.Filter):
+        return node.node is not None and _is_var_or_elems_access(
+            node.node, varname, key
+        )
+    if isinstance(node, jinja2.nodes.Test):
+        return _is_var_or_elems_access(node.node, varname, key)
+
+    if isinstance(node, jinja2.nodes.Getitem) and isinstance(
+        node.arg, jinja2.nodes.Slice
+    ):
+        return _is_var_or_elems_access(node.node, varname, key)
+
+    return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname)
+
+
+def _try_extract_ast(chat_template: str):
+    """Try to parse the Jinja template into an AST"""
+    try:
+        jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
+        return jinja_compiled.environment.parse(chat_template)
+    except Exception as e:
+        logger.debug(f"Error when compiling Jinja template: {e}")
+        return None
+
+
+def detect_template_content_format(chat_template: str) -> str:
+    """
+    Detect whether a chat template expects 'string' or 'openai' content format.
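+
+    Example (illustrative templates only, not taken from any real model):
+
+        >>> detect_template_content_format("{{ message['content'] }}")
+        'string'
+        >>> detect_template_content_format(
+        ...     "{%- for content in message['content'] -%}{{ content['text'] }}{%- endfor -%}"
+        ... )
+        'openai'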
+ + - 'string': content is a simple string (like DeepSeek templates) + - 'openai': content is a list of structured dicts (like Llama4 templates) + + Detection logic: + - If template has loops like {%- for content in message['content'] -%} → 'openai' + - Otherwise → 'string' + """ + jinja_ast = _try_extract_ast(chat_template) + if jinja_ast is None: + return "string" + + try: + # Look for patterns like: {%- for content in message['content'] -%} + for loop_ast in jinja_ast.find_all(jinja2.nodes.For): + loop_iter = loop_ast.iter + + # Check if iterating over message['content'] or similar + if _is_var_or_elems_access(loop_iter, "message", "content"): + return "openai" # Found content iteration → openai format + + return "string" # No content loops found → string format + except Exception as e: + logger.debug(f"Error when parsing AST of Jinja template: {e}") + return "string" + + +def process_content_for_template_format( + msg_dict: dict, + content_format: str, + image_data: list, + audio_data: list, + modalities: list, +) -> dict: + """ + Process message content based on detected template format. + + Args: + msg_dict: Message dictionary with content + content_format: 'string' or 'openai' (detected via AST analysis) + image_data: List to append extracted image URLs + audio_data: List to append extracted audio URLs + modalities: List to append modalities + + Returns: + Processed message dictionary + """ + if not isinstance(msg_dict.get("content"), list): + # Already a string or None, no processing needed + return {k: v for k, v in msg_dict.items() if v is not None} + + if content_format == "openai": + # OpenAI format: preserve structured content list, normalize types + processed_content_parts = [] + for chunk in msg_dict["content"]: + if isinstance(chunk, dict): + chunk_type = chunk.get("type") + + if chunk_type == "image_url": + image_data.append(chunk["image_url"]["url"]) + if chunk.get("modalities"): + modalities.append(chunk.get("modalities")) + # Normalize to simple 'image' type for template compatibility + processed_content_parts.append({"type": "image"}) + elif chunk_type == "audio_url": + audio_data.append(chunk["audio_url"]["url"]) + # Normalize to simple 'audio' type + processed_content_parts.append({"type": "audio"}) + else: + # Keep other content as-is (text, etc.) 
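+                    # e.g. {"type": "text", "text": ...} parts pass through
+                    # unchanged, so the chat template still receives the text chunks.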
+ processed_content_parts.append(chunk) + + new_msg = { + k: v for k, v in msg_dict.items() if v is not None and k != "content" + } + new_msg["content"] = processed_content_parts + return new_msg + + else: # content_format == "string" + # String format: flatten to text only (for templates like DeepSeek) + text_parts = [] + for chunk in msg_dict["content"]: + if isinstance(chunk, dict) and chunk.get("type") == "text": + text_parts.append(chunk["text"]) + # Note: For string format, we ignore images/audio since the template + # doesn't expect structured content - multimodal placeholders would + # need to be inserted differently + + new_msg = msg_dict.copy() + new_msg["content"] = " ".join(text_parts) if text_parts else "" + new_msg = {k: v for k, v in new_msg.items() if v is not None} + return new_msg + + +def calculate_token_usage( + prompt_tokens: int, + completion_tokens: int, + cached_tokens: Optional[Dict[str, int]] = None, +) -> UsageInfo: + """Calculate token usage information""" + return UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + prompt_tokens_details=cached_tokens, + ) + + +def aggregate_token_usage( + responses: List[Dict[str, Any]], + n_choices: int = 1, + enable_cache_report: bool = False, +) -> UsageInfo: + """Aggregate token usage from multiple responses + + Args: + responses: List of response dictionaries with meta_info + n_choices: Number of choices per request (for prompt token counting) + enable_cache_report: Whether to include cached token details + + Returns: + Aggregated UsageInfo + """ + # Sum completion tokens from all responses + completion_tokens = sum( + response["meta_info"]["completion_tokens"] for response in responses + ) + + # For prompt tokens, only count every n_choices-th response to avoid double counting + prompt_tokens = sum( + responses[i]["meta_info"]["prompt_tokens"] + for i in range(0, len(responses), n_choices) + ) + + # Handle cached tokens if cache reporting is enabled + cached_tokens_details = None + if enable_cache_report: + cached_tokens_sum = sum( + response["meta_info"].get("cached_tokens", 0) for response in responses + ) + if cached_tokens_sum > 0: + cached_tokens_details = {"cached_tokens": cached_tokens_sum} + + return calculate_token_usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + cached_tokens=cached_tokens_details, + ) + + +def to_openai_style_logprobs( + input_token_logprobs=None, + output_token_logprobs=None, + input_top_logprobs=None, + output_top_logprobs=None, +): + ret_logprobs = LogProbs() + + def append_token_logprobs(token_logprobs): + for logprob, _, token_text in token_logprobs: + ret_logprobs.tokens.append(token_text) + ret_logprobs.token_logprobs.append(logprob) + + # Not supported yet + ret_logprobs.text_offset.append(-1) + + def append_top_logprobs(top_logprobs): + for tokens in top_logprobs: + if tokens is not None: + ret_logprobs.top_logprobs.append( + {token[2]: token[0] for token in tokens} + ) + else: + ret_logprobs.top_logprobs.append(None) + + if input_token_logprobs is not None: + append_token_logprobs(input_token_logprobs) + if output_token_logprobs is not None: + append_token_logprobs(output_token_logprobs) + if input_top_logprobs is not None: + append_top_logprobs(input_top_logprobs) + if output_top_logprobs is not None: + append_top_logprobs(output_top_logprobs) + + return ret_logprobs diff --git a/test/pytest.ini b/test/pytest.ini new file mode 100644 index 00000000000..2f4c80e3075 --- 
/dev/null +++ b/test/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +asyncio_mode = auto diff --git a/test/srt/openai/test_protocol.py b/test/srt/openai/test_protocol.py new file mode 100644 index 00000000000..a14b3d7179f --- /dev/null +++ b/test/srt/openai/test_protocol.py @@ -0,0 +1,683 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for OpenAI API protocol models""" + +import json +import time +from typing import Dict, List, Optional + +import pytest +from pydantic import ValidationError + +from sglang.srt.entrypoints.openai.protocol import ( + BatchRequest, + BatchResponse, + ChatCompletionMessageContentImagePart, + ChatCompletionMessageContentTextPart, + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, + ChatCompletionTokenLogprob, + ChatMessage, + ChoiceLogprobs, + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + DeltaMessage, + EmbeddingObject, + EmbeddingRequest, + EmbeddingResponse, + ErrorResponse, + FileDeleteResponse, + FileRequest, + FileResponse, + Function, + FunctionResponse, + JsonSchemaResponseFormat, + LogProbs, + ModelCard, + ModelList, + MultimodalEmbeddingInput, + ResponseFormat, + ScoringRequest, + ScoringResponse, + StreamOptions, + StructuralTagResponseFormat, + Tool, + ToolCall, + ToolChoice, + TopLogprob, + UsageInfo, +) + + +class TestModelCard: + """Test ModelCard protocol model""" + + def test_basic_model_card_creation(self): + """Test basic model card creation with required fields""" + card = ModelCard(id="test-model") + assert card.id == "test-model" + assert card.object == "model" + assert card.owned_by == "sglang" + assert isinstance(card.created, int) + assert card.root is None + assert card.max_model_len is None + + def test_model_card_with_optional_fields(self): + """Test model card with optional fields""" + card = ModelCard( + id="test-model", + root="/path/to/model", + max_model_len=2048, + created=1234567890, + ) + assert card.id == "test-model" + assert card.root == "/path/to/model" + assert card.max_model_len == 2048 + assert card.created == 1234567890 + + def test_model_card_serialization(self): + """Test model card JSON serialization""" + card = ModelCard(id="test-model", max_model_len=4096) + data = card.model_dump() + assert data["id"] == "test-model" + assert data["object"] == "model" + assert data["max_model_len"] == 4096 + + +class TestModelList: + """Test ModelList protocol model""" + + def test_empty_model_list(self): + """Test empty model list creation""" + model_list = ModelList() + assert model_list.object == "list" + assert len(model_list.data) == 0 + + def test_model_list_with_cards(self): + """Test model list with model cards""" + cards = [ + ModelCard(id="model-1"), + ModelCard(id="model-2", max_model_len=2048), + ] + model_list = ModelList(data=cards) + assert len(model_list.data) == 2 + assert 
model_list.data[0].id == "model-1" + assert model_list.data[1].id == "model-2" + + +class TestErrorResponse: + """Test ErrorResponse protocol model""" + + def test_basic_error_response(self): + """Test basic error response creation""" + error = ErrorResponse( + message="Invalid request", type="BadRequestError", code=400 + ) + assert error.object == "error" + assert error.message == "Invalid request" + assert error.type == "BadRequestError" + assert error.code == 400 + assert error.param is None + + def test_error_response_with_param(self): + """Test error response with parameter""" + error = ErrorResponse( + message="Invalid temperature", + type="ValidationError", + code=422, + param="temperature", + ) + assert error.param == "temperature" + + +class TestUsageInfo: + """Test UsageInfo protocol model""" + + def test_basic_usage_info(self): + """Test basic usage info creation""" + usage = UsageInfo(prompt_tokens=10, completion_tokens=20, total_tokens=30) + assert usage.prompt_tokens == 10 + assert usage.completion_tokens == 20 + assert usage.total_tokens == 30 + assert usage.prompt_tokens_details is None + + def test_usage_info_with_cache_details(self): + """Test usage info with cache details""" + usage = UsageInfo( + prompt_tokens=10, + completion_tokens=20, + total_tokens=30, + prompt_tokens_details={"cached_tokens": 5}, + ) + assert usage.prompt_tokens_details == {"cached_tokens": 5} + + +class TestCompletionRequest: + """Test CompletionRequest protocol model""" + + def test_basic_completion_request(self): + """Test basic completion request""" + request = CompletionRequest(model="test-model", prompt="Hello world") + assert request.model == "test-model" + assert request.prompt == "Hello world" + assert request.max_tokens == 16 # default + assert request.temperature == 1.0 # default + assert request.n == 1 # default + assert not request.stream # default + assert not request.echo # default + + def test_completion_request_with_options(self): + """Test completion request with various options""" + request = CompletionRequest( + model="test-model", + prompt=["Hello", "world"], + max_tokens=100, + temperature=0.7, + top_p=0.9, + n=2, + stream=True, + echo=True, + stop=[".", "!"], + logprobs=5, + ) + assert request.prompt == ["Hello", "world"] + assert request.max_tokens == 100 + assert request.temperature == 0.7 + assert request.top_p == 0.9 + assert request.n == 2 + assert request.stream + assert request.echo + assert request.stop == [".", "!"] + assert request.logprobs == 5 + + def test_completion_request_sglang_extensions(self): + """Test completion request with SGLang-specific extensions""" + request = CompletionRequest( + model="test-model", + prompt="Hello", + top_k=50, + min_p=0.1, + repetition_penalty=1.1, + regex=r"\d+", + json_schema='{"type": "object"}', + lora_path="/path/to/lora", + ) + assert request.top_k == 50 + assert request.min_p == 0.1 + assert request.repetition_penalty == 1.1 + assert request.regex == r"\d+" + assert request.json_schema == '{"type": "object"}' + assert request.lora_path == "/path/to/lora" + + def test_completion_request_validation_errors(self): + """Test completion request validation errors""" + with pytest.raises(ValidationError): + CompletionRequest() # missing required fields + + with pytest.raises(ValidationError): + CompletionRequest(model="test-model") # missing prompt + + +class TestCompletionResponse: + """Test CompletionResponse protocol model""" + + def test_basic_completion_response(self): + """Test basic completion response""" + choice = 
CompletionResponseChoice( + index=0, text="Hello world!", finish_reason="stop" + ) + usage = UsageInfo(prompt_tokens=2, completion_tokens=3, total_tokens=5) + response = CompletionResponse( + id="test-id", model="test-model", choices=[choice], usage=usage + ) + assert response.id == "test-id" + assert response.object == "text_completion" + assert response.model == "test-model" + assert len(response.choices) == 1 + assert response.choices[0].text == "Hello world!" + assert response.usage.total_tokens == 5 + + +class TestChatCompletionRequest: + """Test ChatCompletionRequest protocol model""" + + def test_basic_chat_completion_request(self): + """Test basic chat completion request""" + messages = [{"role": "user", "content": "Hello"}] + request = ChatCompletionRequest(model="test-model", messages=messages) + assert request.model == "test-model" + assert len(request.messages) == 1 + assert request.messages[0].role == "user" + assert request.messages[0].content == "Hello" + assert request.temperature == 0.7 # default + assert not request.stream # default + assert request.tool_choice == "none" # default when no tools + + def test_chat_completion_with_multimodal_content(self): + """Test chat completion with multimodal content""" + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": {"url": "data:image/jpeg;base64,/9j/4AAQ..."}, + }, + ], + } + ] + request = ChatCompletionRequest(model="test-model", messages=messages) + assert len(request.messages[0].content) == 2 + assert request.messages[0].content[0].type == "text" + assert request.messages[0].content[1].type == "image_url" + + def test_chat_completion_with_tools(self): + """Test chat completion with tools""" + messages = [{"role": "user", "content": "What's the weather?"}] + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather information", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + }, + } + ] + request = ChatCompletionRequest( + model="test-model", messages=messages, tools=tools + ) + assert len(request.tools) == 1 + assert request.tools[0].function.name == "get_weather" + assert request.tool_choice == "auto" # default when tools present + + def test_chat_completion_tool_choice_validation(self): + """Test tool choice validation logic""" + messages = [{"role": "user", "content": "Hello"}] + + # No tools, tool_choice should default to "none" + request1 = ChatCompletionRequest(model="test-model", messages=messages) + assert request1.tool_choice == "none" + + # With tools, tool_choice should default to "auto" + tools = [ + { + "type": "function", + "function": {"name": "test_func", "description": "Test function"}, + } + ] + request2 = ChatCompletionRequest( + model="test-model", messages=messages, tools=tools + ) + assert request2.tool_choice == "auto" + + def test_chat_completion_sglang_extensions(self): + """Test chat completion with SGLang extensions""" + messages = [{"role": "user", "content": "Hello"}] + request = ChatCompletionRequest( + model="test-model", + messages=messages, + top_k=40, + min_p=0.05, + separate_reasoning=False, + stream_reasoning=False, + chat_template_kwargs={"custom_param": "value"}, + ) + assert request.top_k == 40 + assert request.min_p == 0.05 + assert not request.separate_reasoning + assert not request.stream_reasoning + assert request.chat_template_kwargs == {"custom_param": "value"} + + +class 
TestChatCompletionResponse: + """Test ChatCompletionResponse protocol model""" + + def test_basic_chat_completion_response(self): + """Test basic chat completion response""" + message = ChatMessage(role="assistant", content="Hello there!") + choice = ChatCompletionResponseChoice( + index=0, message=message, finish_reason="stop" + ) + usage = UsageInfo(prompt_tokens=2, completion_tokens=3, total_tokens=5) + response = ChatCompletionResponse( + id="test-id", model="test-model", choices=[choice], usage=usage + ) + assert response.id == "test-id" + assert response.object == "chat.completion" + assert response.model == "test-model" + assert len(response.choices) == 1 + assert response.choices[0].message.content == "Hello there!" + + def test_chat_completion_response_with_tool_calls(self): + """Test chat completion response with tool calls""" + tool_call = ToolCall( + id="call_123", + function=FunctionResponse( + name="get_weather", arguments='{"location": "San Francisco"}' + ), + ) + message = ChatMessage(role="assistant", content=None, tool_calls=[tool_call]) + choice = ChatCompletionResponseChoice( + index=0, message=message, finish_reason="tool_calls" + ) + usage = UsageInfo(prompt_tokens=10, completion_tokens=5, total_tokens=15) + response = ChatCompletionResponse( + id="test-id", model="test-model", choices=[choice], usage=usage + ) + assert response.choices[0].message.tool_calls[0].function.name == "get_weather" + assert response.choices[0].finish_reason == "tool_calls" + + +class TestEmbeddingRequest: + """Test EmbeddingRequest protocol model""" + + def test_basic_embedding_request(self): + """Test basic embedding request""" + request = EmbeddingRequest(model="test-model", input="Hello world") + assert request.model == "test-model" + assert request.input == "Hello world" + assert request.encoding_format == "float" # default + assert request.dimensions is None # default + + def test_embedding_request_with_list_input(self): + """Test embedding request with list input""" + request = EmbeddingRequest( + model="test-model", input=["Hello", "world"], dimensions=512 + ) + assert request.input == ["Hello", "world"] + assert request.dimensions == 512 + + def test_multimodal_embedding_request(self): + """Test multimodal embedding request""" + multimodal_input = [ + MultimodalEmbeddingInput(text="Hello", image="base64_image_data"), + MultimodalEmbeddingInput(text="World", image=None), + ] + request = EmbeddingRequest(model="test-model", input=multimodal_input) + assert len(request.input) == 2 + assert request.input[0].text == "Hello" + assert request.input[0].image == "base64_image_data" + assert request.input[1].text == "World" + assert request.input[1].image is None + + +class TestEmbeddingResponse: + """Test EmbeddingResponse protocol model""" + + def test_basic_embedding_response(self): + """Test basic embedding response""" + embedding_obj = EmbeddingObject(embedding=[0.1, 0.2, 0.3], index=0) + usage = UsageInfo(prompt_tokens=3, total_tokens=3) + response = EmbeddingResponse( + data=[embedding_obj], model="test-model", usage=usage + ) + assert response.object == "list" + assert len(response.data) == 1 + assert response.data[0].embedding == [0.1, 0.2, 0.3] + assert response.data[0].index == 0 + assert response.usage.prompt_tokens == 3 + + +class TestScoringRequest: + """Test ScoringRequest protocol model""" + + def test_basic_scoring_request(self): + """Test basic scoring request""" + request = ScoringRequest( + model="test-model", query="Hello", items=["World", "Earth"] + ) + assert 
request.model == "test-model" + assert request.query == "Hello" + assert request.items == ["World", "Earth"] + assert not request.apply_softmax # default + assert not request.item_first # default + + def test_scoring_request_with_token_ids(self): + """Test scoring request with token IDs""" + request = ScoringRequest( + model="test-model", + query=[1, 2, 3], + items=[[4, 5], [6, 7]], + label_token_ids=[8, 9], + apply_softmax=True, + item_first=True, + ) + assert request.query == [1, 2, 3] + assert request.items == [[4, 5], [6, 7]] + assert request.label_token_ids == [8, 9] + assert request.apply_softmax + assert request.item_first + + +class TestScoringResponse: + """Test ScoringResponse protocol model""" + + def test_basic_scoring_response(self): + """Test basic scoring response""" + response = ScoringResponse(scores=[[0.1, 0.9], [0.3, 0.7]], model="test-model") + assert response.object == "scoring" + assert response.scores == [[0.1, 0.9], [0.3, 0.7]] + assert response.model == "test-model" + assert response.usage is None # default + + +class TestFileOperations: + """Test file operation protocol models""" + + def test_file_request(self): + """Test file request model""" + file_data = b"test file content" + request = FileRequest(file=file_data, purpose="batch") + assert request.file == file_data + assert request.purpose == "batch" + + def test_file_response(self): + """Test file response model""" + response = FileResponse( + id="file-123", + bytes=1024, + created_at=1234567890, + filename="test.jsonl", + purpose="batch", + ) + assert response.id == "file-123" + assert response.object == "file" + assert response.bytes == 1024 + assert response.filename == "test.jsonl" + + def test_file_delete_response(self): + """Test file delete response model""" + response = FileDeleteResponse(id="file-123", deleted=True) + assert response.id == "file-123" + assert response.object == "file" + assert response.deleted + + +class TestBatchOperations: + """Test batch operation protocol models""" + + def test_batch_request(self): + """Test batch request model""" + request = BatchRequest( + input_file_id="file-123", + endpoint="/v1/chat/completions", + completion_window="24h", + metadata={"custom": "value"}, + ) + assert request.input_file_id == "file-123" + assert request.endpoint == "/v1/chat/completions" + assert request.completion_window == "24h" + assert request.metadata == {"custom": "value"} + + def test_batch_response(self): + """Test batch response model""" + response = BatchResponse( + id="batch-123", + endpoint="/v1/chat/completions", + input_file_id="file-123", + completion_window="24h", + created_at=1234567890, + ) + assert response.id == "batch-123" + assert response.object == "batch" + assert response.status == "validating" # default + assert response.endpoint == "/v1/chat/completions" + + +class TestResponseFormats: + """Test response format protocol models""" + + def test_basic_response_format(self): + """Test basic response format""" + format_obj = ResponseFormat(type="json_object") + assert format_obj.type == "json_object" + assert format_obj.json_schema is None + + def test_json_schema_response_format(self): + """Test JSON schema response format""" + schema = {"type": "object", "properties": {"name": {"type": "string"}}} + json_schema = JsonSchemaResponseFormat( + name="person_schema", description="Person schema", schema=schema + ) + format_obj = ResponseFormat(type="json_schema", json_schema=json_schema) + assert format_obj.type == "json_schema" + assert format_obj.json_schema.name == 
"person_schema" + assert format_obj.json_schema.schema_ == schema + + def test_structural_tag_response_format(self): + """Test structural tag response format""" + structures = [ + { + "begin": "", + "schema_": {"type": "string"}, + "end": "", + } + ] + format_obj = StructuralTagResponseFormat( + type="structural_tag", structures=structures, triggers=["think"] + ) + assert format_obj.type == "structural_tag" + assert len(format_obj.structures) == 1 + assert format_obj.triggers == ["think"] + + +class TestLogProbs: + """Test LogProbs protocol models""" + + def test_basic_logprobs(self): + """Test basic LogProbs model""" + logprobs = LogProbs( + text_offset=[0, 5, 11], + token_logprobs=[-0.1, -0.2, -0.3], + tokens=["Hello", " ", "world"], + top_logprobs=[{"Hello": -0.1}, {" ": -0.2}, {"world": -0.3}], + ) + assert len(logprobs.tokens) == 3 + assert logprobs.tokens == ["Hello", " ", "world"] + assert logprobs.token_logprobs == [-0.1, -0.2, -0.3] + + def test_choice_logprobs(self): + """Test ChoiceLogprobs model""" + token_logprob = ChatCompletionTokenLogprob( + token="Hello", + bytes=[72, 101, 108, 108, 111], + logprob=-0.1, + top_logprobs=[ + TopLogprob(token="Hello", bytes=[72, 101, 108, 108, 111], logprob=-0.1) + ], + ) + choice_logprobs = ChoiceLogprobs(content=[token_logprob]) + assert len(choice_logprobs.content) == 1 + assert choice_logprobs.content[0].token == "Hello" + + +class TestStreamingModels: + """Test streaming response models""" + + def test_stream_options(self): + """Test StreamOptions model""" + options = StreamOptions(include_usage=True) + assert options.include_usage + + def test_chat_completion_stream_response(self): + """Test ChatCompletionStreamResponse model""" + delta = DeltaMessage(role="assistant", content="Hello") + choice = ChatCompletionResponseStreamChoice(index=0, delta=delta) + response = ChatCompletionStreamResponse( + id="test-id", model="test-model", choices=[choice] + ) + assert response.object == "chat.completion.chunk" + assert response.choices[0].delta.content == "Hello" + + +class TestValidationEdgeCases: + """Test edge cases and validation scenarios""" + + def test_empty_messages_validation(self): + """Test validation with empty messages""" + with pytest.raises(ValidationError): + ChatCompletionRequest(model="test-model", messages=[]) + + def test_invalid_tool_choice_type(self): + """Test invalid tool choice type""" + messages = [{"role": "user", "content": "Hello"}] + with pytest.raises(ValidationError): + ChatCompletionRequest( + model="test-model", messages=messages, tool_choice=123 + ) + + def test_negative_token_limits(self): + """Test negative token limits""" + with pytest.raises(ValidationError): + CompletionRequest(model="test-model", prompt="Hello", max_tokens=-1) + + def test_invalid_temperature_range(self): + """Test invalid temperature values""" + # Note: The current protocol doesn't enforce temperature range, + # but this test documents expected behavior + request = CompletionRequest(model="test-model", prompt="Hello", temperature=5.0) + assert request.temperature == 5.0 # Currently allowed + + def test_model_serialization_roundtrip(self): + """Test that models can be serialized and deserialized""" + original_request = ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "Hello"}], + temperature=0.7, + max_tokens=100, + ) + + # Serialize to dict + data = original_request.model_dump() + + # Deserialize back + restored_request = ChatCompletionRequest(**data) + + assert restored_request.model == 
original_request.model + assert restored_request.temperature == original_request.temperature + assert restored_request.max_tokens == original_request.max_tokens + assert len(restored_request.messages) == len(original_request.messages) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/srt/openai/test_serving_chat.py b/test/srt/openai/test_serving_chat.py new file mode 100644 index 00000000000..b2015866b63 --- /dev/null +++ b/test/srt/openai/test_serving_chat.py @@ -0,0 +1,634 @@ +""" +Unit tests for the OpenAIServingChat class from serving_chat.py. + +These tests ensure that the refactored implementation maintains compatibility +with the original adapter.py functionality. +""" + +import uuid +from unittest.mock import Mock, patch + +import pytest +from fastapi import Request + +from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse +from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat +from sglang.srt.managers.io_struct import GenerateReqInput + + +# Mock TokenizerManager since it may not be directly importable in tests +class MockTokenizerManager: + def __init__(self): + self.model_config = Mock() + self.model_config.is_multimodal = False + self.server_args = Mock() + self.server_args.enable_cache_report = False + self.server_args.tool_call_parser = "hermes" + self.server_args.reasoning_parser = None + self.chat_template_name = "llama-3" + + # Mock tokenizer + self.tokenizer = Mock() + self.tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5]) + self.tokenizer.decode = Mock(return_value="Test response") + self.tokenizer.chat_template = None + self.tokenizer.bos_token_id = 1 + + # Mock generate_request method + async def mock_generate(): + yield { + "text": "Test response", + "meta_info": { + "id": f"chatcmpl-{uuid.uuid4()}", + "prompt_tokens": 10, + "completion_tokens": 5, + "cached_tokens": 0, + "finish_reason": {"type": "stop", "matched": None}, + "output_token_logprobs": [(0.1, 1, "Test"), (0.2, 2, "response")], + "output_top_logprobs": None, + }, + "index": 0, + } + + self.generate_request = Mock(return_value=mock_generate()) + self.create_abort_task = Mock(return_value=None) + + +@pytest.fixture +def mock_tokenizer_manager(): + """Create a mock tokenizer manager for testing.""" + return MockTokenizerManager() + + +@pytest.fixture +def serving_chat(mock_tokenizer_manager): + """Create a OpenAIServingChat instance for testing.""" + return OpenAIServingChat(mock_tokenizer_manager) + + +@pytest.fixture +def mock_request(): + """Create a mock FastAPI request.""" + request = Mock(spec=Request) + request.headers = {} + return request + + +@pytest.fixture +def basic_chat_request(): + """Create a basic chat completion request.""" + return ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "Hello, how are you?"}], + temperature=0.7, + max_tokens=100, + stream=False, + ) + + +@pytest.fixture +def streaming_chat_request(): + """Create a streaming chat completion request.""" + return ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "Hello, how are you?"}], + temperature=0.7, + max_tokens=100, + stream=True, + ) + + +class TestOpenAIServingChatConversion: + """Test request conversion methods.""" + + def test_convert_to_internal_request_single( + self, serving_chat, basic_chat_request, mock_tokenizer_manager + ): + """Test converting single request to internal format.""" + with patch( + "sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv" + ) 
as mock_conv: + mock_conv_instance = Mock() + mock_conv_instance.get_prompt.return_value = "Test prompt" + mock_conv_instance.image_data = None + mock_conv_instance.audio_data = None + mock_conv_instance.modalities = [] + mock_conv_instance.stop_str = [""] + mock_conv.return_value = mock_conv_instance + + # Mock the _process_messages method to return expected values + with patch.object(serving_chat, "_process_messages") as mock_process: + mock_process.return_value = ( + "Test prompt", + [1, 2, 3], + None, + None, + [], + [""], + None, # tool_call_constraint + ) + + adapted_request, processed_request = ( + serving_chat._convert_to_internal_request( + [basic_chat_request], ["test-id"] + ) + ) + + assert isinstance(adapted_request, GenerateReqInput) + assert adapted_request.stream == basic_chat_request.stream + assert processed_request == basic_chat_request + + +class TestToolCalls: + """Test tool call functionality from adapter.py""" + + def test_tool_call_request_conversion(self, serving_chat): + """Test request with tool calls""" + request = ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "What's the weather?"}], + tools=[ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather information", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + }, + } + ], + tool_choice="auto", + ) + + with patch.object(serving_chat, "_process_messages") as mock_process: + mock_process.return_value = ( + "Test prompt", + [1, 2, 3], + None, + None, + [], + [""], + None, # tool_call_constraint + ) + + adapted_request, _ = serving_chat._convert_to_internal_request( + [request], ["test-id"] + ) + + assert adapted_request.rid == "test-id" + # Tool call constraint should be processed + assert request.tools is not None + + def test_tool_choice_none(self, serving_chat): + """Test tool_choice=none disables tool calls""" + request = ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "Hello"}], + tools=[{"type": "function", "function": {"name": "test_func"}}], + tool_choice="none", + ) + + with patch.object(serving_chat, "_process_messages") as mock_process: + mock_process.return_value = ( + "Test prompt", + [1, 2, 3], + None, + None, + [], + [""], + None, # tool_call_constraint + ) + + adapted_request, _ = serving_chat._convert_to_internal_request( + [request], ["test-id"] + ) + + # Tools should not be processed when tool_choice is "none" + assert adapted_request.rid == "test-id" + + def test_tool_call_response_processing(self, serving_chat): + """Test processing tool calls in response""" + mock_ret_item = { + "text": '{"name": "get_weather", "parameters": {"location": "Paris"}}', + "meta_info": { + "output_token_logprobs": [], + "output_top_logprobs": None, + }, + } + + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + }, + } + ] + + finish_reason = {"type": "stop", "matched": None} + + # Mock FunctionCallParser + with patch( + "sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser" + ) as mock_parser_class: + mock_parser = Mock() + mock_parser.has_tool_call.return_value = True + + # Create proper mock tool call object + mock_tool_call = Mock() + mock_tool_call.name = "get_weather" + mock_tool_call.parameters = '{"location": "Paris"}' + + mock_parser.parse_non_stream.return_value = ("", [mock_tool_call]) + mock_parser_class.return_value 
= mock_parser + + tool_calls, text, updated_finish_reason = serving_chat._process_tool_calls( + mock_ret_item["text"], tools, "hermes", finish_reason + ) + + assert tool_calls is not None + assert len(tool_calls) == 1 + assert updated_finish_reason["type"] == "tool_calls" + + +class TestMultimodalContent: + """Test multimodal content handling from adapter.py""" + + def test_multimodal_request_with_images(self, serving_chat): + """Test request with image content""" + request = ChatCompletionRequest( + model="test-model", + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": {"url": "data:image/jpeg;base64,..."}, + }, + ], + } + ], + ) + + # Set multimodal mode + serving_chat.tokenizer_manager.model_config.is_multimodal = True + + with patch.object(serving_chat, "_apply_jinja_template") as mock_apply: + mock_apply.return_value = ( + "prompt", + [1, 2, 3], + ["image_data"], + None, + [], + [], + ) + + with patch.object( + serving_chat, "_apply_conversation_template" + ) as mock_conv: + mock_conv.return_value = ("prompt", ["image_data"], None, [], []) + + ( + prompt, + prompt_ids, + image_data, + audio_data, + modalities, + stop, + tool_call_constraint, + ) = serving_chat._process_messages(request, True) + + assert image_data == ["image_data"] + assert prompt == "prompt" + + def test_multimodal_request_with_audio(self, serving_chat): + """Test request with audio content""" + request = ChatCompletionRequest( + model="test-model", + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "Transcribe this audio"}, + { + "type": "audio_url", + "audio_url": {"url": "data:audio/wav;base64,UklGR..."}, + }, + ], + } + ], + ) + + serving_chat.tokenizer_manager.model_config.is_multimodal = True + + with patch.object(serving_chat, "_apply_jinja_template") as mock_apply: + mock_apply.return_value = ( + "prompt", + [1, 2, 3], + None, + ["audio_data"], + ["audio"], + [], + ) + + with patch.object( + serving_chat, "_apply_conversation_template" + ) as mock_conv: + mock_conv.return_value = ("prompt", None, ["audio_data"], ["audio"], []) + + ( + prompt, + prompt_ids, + image_data, + audio_data, + modalities, + stop, + tool_call_constraint, + ) = serving_chat._process_messages(request, True) + + assert audio_data == ["audio_data"] + assert modalities == ["audio"] + + +class TestTemplateHandling: + """Test chat template handling from adapter.py""" + + def test_jinja_template_processing(self, serving_chat): + """Test Jinja template processing""" + request = ChatCompletionRequest( + model="test-model", messages=[{"role": "user", "content": "Hello"}] + ) + + # Mock the template attribute directly + serving_chat.tokenizer_manager.chat_template_name = None + serving_chat.tokenizer_manager.tokenizer.chat_template = "" + + with patch.object(serving_chat, "_apply_jinja_template") as mock_apply: + mock_apply.return_value = ( + "processed_prompt", + [1, 2, 3], + None, + None, + [], + [""], + ) + + # Mock hasattr to simulate the None check + with patch("builtins.hasattr") as mock_hasattr: + mock_hasattr.return_value = True + + ( + prompt, + prompt_ids, + image_data, + audio_data, + modalities, + stop, + tool_call_constraint, + ) = serving_chat._process_messages(request, False) + + assert prompt == "processed_prompt" + assert prompt_ids == [1, 2, 3] + + def test_conversation_template_processing(self, serving_chat): + """Test conversation template processing""" + request = ChatCompletionRequest( + model="test-model", 
messages=[{"role": "user", "content": "Hello"}] + ) + + serving_chat.tokenizer_manager.chat_template_name = "llama-3" + + with patch.object(serving_chat, "_apply_conversation_template") as mock_apply: + mock_apply.return_value = ("conv_prompt", None, None, [], [""]) + + ( + prompt, + prompt_ids, + image_data, + audio_data, + modalities, + stop, + tool_call_constraint, + ) = serving_chat._process_messages(request, False) + + assert prompt == "conv_prompt" + assert stop == [""] + + def test_continue_final_message(self, serving_chat): + """Test continue_final_message functionality""" + request = ChatCompletionRequest( + model="test-model", + messages=[ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there"}, + ], + continue_final_message=True, + ) + + with patch.object(serving_chat, "_apply_conversation_template") as mock_apply: + mock_apply.return_value = ("Hi there", None, None, [], [""]) + + ( + prompt, + prompt_ids, + image_data, + audio_data, + modalities, + stop, + tool_call_constraint, + ) = serving_chat._process_messages(request, False) + + # Should handle continue_final_message properly + assert prompt == "Hi there" + + +class TestReasoningContent: + """Test reasoning content separation from adapter.py""" + + def test_reasoning_content_request(self, serving_chat): + """Test request with reasoning content separation""" + request = ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "Solve this math problem"}], + separate_reasoning=True, + stream_reasoning=False, + ) + + with patch.object(serving_chat, "_process_messages") as mock_process: + mock_process.return_value = ( + "Test prompt", + [1, 2, 3], + None, + None, + [], + [""], + None, # tool_call_constraint + ) + + adapted_request, _ = serving_chat._convert_to_internal_request( + [request], ["test-id"] + ) + + assert adapted_request.rid == "test-id" + assert request.separate_reasoning == True + + def test_reasoning_content_response(self, serving_chat): + """Test reasoning content in response""" + mock_ret_item = { + "text": "This is reasoningAnswer: 42", + "meta_info": { + "output_token_logprobs": [], + "output_top_logprobs": None, + }, + } + + # Mock ReasoningParser + with patch( + "sglang.srt.entrypoints.openai.serving_chat.ReasoningParser" + ) as mock_parser_class: + mock_parser = Mock() + mock_parser.parse_non_stream.return_value = ( + "This is reasoning", + "Answer: 42", + ) + mock_parser_class.return_value = mock_parser + + choice_logprobs = None + reasoning_text = None + text = mock_ret_item["text"] + + # Simulate reasoning processing + enable_thinking = True + if enable_thinking: + parser = mock_parser_class(model_type="test", stream_reasoning=False) + reasoning_text, text = parser.parse_non_stream(text) + + assert reasoning_text == "This is reasoning" + assert text == "Answer: 42" + + +class TestSamplingParams: + """Test sampling parameter handling from adapter.py""" + + def test_all_sampling_parameters(self, serving_chat): + """Test all sampling parameters are properly handled""" + request = ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "Hello"}], + temperature=0.8, + max_tokens=150, + max_completion_tokens=200, + min_tokens=5, + top_p=0.9, + top_k=50, + min_p=0.1, + presence_penalty=0.1, + frequency_penalty=0.2, + repetition_penalty=1.1, + stop=["<|endoftext|>"], + stop_token_ids=[13, 14], + regex=r"\d+", + ebnf=" ::= ", + n=2, + no_stop_trim=True, + ignore_eos=True, + skip_special_tokens=False, + logit_bias={"1": 0.5, 
"2": -0.3}, + ) + + with patch.object(serving_chat, "_process_messages") as mock_process: + mock_process.return_value = ( + "Test prompt", + [1, 2, 3], + None, + None, + [], + [""], + None, # tool_call_constraint + ) + + sampling_params = serving_chat._build_sampling_params( + request, [""], None + ) + + # Verify all parameters + assert sampling_params["temperature"] == 0.8 + assert sampling_params["max_new_tokens"] == 150 + assert sampling_params["min_new_tokens"] == 5 + assert sampling_params["top_p"] == 0.9 + assert sampling_params["top_k"] == 50 + assert sampling_params["min_p"] == 0.1 + assert sampling_params["presence_penalty"] == 0.1 + assert sampling_params["frequency_penalty"] == 0.2 + assert sampling_params["repetition_penalty"] == 1.1 + assert sampling_params["stop"] == [""] + assert sampling_params["logit_bias"] == {"1": 0.5, "2": -0.3} + + def test_response_format_json_schema(self, serving_chat): + """Test response format with JSON schema""" + request = ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "Generate JSON"}], + response_format={ + "type": "json_schema", + "json_schema": { + "name": "response", + "schema": { + "type": "object", + "properties": {"answer": {"type": "string"}}, + }, + }, + }, + ) + + with patch.object(serving_chat, "_process_messages") as mock_process: + mock_process.return_value = ( + "Test prompt", + [1, 2, 3], + None, + None, + [], + [""], + None, # tool_call_constraint + ) + + sampling_params = serving_chat._build_sampling_params( + request, [""], None + ) + + assert "json_schema" in sampling_params + assert '"type": "object"' in sampling_params["json_schema"] + + def test_response_format_json_object(self, serving_chat): + """Test response format with JSON object""" + request = ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "Generate JSON"}], + response_format={"type": "json_object"}, + ) + + with patch.object(serving_chat, "_process_messages") as mock_process: + mock_process.return_value = ( + "Test prompt", + [1, 2, 3], + None, + None, + [], + [""], + None, # tool_call_constraint + ) + + sampling_params = serving_chat._build_sampling_params( + request, [""], None + ) + + assert sampling_params["json_schema"] == '{"type": "object"}' diff --git a/test/srt/openai/test_serving_completions.py b/test/srt/openai/test_serving_completions.py new file mode 100644 index 00000000000..3e8fc42c8c6 --- /dev/null +++ b/test/srt/openai/test_serving_completions.py @@ -0,0 +1,176 @@ +""" +Tests for the refactored completions serving handler +""" + +from unittest.mock import AsyncMock, Mock, patch + +import pytest + +from sglang.srt.entrypoints.openai.protocol import ( + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + CompletionStreamResponse, + ErrorResponse, +) +from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion +from sglang.srt.managers.io_struct import GenerateReqInput +from sglang.srt.managers.tokenizer_manager import TokenizerManager + + +@pytest.fixture +def mock_tokenizer_manager(): + """Create a mock tokenizer manager""" + manager = Mock(spec=TokenizerManager) + + # Mock tokenizer + manager.tokenizer = Mock() + manager.tokenizer.encode = Mock(return_value=[1, 2, 3, 4]) + manager.tokenizer.decode = Mock(return_value="decoded text") + manager.tokenizer.bos_token_id = 1 + + # Mock model config + manager.model_config = Mock() + manager.model_config.is_multimodal = False + + # Mock server args + manager.server_args = Mock() + 
+    manager.server_args.enable_cache_report = False
+
+    # Mock generation
+    manager.generate_request = AsyncMock()
+    manager.create_abort_task = Mock(return_value=None)
+
+    return manager
+
+
+@pytest.fixture
+def serving_completion(mock_tokenizer_manager):
+    """Create an OpenAIServingCompletion instance"""
+    return OpenAIServingCompletion(mock_tokenizer_manager)
+
+
+class TestPromptHandling:
+    """Test different prompt types and formats from adapter.py"""
+
+    def test_single_string_prompt(self, serving_completion):
+        """Test handling single string prompt"""
+        request = CompletionRequest(
+            model="test-model", prompt="Hello world", max_tokens=100
+        )
+
+        adapted_request, _ = serving_completion._convert_to_internal_request(
+            [request], ["test-id"]
+        )
+
+        assert adapted_request.text == "Hello world"
+
+    def test_single_token_ids_prompt(self, serving_completion):
+        """Test handling single token IDs prompt"""
+        request = CompletionRequest(
+            model="test-model", prompt=[1, 2, 3, 4], max_tokens=100
+        )
+
+        adapted_request, _ = serving_completion._convert_to_internal_request(
+            [request], ["test-id"]
+        )
+
+        assert adapted_request.input_ids == [1, 2, 3, 4]
+
+    def test_completion_template_handling(self, serving_completion):
+        """Test completion template processing"""
+        request = CompletionRequest(
+            model="test-model",
+            prompt="def hello():",
+            suffix="return 'world'",
+            max_tokens=100,
+        )
+
+        with patch(
+            "sglang.srt.entrypoints.openai.serving_completions.is_completion_template_defined",
+            return_value=True,
+        ):
+            with patch(
+                "sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request",
+                return_value="processed_prompt",
+            ):
+                adapted_request, _ = serving_completion._convert_to_internal_request(
+                    [request], ["test-id"]
+                )
+
+                assert adapted_request.text == "processed_prompt"
+
+
+class TestEchoHandling:
+    """Test echo functionality from adapter.py"""
+
+    def test_echo_with_string_prompt_streaming(self, serving_completion):
+        """Test echo handling with string prompt in streaming"""
+        request = CompletionRequest(
+            model="test-model", prompt="Hello", max_tokens=100, echo=True
+        )
+
+        # Test _get_echo_text method
+        echo_text = serving_completion._get_echo_text(request, 0)
+        assert echo_text == "Hello"
+
+    def test_echo_with_list_of_strings_streaming(self, serving_completion):
+        """Test echo handling with list of strings in streaming"""
+        request = CompletionRequest(
+            model="test-model",
+            prompt=["Hello", "World"],
+            max_tokens=100,
+            echo=True,
+            n=1,
+        )
+
+        echo_text = serving_completion._get_echo_text(request, 0)
+        assert echo_text == "Hello"
+
+        echo_text = serving_completion._get_echo_text(request, 1)
+        assert echo_text == "World"
+
+    def test_echo_with_token_ids_streaming(self, serving_completion):
+        """Test echo handling with token IDs in streaming"""
+        request = CompletionRequest(
+            model="test-model", prompt=[1, 2, 3], max_tokens=100, echo=True
+        )
+
+        serving_completion.tokenizer_manager.tokenizer.decode.return_value = (
+            "decoded_prompt"
+        )
+        echo_text = serving_completion._get_echo_text(request, 0)
+        assert echo_text == "decoded_prompt"
+
+    def test_echo_with_multiple_token_ids_streaming(self, serving_completion):
+        """Test echo handling with multiple token ID prompts in streaming"""
+        request = CompletionRequest(
+            model="test-model", prompt=[[1, 2], [3, 4]], max_tokens=100, echo=True, n=1
+        )
+
+        serving_completion.tokenizer_manager.tokenizer.decode.return_value = "decoded"
+        echo_text = serving_completion._get_echo_text(request, 0)
+        assert echo_text == "decoded"
+
+    def test_prepare_echo_prompts_non_streaming(self, serving_completion):
+        """Test prepare echo prompts for non-streaming response"""
+        # Test with single string
+        request = CompletionRequest(model="test-model", prompt="Hello", echo=True)
+
+        echo_prompts = serving_completion._prepare_echo_prompts(request)
+        assert echo_prompts == ["Hello"]
+
+        # Test with list of strings
+        request = CompletionRequest(
+            model="test-model", prompt=["Hello", "World"], echo=True
+        )
+
+        echo_prompts = serving_completion._prepare_echo_prompts(request)
+        assert echo_prompts == ["Hello", "World"]
+
+        # Test with token IDs
+        request = CompletionRequest(model="test-model", prompt=[1, 2, 3], echo=True)
+
+        serving_completion.tokenizer_manager.tokenizer.decode.return_value = "decoded"
+        echo_prompts = serving_completion._prepare_echo_prompts(request)
+        assert echo_prompts == ["decoded"]
diff --git a/test/srt/openai/test_serving_embedding.py b/test/srt/openai/test_serving_embedding.py
new file mode 100644
index 00000000000..58438bba884
--- /dev/null
+++ b/test/srt/openai/test_serving_embedding.py
@@ -0,0 +1,316 @@
+"""
+Unit tests for the OpenAIServingEmbedding class from serving_embedding.py.
+
+These tests ensure that the embedding serving implementation maintains compatibility
+with the original adapter.py functionality and follows OpenAI API specifications.
+"""
+
+import asyncio
+import json
+import time
+import uuid
+from typing import Any, Dict, List
+from unittest.mock import AsyncMock, Mock, patch
+
+import pytest
+from fastapi import Request
+from fastapi.responses import ORJSONResponse
+from pydantic_core import ValidationError
+
+from sglang.srt.entrypoints.openai.protocol import (
+    EmbeddingObject,
+    EmbeddingRequest,
+    EmbeddingResponse,
+    ErrorResponse,
+    MultimodalEmbeddingInput,
+    UsageInfo,
+)
+from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from sglang.srt.managers.io_struct import EmbeddingReqInput
+
+
+# Mock TokenizerManager for embedding tests
+class MockTokenizerManager:
+    def __init__(self):
+        self.model_config = Mock()
+        self.model_config.is_multimodal = False
+        self.server_args = Mock()
+        self.server_args.enable_cache_report = False
+        self.model_path = "test-model"
+
+        # Mock tokenizer
+        self.tokenizer = Mock()
+        self.tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
+        self.tokenizer.decode = Mock(return_value="Test embedding input")
+        self.tokenizer.chat_template = None
+        self.tokenizer.bos_token_id = 1
+
+        # Mock generate_request method for embeddings
+        async def mock_generate_embedding():
+            yield {
+                "embedding": [0.1, 0.2, 0.3, 0.4, 0.5] * 20,  # 100-dim embedding
+                "meta_info": {
+                    "id": f"embd-{uuid.uuid4()}",
+                    "prompt_tokens": 5,
+                },
+            }
+
+        self.generate_request = Mock(return_value=mock_generate_embedding())
+
+
+@pytest.fixture
+def mock_tokenizer_manager():
+    """Create a mock tokenizer manager for testing."""
+    return MockTokenizerManager()
+
+
+@pytest.fixture
+def serving_embedding(mock_tokenizer_manager):
+    """Create an OpenAIServingEmbedding instance for testing."""
+    return OpenAIServingEmbedding(mock_tokenizer_manager)
+
+
+@pytest.fixture
+def mock_request():
+    """Create a mock FastAPI request."""
+    request = Mock(spec=Request)
+    request.headers = {}
+    return request
+
+
+@pytest.fixture
+def basic_embedding_request():
+    """Create a basic embedding request."""
+    return EmbeddingRequest(
+        model="test-model",
+        input="Hello, how are you?",
+        encoding_format="float",
+    )
+
+
+@pytest.fixture
+def list_embedding_request():
+    """Create an embedding request with list input."""
+    return EmbeddingRequest(
+        model="test-model",
+        input=["Hello, how are you?", "I am fine, thank you!"],
+        encoding_format="float",
+    )
+
+
+@pytest.fixture
+def multimodal_embedding_request():
+    """Create a multimodal embedding request."""
+    return EmbeddingRequest(
+        model="test-model",
+        input=[
+            MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
+            MultimodalEmbeddingInput(text="World", image=None),
+        ],
+        encoding_format="float",
+    )
+
+
+@pytest.fixture
+def token_ids_embedding_request():
+    """Create an embedding request with token IDs."""
+    return EmbeddingRequest(
+        model="test-model",
+        input=[1, 2, 3, 4, 5],
+        encoding_format="float",
+    )
+
+
+class TestOpenAIServingEmbeddingConversion:
+    """Test request conversion methods."""
+
+    def test_convert_single_string_request(
+        self, serving_embedding, basic_embedding_request
+    ):
+        """Test converting single string request to internal format."""
+        adapted_request, processed_request = (
+            serving_embedding._convert_to_internal_request(
+                [basic_embedding_request], ["test-id"]
+            )
+        )
+
+        assert isinstance(adapted_request, EmbeddingReqInput)
+        assert adapted_request.text == "Hello, how are you?"
+        assert adapted_request.rid == "test-id"
+        assert processed_request == basic_embedding_request
+
+    def test_convert_list_string_request(
+        self, serving_embedding, list_embedding_request
+    ):
+        """Test converting list of strings request to internal format."""
+        adapted_request, processed_request = (
+            serving_embedding._convert_to_internal_request(
+                [list_embedding_request], ["test-id"]
+            )
+        )
+
+        assert isinstance(adapted_request, EmbeddingReqInput)
+        assert adapted_request.text == ["Hello, how are you?", "I am fine, thank you!"]
+        assert adapted_request.rid == "test-id"
+        assert processed_request == list_embedding_request
+
+    def test_convert_token_ids_request(
+        self, serving_embedding, token_ids_embedding_request
+    ):
+        """Test converting token IDs request to internal format."""
+        adapted_request, processed_request = (
+            serving_embedding._convert_to_internal_request(
+                [token_ids_embedding_request], ["test-id"]
+            )
+        )
+
+        assert isinstance(adapted_request, EmbeddingReqInput)
+        assert adapted_request.input_ids == [1, 2, 3, 4, 5]
+        assert adapted_request.rid == "test-id"
+        assert processed_request == token_ids_embedding_request
+
+    def test_convert_multimodal_request(
+        self, serving_embedding, multimodal_embedding_request
+    ):
+        """Test converting multimodal request to internal format."""
+        adapted_request, processed_request = (
+            serving_embedding._convert_to_internal_request(
+                [multimodal_embedding_request], ["test-id"]
+            )
+        )
+
+        assert isinstance(adapted_request, EmbeddingReqInput)
+        # Should extract text and images separately
+        assert len(adapted_request.text) == 2
+        assert "Hello" in adapted_request.text
+        assert "World" in adapted_request.text
+        assert adapted_request.image_data[0] == "base64_image_data"
+        assert adapted_request.image_data[1] is None
+        assert adapted_request.rid == "test-id"
+
+
+class TestEmbeddingResponseBuilding:
+    """Test response building methods."""
+
+    def test_build_single_embedding_response(self, serving_embedding):
+        """Test building response for single embedding."""
+        ret_data = [
+            {
+                "embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
+                "meta_info": {"prompt_tokens": 5},
+            }
+        ]
+
+        response = serving_embedding._build_embedding_response(ret_data, "test-model")
+
+        assert isinstance(response, EmbeddingResponse)
+        assert response.model == "test-model"
+        assert len(response.data) == 1
+        assert response.data[0].embedding == [0.1, 0.2, 0.3, 0.4, 0.5]
+        assert response.data[0].index == 0
+        assert response.data[0].object == "embedding"
+        assert response.usage.prompt_tokens == 5
+        assert response.usage.total_tokens == 5
+        assert response.usage.completion_tokens == 0
+
+    def test_build_multiple_embedding_response(self, serving_embedding):
+        """Test building response for multiple embeddings."""
+        ret_data = [
+            {
+                "embedding": [0.1, 0.2, 0.3],
+                "meta_info": {"prompt_tokens": 3},
+            },
+            {
+                "embedding": [0.4, 0.5, 0.6],
+                "meta_info": {"prompt_tokens": 4},
+            },
+        ]
+
+        response = serving_embedding._build_embedding_response(ret_data, "test-model")
+
+        assert isinstance(response, EmbeddingResponse)
+        assert len(response.data) == 2
+        assert response.data[0].embedding == [0.1, 0.2, 0.3]
+        assert response.data[0].index == 0
+        assert response.data[1].embedding == [0.4, 0.5, 0.6]
+        assert response.data[1].index == 1
+        assert response.usage.prompt_tokens == 7  # 3 + 4
+        assert response.usage.total_tokens == 7
+
+
+@pytest.mark.asyncio
+class TestOpenAIServingEmbeddingAsyncMethods:
+    """Test async methods of OpenAIServingEmbedding."""
+
+    async def test_handle_request_success(
+        self, serving_embedding, basic_embedding_request, mock_request
+    ):
+        """Test successful embedding request handling."""
+
+        # Mock the generate_request to return expected data
+        async def mock_generate():
+            yield {
+                "embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
+                "meta_info": {"prompt_tokens": 5},
+            }
+
+        serving_embedding.tokenizer_manager.generate_request = Mock(
+            return_value=mock_generate()
+        )
+
+        response = await serving_embedding.handle_request(
+            basic_embedding_request, mock_request
+        )
+
+        assert isinstance(response, EmbeddingResponse)
+        assert len(response.data) == 1
+        assert response.data[0].embedding == [0.1, 0.2, 0.3, 0.4, 0.5]
+
+    async def test_handle_request_validation_error(
+        self, serving_embedding, mock_request
+    ):
+        """Test handling request with validation error."""
+        invalid_request = EmbeddingRequest(model="test-model", input="")
+
+        response = await serving_embedding.handle_request(invalid_request, mock_request)
+
+        assert isinstance(response, ORJSONResponse)
+        assert response.status_code == 400
+
+    async def test_handle_request_generation_error(
+        self, serving_embedding, basic_embedding_request, mock_request
+    ):
+        """Test handling request with generation error."""
+
+        # Mock generate_request to raise an error
+        async def mock_generate_error():
+            raise ValueError("Generation failed")
+            yield  # This won't be reached but needed for async generator
+
+        serving_embedding.tokenizer_manager.generate_request = Mock(
+            return_value=mock_generate_error()
+        )
+
+        response = await serving_embedding.handle_request(
+            basic_embedding_request, mock_request
+        )
+
+        assert isinstance(response, ORJSONResponse)
+        assert response.status_code == 400
+
+    async def test_handle_request_internal_error(
+        self, serving_embedding, basic_embedding_request, mock_request
+    ):
+        """Test handling request with internal server error."""
+        # Mock _convert_to_internal_request to raise an exception
+        with patch.object(
+            serving_embedding,
+            "_convert_to_internal_request",
+            side_effect=Exception("Internal error"),
+        ):
+            response = await serving_embedding.handle_request(
+                basic_embedding_request, mock_request
+            )
+
+        assert isinstance(response, ORJSONResponse)
+        assert response.status_code == 500
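
Reviewer note (not part of the diff above): the string-based echo cases in TestEchoHandling all assert the same mapping from prompt and index to echo text, so they could be collapsed into one parametrized test. The sketch below is illustrative only; it reuses the serving_completion fixture and the _get_echo_text helper exercised in the diff, and asserts exactly the behavior the existing tests already check. The test name and parameter ids are hypothetical.

import pytest

from sglang.srt.entrypoints.openai.protocol import CompletionRequest


@pytest.mark.parametrize(
    "prompt, index, expected",
    [
        ("Hello", 0, "Hello"),  # single string prompt
        (["Hello", "World"], 0, "Hello"),  # first prompt in a list of strings
        (["Hello", "World"], 1, "World"),  # second prompt in a list of strings
    ],
)
def test_echo_text_parametrized(serving_completion, prompt, index, expected):
    # Mirrors test_echo_with_string_prompt_streaming and
    # test_echo_with_list_of_strings_streaming from the diff.
    request = CompletionRequest(
        model="test-model", prompt=prompt, max_tokens=100, echo=True, n=1
    )
    assert serving_completion._get_echo_text(request, index) == expected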