diff --git a/python/sglang/srt/entrypoints/openai/__init__.py b/python/sglang/srt/entrypoints/openai/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py
new file mode 100644
index 00000000000..c7423ed1b43
--- /dev/null
+++ b/python/sglang/srt/entrypoints/openai/protocol.py
@@ -0,0 +1,539 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Pydantic models for OpenAI API protocol"""
+
+import time
+from typing import Dict, List, Optional, Union
+
+from pydantic import BaseModel, Field, field_validator, model_validator
+from typing_extensions import Literal
+
+
+class ModelCard(BaseModel):
+ """Model cards."""
+
+ id: str
+ object: str = "model"
+ created: int = Field(default_factory=lambda: int(time.time()))
+ owned_by: str = "sglang"
+ root: Optional[str] = None
+ max_model_len: Optional[int] = None
+
+
+class ModelList(BaseModel):
+ """Model list consists of model cards."""
+
+ object: str = "list"
+ data: List[ModelCard] = Field(default_factory=list)
+
+
+class ErrorResponse(BaseModel):
+ object: str = "error"
+ message: str
+ type: str
+ param: Optional[str] = None
+ code: int
+
+
+class LogProbs(BaseModel):
+ text_offset: List[int] = Field(default_factory=list)
+ token_logprobs: List[Optional[float]] = Field(default_factory=list)
+ tokens: List[str] = Field(default_factory=list)
+ top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
+
+
+class TopLogprob(BaseModel):
+ token: str
+ bytes: List[int]
+ logprob: float
+
+
+class ChatCompletionTokenLogprob(BaseModel):
+ token: str
+ bytes: List[int]
+ logprob: float
+ top_logprobs: List[TopLogprob]
+
+
+class ChoiceLogprobs(BaseModel):
+    # Built for the v1/chat/completions response.
+ content: List[ChatCompletionTokenLogprob]
+
+
+class UsageInfo(BaseModel):
+ prompt_tokens: int = 0
+ total_tokens: int = 0
+ completion_tokens: Optional[int] = 0
+ # only used to return cached tokens when --enable-cache-report is set
+ prompt_tokens_details: Optional[Dict[str, int]] = None
+
+
+class StreamOptions(BaseModel):
+ include_usage: Optional[bool] = False
+
+
+class JsonSchemaResponseFormat(BaseModel):
+ name: str
+ description: Optional[str] = None
+    # Use an alias to work around the pydantic conflict with the "schema" attribute.
+ schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
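+    # Dump with model_dump(by_alias=True) so this field is emitted as "schema" on the wire.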
+ strict: Optional[bool] = False
+
+
+class FileRequest(BaseModel):
+ # https://platform.openai.com/docs/api-reference/files/create
+ file: bytes # The File object (not file name) to be uploaded
+ purpose: str = (
+ "batch" # The intended purpose of the uploaded file, default is "batch"
+ )
+
+
+class FileResponse(BaseModel):
+ id: str
+ object: str = "file"
+ bytes: int
+ created_at: int
+ filename: str
+ purpose: str
+
+
+class FileDeleteResponse(BaseModel):
+ id: str
+ object: str = "file"
+ deleted: bool
+
+
+class BatchRequest(BaseModel):
+ input_file_id: (
+ str # The ID of an uploaded file that contains requests for the new batch
+ )
+ endpoint: str # The endpoint to be used for all requests in the batch
+ completion_window: str # The time frame within which the batch should be processed
+ metadata: Optional[dict] = None # Optional custom metadata for the batch
+
+
+class BatchResponse(BaseModel):
+ id: str
+ object: str = "batch"
+ endpoint: str
+ errors: Optional[dict] = None
+ input_file_id: str
+ completion_window: str
+ status: str = "validating"
+ output_file_id: Optional[str] = None
+ error_file_id: Optional[str] = None
+ created_at: int
+ in_progress_at: Optional[int] = None
+ expires_at: Optional[int] = None
+ finalizing_at: Optional[int] = None
+ completed_at: Optional[int] = None
+ failed_at: Optional[int] = None
+ expired_at: Optional[int] = None
+ cancelling_at: Optional[int] = None
+ cancelled_at: Optional[int] = None
+ request_counts: Optional[dict] = None
+ metadata: Optional[dict] = None
+
+
+class CompletionRequest(BaseModel):
+    # Fields are ordered according to the official OpenAI API documentation
+ # https://platform.openai.com/docs/api-reference/completions/create
+ model: str
+ prompt: Union[List[int], List[List[int]], str, List[str]]
+ best_of: Optional[int] = None
+ echo: bool = False
+ frequency_penalty: float = 0.0
+ logit_bias: Optional[Dict[str, float]] = None
+ logprobs: Optional[int] = None
+ max_tokens: int = 16
+ n: int = 1
+ presence_penalty: float = 0.0
+ seed: Optional[int] = None
+ stop: Optional[Union[str, List[str]]] = None
+ stream: bool = False
+ stream_options: Optional[StreamOptions] = None
+ suffix: Optional[str] = None
+ temperature: float = 1.0
+ top_p: float = 1.0
+ user: Optional[str] = None
+
+    # Extra parameters supported only by the SRT backend; OpenAI models ignore them.
+ top_k: int = -1
+ min_p: float = 0.0
+ min_tokens: int = 0
+ json_schema: Optional[str] = None
+ regex: Optional[str] = None
+ ebnf: Optional[str] = None
+ repetition_penalty: float = 1.0
+ stop_token_ids: Optional[List[int]] = None
+ no_stop_trim: bool = False
+ ignore_eos: bool = False
+ skip_special_tokens: bool = True
+ lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+ session_params: Optional[Dict] = None
+
+    # For PD (prefill-decode) disaggregation
+ bootstrap_host: Optional[str] = None
+ bootstrap_port: Optional[int] = None
+ bootstrap_room: Optional[int] = None
+
+ @field_validator("max_tokens")
+ @classmethod
+ def validate_max_tokens_positive(cls, v):
+ if v is not None and v <= 0:
+ raise ValueError("max_tokens must be positive")
+ return v
+
+
+class CompletionResponseChoice(BaseModel):
+ index: int
+ text: str
+ logprobs: Optional[LogProbs] = None
+ finish_reason: Literal["stop", "length", "content_filter", "abort"]
+ matched_stop: Union[None, int, str] = None
+
+
+class CompletionResponse(BaseModel):
+ id: str
+ object: str = "text_completion"
+ created: int = Field(default_factory=lambda: int(time.time()))
+ model: str
+ choices: List[CompletionResponseChoice]
+ usage: UsageInfo
+
+
+class CompletionResponseStreamChoice(BaseModel):
+ index: int
+ text: str
+ logprobs: Optional[LogProbs] = None
+ finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
+ matched_stop: Union[None, int, str] = None
+
+
+class CompletionStreamResponse(BaseModel):
+ id: str
+ object: str = "text_completion"
+ created: int = Field(default_factory=lambda: int(time.time()))
+ model: str
+ choices: List[CompletionResponseStreamChoice]
+ usage: Optional[UsageInfo] = None
+
+
+class ChatCompletionMessageContentTextPart(BaseModel):
+ type: Literal["text"]
+ text: str
+
+
+class ChatCompletionMessageContentImageURL(BaseModel):
+ url: str
+ detail: Optional[Literal["auto", "low", "high"]] = "auto"
+
+
+class ChatCompletionMessageContentAudioURL(BaseModel):
+ url: str
+
+
+class ChatCompletionMessageContentImagePart(BaseModel):
+ type: Literal["image_url"]
+ image_url: ChatCompletionMessageContentImageURL
+ modalities: Optional[Literal["image", "multi-images", "video"]] = "image"
+
+
+class ChatCompletionMessageContentAudioPart(BaseModel):
+ type: Literal["audio_url"]
+ audio_url: ChatCompletionMessageContentAudioURL
+
+
+ChatCompletionMessageContentPart = Union[
+ ChatCompletionMessageContentTextPart,
+ ChatCompletionMessageContentImagePart,
+ ChatCompletionMessageContentAudioPart,
+]
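+# Example of a multimodal user message content list accepted by these models
+# (illustrative values only):
+#   [
+#       {"type": "text", "text": "What is in this image?"},
+#       {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
+#   ]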
+
+
+class FunctionResponse(BaseModel):
+ """Function response."""
+
+ name: Optional[str] = None
+ arguments: Optional[str] = None
+
+
+class ToolCall(BaseModel):
+ """Tool call response."""
+
+ id: Optional[str] = None
+ index: Optional[int] = None
+ type: Literal["function"] = "function"
+ function: FunctionResponse
+
+
+class ChatCompletionMessageGenericParam(BaseModel):
+ role: Literal["system", "assistant", "tool"]
+ content: Union[str, List[ChatCompletionMessageContentTextPart], None]
+ tool_call_id: Optional[str] = None
+ name: Optional[str] = None
+ reasoning_content: Optional[str] = None
+ tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
+
+
+class ChatCompletionMessageUserParam(BaseModel):
+ role: Literal["user"]
+ content: Union[str, List[ChatCompletionMessageContentPart]]
+
+
+ChatCompletionMessageParam = Union[
+ ChatCompletionMessageGenericParam, ChatCompletionMessageUserParam
+]
+
+
+class ResponseFormat(BaseModel):
+ type: Literal["text", "json_object", "json_schema"]
+ json_schema: Optional[JsonSchemaResponseFormat] = None
+
+
+class StructuresResponseFormat(BaseModel):
+ begin: str
+ schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
+ end: str
+
+
+class StructuralTagResponseFormat(BaseModel):
+ type: Literal["structural_tag"]
+ structures: List[StructuresResponseFormat]
+ triggers: List[str]
+
+
+class Function(BaseModel):
+ """Function descriptions."""
+
+ description: Optional[str] = Field(default=None, examples=[None])
+ name: Optional[str] = None
+ parameters: Optional[object] = None
+ strict: bool = False
+
+
+class Tool(BaseModel):
+ """Function wrapper."""
+
+ type: str = Field(default="function", examples=["function"])
+ function: Function
+
+
+class ToolChoiceFuncName(BaseModel):
+ """The name of tool choice function."""
+
+ name: Optional[str] = None
+
+
+class ToolChoice(BaseModel):
+ """The tool choice definition."""
+
+ function: ToolChoiceFuncName
+ type: Literal["function"] = Field(default="function", examples=["function"])
+
+
+class ChatCompletionRequest(BaseModel):
+    # Fields are ordered according to the official OpenAI API documentation
+ # https://platform.openai.com/docs/api-reference/chat/create
+ messages: List[ChatCompletionMessageParam]
+ model: str
+ frequency_penalty: float = 0.0
+ logit_bias: Optional[Dict[str, float]] = None
+ logprobs: bool = False
+ top_logprobs: Optional[int] = None
+ max_tokens: Optional[int] = Field(
+ default=None,
+ deprecated="max_tokens is deprecated in favor of the max_completion_tokens field",
+ description="The maximum number of tokens that can be generated in the chat completion. ",
+ )
+ max_completion_tokens: Optional[int] = Field(
+ default=None,
+ description="The maximum number of completion tokens for a chat completion request, "
+ "including visible output tokens and reasoning tokens. Input tokens are not included. ",
+ )
+ n: int = 1
+ presence_penalty: float = 0.0
+ response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
+ seed: Optional[int] = None
+ stop: Optional[Union[str, List[str]]] = None
+ stream: bool = False
+ stream_options: Optional[StreamOptions] = None
+ temperature: float = 0.7
+ top_p: float = 1.0
+ user: Optional[str] = None
+ tools: Optional[List[Tool]] = Field(default=None, examples=[None])
+ tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
+ default="auto", examples=["none"]
+ ) # noqa
+
+ @model_validator(mode="before")
+ @classmethod
+ def set_tool_choice_default(cls, values):
+ if isinstance(values, dict):
+ if values.get("tool_choice") is None:
+ if values.get("tools") is None:
+ values["tool_choice"] = "none"
+ else:
+ values["tool_choice"] = "auto"
+ return values
+
+ @field_validator("messages")
+ @classmethod
+ def validate_messages_not_empty(cls, v):
+ if not v:
+ raise ValueError("Messages cannot be empty")
+ return v
+
+    # Extra parameters supported only by the SRT backend; OpenAI models ignore them.
+ top_k: int = -1
+ min_p: float = 0.0
+ min_tokens: int = 0
+ regex: Optional[str] = None
+ ebnf: Optional[str] = None
+ repetition_penalty: float = 1.0
+ stop_token_ids: Optional[List[int]] = None
+ no_stop_trim: bool = False
+ ignore_eos: bool = False
+ continue_final_message: bool = False
+ skip_special_tokens: bool = True
+ lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+ session_params: Optional[Dict] = None
+ separate_reasoning: bool = True
+ stream_reasoning: bool = True
+ chat_template_kwargs: Optional[Dict] = None
+
+ # The request id.
+ rid: Optional[str] = None
+
+    # For PD (prefill-decode) disaggregation
+ bootstrap_host: Optional[str] = None
+ bootstrap_port: Optional[int] = None
+ bootstrap_room: Optional[int] = None
+
+
+class ChatMessage(BaseModel):
+ role: Optional[str] = None
+ content: Optional[str] = None
+ reasoning_content: Optional[str] = None
+ tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
+
+
+class ChatCompletionResponseChoice(BaseModel):
+ index: int
+ message: ChatMessage
+ logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
+ finish_reason: Literal[
+ "stop", "length", "tool_calls", "content_filter", "function_call", "abort"
+ ]
+ matched_stop: Union[None, int, str] = None
+
+
+class ChatCompletionResponse(BaseModel):
+ id: str
+ object: str = "chat.completion"
+ created: int = Field(default_factory=lambda: int(time.time()))
+ model: str
+ choices: List[ChatCompletionResponseChoice]
+ usage: UsageInfo
+
+
+class DeltaMessage(BaseModel):
+ role: Optional[str] = None
+ content: Optional[str] = None
+ reasoning_content: Optional[str] = None
+ tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+ index: int
+ delta: DeltaMessage
+ logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
+ finish_reason: Optional[
+ Literal["stop", "length", "tool_calls", "content_filter", "function_call"]
+ ] = None
+ matched_stop: Union[None, int, str] = None
+
+
+class ChatCompletionStreamResponse(BaseModel):
+ id: str
+ object: str = "chat.completion.chunk"
+ created: int = Field(default_factory=lambda: int(time.time()))
+ model: str
+ choices: List[ChatCompletionResponseStreamChoice]
+ usage: Optional[UsageInfo] = None
+
+
+class MultimodalEmbeddingInput(BaseModel):
+ text: Optional[str] = None
+ image: Optional[str] = None
+
+
+EmbeddingInput = Union[
+ List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput]
+]
+
+
+class EmbeddingRequest(BaseModel):
+    # Fields are ordered according to the official OpenAI API documentation
+ # https://platform.openai.com/docs/api-reference/embeddings/create
+ input: EmbeddingInput
+ model: str
+ encoding_format: str = "float"
+    dimensions: Optional[int] = None
+ user: Optional[str] = None
+
+ # The request id.
+ rid: Optional[str] = None
+
+
+class EmbeddingObject(BaseModel):
+ embedding: List[float]
+ index: int
+ object: str = "embedding"
+
+
+class EmbeddingResponse(BaseModel):
+ data: List[EmbeddingObject]
+ model: str
+ object: str = "list"
+ usage: Optional[UsageInfo] = None
+
+
+class ScoringRequest(BaseModel):
+ query: Optional[Union[str, List[int]]] = (
+ None # Query text or pre-tokenized token IDs
+ )
+ items: Optional[Union[str, List[str], List[List[int]]]] = (
+ None # Item text(s) or pre-tokenized token IDs
+ )
+ label_token_ids: Optional[List[int]] = (
+ None # Token IDs to compute probabilities for
+ )
+ apply_softmax: bool = False
+ item_first: bool = False
+ model: str
+
+
+class ScoringResponse(BaseModel):
+ scores: List[
+ List[float]
+ ] # List of lists of probabilities, each in the order of label_token_ids
+ model: str
+ usage: Optional[UsageInfo] = None
+ object: str = "scoring"
+
+
+OpenAIServingRequest = Union[
+ ChatCompletionRequest, CompletionRequest, EmbeddingRequest, ScoringRequest
+]
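+# Minimal sketch of validating an incoming payload with these models
+# (model name and message content are illustrative):
+#   req = ChatCompletionRequest.model_validate(
+#       {"model": "my-model", "messages": [{"role": "user", "content": "Hello"}]}
+#   )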
diff --git a/python/sglang/srt/entrypoints/openai/serving_base.py b/python/sglang/srt/entrypoints/openai/serving_base.py
new file mode 100644
index 00000000000..718378f5e0f
--- /dev/null
+++ b/python/sglang/srt/entrypoints/openai/serving_base.py
@@ -0,0 +1,178 @@
+import json
+import logging
+import uuid
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Union
+
+from fastapi import Request
+from fastapi.responses import ORJSONResponse, StreamingResponse
+
+from sglang.srt.entrypoints.openai.protocol import (
+ ErrorResponse,
+ OpenAIServingRequest,
+ UsageInfo,
+)
+from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.tokenizer_manager import TokenizerManager
+
+logger = logging.getLogger(__name__)
+
+
+# Base class for specific endpoint handlers
+class OpenAIServingBase(ABC):
+ """Abstract base class for OpenAI endpoint handlers"""
+
+ def __init__(self, tokenizer_manager: TokenizerManager):
+ self.tokenizer_manager = tokenizer_manager
+
+ async def handle_request(
+ self, request: OpenAIServingRequest, raw_request: Request
+ ) -> Union[Any, StreamingResponse, ErrorResponse]:
+ """Handle the specific request type with common pattern"""
+ try:
+ # Validate request
+ error_msg = self._validate_request(request)
+ if error_msg:
+ return self.create_error_response(error_msg)
+
+ # Convert to internal format
+ adapted_request, processed_request = self._convert_to_internal_request(
+ [request], [self._generate_request_id_base(request)]
+ )
+
+            # Note(Xinyuan): raw_request below is only used to detect whether the client has disconnected
+ if hasattr(request, "stream") and request.stream:
+ return await self._handle_streaming_request(
+ adapted_request, processed_request, raw_request
+ )
+ else:
+ return await self._handle_non_streaming_request(
+ adapted_request, processed_request, raw_request
+ )
+
+ except Exception as e:
+ logger.error(f"Error in request: {e}")
+ return self.create_error_response(
+ message=f"Internal server error: {str(e)}",
+ err_type="InternalServerError",
+ status_code=500,
+ )
+
+ @abstractmethod
+ def _request_id_prefix(self) -> str:
+ """Generate request ID based on request type"""
+ pass
+
+ def _generate_request_id_base(self, request: OpenAIServingRequest) -> str:
+ """Generate request ID based on request type"""
+ if rid := getattr(request, "rid", None):
+ return rid
+
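+        # e.g. "chatcmpl-<32 hex chars>" for chat requests when the client supplies no rid.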
+ return f"{self._request_id_prefix()}{uuid.uuid4().hex}"
+
+ @abstractmethod
+ def _convert_to_internal_request(
+ self,
+ all_requests: List[OpenAIServingRequest],
+ request_ids: List[str],
+ ) -> tuple[
+ GenerateReqInput, Union[OpenAIServingRequest, List[OpenAIServingRequest]]
+ ]:
+ """Convert OpenAI request to internal format"""
+ pass
+
+ async def _handle_streaming_request(
+ self,
+ adapted_request: GenerateReqInput,
+ request: OpenAIServingRequest,
+ raw_request: Request,
+ ) -> StreamingResponse:
+ """Handle streaming request
+
+ Override this method in child classes that support streaming requests.
+ """
+ return self.create_error_response(
+ message=f"{self.__class__.__name__} does not support streaming requests",
+ err_type="NotImplementedError",
+ status_code=501,
+ )
+
+ async def _handle_non_streaming_request(
+ self,
+ adapted_request: GenerateReqInput,
+ request: OpenAIServingRequest,
+ raw_request: Request,
+ ) -> Union[Any, ErrorResponse]:
+ """Handle non-streaming request
+
+ Override this method in child classes that support non-streaming requests.
+ """
+ return self.create_error_response(
+ message=f"{self.__class__.__name__} does not support non-streaming requests",
+ err_type="NotImplementedError",
+ status_code=501,
+ )
+
+ def _validate_request(self, request: OpenAIServingRequest) -> Optional[str]:
+ """Validate request"""
+ pass
+
+ def _calculate_streaming_usage_base(
+ self,
+ prompt_tokens: Dict[int, int],
+ completion_tokens: Dict[int, int],
+ cached_tokens: Dict[int, int],
+ n_choices: int,
+ ) -> UsageInfo:
+ """Calculate usage information for streaming responses (common logic)"""
+ total_prompt_tokens = sum(
+ tokens for i, tokens in prompt_tokens.items() if i % n_choices == 0
+ )
+ total_completion_tokens = sum(tokens for tokens in completion_tokens.values())
+
+ cache_report = self.tokenizer_manager.server_args.enable_cache_report
+ prompt_tokens_details = None
+ if cache_report:
+ cached_tokens_sum = sum(tokens for tokens in cached_tokens.values())
+ if cached_tokens_sum > 0:
+ prompt_tokens_details = {"cached_tokens": cached_tokens_sum}
+
+ return UsageInfo(
+ prompt_tokens=total_prompt_tokens,
+ completion_tokens=total_completion_tokens,
+ total_tokens=total_prompt_tokens + total_completion_tokens,
+ prompt_tokens_details=prompt_tokens_details,
+ )
+
+ def create_error_response(
+ self,
+ message: str,
+ err_type: str = "BadRequestError",
+ status_code: int = 400,
+ param: Optional[str] = None,
+ ) -> ORJSONResponse:
+ """Create an error response"""
+ error = ErrorResponse(
+ object="error",
+ message=message,
+ type=err_type,
+ param=param,
+ code=status_code,
+ )
+ return ORJSONResponse(content=error.model_dump(), status_code=status_code)
+
+ def create_streaming_error_response(
+ self,
+ message: str,
+ err_type: str = "BadRequestError",
+ status_code: int = 400,
+ ) -> str:
+ """Create a streaming error response"""
+ error = ErrorResponse(
+ object="error",
+ message=message,
+ type=err_type,
+ param=None,
+ code=status_code,
+ )
+ return json.dumps({"error": error.model_dump()})
diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py
new file mode 100644
index 00000000000..54b490131ab
--- /dev/null
+++ b/python/sglang/srt/entrypoints/openai/serving_chat.py
@@ -0,0 +1,938 @@
+import base64
+import json
+import logging
+import time
+import uuid
+from typing import Any, Dict, List, Optional, Union
+
+from fastapi import Request
+from fastapi.responses import StreamingResponse
+
+from sglang.srt.conversation import generate_chat_conv
+from sglang.srt.entrypoints.openai.protocol import (
+ ChatCompletionRequest,
+ ChatCompletionResponse,
+ ChatCompletionResponseChoice,
+ ChatCompletionResponseStreamChoice,
+ ChatCompletionStreamResponse,
+ ChatCompletionTokenLogprob,
+ ChatMessage,
+ ChoiceLogprobs,
+ DeltaMessage,
+ ErrorResponse,
+ FunctionResponse,
+ LogProbs,
+ ToolCall,
+ TopLogprob,
+)
+from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
+from sglang.srt.entrypoints.openai.utils import (
+ aggregate_token_usage,
+ detect_template_content_format,
+ process_content_for_template_format,
+ to_openai_style_logprobs,
+)
+from sglang.srt.function_call.function_call_parser import FunctionCallParser
+from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.reasoning_parser import ReasoningParser
+from sglang.utils import convert_json_schema_to_str
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAIServingChat(OpenAIServingBase):
+ """Handler for chat completion requests"""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ # Instance-specific cache for template content format detection
+ self._cached_chat_template = None
+ self._cached_template_format = None
+
+ def _request_id_prefix(self) -> str:
+ return "chatcmpl-"
+
+ def _validate_request(self, request: ChatCompletionRequest) -> Optional[str]:
+ """Validate chat messages format and content"""
+ if not (messages := request.messages):
+ return "Messages cannot be empty"
+
+        roles = [msg.role for msg in messages]
+
+        # First message should typically be from user or system
+        if roles[0] not in ["user", "system"]:
+            return "First message should be from 'user' or 'system'"
+
+        # Consecutive assistant messages are allowed (e.g. for tool-call follow-ups),
+        # so no alternation check is enforced here.
+
+ # Validate message content
+ for i, msg in enumerate(messages):
+ if msg.role == "user":
+ if not msg.content:
+ return f"User message at index {i} has no content"
+ elif msg.role == "assistant":
+ # Assistant messages can have no content if they have tool_calls
+ if not msg.content and not getattr(msg, "tool_calls", None):
+ return (
+ f"Assistant message at index {i} has no content or tool calls"
+ )
+
+ return None
+
+ def _convert_to_internal_request(
+ self,
+ all_requests: List[ChatCompletionRequest],
+ request_ids: List[str],
+ ) -> tuple[
+ GenerateReqInput, Union[ChatCompletionRequest, List[ChatCompletionRequest]]
+ ]:
+ """Convert OpenAI chat completion request to internal format"""
+ input_ids = []
+ prompts = []
+ sampling_params_list = []
+ image_data_list = []
+ audio_data_list = []
+ return_logprobs = []
+ logprob_start_lens = []
+ top_logprobs_nums = []
+ modalities_list = []
+ lora_paths = []
+
+ is_multimodal = self.tokenizer_manager.model_config.is_multimodal
+
+ for request in all_requests:
+ # Process messages and apply chat template
+ (
+ prompt,
+ prompt_ids,
+ image_data,
+ audio_data,
+ modalities,
+ stop,
+ tool_call_constraint,
+ ) = self._process_messages(request, is_multimodal)
+
+ input_ids.append(prompt_ids)
+ prompts.append(prompt)
+ return_logprobs.append(request.logprobs)
+ logprob_start_lens.append(-1)
+ top_logprobs_nums.append(request.top_logprobs or 0)
+ lora_paths.append(request.lora_path)
+
+ # Build sampling parameters
+ sampling_params = self._build_sampling_params(
+ request, stop, tool_call_constraint
+ )
+ sampling_params_list.append(sampling_params)
+
+ image_data_list.append(image_data)
+ audio_data_list.append(audio_data)
+ modalities_list.append(modalities)
+
+ # Handle single vs multiple requests
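+        # A single request is collapsed from per-request lists to scalars so that
+        # GenerateReqInput treats it as one request rather than a batch.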
+ if len(all_requests) == 1:
+ if is_multimodal:
+ prompt_kwargs = {"text": prompts[0]}
+ else:
+ if isinstance(input_ids[0], str):
+ prompt_kwargs = {"text": input_ids[0]}
+ else:
+ prompt_kwargs = {"input_ids": input_ids[0]}
+
+ sampling_params_list = sampling_params_list[0]
+ image_data_list = image_data_list[0]
+ audio_data_list = audio_data_list[0]
+ return_logprobs = return_logprobs[0]
+ logprob_start_lens = logprob_start_lens[0]
+ top_logprobs_nums = top_logprobs_nums[0]
+ modalities_list = modalities_list[0]
+ lora_paths = lora_paths[0]
+ request_ids = request_ids[0]
+ else:
+ if is_multimodal:
+ prompt_kwargs = {"text": prompts}
+ else:
+ if isinstance(input_ids[0], str):
+ prompt_kwargs = {"text": input_ids}
+ else:
+ prompt_kwargs = {"input_ids": input_ids}
+
+ adapted_request = GenerateReqInput(
+ **prompt_kwargs,
+ image_data=image_data_list,
+ audio_data=audio_data_list,
+ sampling_params=sampling_params_list,
+ return_logprob=return_logprobs,
+ logprob_start_len=logprob_start_lens,
+ top_logprobs_num=top_logprobs_nums,
+ stream=all_requests[0].stream,
+ return_text_in_logprobs=True,
+ rid=request_ids,
+ modalities=modalities_list,
+ lora_path=lora_paths,
+ bootstrap_host=all_requests[0].bootstrap_host,
+ bootstrap_port=all_requests[0].bootstrap_port,
+ bootstrap_room=all_requests[0].bootstrap_room,
+ )
+
+ return adapted_request, (
+ all_requests if len(all_requests) > 1 else all_requests[0]
+ )
+
+ def _process_messages(
+ self, request: ChatCompletionRequest, is_multimodal: bool
+ ) -> tuple[
+ str,
+ Union[str, List[int]],
+ Optional[Any],
+ Optional[Any],
+ List[str],
+ List[str],
+ Optional[Any],
+ ]:
+ """Process chat messages and apply chat template"""
+ tool_call_constraint = None
+ prompt = ""
+ prompt_ids = []
+
+ if not isinstance(request.messages, str):
+ # Apply chat template and its stop strings
+ tools = None
+ if request.tools and request.tool_choice != "none":
+ request.skip_special_tokens = False
+ if not isinstance(request.tool_choice, str):
+ tools = [
+ item.function.model_dump()
+ for item in request.tools
+ if item.function.name == request.tool_choice.function.name
+ ]
+ else:
+ tools = [item.function.model_dump() for item in request.tools]
+
+ tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
+ parser = FunctionCallParser(request.tools, tool_call_parser)
+ tool_call_constraint = parser.get_structure_constraint(
+ request.tool_choice
+ )
+
+ # Use chat template
+ if (
+ hasattr(self.tokenizer_manager, "chat_template_name")
+ and self.tokenizer_manager.chat_template_name is None
+ ):
+ prompt, prompt_ids, image_data, audio_data, modalities, stop = (
+ self._apply_jinja_template(request, tools, is_multimodal)
+ )
+ else:
+ prompt, image_data, audio_data, modalities, stop = (
+ self._apply_conversation_template(request)
+ )
+ if not is_multimodal:
+ prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt)
+ else:
+ # Use raw prompt
+ prompt_ids = request.messages
+ stop = request.stop or []
+ image_data = None
+ audio_data = None
+ modalities = []
+ prompt = request.messages
+
+ return (
+ prompt,
+ prompt_ids,
+ image_data,
+ audio_data,
+ modalities,
+ stop,
+ tool_call_constraint,
+ )
+
+ def _apply_jinja_template(
+ self,
+ request: ChatCompletionRequest,
+ tools: Optional[List[Dict]],
+ is_multimodal: bool,
+ ) -> tuple[str, List[int], Optional[Any], Optional[Any], List[str], List[str]]:
+ """Apply Jinja chat template"""
+ openai_compatible_messages = []
+ image_data = []
+ audio_data = []
+ modalities = []
+
+ # Detect template content format
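+        # (typically "string" for templates expecting flattened text content vs. "openai"
+        # for templates that iterate over structured content parts); the result is cached
+        # per template so detection does not rerun on every request.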
+ current_template = self.tokenizer_manager.tokenizer.chat_template
+ if current_template != self._cached_chat_template:
+ self._cached_chat_template = current_template
+ self._cached_template_format = detect_template_content_format(
+ current_template
+ )
+ logger.info(
+ f"Detected chat template content format: {self._cached_template_format}"
+ )
+
+ template_content_format = self._cached_template_format
+
+ for message in request.messages:
+ if message.content is None:
+ message.content = ""
+ msg_dict = message.model_dump()
+
+ # Process content based on detected template format
+ processed_msg = process_content_for_template_format(
+ msg_dict,
+ template_content_format,
+ image_data,
+ audio_data,
+ modalities,
+ )
+ openai_compatible_messages.append(processed_msg)
+
+ # Handle assistant prefix for continue_final_message
+ assistant_prefix = None
+ if (
+ openai_compatible_messages
+ and openai_compatible_messages[-1]["role"] == "assistant"
+ ):
+ if request.continue_final_message:
+ assistant_prefix = openai_compatible_messages[-1]["content"]
+ openai_compatible_messages = openai_compatible_messages[:-1]
+
+ try:
+ prompt_ids = self.tokenizer_manager.tokenizer.apply_chat_template(
+ openai_compatible_messages,
+ tokenize=True,
+ add_generation_prompt=True,
+ tools=tools,
+ **(
+ request.chat_template_kwargs if request.chat_template_kwargs else {}
+ ),
+ )
+ except Exception:
+            # This branch is triggered when the chosen model expects a tools format
+            # that is not compatible with OpenAI's apply_chat_template tool_call
+            # format (e.g. Mistral), so the tools are re-wrapped before retrying.
+ tools = (
+ [t if "function" in t else {"function": t} for t in tools]
+ if tools
+ else None
+ )
+ prompt_ids = self.tokenizer_manager.tokenizer.apply_chat_template(
+ openai_compatible_messages,
+ tokenize=True,
+ add_generation_prompt=True,
+ tools=tools,
+ **(
+ request.chat_template_kwargs if request.chat_template_kwargs else {}
+ ),
+ )
+
+ if assistant_prefix:
+ encoded = self.tokenizer_manager.tokenizer.encode(assistant_prefix)
+ if encoded and encoded[0] == self.tokenizer_manager.tokenizer.bos_token_id:
+ encoded = encoded[1:]
+ prompt_ids += encoded
+
+        if is_multimodal:
+            prompt = self.tokenizer_manager.tokenizer.decode(prompt_ids)
+        else:
+            # Non-multimodal requests are driven by prompt_ids downstream; bind an empty
+            # text prompt so the return value is always defined.
+            prompt = ""
+
+ stop = request.stop or []
+ return prompt, prompt_ids, image_data, audio_data, modalities, stop
+
+ def _apply_conversation_template(
+ self, request: ChatCompletionRequest
+ ) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str]]:
+ """Apply conversation template"""
+ conv = generate_chat_conv(request, self.tokenizer_manager.chat_template_name)
+
+ # If we should continue the final assistant message, adjust the conversation.
+ if (
+ request.continue_final_message
+ and request.messages
+ and request.messages[-1].role == "assistant"
+ ):
+ # Remove the auto-added blank assistant turn, if present.
+ if conv.messages and conv.messages[-1][1] is None:
+ conv.messages.pop()
+ # Rebuild the prompt from the conversation.
+ prompt = conv.get_prompt()
+ # Strip trailing stop tokens or separators that indicate end-of-assistant.
+ if isinstance(conv.stop_str, list):
+ for stop_token in conv.stop_str:
+ if prompt.endswith(stop_token):
+ prompt = prompt[: -len(stop_token)]
+ elif isinstance(conv.stop_str, str) and prompt.endswith(conv.stop_str):
+ prompt = prompt[: -len(conv.stop_str)]
+ if conv.sep and prompt.endswith(conv.sep):
+ prompt = prompt[: -len(conv.sep)]
+ if getattr(conv, "sep2", None) and prompt.endswith(conv.sep2):
+ prompt = prompt[: -len(conv.sep2)]
+ else:
+ prompt = conv.get_prompt()
+
+ image_data = conv.image_data
+ audio_data = conv.audio_data
+ modalities = conv.modalities
+ stop = conv.stop_str or [] if not request.ignore_eos else []
+
+ if request.stop:
+ if isinstance(request.stop, str):
+ stop.append(request.stop)
+ else:
+ stop.extend(request.stop)
+
+ return prompt, image_data, audio_data, modalities, stop
+
+ def _build_sampling_params(
+ self,
+ request: ChatCompletionRequest,
+ stop: List[str],
+ tool_call_constraint: Optional[Any],
+ ) -> Dict[str, Any]:
+ """Build sampling parameters for the request"""
+
+ sampling_params = {
+ "temperature": request.temperature,
+ "max_new_tokens": request.max_tokens or request.max_completion_tokens,
+ "min_new_tokens": request.min_tokens,
+ "stop": stop,
+ "stop_token_ids": request.stop_token_ids,
+ "top_p": request.top_p,
+ "top_k": request.top_k,
+ "min_p": request.min_p,
+ "presence_penalty": request.presence_penalty,
+ "frequency_penalty": request.frequency_penalty,
+ "repetition_penalty": request.repetition_penalty,
+ "regex": request.regex,
+ "ebnf": request.ebnf,
+ "n": request.n,
+ "no_stop_trim": request.no_stop_trim,
+ "ignore_eos": request.ignore_eos,
+ "skip_special_tokens": request.skip_special_tokens,
+ "logit_bias": request.logit_bias,
+ }
+
+ if request.response_format and request.response_format.type == "json_schema":
+ sampling_params["json_schema"] = convert_json_schema_to_str(
+ request.response_format.json_schema.schema_
+ )
+ elif request.response_format and request.response_format.type == "json_object":
+ sampling_params["json_schema"] = '{"type": "object"}'
+ elif (
+ request.response_format and request.response_format.type == "structural_tag"
+ ):
+ sampling_params["structural_tag"] = convert_json_schema_to_str(
+ request.response_format.model_dump(by_alias=True)
+ )
+
+ # Check if there are already existing output constraints
+ has_existing_constraints = (
+ sampling_params.get("regex")
+ or sampling_params.get("ebnf")
+ or sampling_params.get("structural_tag")
+ or sampling_params.get("json_schema")
+ )
+
+ if tool_call_constraint and has_existing_constraints:
+ logger.warning("Constrained decoding is not compatible with tool calls.")
+ elif tool_call_constraint:
+ constraint_type, constraint_value = tool_call_constraint
+ if constraint_type == "structural_tag":
+ sampling_params[constraint_type] = convert_json_schema_to_str(
+ constraint_value.model_dump(by_alias=True)
+ )
+ else:
+ sampling_params[constraint_type] = constraint_value
+ return sampling_params
+
+ async def _handle_streaming_request(
+ self,
+ adapted_request: GenerateReqInput,
+ request: ChatCompletionRequest,
+ raw_request: Request,
+ ) -> StreamingResponse:
+ """Handle streaming chat completion request"""
+
+ async def generate_stream_resp():
+ parser_dict = {}
+ reasoning_parser_dict = {}
+ is_firsts = {}
+ stream_buffers = {}
+ n_prev_tokens = {}
+ prompt_tokens = {}
+ completion_tokens = {}
+ cached_tokens = {}
+
+ try:
+ async for content in self.tokenizer_manager.generate_request(
+ adapted_request, raw_request
+ ):
+ index = content.get("index", 0)
+
+ is_first = is_firsts.get(index, True)
+ stream_buffer = stream_buffers.get(index, "")
+ n_prev_token = n_prev_tokens.get(index, 0)
+
+ prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
+ completion_tokens[index] = content["meta_info"]["completion_tokens"]
+ cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
+
+ # Handle logprobs
+ choice_logprobs = None
+ if request.logprobs:
+ choice_logprobs = self._process_streaming_logprobs(
+ content, n_prev_token
+ )
+ n_prev_token = len(
+ content["meta_info"]["output_token_logprobs"]
+ )
+
+ finish_reason = content["meta_info"]["finish_reason"]
+ finish_reason_type = (
+ finish_reason["type"] if finish_reason else None
+ )
+
+ # First chunk with role
+ if is_first:
+ is_first = False
+ delta = DeltaMessage(role="assistant")
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=index,
+ delta=delta,
+ finish_reason=finish_reason_type,
+ matched_stop=(
+ finish_reason["matched"]
+ if finish_reason and "matched" in finish_reason
+ else None
+ ),
+ logprobs=choice_logprobs,
+ )
+ chunk = ChatCompletionStreamResponse(
+ id=content["meta_info"]["id"],
+ created=int(time.time()),
+ choices=[choice_data],
+ model=request.model,
+ )
+ yield f"data: {chunk.model_dump_json()}\n\n"
+
+ # Process content delta
+ delta = content["text"][len(stream_buffer) :]
+ new_stream_buffer = stream_buffer + delta
+
+ # Handle reasoning content
+                    # chat_template_kwargs defaults to None, so guard before calling .get()
+                    enable_thinking = (request.chat_template_kwargs or {}).get(
+                        "enable_thinking", True
+                    )
+ if (
+ self.tokenizer_manager.server_args.reasoning_parser
+ and request.separate_reasoning
+ and enable_thinking
+ ):
+ reasoning_text, delta = self._process_reasoning_stream(
+ index, delta, reasoning_parser_dict, content, request
+ )
+ if reasoning_text:
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=index,
+ delta=DeltaMessage(reasoning_content=reasoning_text),
+ finish_reason=finish_reason_type,
+ )
+ chunk = ChatCompletionStreamResponse(
+ id=content["meta_info"]["id"],
+ created=int(time.time()),
+ choices=[choice_data],
+ model=request.model,
+ )
+ yield f"data: {chunk.model_dump_json()}\n\n"
+
+ if not delta:
+ stream_buffers[index] = new_stream_buffer
+ is_firsts[index] = is_first
+ n_prev_tokens[index] = n_prev_token
+ continue
+
+ # Handle tool calls
+ if request.tool_choice != "none" and request.tools:
+ async for chunk in self._process_tool_call_stream(
+ index,
+ delta,
+ parser_dict,
+ content,
+ request,
+ finish_reason_type,
+ ):
+ yield chunk
+ else:
+ # Regular content
+ if delta or not (
+ request.stream_options
+ and request.stream_options.include_usage
+ ):
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=index,
+ delta=DeltaMessage(content=delta if delta else None),
+ finish_reason=(
+ None
+ if request.stream_options
+ and request.stream_options.include_usage
+ else finish_reason_type
+ ),
+ matched_stop=(
+ finish_reason["matched"]
+ if finish_reason and "matched" in finish_reason
+ else None
+ ),
+ logprobs=choice_logprobs,
+ )
+ chunk = ChatCompletionStreamResponse(
+ id=content["meta_info"]["id"],
+ created=int(time.time()),
+ choices=[choice_data],
+ model=request.model,
+ )
+ yield f"data: {chunk.model_dump_json()}\n\n"
+
+ stream_buffers[index] = new_stream_buffer
+ is_firsts[index] = is_first
+ n_prev_tokens[index] = n_prev_token
+
+ # Final chunk with usage
+ if request.stream_options and request.stream_options.include_usage:
+ usage = self._calculate_streaming_usage_base(
+ prompt_tokens, completion_tokens, cached_tokens, request.n
+ )
+ else:
+ usage = None
+
+ final_chunk = ChatCompletionStreamResponse(
+ id=content["meta_info"]["id"],
+ created=int(time.time()),
+ choices=[
+ ChatCompletionResponseStreamChoice(
+ index=index,
+ delta=DeltaMessage(),
+ finish_reason=finish_reason_type,
+ )
+ ],
+ model=request.model,
+ usage=usage,
+ )
+ yield f"data: {final_chunk.model_dump_json()}\n\n"
+
+ except Exception as e:
+ error = self.create_streaming_error_response(str(e))
+ yield f"data: {error}\n\n"
+
+ yield "data: [DONE]\n\n"
+
+ return StreamingResponse(
+ generate_stream_resp(),
+ media_type="text/event-stream",
+ background=self.tokenizer_manager.create_abort_task(adapted_request),
+ )
+
+ async def _handle_non_streaming_request(
+ self,
+ adapted_request: GenerateReqInput,
+ request: ChatCompletionRequest,
+ raw_request: Request,
+ ) -> Union[ChatCompletionResponse, ErrorResponse]:
+ """Handle non-streaming chat completion request"""
+ try:
+ ret = await self.tokenizer_manager.generate_request(
+ adapted_request, raw_request
+ ).__anext__()
+ except ValueError as e:
+ return self.create_error_response(str(e))
+
+ if not isinstance(ret, list):
+ ret = [ret]
+
+ response = self._build_chat_response(
+ request,
+ ret,
+ int(time.time()),
+ cache_report=self.tokenizer_manager.server_args.enable_cache_report,
+ tool_call_parser=self.tokenizer_manager.server_args.tool_call_parser,
+ reasoning_parser=self.tokenizer_manager.server_args.reasoning_parser,
+ )
+
+ return response
+
+ def _build_chat_response(
+ self,
+ request: ChatCompletionRequest,
+ ret: List[Dict[str, Any]],
+ created: int,
+ cache_report: bool = False,
+ tool_call_parser: Optional[str] = None,
+ reasoning_parser: Optional[str] = None,
+ ) -> ChatCompletionResponse:
+ """Build chat completion response from generation results"""
+ choices = []
+
+ for idx, ret_item in enumerate(ret):
+ # Process logprobs
+ choice_logprobs = None
+ if request.logprobs:
+ choice_logprobs = self._process_response_logprobs(ret_item)
+
+ finish_reason = ret_item["meta_info"]["finish_reason"]
+ text = ret_item["text"]
+
+ # Handle reasoning content
+ reasoning_text = None
+            # chat_template_kwargs defaults to None, so guard before calling .get()
+            enable_thinking = (request.chat_template_kwargs or {}).get(
+                "enable_thinking", True
+            )
+ if reasoning_parser and request.separate_reasoning and enable_thinking:
+ try:
+ parser = ReasoningParser(
+ model_type=reasoning_parser, stream_reasoning=False
+ )
+ reasoning_text, text = parser.parse_non_stream(text)
+ except Exception as e:
+ logger.error(f"Reasoning parsing error: {e}")
+ return self.create_error_response(
+ "Failed to parse reasoning content",
+ err_type="InternalServerError",
+ status_code=500,
+ )
+
+ # Handle tool calls
+ tool_calls = None
+ if request.tool_choice != "none" and request.tools:
+ tool_calls, text, finish_reason = self._process_tool_calls(
+ text, request.tools, tool_call_parser, finish_reason
+ )
+
+ choice_data = ChatCompletionResponseChoice(
+ index=idx,
+ message=ChatMessage(
+ role="assistant",
+ content=text if text else None,
+ tool_calls=tool_calls,
+ reasoning_content=reasoning_text if reasoning_text else None,
+ ),
+ logprobs=choice_logprobs,
+ finish_reason=finish_reason["type"] if finish_reason else None,
+ matched_stop=(
+ finish_reason["matched"]
+ if finish_reason and "matched" in finish_reason
+ else None
+ ),
+ )
+ choices.append(choice_data)
+
+ # Calculate usage
+ usage = aggregate_token_usage(ret, request.n, cache_report)
+
+ return ChatCompletionResponse(
+ id=ret[0]["meta_info"]["id"],
+ created=created,
+ model=request.model,
+ choices=choices,
+ usage=usage,
+ )
+
+ def _process_logprobs_tokens(
+ self, logprobs: LogProbs, use_token_index: bool = False
+ ) -> List[ChatCompletionTokenLogprob]:
+ """Common helper to process logprobs tokens for both streaming and non-streaming
+
+ Args:
+ logprobs: LogProbs data from model
+ use_token_index: True for non-streaming (use token_idx), False for streaming (use index 0)
+ """
+ token_logprobs = []
+
+ for token_idx, (token, logprob) in enumerate(
+ zip(logprobs.tokens, logprobs.token_logprobs)
+ ):
+ token_bytes = list(token.encode("utf-8"))
+ top_logprobs = []
+ if logprobs.top_logprobs:
+ # - Non-streaming (use_token_index=True): uses token_idx for full data
+ # - Streaming (use_token_index=False): uses index 0 for pre-sliced data
+ top_logprobs_idx = token_idx if use_token_index else 0
+ for top_token, top_logprob in logprobs.top_logprobs[
+ top_logprobs_idx
+ ].items():
+ top_token_bytes = list(top_token.encode("utf-8"))
+ top_logprobs.append(
+ TopLogprob(
+ token=top_token,
+ bytes=top_token_bytes,
+ logprob=top_logprob,
+ )
+ )
+ token_logprobs.append(
+ ChatCompletionTokenLogprob(
+ token=token,
+ bytes=token_bytes,
+ logprob=logprob,
+ top_logprobs=top_logprobs,
+ )
+ )
+
+ return token_logprobs
+
+ def _process_response_logprobs(self, ret_item: Dict[str, Any]) -> ChoiceLogprobs:
+ """Process logprobs for non-streaming response"""
+ logprobs = to_openai_style_logprobs(
+ output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
+ output_top_logprobs=ret_item["meta_info"].get("output_top_logprobs", None),
+ )
+
+ token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=True)
+ return ChoiceLogprobs(content=token_logprobs)
+
+ def _process_tool_calls(
+ self,
+ text: str,
+ tools: List[Any],
+ tool_call_parser: Optional[str],
+ finish_reason: Dict[str, Any],
+ ) -> tuple[Optional[List[ToolCall]], str, Dict[str, Any]]:
+ """Process tool calls in the response"""
+ parser = FunctionCallParser(tools, tool_call_parser)
+ if parser.has_tool_call(text):
+ if finish_reason["type"] == "stop":
+ finish_reason["type"] = "tool_calls"
+ finish_reason["matched"] = None
+ try:
+ text, call_info_list = parser.parse_non_stream(text)
+ tool_calls = [
+ ToolCall(
+ id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}",
+ function=FunctionResponse(
+ name=call_info.name, arguments=call_info.parameters
+ ),
+ )
+ for call_info in call_info_list
+ ]
+ return tool_calls, text, finish_reason
+ except Exception as e:
+ logger.error(f"Tool call parsing error: {e}")
+ # Return error but don't fail the whole request
+ return None, text, finish_reason
+
+ return None, text, finish_reason
+
+ def _process_streaming_logprobs(
+ self, content: Dict[str, Any], n_prev_token: int
+ ) -> ChoiceLogprobs:
+ """Process logprobs for streaming response"""
+ logprobs = to_openai_style_logprobs(
+ output_token_logprobs=content["meta_info"]["output_token_logprobs"][
+ n_prev_token:
+ ],
+ output_top_logprobs=content["meta_info"].get("output_top_logprobs", [])[
+ n_prev_token:
+ ],
+ )
+
+ token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=False)
+ return ChoiceLogprobs(content=token_logprobs)
+
+ def _process_reasoning_stream(
+ self,
+ index: int,
+ delta: str,
+ reasoning_parser_dict: Dict[int, ReasoningParser],
+ content: Dict[str, Any],
+ request: ChatCompletionRequest,
+ ) -> tuple[Optional[str], str]:
+ """Process reasoning content in streaming response"""
+ if index not in reasoning_parser_dict:
+ reasoning_parser_dict[index] = ReasoningParser(
+ self.tokenizer_manager.server_args.reasoning_parser,
+ request.stream_reasoning,
+ )
+ reasoning_parser = reasoning_parser_dict[index]
+ return reasoning_parser.parse_stream_chunk(delta)
+
+ async def _process_tool_call_stream(
+ self,
+ index: int,
+ delta: str,
+ parser_dict: Dict[int, FunctionCallParser],
+ content: Dict[str, Any],
+ request: ChatCompletionRequest,
+ finish_reason_type: Optional[str],
+ ):
+ """Process tool calls in streaming response"""
+ if index not in parser_dict:
+ parser_dict[index] = FunctionCallParser(
+ tools=request.tools,
+ tool_call_parser=self.tokenizer_manager.server_args.tool_call_parser,
+ )
+ parser = parser_dict[index]
+
+ normal_text, calls = parser.parse_stream_chunk(delta)
+
+ # Yield normal text
+ if normal_text:
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=index,
+ delta=DeltaMessage(content=normal_text),
+ finish_reason=finish_reason_type,
+ )
+ chunk = ChatCompletionStreamResponse(
+ id=content["meta_info"]["id"],
+ created=int(time.time()),
+ choices=[choice_data],
+ model=request.model,
+ )
+ yield f"data: {chunk.model_dump_json()}\n\n"
+
+ # Yield tool calls
+ for call_item in calls:
+ if finish_reason_type == "stop":
+ # Handle remaining arguments
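+                # If generation stopped mid-tool-call, diff the fully parsed arguments
+                # against what was already streamed and emit only the remainder as a
+                # final arguments delta.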
+ latest_delta_len = 0
+ if isinstance(call_item.parameters, str):
+ latest_delta_len = len(call_item.parameters)
+
+ expected_call = json.dumps(
+ parser.detector.prev_tool_call_arr[index].get("arguments", {}),
+ ensure_ascii=False,
+ )
+ actual_call = parser.detector.streamed_args_for_tool[index]
+ if latest_delta_len > 0:
+ actual_call = actual_call[:-latest_delta_len]
+ remaining_call = expected_call.replace(actual_call, "", 1)
+ call_item.parameters = remaining_call
+ finish_reason_type = "tool_calls"
+
+ tool_call = ToolCall(
+ id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}",
+ index=call_item.tool_index,
+ function=FunctionResponse(
+ name=call_item.name,
+ arguments=call_item.parameters,
+ ),
+ )
+
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=index,
+ delta=DeltaMessage(tool_calls=[tool_call]),
+ finish_reason=(
+ None
+ if request.stream_options and request.stream_options.include_usage
+ else finish_reason_type
+ ),
+ )
+ chunk = ChatCompletionStreamResponse(
+ id=content["meta_info"]["id"],
+ created=int(time.time()),
+ choices=[choice_data],
+ model=request.model,
+ )
+ yield f"data: {chunk.model_dump_json()}\n\n"
diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py
new file mode 100644
index 00000000000..af501727562
--- /dev/null
+++ b/python/sglang/srt/entrypoints/openai/serving_completions.py
@@ -0,0 +1,467 @@
+import time
+from typing import Any, Dict, List, Optional, Union
+
+from fastapi import Request
+from fastapi.responses import StreamingResponse
+
+from sglang.srt.code_completion_parser import (
+ generate_completion_prompt_from_request,
+ is_completion_template_defined,
+)
+from sglang.srt.entrypoints.openai.protocol import (
+ CompletionRequest,
+ CompletionResponse,
+ CompletionResponseChoice,
+ CompletionResponseStreamChoice,
+ CompletionStreamResponse,
+ ErrorResponse,
+)
+from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
+from sglang.srt.entrypoints.openai.utils import (
+ aggregate_token_usage,
+ to_openai_style_logprobs,
+)
+from sglang.srt.managers.io_struct import GenerateReqInput
+
+
+class OpenAIServingCompletion(OpenAIServingBase):
+ """Handler for completion requests"""
+
+ def _request_id_prefix(self) -> str:
+ return "cmpl-"
+
+ def _validate_request(self, request: CompletionRequest) -> Optional[str]:
+ """Validate completion prompt format and content"""
+ if not (prompt := request.prompt):
+ return "Prompt cannot be None"
+
+ if isinstance(prompt, str):
+ if not prompt.strip():
+ return "Prompt cannot be empty or whitespace only"
+ elif isinstance(prompt, list):
+ if not prompt:
+ return "Prompt list cannot be empty"
+
+ # Check if it's a list of strings
+ if all(isinstance(item, str) for item in prompt):
+ for i, item in enumerate(prompt):
+ if not item.strip():
+ return f"Prompt at index {i} cannot be empty or whitespace only"
+
+ # Check if it's a list of token IDs (integers)
+ elif all(isinstance(item, int) for item in prompt):
+ if any(item < 0 for item in prompt):
+ return "Token IDs must be non-negative"
+
+ # Check if it's a list of lists (multiple token sequences)
+ elif all(isinstance(item, list) for item in prompt):
+ for i, item in enumerate(prompt):
+ if not item:
+ return f"Token sequence at index {i} cannot be empty"
+ if not all(isinstance(token, int) for token in item):
+ return f"Token sequence at index {i} must contain only integers"
+ if any(token < 0 for token in item):
+ return (
+ f"Token sequence at index {i} contains negative token IDs"
+ )
+ else:
+ return "Prompt must be string, list of strings, list of integers, or list of integer lists"
+ else:
+ return "Prompt must be string or list"
+
+ return None
+
+ def _convert_to_internal_request(
+ self,
+ all_requests: List[CompletionRequest],
+ request_ids: List[str],
+ ) -> tuple[GenerateReqInput, Union[CompletionRequest, List[CompletionRequest]]]:
+ """Convert OpenAI completion request to internal format"""
+ # Validate batch requests
+ if len(all_requests) > 1:
+ first_prompt_type = type(all_requests[0].prompt)
+ for request in all_requests:
+ assert (
+ type(request.prompt) is first_prompt_type
+ ), "All prompts must be of the same type in file input settings"
+ if request.n > 1:
+ raise ValueError(
+ "Parallel sampling is not supported for completions from files"
+ )
+
+ prompts = []
+ sampling_params_list = []
+ return_logprobs = []
+ logprob_start_lens = []
+ top_logprobs_nums = []
+ lora_paths = []
+
+ for request in all_requests:
+ # Process prompt
+ prompt = request.prompt
+ if is_completion_template_defined():
+ prompt = generate_completion_prompt_from_request(request)
+
+ prompts.append(prompt)
+
+ lora_paths.append(request.lora_path)
+
+ # Set logprob start length based on echo and logprobs
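+            # 0 requests prompt (input) logprobs as well, which echo needs in order to
+            # return them; -1 effectively restricts logprobs to the generated output.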
+ if request.echo and request.logprobs:
+ current_logprob_start_len = 0
+ else:
+ current_logprob_start_len = -1
+
+ # Build sampling parameters
+ sampling_params = self._build_sampling_params(request)
+ sampling_params_list.append(sampling_params)
+
+ return_logprobs.append(request.logprobs is not None)
+ logprob_start_lens.append(current_logprob_start_len)
+ top_logprobs_nums.append(
+ request.logprobs if request.logprobs is not None else 0
+ )
+
+ # Handle single vs multiple requests
+ if len(all_requests) == 1:
+ if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
+ prompt_kwargs = {"text": prompts[0]}
+ else:
+ prompt_kwargs = {"input_ids": prompts[0]}
+ sampling_params_list = sampling_params_list[0]
+ return_logprobs = return_logprobs[0]
+ logprob_start_lens = logprob_start_lens[0]
+ top_logprobs_nums = top_logprobs_nums[0]
+ lora_paths = lora_paths[0]
+ request_ids = request_ids[0]
+ else:
+ if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
+ prompt_kwargs = {"text": prompts}
+ else:
+ prompt_kwargs = {"input_ids": prompts}
+
+ adapted_request = GenerateReqInput(
+ **prompt_kwargs,
+ sampling_params=sampling_params_list,
+ return_logprob=return_logprobs,
+ top_logprobs_num=top_logprobs_nums,
+ logprob_start_len=logprob_start_lens,
+ return_text_in_logprobs=True,
+ stream=all_requests[0].stream,
+ rid=request_ids,
+ lora_path=lora_paths,
+ bootstrap_host=all_requests[0].bootstrap_host,
+ bootstrap_port=all_requests[0].bootstrap_port,
+ bootstrap_room=all_requests[0].bootstrap_room,
+ )
+
+ return adapted_request, (
+ all_requests if len(all_requests) > 1 else all_requests[0]
+ )
+
+ def _build_sampling_params(self, request: CompletionRequest) -> Dict[str, Any]:
+ """Build sampling parameters for the request"""
+ # Start with common parameters
+ sampling_params = {
+ "temperature": request.temperature,
+ "max_new_tokens": request.max_tokens,
+ "min_new_tokens": request.min_tokens,
+ "stop": request.stop,
+ "stop_token_ids": request.stop_token_ids,
+ "top_p": request.top_p,
+ "top_k": request.top_k,
+ "min_p": request.min_p,
+ "presence_penalty": request.presence_penalty,
+ "frequency_penalty": request.frequency_penalty,
+ "repetition_penalty": request.repetition_penalty,
+ "regex": request.regex,
+ "json_schema": request.json_schema,
+ "ebnf": request.ebnf,
+ "n": request.n,
+ "no_stop_trim": request.no_stop_trim,
+ "ignore_eos": request.ignore_eos,
+ "skip_special_tokens": request.skip_special_tokens,
+ "logit_bias": request.logit_bias,
+ }
+
+        # No additional completion-specific parameters are needed currently;
+        # json_schema is passed through directly in the dict above.
+
+ return sampling_params
+
+ async def _handle_streaming_request(
+ self,
+ adapted_request: GenerateReqInput,
+ request: CompletionRequest,
+ raw_request: Request,
+ ) -> StreamingResponse:
+ """Handle streaming completion request"""
+ created = int(time.time())
+
+ async def generate_stream_resp():
+ stream_buffers = {}
+ n_prev_tokens = {}
+ prompt_tokens = {}
+ completion_tokens = {}
+ cached_tokens = {}
+
+ try:
+ async for content in self.tokenizer_manager.generate_request(
+ adapted_request, raw_request
+ ):
+ index = content.get("index", 0)
+
+ stream_buffer = stream_buffers.get(index, "")
+ n_prev_token = n_prev_tokens.get(index, 0)
+
+ text = content["text"]
+ prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
+ completion_tokens[index] = content["meta_info"]["completion_tokens"]
+ cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
+
+ # Handle echo for first chunk
+ if not stream_buffer: # The first chunk
+ if request.echo:
+ echo_text = self._get_echo_text(request, index)
+ text = echo_text + text
+
+ # Handle logprobs
+ logprobs = None
+ if request.logprobs is not None:
+                        # On the first chunk with echo enabled, include the input (prompt) logprobs too.
+ if not stream_buffer and request.echo:
+ input_token_logprobs = content["meta_info"][
+ "input_token_logprobs"
+ ]
+ input_top_logprobs = content["meta_info"][
+ "input_top_logprobs"
+ ]
+ else:
+ input_token_logprobs = None
+ input_top_logprobs = None
+
+ logprobs = to_openai_style_logprobs(
+ input_token_logprobs=input_token_logprobs,
+ input_top_logprobs=input_top_logprobs,
+ output_token_logprobs=content["meta_info"][
+ "output_token_logprobs"
+ ][n_prev_token:],
+ output_top_logprobs=content["meta_info"][
+ "output_top_logprobs"
+ ][n_prev_token:],
+ )
+ n_prev_token = len(
+ content["meta_info"]["output_token_logprobs"]
+ )
+
+ # Generate delta
+ delta = text[len(stream_buffer) :]
+ stream_buffer = stream_buffer + delta
+ finish_reason = content["meta_info"]["finish_reason"]
+
+ choice_data = CompletionResponseStreamChoice(
+ index=index,
+ text=delta,
+ logprobs=logprobs,
+ finish_reason=finish_reason["type"] if finish_reason else None,
+ matched_stop=(
+ finish_reason["matched"]
+ if finish_reason and "matched" in finish_reason
+ else None
+ ),
+ )
+ chunk = CompletionStreamResponse(
+ id=content["meta_info"]["id"],
+ created=created,
+ object="text_completion",
+ choices=[choice_data],
+ model=request.model,
+ )
+
+ stream_buffers[index] = stream_buffer
+ n_prev_tokens[index] = n_prev_token
+
+ yield f"data: {chunk.model_dump_json()}\n\n"
+
+ # Handle final usage chunk
+ if request.stream_options and request.stream_options.include_usage:
+ usage = self._calculate_streaming_usage_base(
+ prompt_tokens, completion_tokens, cached_tokens, request.n
+ )
+ final_usage_chunk = CompletionStreamResponse(
+ id=content["meta_info"]["id"],
+ created=created,
+ choices=[],
+ model=request.model,
+ usage=usage,
+ )
+ final_usage_data = final_usage_chunk.model_dump_json(
+ exclude_none=True
+ )
+ yield f"data: {final_usage_data}\n\n"
+
+ except Exception as e:
+ error = self.create_streaming_error_response(str(e))
+ yield f"data: {error}\n\n"
+
+ yield "data: [DONE]\n\n"
+
+ return StreamingResponse(
+ generate_stream_resp(),
+ media_type="text/event-stream",
+ background=self.tokenizer_manager.create_abort_task(adapted_request),
+ )
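+    # The stream produced above emits Server-Sent Events (illustrative):
+    #   data: {"id": "...", "object": "text_completion", "choices": [{"index": 0, "text": "Hel", ...}], ...}
+    #   data: {"id": "...", "object": "text_completion", "choices": [{"index": 0, "text": "lo", ...}], ...}
+    #   data: [DONE]
+    # plus an optional final usage chunk before [DONE] when stream_options.include_usage is set.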
+
+ async def _handle_non_streaming_request(
+ self,
+ adapted_request: GenerateReqInput,
+ request: CompletionRequest,
+ raw_request: Request,
+ ) -> Union[CompletionResponse, ErrorResponse]:
+ """Handle non-streaming completion request"""
+ try:
+ generator = self.tokenizer_manager.generate_request(
+ adapted_request, raw_request
+ )
+ ret = await generator.__anext__()
+ except ValueError as e:
+ return self.create_error_response(str(e))
+
+ if not isinstance(ret, list):
+ ret = [ret]
+
+ response = self._build_completion_response(
+ request,
+ ret,
+ int(time.time()),
+ cache_report=self.tokenizer_manager.server_args.enable_cache_report,
+ )
+
+ return response
+
+ def _build_completion_response(
+ self,
+ request: CompletionRequest,
+ ret: List[Dict[str, Any]],
+ created: int,
+ cache_report: bool = False,
+ ) -> CompletionResponse:
+ """Build completion response from generation results"""
+ choices = []
+ echo = False
+
+ # Prepare echo prompts if needed
+ echo_prompts = []
+ if (not isinstance(request, list)) and request.echo:
+ echo_prompts = self._prepare_echo_prompts(request)
+ echo = True
+
+ for idx, ret_item in enumerate(ret):
+ text = ret_item["text"]
+
+ # Handle echo
+ if isinstance(request, list) and request[idx].echo:
+ echo = True
+ text = request[idx].prompt + text
+ elif echo and not isinstance(request, list):
+ prompt_index = idx // request.n
+ text = echo_prompts[prompt_index] + text
+
+ # Handle logprobs
+ logprobs = None
+ if isinstance(request, list) and request[idx].logprobs is not None:
+ logprobs = True
+ elif (not isinstance(request, list)) and request.logprobs is not None:
+ logprobs = True
+
+ if logprobs:
+ if echo:
+ input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
+ input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
+ else:
+ input_token_logprobs = None
+ input_top_logprobs = None
+
+ logprobs = to_openai_style_logprobs(
+ input_token_logprobs=input_token_logprobs,
+ input_top_logprobs=input_top_logprobs,
+ output_token_logprobs=ret_item["meta_info"][
+ "output_token_logprobs"
+ ],
+ output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
+ )
+
+ finish_reason = ret_item["meta_info"]["finish_reason"]
+
+ choice_data = CompletionResponseChoice(
+ index=idx,
+ text=text,
+ logprobs=logprobs,
+ finish_reason=finish_reason["type"] if finish_reason else None,
+ matched_stop=(
+ finish_reason["matched"]
+ if finish_reason and "matched" in finish_reason
+ else None
+ ),
+ )
+ choices.append(choice_data)
+
+ # Calculate usage
+ usage = aggregate_token_usage(ret, request.n, cache_report)
+
+ return CompletionResponse(
+ id=ret[0]["meta_info"]["id"],
+ model=request.model,
+ created=created,
+ choices=choices,
+ usage=usage,
+ )
+
+ def _get_echo_text(self, request: CompletionRequest, index: int) -> str:
+ """Get echo text for streaming response"""
+ if isinstance(request.prompt, str):
+            # for the case of a single str prompt
+ return request.prompt
+ elif isinstance(request.prompt, list):
+ if isinstance(request.prompt[0], str):
+ # for the case of multiple str prompts
+ return request.prompt[index // request.n]
+ elif isinstance(request.prompt[0], int):
+ # for the case of single token ids prompt
+ return self.tokenizer_manager.tokenizer.decode(
+ request.prompt, skip_special_tokens=True
+ )
+ elif isinstance(request.prompt[0], list) and isinstance(
+ request.prompt[0][0], int
+ ):
+ # for the case of multiple token ids prompts
+ return self.tokenizer_manager.tokenizer.decode(
+ request.prompt[index // request.n],
+ skip_special_tokens=True,
+ )
+ return ""
+
+ def _prepare_echo_prompts(self, request: CompletionRequest) -> List[str]:
+ """Prepare echo prompts for non-streaming response"""
+ # TODO: handle the case prompt is token ids
+ if isinstance(request.prompt, list) and isinstance(request.prompt[0], str):
+ # for the case of multiple str prompts
+ return request.prompt
+ elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list):
+ # for the case of multiple token ids prompts
+ return [
+ self.tokenizer_manager.tokenizer.decode(
+ prompt, skip_special_tokens=True
+ )
+ for prompt in request.prompt
+ ]
+ elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int):
+ # for the case of single token ids prompt
+ return [
+ self.tokenizer_manager.tokenizer.decode(
+ request.prompt, skip_special_tokens=True
+ )
+ ]
+ else:
+ # for the case of single str prompt
+ return [request.prompt]
diff --git a/python/sglang/srt/entrypoints/openai/serving_embedding.py b/python/sglang/srt/entrypoints/openai/serving_embedding.py
new file mode 100644
index 00000000000..33d1e591878
--- /dev/null
+++ b/python/sglang/srt/entrypoints/openai/serving_embedding.py
@@ -0,0 +1,227 @@
+from typing import Any, Dict, List, Optional, Union
+
+from fastapi import Request
+
+from sglang.srt.conversation import generate_embedding_convs
+from sglang.srt.entrypoints.openai.protocol import (
+ EmbeddingObject,
+ EmbeddingRequest,
+ EmbeddingResponse,
+ ErrorResponse,
+ MultimodalEmbeddingInput,
+ UsageInfo,
+)
+from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
+from sglang.srt.managers.io_struct import EmbeddingReqInput
+
+
+class OpenAIServingEmbedding(OpenAIServingBase):
+ """Handler for embedding requests"""
+
+ def _request_id_prefix(self) -> str:
+ return "embd-"
+
+ def _validate_request(self, request: EmbeddingRequest) -> Optional[str]:
+ """Validate that the input is not empty or whitespace only."""
+ if not (input := request.input):
+ return "Input cannot be empty"
+
+ # Handle single string
+ if isinstance(input, str):
+ if not input.strip():
+ return "Input cannot be empty or whitespace only"
+ return None
+
+ # Handle list inputs
+ if isinstance(input, list):
+ if len(input) == 0:
+ return "Input cannot be empty"
+
+ # Check first element to determine type
+ first_item = input[0]
+
+ if isinstance(first_item, str):
+ # List of strings
+ for i, item in enumerate(input):
+ if not isinstance(item, str):
+ return f"All items in input list must be strings"
+ if not item.strip():
+ return f"Input at index {i} cannot be empty or whitespace only"
+ elif isinstance(first_item, int):
+ # List of integers (token IDs)
+ for i, item in enumerate(input):
+ if not isinstance(item, int):
+ return f"All items in input list must be integers"
+ if item < 0:
+ return f"Token ID at index {i} must be non-negative"
+ elif isinstance(first_item, list):
+ # List of lists (multiple token sequences)
+ for i, item in enumerate(input):
+ if not isinstance(item, list):
+ return f"Input at index {i} must be a list"
+ if not item:
+ return f"Input at index {i} cannot be empty"
+ if not all(isinstance(token, int) for token in item):
+ return f"Input at index {i} must contain only integers"
+ if any(token < 0 for token in item):
+ return f"Input at index {i} contains negative token IDs"
+ # Note: MultimodalEmbeddingInput validation would be handled by Pydantic
+
+ return None
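+    # Validation examples (illustrative):
+    #   input=""            -> "Input cannot be empty"
+    #   input=["a", "  "]   -> "Input at index 1 cannot be empty or whitespace only"
+    #   input=[[1, 2], [3]] -> None (valid token ID sequences)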
+
+ def _convert_to_internal_request(
+ self,
+ all_requests: List[EmbeddingRequest],
+ request_ids: List[str],
+ ) -> tuple[EmbeddingReqInput, Union[EmbeddingRequest, List[EmbeddingRequest]]]:
+ """Convert OpenAI embedding request to internal format"""
+ prompts = [request.input for request in all_requests]
+
+ # Handle single vs multiple requests
+ if len(all_requests) == 1:
+ prompt = prompts[0]
+ if isinstance(prompt, str):
+ # Single string input
+ prompt_kwargs = {"text": prompt}
+ elif isinstance(prompt, list):
+ if len(prompt) > 0 and isinstance(prompt[0], str):
+ # List of strings
+ prompt_kwargs = {"text": prompt}
+ elif len(prompt) > 0 and isinstance(
+ prompt[0], MultimodalEmbeddingInput
+ ):
+ # Handle multimodal embedding inputs
+ texts = []
+ images = []
+ for item in prompt:
+ # Use padding for text if None - this could be improved
+ texts.append(item.text if item.text is not None else "padding")
+ images.append(item.image if item.image is not None else None)
+
+ generate_prompts = []
+ # Check if we have a chat template for multimodal embeddings
+ # This would need to be passed in from the server configuration
+ chat_template_name = getattr(
+ self.tokenizer_manager, "chat_template_name", None
+ )
+ if chat_template_name is not None:
+ convs = generate_embedding_convs(
+ texts, images, chat_template_name
+ )
+ for conv in convs:
+ generate_prompts.append(conv.get_prompt())
+ else:
+ generate_prompts = texts
+
+ if len(generate_prompts) == 1:
+ prompt_kwargs = {
+ "text": generate_prompts[0],
+ "image_data": images[0],
+ }
+ else:
+ prompt_kwargs = {
+ "text": generate_prompts,
+ "image_data": images,
+ }
+ else:
+ # List of integers (token IDs) or empty list
+ prompt_kwargs = {"input_ids": prompt}
+ else:
+ # Other types (should not happen but handle gracefully)
+ prompt_kwargs = {"input_ids": prompt}
+            # Use the passed request id for the single request
+            final_request_id = request_ids[0]
+ else:
+ # Handle batch requests
+ if len(prompts) > 0:
+ # Validate that all prompts have the same type
+ first_prompt = prompts[0]
+ first_type = type(first_prompt)
+ for i, prompt in enumerate(prompts[1:], 1):
+ if type(prompt) != first_type:
+ raise AssertionError(
+ f"All prompts in batch must have the same type, but prompt at index {i} has different type"
+ )
+
+ if isinstance(first_prompt, str):
+ # Batch of strings
+ prompt_kwargs = {"text": prompts}
+ elif isinstance(first_prompt, list):
+ if len(first_prompt) > 0 and isinstance(first_prompt[0], str):
+ # Batch of lists of strings
+ prompt_kwargs = {"text": prompts}
+ elif len(first_prompt) > 0 and isinstance(
+ first_prompt[0], MultimodalEmbeddingInput
+ ):
+ # Handle multimodal batch requests
+ raise NotImplementedError(
+ "Multiple requests with multimodal inputs are not supported yet"
+ )
+ else:
+ # Batch of token ID lists
+ prompt_kwargs = {"input_ids": prompts}
+ else:
+ # Other types
+ prompt_kwargs = {"input_ids": prompts}
+ else:
+ prompt_kwargs = {"input_ids": prompts}
+ # Use the passed request_ids for batch requests
+ final_request_id = request_ids
+
+ adapted_request = EmbeddingReqInput(
+ rid=final_request_id,
+ **prompt_kwargs,
+ )
+
+ return adapted_request, (
+ all_requests[0] if len(all_requests) == 1 else all_requests
+ )
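+    # Input-to-kwargs mapping used above (illustrative):
+    #   "hello"                              -> {"text": "hello"}
+    #   ["hello", "world"]                   -> {"text": ["hello", "world"]}
+    #   [1, 2, 3]                            -> {"input_ids": [1, 2, 3]}
+    #   [MultimodalEmbeddingInput(text="hi", image=img)]
+    #                                        -> {"text": "hi", "image_data": img}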
+
+ async def _handle_non_streaming_request(
+ self,
+ adapted_request: EmbeddingReqInput,
+ request: EmbeddingRequest,
+ raw_request: Request,
+ ) -> Union[EmbeddingResponse, ErrorResponse]:
+ """Handle the embedding request"""
+ try:
+ ret = await self.tokenizer_manager.generate_request(
+ adapted_request, raw_request
+ ).__anext__()
+ except ValueError as e:
+ return self.create_error_response(str(e))
+
+ if not isinstance(ret, list):
+ ret = [ret]
+
+ response = self._build_embedding_response(
+ ret, self.tokenizer_manager.model_path
+ )
+ return response
+
+ def _build_embedding_response(
+ self, ret: List[Dict[str, Any]], model_path: str
+ ) -> EmbeddingResponse:
+ """Build the embedding response"""
+ embedding_objects = []
+ prompt_tokens = 0
+
+ for idx, ret_item in enumerate(ret):
+ embedding_objects.append(
+ EmbeddingObject(
+ embedding=ret_item["embedding"],
+ index=idx,
+ )
+ )
+ # Handle missing prompt_tokens gracefully
+ meta_info = ret_item.get("meta_info", {})
+ prompt_tokens += meta_info.get("prompt_tokens", 0)
+
+ return EmbeddingResponse(
+ data=embedding_objects,
+ model=model_path,
+ usage=UsageInfo(
+ prompt_tokens=prompt_tokens,
+ total_tokens=prompt_tokens,
+ ),
+ )
diff --git a/python/sglang/srt/entrypoints/openai/utils.py b/python/sglang/srt/entrypoints/openai/utils.py
new file mode 100644
index 00000000000..53c67831cdb
--- /dev/null
+++ b/python/sglang/srt/entrypoints/openai/utils.py
@@ -0,0 +1,264 @@
+import logging
+from typing import Any, Dict, List, Optional
+
+import jinja2.nodes
+import transformers.utils.chat_template_utils as hf_chat_utils
+
+from sglang.srt.entrypoints.openai.protocol import LogProbs, UsageInfo
+
+logger = logging.getLogger(__name__)
+
+
+# ============================================================================
+# JINJA TEMPLATE CONTENT FORMAT DETECTION
+# ============================================================================
+#
+# This adapts vLLM's approach for detecting chat template content format:
+# https://github.com/vllm-project/vllm/blob/02f0c7b220422792f5e53de2a7d51d2d3ff2df28/vllm/entrypoints/chat_utils.py#L296-L313
+# - Analyzes Jinja template AST to detect content iteration patterns
+# - 'openai' format: templates with {%- for content in message['content'] -%} loops
+# - 'string' format: templates that expect simple string content
+# - Processes content accordingly to match template expectations
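+#
+# Example (illustrative):
+#   "{%- for content in message['content'] -%}{{ content['text'] }}{%- endfor -%}"
+#       -> detected as 'openai' (iterates over message['content'])
+#   "{{ message['content'] }}"
+#       -> detected as 'string' (uses content as plain text)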
+
+
+def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
+ """Check if node is a variable access like {{ varname }}"""
+ if isinstance(node, jinja2.nodes.Name):
+ return node.ctx == "load" and node.name == varname
+ return False
+
+
+def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
+ """Check if node is an attribute access like {{ varname['key'] }} or {{ varname.key }}"""
+ if isinstance(node, jinja2.nodes.Getitem):
+ return (
+ _is_var_access(node.node, varname)
+ and isinstance(node.arg, jinja2.nodes.Const)
+ and node.arg.value == key
+ )
+
+ if isinstance(node, jinja2.nodes.Getattr):
+ return _is_var_access(node.node, varname) and node.attr == key
+
+ return False
+
+
+def _is_var_or_elems_access(
+ node: jinja2.nodes.Node,
+ varname: str,
+    key: Optional[str] = None,
+) -> bool:
+ """Check if node accesses varname or varname[key] with filters/tests"""
+ if isinstance(node, jinja2.nodes.Filter):
+ return node.node is not None and _is_var_or_elems_access(
+ node.node, varname, key
+ )
+ if isinstance(node, jinja2.nodes.Test):
+ return _is_var_or_elems_access(node.node, varname, key)
+
+ if isinstance(node, jinja2.nodes.Getitem) and isinstance(
+ node.arg, jinja2.nodes.Slice
+ ):
+ return _is_var_or_elems_access(node.node, varname, key)
+
+ return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname)
+
+
+def _try_extract_ast(chat_template: str):
+ """Try to parse the Jinja template into an AST"""
+ try:
+ jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
+ return jinja_compiled.environment.parse(chat_template)
+ except Exception as e:
+ logger.debug(f"Error when compiling Jinja template: {e}")
+ return None
+
+
+def detect_template_content_format(chat_template: str) -> str:
+ """
+ Detect whether a chat template expects 'string' or 'openai' content format.
+
+ - 'string': content is a simple string (like DeepSeek templates)
+ - 'openai': content is a list of structured dicts (like Llama4 templates)
+
+ Detection logic:
+ - If template has loops like {%- for content in message['content'] -%} → 'openai'
+ - Otherwise → 'string'
+ """
+ jinja_ast = _try_extract_ast(chat_template)
+ if jinja_ast is None:
+ return "string"
+
+ try:
+ # Look for patterns like: {%- for content in message['content'] -%}
+ for loop_ast in jinja_ast.find_all(jinja2.nodes.For):
+ loop_iter = loop_ast.iter
+
+ # Check if iterating over message['content'] or similar
+ if _is_var_or_elems_access(loop_iter, "message", "content"):
+ return "openai" # Found content iteration → openai format
+
+ return "string" # No content loops found → string format
+ except Exception as e:
+ logger.debug(f"Error when parsing AST of Jinja template: {e}")
+ return "string"
+
+
+def process_content_for_template_format(
+ msg_dict: dict,
+ content_format: str,
+ image_data: list,
+ audio_data: list,
+ modalities: list,
+) -> dict:
+ """
+ Process message content based on detected template format.
+
+ Args:
+ msg_dict: Message dictionary with content
+ content_format: 'string' or 'openai' (detected via AST analysis)
+ image_data: List to append extracted image URLs
+ audio_data: List to append extracted audio URLs
+ modalities: List to append modalities
+
+ Returns:
+ Processed message dictionary
+ """
+ if not isinstance(msg_dict.get("content"), list):
+ # Already a string or None, no processing needed
+ return {k: v for k, v in msg_dict.items() if v is not None}
+
+ if content_format == "openai":
+ # OpenAI format: preserve structured content list, normalize types
+ processed_content_parts = []
+ for chunk in msg_dict["content"]:
+ if isinstance(chunk, dict):
+ chunk_type = chunk.get("type")
+
+ if chunk_type == "image_url":
+ image_data.append(chunk["image_url"]["url"])
+ if chunk.get("modalities"):
+ modalities.append(chunk.get("modalities"))
+ # Normalize to simple 'image' type for template compatibility
+ processed_content_parts.append({"type": "image"})
+ elif chunk_type == "audio_url":
+ audio_data.append(chunk["audio_url"]["url"])
+ # Normalize to simple 'audio' type
+ processed_content_parts.append({"type": "audio"})
+ else:
+ # Keep other content as-is (text, etc.)
+ processed_content_parts.append(chunk)
+
+ new_msg = {
+ k: v for k, v in msg_dict.items() if v is not None and k != "content"
+ }
+ new_msg["content"] = processed_content_parts
+ return new_msg
+
+ else: # content_format == "string"
+ # String format: flatten to text only (for templates like DeepSeek)
+ text_parts = []
+ for chunk in msg_dict["content"]:
+ if isinstance(chunk, dict) and chunk.get("type") == "text":
+ text_parts.append(chunk["text"])
+ # Note: For string format, we ignore images/audio since the template
+ # doesn't expect structured content - multimodal placeholders would
+ # need to be inserted differently
+
+ new_msg = msg_dict.copy()
+ new_msg["content"] = " ".join(text_parts) if text_parts else ""
+ new_msg = {k: v for k, v in new_msg.items() if v is not None}
+ return new_msg
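+
+
+# Typical flow (illustrative): detect the format once per chat template, then
+# normalize each message before rendering:
+#
+#   content_format = detect_template_content_format(chat_template)
+#   image_data, audio_data, modalities = [], [], []
+#   msg = process_content_for_template_format(
+#       {
+#           "role": "user",
+#           "content": [
+#               {"type": "text", "text": "hi"},
+#               {"type": "image_url", "image_url": {"url": "..."}},
+#           ],
+#       },
+#       content_format, image_data, audio_data, modalities,
+#   )
+#   # 'openai' keeps a structured list with a {"type": "image"} placeholder and
+#   # appends the URL to image_data; 'string' flattens the content to "hi" and
+#   # ignores the image part.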
+
+
+def calculate_token_usage(
+ prompt_tokens: int,
+ completion_tokens: int,
+ cached_tokens: Optional[Dict[str, int]] = None,
+) -> UsageInfo:
+ """Calculate token usage information"""
+ return UsageInfo(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ prompt_tokens_details=cached_tokens,
+ )
+
+
+def aggregate_token_usage(
+ responses: List[Dict[str, Any]],
+ n_choices: int = 1,
+ enable_cache_report: bool = False,
+) -> UsageInfo:
+ """Aggregate token usage from multiple responses
+
+ Args:
+ responses: List of response dictionaries with meta_info
+ n_choices: Number of choices per request (for prompt token counting)
+ enable_cache_report: Whether to include cached token details
+
+ Returns:
+ Aggregated UsageInfo
+ """
+ # Sum completion tokens from all responses
+ completion_tokens = sum(
+ response["meta_info"]["completion_tokens"] for response in responses
+ )
+
+ # For prompt tokens, only count every n_choices-th response to avoid double counting
+ prompt_tokens = sum(
+ responses[i]["meta_info"]["prompt_tokens"]
+ for i in range(0, len(responses), n_choices)
+ )
+
+ # Handle cached tokens if cache reporting is enabled
+ cached_tokens_details = None
+ if enable_cache_report:
+ cached_tokens_sum = sum(
+ response["meta_info"].get("cached_tokens", 0) for response in responses
+ )
+ if cached_tokens_sum > 0:
+ cached_tokens_details = {"cached_tokens": cached_tokens_sum}
+
+ return calculate_token_usage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ cached_tokens=cached_tokens_details,
+ )
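+
+
+# Worked example (illustrative): two choices (n_choices=2) generated for one prompt:
+#   responses = [
+#       {"meta_info": {"prompt_tokens": 10, "completion_tokens": 4}},
+#       {"meta_info": {"prompt_tokens": 10, "completion_tokens": 6}},
+#   ]
+#   aggregate_token_usage(responses, n_choices=2)
+#   # -> prompt_tokens=10 (counted once per prompt), completion_tokens=10, total_tokens=20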
+
+
+def to_openai_style_logprobs(
+ input_token_logprobs=None,
+ output_token_logprobs=None,
+ input_top_logprobs=None,
+ output_top_logprobs=None,
+):
+ ret_logprobs = LogProbs()
+
+ def append_token_logprobs(token_logprobs):
+ for logprob, _, token_text in token_logprobs:
+ ret_logprobs.tokens.append(token_text)
+ ret_logprobs.token_logprobs.append(logprob)
+
+ # Not supported yet
+ ret_logprobs.text_offset.append(-1)
+
+ def append_top_logprobs(top_logprobs):
+ for tokens in top_logprobs:
+ if tokens is not None:
+ ret_logprobs.top_logprobs.append(
+ {token[2]: token[0] for token in tokens}
+ )
+ else:
+ ret_logprobs.top_logprobs.append(None)
+
+ if input_token_logprobs is not None:
+ append_token_logprobs(input_token_logprobs)
+ if output_token_logprobs is not None:
+ append_token_logprobs(output_token_logprobs)
+ if input_top_logprobs is not None:
+ append_top_logprobs(input_top_logprobs)
+ if output_top_logprobs is not None:
+ append_top_logprobs(output_top_logprobs)
+
+ return ret_logprobs
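+
+
+# Example (illustrative): token logprob entries are (logprob, token_id, token_text) tuples:
+#   to_openai_style_logprobs(output_token_logprobs=[(-0.1, 42, "Hello"), (-0.5, 7, "!")])
+#   # -> LogProbs(tokens=["Hello", "!"], token_logprobs=[-0.1, -0.5], text_offset=[-1, -1])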
diff --git a/test/pytest.ini b/test/pytest.ini
new file mode 100644
index 00000000000..2f4c80e3075
--- /dev/null
+++ b/test/pytest.ini
@@ -0,0 +1,2 @@
+[pytest]
+asyncio_mode = auto
diff --git a/test/srt/openai/test_protocol.py b/test/srt/openai/test_protocol.py
new file mode 100644
index 00000000000..a14b3d7179f
--- /dev/null
+++ b/test/srt/openai/test_protocol.py
@@ -0,0 +1,683 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for OpenAI API protocol models"""
+
+import json
+import time
+from typing import Dict, List, Optional
+
+import pytest
+from pydantic import ValidationError
+
+from sglang.srt.entrypoints.openai.protocol import (
+ BatchRequest,
+ BatchResponse,
+ ChatCompletionMessageContentImagePart,
+ ChatCompletionMessageContentTextPart,
+ ChatCompletionRequest,
+ ChatCompletionResponse,
+ ChatCompletionResponseChoice,
+ ChatCompletionResponseStreamChoice,
+ ChatCompletionStreamResponse,
+ ChatCompletionTokenLogprob,
+ ChatMessage,
+ ChoiceLogprobs,
+ CompletionRequest,
+ CompletionResponse,
+ CompletionResponseChoice,
+ DeltaMessage,
+ EmbeddingObject,
+ EmbeddingRequest,
+ EmbeddingResponse,
+ ErrorResponse,
+ FileDeleteResponse,
+ FileRequest,
+ FileResponse,
+ Function,
+ FunctionResponse,
+ JsonSchemaResponseFormat,
+ LogProbs,
+ ModelCard,
+ ModelList,
+ MultimodalEmbeddingInput,
+ ResponseFormat,
+ ScoringRequest,
+ ScoringResponse,
+ StreamOptions,
+ StructuralTagResponseFormat,
+ Tool,
+ ToolCall,
+ ToolChoice,
+ TopLogprob,
+ UsageInfo,
+)
+
+
+class TestModelCard:
+ """Test ModelCard protocol model"""
+
+ def test_basic_model_card_creation(self):
+ """Test basic model card creation with required fields"""
+ card = ModelCard(id="test-model")
+ assert card.id == "test-model"
+ assert card.object == "model"
+ assert card.owned_by == "sglang"
+ assert isinstance(card.created, int)
+ assert card.root is None
+ assert card.max_model_len is None
+
+ def test_model_card_with_optional_fields(self):
+ """Test model card with optional fields"""
+ card = ModelCard(
+ id="test-model",
+ root="/path/to/model",
+ max_model_len=2048,
+ created=1234567890,
+ )
+ assert card.id == "test-model"
+ assert card.root == "/path/to/model"
+ assert card.max_model_len == 2048
+ assert card.created == 1234567890
+
+ def test_model_card_serialization(self):
+ """Test model card JSON serialization"""
+ card = ModelCard(id="test-model", max_model_len=4096)
+ data = card.model_dump()
+ assert data["id"] == "test-model"
+ assert data["object"] == "model"
+ assert data["max_model_len"] == 4096
+
+
+class TestModelList:
+ """Test ModelList protocol model"""
+
+ def test_empty_model_list(self):
+ """Test empty model list creation"""
+ model_list = ModelList()
+ assert model_list.object == "list"
+ assert len(model_list.data) == 0
+
+ def test_model_list_with_cards(self):
+ """Test model list with model cards"""
+ cards = [
+ ModelCard(id="model-1"),
+ ModelCard(id="model-2", max_model_len=2048),
+ ]
+ model_list = ModelList(data=cards)
+ assert len(model_list.data) == 2
+ assert model_list.data[0].id == "model-1"
+ assert model_list.data[1].id == "model-2"
+
+
+class TestErrorResponse:
+ """Test ErrorResponse protocol model"""
+
+ def test_basic_error_response(self):
+ """Test basic error response creation"""
+ error = ErrorResponse(
+ message="Invalid request", type="BadRequestError", code=400
+ )
+ assert error.object == "error"
+ assert error.message == "Invalid request"
+ assert error.type == "BadRequestError"
+ assert error.code == 400
+ assert error.param is None
+
+ def test_error_response_with_param(self):
+ """Test error response with parameter"""
+ error = ErrorResponse(
+ message="Invalid temperature",
+ type="ValidationError",
+ code=422,
+ param="temperature",
+ )
+ assert error.param == "temperature"
+
+
+class TestUsageInfo:
+ """Test UsageInfo protocol model"""
+
+ def test_basic_usage_info(self):
+ """Test basic usage info creation"""
+ usage = UsageInfo(prompt_tokens=10, completion_tokens=20, total_tokens=30)
+ assert usage.prompt_tokens == 10
+ assert usage.completion_tokens == 20
+ assert usage.total_tokens == 30
+ assert usage.prompt_tokens_details is None
+
+ def test_usage_info_with_cache_details(self):
+ """Test usage info with cache details"""
+ usage = UsageInfo(
+ prompt_tokens=10,
+ completion_tokens=20,
+ total_tokens=30,
+ prompt_tokens_details={"cached_tokens": 5},
+ )
+ assert usage.prompt_tokens_details == {"cached_tokens": 5}
+
+
+class TestCompletionRequest:
+ """Test CompletionRequest protocol model"""
+
+ def test_basic_completion_request(self):
+ """Test basic completion request"""
+ request = CompletionRequest(model="test-model", prompt="Hello world")
+ assert request.model == "test-model"
+ assert request.prompt == "Hello world"
+ assert request.max_tokens == 16 # default
+ assert request.temperature == 1.0 # default
+ assert request.n == 1 # default
+ assert not request.stream # default
+ assert not request.echo # default
+
+ def test_completion_request_with_options(self):
+ """Test completion request with various options"""
+ request = CompletionRequest(
+ model="test-model",
+ prompt=["Hello", "world"],
+ max_tokens=100,
+ temperature=0.7,
+ top_p=0.9,
+ n=2,
+ stream=True,
+ echo=True,
+ stop=[".", "!"],
+ logprobs=5,
+ )
+ assert request.prompt == ["Hello", "world"]
+ assert request.max_tokens == 100
+ assert request.temperature == 0.7
+ assert request.top_p == 0.9
+ assert request.n == 2
+ assert request.stream
+ assert request.echo
+ assert request.stop == [".", "!"]
+ assert request.logprobs == 5
+
+ def test_completion_request_sglang_extensions(self):
+ """Test completion request with SGLang-specific extensions"""
+ request = CompletionRequest(
+ model="test-model",
+ prompt="Hello",
+ top_k=50,
+ min_p=0.1,
+ repetition_penalty=1.1,
+ regex=r"\d+",
+ json_schema='{"type": "object"}',
+ lora_path="/path/to/lora",
+ )
+ assert request.top_k == 50
+ assert request.min_p == 0.1
+ assert request.repetition_penalty == 1.1
+ assert request.regex == r"\d+"
+ assert request.json_schema == '{"type": "object"}'
+ assert request.lora_path == "/path/to/lora"
+
+ def test_completion_request_validation_errors(self):
+ """Test completion request validation errors"""
+ with pytest.raises(ValidationError):
+ CompletionRequest() # missing required fields
+
+ with pytest.raises(ValidationError):
+ CompletionRequest(model="test-model") # missing prompt
+
+
+class TestCompletionResponse:
+ """Test CompletionResponse protocol model"""
+
+ def test_basic_completion_response(self):
+ """Test basic completion response"""
+ choice = CompletionResponseChoice(
+ index=0, text="Hello world!", finish_reason="stop"
+ )
+ usage = UsageInfo(prompt_tokens=2, completion_tokens=3, total_tokens=5)
+ response = CompletionResponse(
+ id="test-id", model="test-model", choices=[choice], usage=usage
+ )
+ assert response.id == "test-id"
+ assert response.object == "text_completion"
+ assert response.model == "test-model"
+ assert len(response.choices) == 1
+ assert response.choices[0].text == "Hello world!"
+ assert response.usage.total_tokens == 5
+
+
+class TestChatCompletionRequest:
+ """Test ChatCompletionRequest protocol model"""
+
+ def test_basic_chat_completion_request(self):
+ """Test basic chat completion request"""
+ messages = [{"role": "user", "content": "Hello"}]
+ request = ChatCompletionRequest(model="test-model", messages=messages)
+ assert request.model == "test-model"
+ assert len(request.messages) == 1
+ assert request.messages[0].role == "user"
+ assert request.messages[0].content == "Hello"
+ assert request.temperature == 0.7 # default
+ assert not request.stream # default
+ assert request.tool_choice == "none" # default when no tools
+
+ def test_chat_completion_with_multimodal_content(self):
+ """Test chat completion with multimodal content"""
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "What's in this image?"},
+ {
+ "type": "image_url",
+ "image_url": {"url": "data:image/jpeg;base64,/9j/4AAQ..."},
+ },
+ ],
+ }
+ ]
+ request = ChatCompletionRequest(model="test-model", messages=messages)
+ assert len(request.messages[0].content) == 2
+ assert request.messages[0].content[0].type == "text"
+ assert request.messages[0].content[1].type == "image_url"
+
+ def test_chat_completion_with_tools(self):
+ """Test chat completion with tools"""
+ messages = [{"role": "user", "content": "What's the weather?"}]
+ tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "description": "Get weather information",
+ "parameters": {
+ "type": "object",
+ "properties": {"location": {"type": "string"}},
+ },
+ },
+ }
+ ]
+ request = ChatCompletionRequest(
+ model="test-model", messages=messages, tools=tools
+ )
+ assert len(request.tools) == 1
+ assert request.tools[0].function.name == "get_weather"
+ assert request.tool_choice == "auto" # default when tools present
+
+ def test_chat_completion_tool_choice_validation(self):
+ """Test tool choice validation logic"""
+ messages = [{"role": "user", "content": "Hello"}]
+
+ # No tools, tool_choice should default to "none"
+ request1 = ChatCompletionRequest(model="test-model", messages=messages)
+ assert request1.tool_choice == "none"
+
+ # With tools, tool_choice should default to "auto"
+ tools = [
+ {
+ "type": "function",
+ "function": {"name": "test_func", "description": "Test function"},
+ }
+ ]
+ request2 = ChatCompletionRequest(
+ model="test-model", messages=messages, tools=tools
+ )
+ assert request2.tool_choice == "auto"
+
+ def test_chat_completion_sglang_extensions(self):
+ """Test chat completion with SGLang extensions"""
+ messages = [{"role": "user", "content": "Hello"}]
+ request = ChatCompletionRequest(
+ model="test-model",
+ messages=messages,
+ top_k=40,
+ min_p=0.05,
+ separate_reasoning=False,
+ stream_reasoning=False,
+ chat_template_kwargs={"custom_param": "value"},
+ )
+ assert request.top_k == 40
+ assert request.min_p == 0.05
+ assert not request.separate_reasoning
+ assert not request.stream_reasoning
+ assert request.chat_template_kwargs == {"custom_param": "value"}
+
+
+class TestChatCompletionResponse:
+ """Test ChatCompletionResponse protocol model"""
+
+ def test_basic_chat_completion_response(self):
+ """Test basic chat completion response"""
+ message = ChatMessage(role="assistant", content="Hello there!")
+ choice = ChatCompletionResponseChoice(
+ index=0, message=message, finish_reason="stop"
+ )
+ usage = UsageInfo(prompt_tokens=2, completion_tokens=3, total_tokens=5)
+ response = ChatCompletionResponse(
+ id="test-id", model="test-model", choices=[choice], usage=usage
+ )
+ assert response.id == "test-id"
+ assert response.object == "chat.completion"
+ assert response.model == "test-model"
+ assert len(response.choices) == 1
+ assert response.choices[0].message.content == "Hello there!"
+
+ def test_chat_completion_response_with_tool_calls(self):
+ """Test chat completion response with tool calls"""
+ tool_call = ToolCall(
+ id="call_123",
+ function=FunctionResponse(
+ name="get_weather", arguments='{"location": "San Francisco"}'
+ ),
+ )
+ message = ChatMessage(role="assistant", content=None, tool_calls=[tool_call])
+ choice = ChatCompletionResponseChoice(
+ index=0, message=message, finish_reason="tool_calls"
+ )
+ usage = UsageInfo(prompt_tokens=10, completion_tokens=5, total_tokens=15)
+ response = ChatCompletionResponse(
+ id="test-id", model="test-model", choices=[choice], usage=usage
+ )
+ assert response.choices[0].message.tool_calls[0].function.name == "get_weather"
+ assert response.choices[0].finish_reason == "tool_calls"
+
+
+class TestEmbeddingRequest:
+ """Test EmbeddingRequest protocol model"""
+
+ def test_basic_embedding_request(self):
+ """Test basic embedding request"""
+ request = EmbeddingRequest(model="test-model", input="Hello world")
+ assert request.model == "test-model"
+ assert request.input == "Hello world"
+ assert request.encoding_format == "float" # default
+ assert request.dimensions is None # default
+
+ def test_embedding_request_with_list_input(self):
+ """Test embedding request with list input"""
+ request = EmbeddingRequest(
+ model="test-model", input=["Hello", "world"], dimensions=512
+ )
+ assert request.input == ["Hello", "world"]
+ assert request.dimensions == 512
+
+ def test_multimodal_embedding_request(self):
+ """Test multimodal embedding request"""
+ multimodal_input = [
+ MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
+ MultimodalEmbeddingInput(text="World", image=None),
+ ]
+ request = EmbeddingRequest(model="test-model", input=multimodal_input)
+ assert len(request.input) == 2
+ assert request.input[0].text == "Hello"
+ assert request.input[0].image == "base64_image_data"
+ assert request.input[1].text == "World"
+ assert request.input[1].image is None
+
+
+class TestEmbeddingResponse:
+ """Test EmbeddingResponse protocol model"""
+
+ def test_basic_embedding_response(self):
+ """Test basic embedding response"""
+ embedding_obj = EmbeddingObject(embedding=[0.1, 0.2, 0.3], index=0)
+ usage = UsageInfo(prompt_tokens=3, total_tokens=3)
+ response = EmbeddingResponse(
+ data=[embedding_obj], model="test-model", usage=usage
+ )
+ assert response.object == "list"
+ assert len(response.data) == 1
+ assert response.data[0].embedding == [0.1, 0.2, 0.3]
+ assert response.data[0].index == 0
+ assert response.usage.prompt_tokens == 3
+
+
+class TestScoringRequest:
+ """Test ScoringRequest protocol model"""
+
+ def test_basic_scoring_request(self):
+ """Test basic scoring request"""
+ request = ScoringRequest(
+ model="test-model", query="Hello", items=["World", "Earth"]
+ )
+ assert request.model == "test-model"
+ assert request.query == "Hello"
+ assert request.items == ["World", "Earth"]
+ assert not request.apply_softmax # default
+ assert not request.item_first # default
+
+ def test_scoring_request_with_token_ids(self):
+ """Test scoring request with token IDs"""
+ request = ScoringRequest(
+ model="test-model",
+ query=[1, 2, 3],
+ items=[[4, 5], [6, 7]],
+ label_token_ids=[8, 9],
+ apply_softmax=True,
+ item_first=True,
+ )
+ assert request.query == [1, 2, 3]
+ assert request.items == [[4, 5], [6, 7]]
+ assert request.label_token_ids == [8, 9]
+ assert request.apply_softmax
+ assert request.item_first
+
+
+class TestScoringResponse:
+ """Test ScoringResponse protocol model"""
+
+ def test_basic_scoring_response(self):
+ """Test basic scoring response"""
+ response = ScoringResponse(scores=[[0.1, 0.9], [0.3, 0.7]], model="test-model")
+ assert response.object == "scoring"
+ assert response.scores == [[0.1, 0.9], [0.3, 0.7]]
+ assert response.model == "test-model"
+ assert response.usage is None # default
+
+
+class TestFileOperations:
+ """Test file operation protocol models"""
+
+ def test_file_request(self):
+ """Test file request model"""
+ file_data = b"test file content"
+ request = FileRequest(file=file_data, purpose="batch")
+ assert request.file == file_data
+ assert request.purpose == "batch"
+
+ def test_file_response(self):
+ """Test file response model"""
+ response = FileResponse(
+ id="file-123",
+ bytes=1024,
+ created_at=1234567890,
+ filename="test.jsonl",
+ purpose="batch",
+ )
+ assert response.id == "file-123"
+ assert response.object == "file"
+ assert response.bytes == 1024
+ assert response.filename == "test.jsonl"
+
+ def test_file_delete_response(self):
+ """Test file delete response model"""
+ response = FileDeleteResponse(id="file-123", deleted=True)
+ assert response.id == "file-123"
+ assert response.object == "file"
+ assert response.deleted
+
+
+class TestBatchOperations:
+ """Test batch operation protocol models"""
+
+ def test_batch_request(self):
+ """Test batch request model"""
+ request = BatchRequest(
+ input_file_id="file-123",
+ endpoint="/v1/chat/completions",
+ completion_window="24h",
+ metadata={"custom": "value"},
+ )
+ assert request.input_file_id == "file-123"
+ assert request.endpoint == "/v1/chat/completions"
+ assert request.completion_window == "24h"
+ assert request.metadata == {"custom": "value"}
+
+ def test_batch_response(self):
+ """Test batch response model"""
+ response = BatchResponse(
+ id="batch-123",
+ endpoint="/v1/chat/completions",
+ input_file_id="file-123",
+ completion_window="24h",
+ created_at=1234567890,
+ )
+ assert response.id == "batch-123"
+ assert response.object == "batch"
+ assert response.status == "validating" # default
+ assert response.endpoint == "/v1/chat/completions"
+
+
+class TestResponseFormats:
+ """Test response format protocol models"""
+
+ def test_basic_response_format(self):
+ """Test basic response format"""
+ format_obj = ResponseFormat(type="json_object")
+ assert format_obj.type == "json_object"
+ assert format_obj.json_schema is None
+
+ def test_json_schema_response_format(self):
+ """Test JSON schema response format"""
+ schema = {"type": "object", "properties": {"name": {"type": "string"}}}
+ json_schema = JsonSchemaResponseFormat(
+ name="person_schema", description="Person schema", schema=schema
+ )
+ format_obj = ResponseFormat(type="json_schema", json_schema=json_schema)
+ assert format_obj.type == "json_schema"
+ assert format_obj.json_schema.name == "person_schema"
+ assert format_obj.json_schema.schema_ == schema
+
+ def test_structural_tag_response_format(self):
+ """Test structural tag response format"""
+ structures = [
+ {
+ "begin": "",
+ "schema_": {"type": "string"},
+ "end": "",
+ }
+ ]
+ format_obj = StructuralTagResponseFormat(
+ type="structural_tag", structures=structures, triggers=["think"]
+ )
+ assert format_obj.type == "structural_tag"
+ assert len(format_obj.structures) == 1
+ assert format_obj.triggers == ["think"]
+
+
+class TestLogProbs:
+ """Test LogProbs protocol models"""
+
+ def test_basic_logprobs(self):
+ """Test basic LogProbs model"""
+ logprobs = LogProbs(
+ text_offset=[0, 5, 11],
+ token_logprobs=[-0.1, -0.2, -0.3],
+ tokens=["Hello", " ", "world"],
+ top_logprobs=[{"Hello": -0.1}, {" ": -0.2}, {"world": -0.3}],
+ )
+ assert len(logprobs.tokens) == 3
+ assert logprobs.tokens == ["Hello", " ", "world"]
+ assert logprobs.token_logprobs == [-0.1, -0.2, -0.3]
+
+ def test_choice_logprobs(self):
+ """Test ChoiceLogprobs model"""
+ token_logprob = ChatCompletionTokenLogprob(
+ token="Hello",
+ bytes=[72, 101, 108, 108, 111],
+ logprob=-0.1,
+ top_logprobs=[
+ TopLogprob(token="Hello", bytes=[72, 101, 108, 108, 111], logprob=-0.1)
+ ],
+ )
+ choice_logprobs = ChoiceLogprobs(content=[token_logprob])
+ assert len(choice_logprobs.content) == 1
+ assert choice_logprobs.content[0].token == "Hello"
+
+
+class TestStreamingModels:
+ """Test streaming response models"""
+
+ def test_stream_options(self):
+ """Test StreamOptions model"""
+ options = StreamOptions(include_usage=True)
+ assert options.include_usage
+
+ def test_chat_completion_stream_response(self):
+ """Test ChatCompletionStreamResponse model"""
+ delta = DeltaMessage(role="assistant", content="Hello")
+ choice = ChatCompletionResponseStreamChoice(index=0, delta=delta)
+ response = ChatCompletionStreamResponse(
+ id="test-id", model="test-model", choices=[choice]
+ )
+ assert response.object == "chat.completion.chunk"
+ assert response.choices[0].delta.content == "Hello"
+
+
+class TestValidationEdgeCases:
+ """Test edge cases and validation scenarios"""
+
+ def test_empty_messages_validation(self):
+ """Test validation with empty messages"""
+ with pytest.raises(ValidationError):
+ ChatCompletionRequest(model="test-model", messages=[])
+
+ def test_invalid_tool_choice_type(self):
+ """Test invalid tool choice type"""
+ messages = [{"role": "user", "content": "Hello"}]
+ with pytest.raises(ValidationError):
+ ChatCompletionRequest(
+ model="test-model", messages=messages, tool_choice=123
+ )
+
+ def test_negative_token_limits(self):
+ """Test negative token limits"""
+ with pytest.raises(ValidationError):
+ CompletionRequest(model="test-model", prompt="Hello", max_tokens=-1)
+
+ def test_invalid_temperature_range(self):
+ """Test invalid temperature values"""
+ # Note: The current protocol doesn't enforce temperature range,
+ # but this test documents expected behavior
+ request = CompletionRequest(model="test-model", prompt="Hello", temperature=5.0)
+ assert request.temperature == 5.0 # Currently allowed
+
+ def test_model_serialization_roundtrip(self):
+ """Test that models can be serialized and deserialized"""
+ original_request = ChatCompletionRequest(
+ model="test-model",
+ messages=[{"role": "user", "content": "Hello"}],
+ temperature=0.7,
+ max_tokens=100,
+ )
+
+ # Serialize to dict
+ data = original_request.model_dump()
+
+ # Deserialize back
+ restored_request = ChatCompletionRequest(**data)
+
+ assert restored_request.model == original_request.model
+ assert restored_request.temperature == original_request.temperature
+ assert restored_request.max_tokens == original_request.max_tokens
+ assert len(restored_request.messages) == len(original_request.messages)
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])
diff --git a/test/srt/openai/test_serving_chat.py b/test/srt/openai/test_serving_chat.py
new file mode 100644
index 00000000000..b2015866b63
--- /dev/null
+++ b/test/srt/openai/test_serving_chat.py
@@ -0,0 +1,634 @@
+"""
+Unit tests for the OpenAIServingChat class from serving_chat.py.
+
+These tests ensure that the refactored implementation maintains compatibility
+with the original adapter.py functionality.
+"""
+
+import uuid
+from unittest.mock import Mock, patch
+
+import pytest
+from fastapi import Request
+
+from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse
+from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
+from sglang.srt.managers.io_struct import GenerateReqInput
+
+
+# Mock TokenizerManager since it may not be directly importable in tests
+class MockTokenizerManager:
+ def __init__(self):
+ self.model_config = Mock()
+ self.model_config.is_multimodal = False
+ self.server_args = Mock()
+ self.server_args.enable_cache_report = False
+ self.server_args.tool_call_parser = "hermes"
+ self.server_args.reasoning_parser = None
+ self.chat_template_name = "llama-3"
+
+ # Mock tokenizer
+ self.tokenizer = Mock()
+ self.tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
+ self.tokenizer.decode = Mock(return_value="Test response")
+ self.tokenizer.chat_template = None
+ self.tokenizer.bos_token_id = 1
+
+ # Mock generate_request method
+ async def mock_generate():
+ yield {
+ "text": "Test response",
+ "meta_info": {
+ "id": f"chatcmpl-{uuid.uuid4()}",
+ "prompt_tokens": 10,
+ "completion_tokens": 5,
+ "cached_tokens": 0,
+ "finish_reason": {"type": "stop", "matched": None},
+ "output_token_logprobs": [(0.1, 1, "Test"), (0.2, 2, "response")],
+ "output_top_logprobs": None,
+ },
+ "index": 0,
+ }
+
+ self.generate_request = Mock(return_value=mock_generate())
+ self.create_abort_task = Mock(return_value=None)
+
+
+@pytest.fixture
+def mock_tokenizer_manager():
+ """Create a mock tokenizer manager for testing."""
+ return MockTokenizerManager()
+
+
+@pytest.fixture
+def serving_chat(mock_tokenizer_manager):
+ """Create a OpenAIServingChat instance for testing."""
+ return OpenAIServingChat(mock_tokenizer_manager)
+
+
+@pytest.fixture
+def mock_request():
+ """Create a mock FastAPI request."""
+ request = Mock(spec=Request)
+ request.headers = {}
+ return request
+
+
+@pytest.fixture
+def basic_chat_request():
+ """Create a basic chat completion request."""
+ return ChatCompletionRequest(
+ model="test-model",
+ messages=[{"role": "user", "content": "Hello, how are you?"}],
+ temperature=0.7,
+ max_tokens=100,
+ stream=False,
+ )
+
+
+@pytest.fixture
+def streaming_chat_request():
+ """Create a streaming chat completion request."""
+ return ChatCompletionRequest(
+ model="test-model",
+ messages=[{"role": "user", "content": "Hello, how are you?"}],
+ temperature=0.7,
+ max_tokens=100,
+ stream=True,
+ )
+
+
+class TestOpenAIServingChatConversion:
+ """Test request conversion methods."""
+
+ def test_convert_to_internal_request_single(
+ self, serving_chat, basic_chat_request, mock_tokenizer_manager
+ ):
+ """Test converting single request to internal format."""
+ with patch(
+ "sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
+ ) as mock_conv:
+ mock_conv_instance = Mock()
+ mock_conv_instance.get_prompt.return_value = "Test prompt"
+ mock_conv_instance.image_data = None
+ mock_conv_instance.audio_data = None
+ mock_conv_instance.modalities = []
+ mock_conv_instance.stop_str = [""]
+ mock_conv.return_value = mock_conv_instance
+
+ # Mock the _process_messages method to return expected values
+ with patch.object(serving_chat, "_process_messages") as mock_process:
+ mock_process.return_value = (
+ "Test prompt",
+ [1, 2, 3],
+ None,
+ None,
+ [],
+ [""],
+ None, # tool_call_constraint
+ )
+
+ adapted_request, processed_request = (
+ serving_chat._convert_to_internal_request(
+ [basic_chat_request], ["test-id"]
+ )
+ )
+
+ assert isinstance(adapted_request, GenerateReqInput)
+ assert adapted_request.stream == basic_chat_request.stream
+ assert processed_request == basic_chat_request
+
+
+class TestToolCalls:
+ """Test tool call functionality from adapter.py"""
+
+ def test_tool_call_request_conversion(self, serving_chat):
+ """Test request with tool calls"""
+ request = ChatCompletionRequest(
+ model="test-model",
+ messages=[{"role": "user", "content": "What's the weather?"}],
+ tools=[
+ {
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "description": "Get weather information",
+ "parameters": {
+ "type": "object",
+ "properties": {"location": {"type": "string"}},
+ },
+ },
+ }
+ ],
+ tool_choice="auto",
+ )
+
+ with patch.object(serving_chat, "_process_messages") as mock_process:
+ mock_process.return_value = (
+ "Test prompt",
+ [1, 2, 3],
+ None,
+ None,
+ [],
+ [""],
+ None, # tool_call_constraint
+ )
+
+ adapted_request, _ = serving_chat._convert_to_internal_request(
+ [request], ["test-id"]
+ )
+
+ assert adapted_request.rid == "test-id"
+ # Tool call constraint should be processed
+ assert request.tools is not None
+
+ def test_tool_choice_none(self, serving_chat):
+ """Test tool_choice=none disables tool calls"""
+ request = ChatCompletionRequest(
+ model="test-model",
+ messages=[{"role": "user", "content": "Hello"}],
+ tools=[{"type": "function", "function": {"name": "test_func"}}],
+ tool_choice="none",
+ )
+
+ with patch.object(serving_chat, "_process_messages") as mock_process:
+ mock_process.return_value = (
+ "Test prompt",
+ [1, 2, 3],
+ None,
+ None,
+ [],
+ [""],
+ None, # tool_call_constraint
+ )
+
+ adapted_request, _ = serving_chat._convert_to_internal_request(
+ [request], ["test-id"]
+ )
+
+ # Tools should not be processed when tool_choice is "none"
+ assert adapted_request.rid == "test-id"
+
+ def test_tool_call_response_processing(self, serving_chat):
+ """Test processing tool calls in response"""
+ mock_ret_item = {
+ "text": '{"name": "get_weather", "parameters": {"location": "Paris"}}',
+ "meta_info": {
+ "output_token_logprobs": [],
+ "output_top_logprobs": None,
+ },
+ }
+
+ tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "parameters": {
+ "type": "object",
+ "properties": {"location": {"type": "string"}},
+ },
+ },
+ }
+ ]
+
+ finish_reason = {"type": "stop", "matched": None}
+
+ # Mock FunctionCallParser
+ with patch(
+ "sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
+ ) as mock_parser_class:
+ mock_parser = Mock()
+ mock_parser.has_tool_call.return_value = True
+
+ # Create proper mock tool call object
+ mock_tool_call = Mock()
+ mock_tool_call.name = "get_weather"
+ mock_tool_call.parameters = '{"location": "Paris"}'
+
+ mock_parser.parse_non_stream.return_value = ("", [mock_tool_call])
+ mock_parser_class.return_value = mock_parser
+
+ tool_calls, text, updated_finish_reason = serving_chat._process_tool_calls(
+ mock_ret_item["text"], tools, "hermes", finish_reason
+ )
+
+ assert tool_calls is not None
+ assert len(tool_calls) == 1
+ assert updated_finish_reason["type"] == "tool_calls"
+
+
+class TestMultimodalContent:
+ """Test multimodal content handling from adapter.py"""
+
+ def test_multimodal_request_with_images(self, serving_chat):
+ """Test request with image content"""
+ request = ChatCompletionRequest(
+ model="test-model",
+ messages=[
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "What's in this image?"},
+ {
+ "type": "image_url",
+ "image_url": {"url": "data:image/jpeg;base64,..."},
+ },
+ ],
+ }
+ ],
+ )
+
+ # Set multimodal mode
+ serving_chat.tokenizer_manager.model_config.is_multimodal = True
+
+ with patch.object(serving_chat, "_apply_jinja_template") as mock_apply:
+ mock_apply.return_value = (
+ "prompt",
+ [1, 2, 3],
+ ["image_data"],
+ None,
+ [],
+ [],
+ )
+
+ with patch.object(
+ serving_chat, "_apply_conversation_template"
+ ) as mock_conv:
+ mock_conv.return_value = ("prompt", ["image_data"], None, [], [])
+
+ (
+ prompt,
+ prompt_ids,
+ image_data,
+ audio_data,
+ modalities,
+ stop,
+ tool_call_constraint,
+ ) = serving_chat._process_messages(request, True)
+
+ assert image_data == ["image_data"]
+ assert prompt == "prompt"
+
+ def test_multimodal_request_with_audio(self, serving_chat):
+ """Test request with audio content"""
+ request = ChatCompletionRequest(
+ model="test-model",
+ messages=[
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "Transcribe this audio"},
+ {
+ "type": "audio_url",
+ "audio_url": {"url": "data:audio/wav;base64,UklGR..."},
+ },
+ ],
+ }
+ ],
+ )
+
+ serving_chat.tokenizer_manager.model_config.is_multimodal = True
+
+ with patch.object(serving_chat, "_apply_jinja_template") as mock_apply:
+ mock_apply.return_value = (
+ "prompt",
+ [1, 2, 3],
+ None,
+ ["audio_data"],
+ ["audio"],
+ [],
+ )
+
+ with patch.object(
+ serving_chat, "_apply_conversation_template"
+ ) as mock_conv:
+ mock_conv.return_value = ("prompt", None, ["audio_data"], ["audio"], [])
+
+ (
+ prompt,
+ prompt_ids,
+ image_data,
+ audio_data,
+ modalities,
+ stop,
+ tool_call_constraint,
+ ) = serving_chat._process_messages(request, True)
+
+ assert audio_data == ["audio_data"]
+ assert modalities == ["audio"]
+
+
+class TestTemplateHandling:
+ """Test chat template handling from adapter.py"""
+
+ def test_jinja_template_processing(self, serving_chat):
+ """Test Jinja template processing"""
+ request = ChatCompletionRequest(
+ model="test-model", messages=[{"role": "user", "content": "Hello"}]
+ )
+
+ # Mock the template attribute directly
+ serving_chat.tokenizer_manager.chat_template_name = None
+ serving_chat.tokenizer_manager.tokenizer.chat_template = ""
+
+ with patch.object(serving_chat, "_apply_jinja_template") as mock_apply:
+ mock_apply.return_value = (
+ "processed_prompt",
+ [1, 2, 3],
+ None,
+ None,
+ [],
+ [""],
+ )
+
+ # Mock hasattr to simulate the None check
+ with patch("builtins.hasattr") as mock_hasattr:
+ mock_hasattr.return_value = True
+
+ (
+ prompt,
+ prompt_ids,
+ image_data,
+ audio_data,
+ modalities,
+ stop,
+ tool_call_constraint,
+ ) = serving_chat._process_messages(request, False)
+
+ assert prompt == "processed_prompt"
+ assert prompt_ids == [1, 2, 3]
+
+ def test_conversation_template_processing(self, serving_chat):
+ """Test conversation template processing"""
+ request = ChatCompletionRequest(
+ model="test-model", messages=[{"role": "user", "content": "Hello"}]
+ )
+
+ serving_chat.tokenizer_manager.chat_template_name = "llama-3"
+
+ with patch.object(serving_chat, "_apply_conversation_template") as mock_apply:
+ mock_apply.return_value = ("conv_prompt", None, None, [], [""])
+
+ (
+ prompt,
+ prompt_ids,
+ image_data,
+ audio_data,
+ modalities,
+ stop,
+ tool_call_constraint,
+ ) = serving_chat._process_messages(request, False)
+
+ assert prompt == "conv_prompt"
+ assert stop == [""]
+
+ def test_continue_final_message(self, serving_chat):
+ """Test continue_final_message functionality"""
+ request = ChatCompletionRequest(
+ model="test-model",
+ messages=[
+ {"role": "user", "content": "Hello"},
+ {"role": "assistant", "content": "Hi there"},
+ ],
+ continue_final_message=True,
+ )
+
+ with patch.object(serving_chat, "_apply_conversation_template") as mock_apply:
+ mock_apply.return_value = ("Hi there", None, None, [], [""])
+
+ (
+ prompt,
+ prompt_ids,
+ image_data,
+ audio_data,
+ modalities,
+ stop,
+ tool_call_constraint,
+ ) = serving_chat._process_messages(request, False)
+
+ # Should handle continue_final_message properly
+ assert prompt == "Hi there"
+
+
+class TestReasoningContent:
+ """Test reasoning content separation from adapter.py"""
+
+ def test_reasoning_content_request(self, serving_chat):
+ """Test request with reasoning content separation"""
+ request = ChatCompletionRequest(
+ model="test-model",
+ messages=[{"role": "user", "content": "Solve this math problem"}],
+ separate_reasoning=True,
+ stream_reasoning=False,
+ )
+
+ with patch.object(serving_chat, "_process_messages") as mock_process:
+ mock_process.return_value = (
+ "Test prompt",
+ [1, 2, 3],
+ None,
+ None,
+ [],
+ [""],
+ None, # tool_call_constraint
+ )
+
+ adapted_request, _ = serving_chat._convert_to_internal_request(
+ [request], ["test-id"]
+ )
+
+ assert adapted_request.rid == "test-id"
+            assert request.separate_reasoning
+
+ def test_reasoning_content_response(self, serving_chat):
+ """Test reasoning content in response"""
+ mock_ret_item = {
+ "text": "This is reasoningAnswer: 42",
+ "meta_info": {
+ "output_token_logprobs": [],
+ "output_top_logprobs": None,
+ },
+ }
+
+ # Mock ReasoningParser
+ with patch(
+ "sglang.srt.entrypoints.openai.serving_chat.ReasoningParser"
+ ) as mock_parser_class:
+ mock_parser = Mock()
+ mock_parser.parse_non_stream.return_value = (
+ "This is reasoning",
+ "Answer: 42",
+ )
+ mock_parser_class.return_value = mock_parser
+
+ choice_logprobs = None
+ reasoning_text = None
+ text = mock_ret_item["text"]
+
+ # Simulate reasoning processing
+ enable_thinking = True
+ if enable_thinking:
+ parser = mock_parser_class(model_type="test", stream_reasoning=False)
+ reasoning_text, text = parser.parse_non_stream(text)
+
+ assert reasoning_text == "This is reasoning"
+ assert text == "Answer: 42"
+
+
+class TestSamplingParams:
+ """Test sampling parameter handling from adapter.py"""
+
+ def test_all_sampling_parameters(self, serving_chat):
+ """Test all sampling parameters are properly handled"""
+ request = ChatCompletionRequest(
+ model="test-model",
+ messages=[{"role": "user", "content": "Hello"}],
+ temperature=0.8,
+ max_tokens=150,
+ max_completion_tokens=200,
+ min_tokens=5,
+ top_p=0.9,
+ top_k=50,
+ min_p=0.1,
+ presence_penalty=0.1,
+ frequency_penalty=0.2,
+ repetition_penalty=1.1,
+ stop=["<|endoftext|>"],
+ stop_token_ids=[13, 14],
+ regex=r"\d+",
+            ebnf="<start> ::= <expr>",
+ n=2,
+ no_stop_trim=True,
+ ignore_eos=True,
+ skip_special_tokens=False,
+ logit_bias={"1": 0.5, "2": -0.3},
+ )
+
+ with patch.object(serving_chat, "_process_messages") as mock_process:
+ mock_process.return_value = (
+ "Test prompt",
+ [1, 2, 3],
+ None,
+ None,
+ [],
+ [""],
+ None, # tool_call_constraint
+ )
+
+ sampling_params = serving_chat._build_sampling_params(
+ request, [""], None
+ )
+
+ # Verify all parameters
+ assert sampling_params["temperature"] == 0.8
+ assert sampling_params["max_new_tokens"] == 150
+ assert sampling_params["min_new_tokens"] == 5
+ assert sampling_params["top_p"] == 0.9
+ assert sampling_params["top_k"] == 50
+ assert sampling_params["min_p"] == 0.1
+ assert sampling_params["presence_penalty"] == 0.1
+ assert sampling_params["frequency_penalty"] == 0.2
+ assert sampling_params["repetition_penalty"] == 1.1
+ assert sampling_params["stop"] == [""]
+ assert sampling_params["logit_bias"] == {"1": 0.5, "2": -0.3}
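+
+    # Hedged sketch: assumes the core sampling keys are always emitted, even when
+    # the request leaves every sampling field at its default.
+    def test_default_sampling_parameters(self, serving_chat):
+        """Test that default requests still produce the core sampling keys"""
+        request = ChatCompletionRequest(
+            model="test-model", messages=[{"role": "user", "content": "Hello"}]
+        )
+
+        sampling_params = serving_chat._build_sampling_params(request, [""], None)
+
+        # These keys are asserted with explicit values above; here only their
+        # presence is checked.
+        for key in ("temperature", "top_p", "stop"):
+            assert key in sampling_params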
+
+ def test_response_format_json_schema(self, serving_chat):
+ """Test response format with JSON schema"""
+ request = ChatCompletionRequest(
+ model="test-model",
+ messages=[{"role": "user", "content": "Generate JSON"}],
+ response_format={
+ "type": "json_schema",
+ "json_schema": {
+ "name": "response",
+ "schema": {
+ "type": "object",
+ "properties": {"answer": {"type": "string"}},
+ },
+ },
+ },
+ )
+
+ with patch.object(serving_chat, "_process_messages") as mock_process:
+ mock_process.return_value = (
+ "Test prompt",
+ [1, 2, 3],
+ None,
+ None,
+ [],
+ [""],
+ None, # tool_call_constraint
+ )
+
+ sampling_params = serving_chat._build_sampling_params(
+ request, [""], None
+ )
+
+ assert "json_schema" in sampling_params
+ assert '"type": "object"' in sampling_params["json_schema"]
+
+ def test_response_format_json_object(self, serving_chat):
+ """Test response format with JSON object"""
+ request = ChatCompletionRequest(
+ model="test-model",
+ messages=[{"role": "user", "content": "Generate JSON"}],
+ response_format={"type": "json_object"},
+ )
+
+ with patch.object(serving_chat, "_process_messages") as mock_process:
+ mock_process.return_value = (
+ "Test prompt",
+ [1, 2, 3],
+ None,
+ None,
+ [],
+ [""],
+ None, # tool_call_constraint
+ )
+
+ sampling_params = serving_chat._build_sampling_params(
+ request, [""], None
+ )
+
+ assert sampling_params["json_schema"] == '{"type": "object"}'
diff --git a/test/srt/openai/test_serving_completions.py b/test/srt/openai/test_serving_completions.py
new file mode 100644
index 00000000000..3e8fc42c8c6
--- /dev/null
+++ b/test/srt/openai/test_serving_completions.py
@@ -0,0 +1,176 @@
+"""
+Tests for the refactored completions serving handler
+"""
+
+from unittest.mock import AsyncMock, Mock, patch
+
+import pytest
+
+from sglang.srt.entrypoints.openai.protocol import (
+ CompletionRequest,
+ CompletionResponse,
+ CompletionResponseChoice,
+ CompletionStreamResponse,
+ ErrorResponse,
+)
+from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
+from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.tokenizer_manager import TokenizerManager
+
+
+@pytest.fixture
+def mock_tokenizer_manager():
+ """Create a mock tokenizer manager"""
+ manager = Mock(spec=TokenizerManager)
+
+ # Mock tokenizer
+ manager.tokenizer = Mock()
+ manager.tokenizer.encode = Mock(return_value=[1, 2, 3, 4])
+ manager.tokenizer.decode = Mock(return_value="decoded text")
+ manager.tokenizer.bos_token_id = 1
+
+ # Mock model config
+ manager.model_config = Mock()
+ manager.model_config.is_multimodal = False
+
+ # Mock server args
+ manager.server_args = Mock()
+ manager.server_args.enable_cache_report = False
+
+ # Mock generation
+ manager.generate_request = AsyncMock()
+ manager.create_abort_task = Mock(return_value=None)
+
+ return manager
+
+
+@pytest.fixture
+def serving_completion(mock_tokenizer_manager):
+    """Create an OpenAIServingCompletion instance"""
+ return OpenAIServingCompletion(mock_tokenizer_manager)
+
+
+class TestPromptHandling:
+ """Test different prompt types and formats from adapter.py"""
+
+ def test_single_string_prompt(self, serving_completion):
+ """Test handling single string prompt"""
+ request = CompletionRequest(
+ model="test-model", prompt="Hello world", max_tokens=100
+ )
+
+ adapted_request, _ = serving_completion._convert_to_internal_request(
+ [request], ["test-id"]
+ )
+
+ assert adapted_request.text == "Hello world"
+
+ def test_single_token_ids_prompt(self, serving_completion):
+ """Test handling single token IDs prompt"""
+ request = CompletionRequest(
+ model="test-model", prompt=[1, 2, 3, 4], max_tokens=100
+ )
+
+ adapted_request, _ = serving_completion._convert_to_internal_request(
+ [request], ["test-id"]
+ )
+
+ assert adapted_request.input_ids == [1, 2, 3, 4]
+
+ def test_completion_template_handling(self, serving_completion):
+ """Test completion template processing"""
+ request = CompletionRequest(
+ model="test-model",
+ prompt="def hello():",
+ suffix="return 'world'",
+ max_tokens=100,
+ )
+
+ with patch(
+ "sglang.srt.entrypoints.openai.serving_completions.is_completion_template_defined",
+ return_value=True,
+ ):
+ with patch(
+ "sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request",
+ return_value="processed_prompt",
+ ):
+ adapted_request, _ = serving_completion._convert_to_internal_request(
+ [request], ["test-id"]
+ )
+
+ assert adapted_request.text == "processed_prompt"
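+
+    # Hedged sketch: structural check that the adapter emits a GenerateReqInput
+    # carrying the supplied request id; mirrors the rid assertions in the other
+    # serving tests and assumes single-request batches keep a plain string rid.
+    def test_adapted_request_type_and_rid(self, serving_completion):
+        """Test the adapted request type and request id propagation"""
+        request = CompletionRequest(
+            model="test-model", prompt="Hello world", max_tokens=100
+        )
+
+        adapted_request, _ = serving_completion._convert_to_internal_request(
+            [request], ["test-id"]
+        )
+
+        assert isinstance(adapted_request, GenerateReqInput)
+        assert adapted_request.rid == "test-id"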
+
+
+class TestEchoHandling:
+ """Test echo functionality from adapter.py"""
+
+ def test_echo_with_string_prompt_streaming(self, serving_completion):
+ """Test echo handling with string prompt in streaming"""
+ request = CompletionRequest(
+ model="test-model", prompt="Hello", max_tokens=100, echo=True
+ )
+
+ # Test _get_echo_text method
+ echo_text = serving_completion._get_echo_text(request, 0)
+ assert echo_text == "Hello"
+
+ def test_echo_with_list_of_strings_streaming(self, serving_completion):
+ """Test echo handling with list of strings in streaming"""
+ request = CompletionRequest(
+ model="test-model",
+ prompt=["Hello", "World"],
+ max_tokens=100,
+ echo=True,
+ n=1,
+ )
+
+ echo_text = serving_completion._get_echo_text(request, 0)
+ assert echo_text == "Hello"
+
+ echo_text = serving_completion._get_echo_text(request, 1)
+ assert echo_text == "World"
+
+ def test_echo_with_token_ids_streaming(self, serving_completion):
+ """Test echo handling with token IDs in streaming"""
+ request = CompletionRequest(
+ model="test-model", prompt=[1, 2, 3], max_tokens=100, echo=True
+ )
+
+ serving_completion.tokenizer_manager.tokenizer.decode.return_value = (
+ "decoded_prompt"
+ )
+ echo_text = serving_completion._get_echo_text(request, 0)
+ assert echo_text == "decoded_prompt"
+
+ def test_echo_with_multiple_token_ids_streaming(self, serving_completion):
+ """Test echo handling with multiple token ID prompts in streaming"""
+ request = CompletionRequest(
+ model="test-model", prompt=[[1, 2], [3, 4]], max_tokens=100, echo=True, n=1
+ )
+
+ serving_completion.tokenizer_manager.tokenizer.decode.return_value = "decoded"
+ echo_text = serving_completion._get_echo_text(request, 0)
+ assert echo_text == "decoded"
+
+ def test_prepare_echo_prompts_non_streaming(self, serving_completion):
+ """Test prepare echo prompts for non-streaming response"""
+ # Test with single string
+ request = CompletionRequest(model="test-model", prompt="Hello", echo=True)
+
+ echo_prompts = serving_completion._prepare_echo_prompts(request)
+ assert echo_prompts == ["Hello"]
+
+ # Test with list of strings
+ request = CompletionRequest(
+ model="test-model", prompt=["Hello", "World"], echo=True
+ )
+
+ echo_prompts = serving_completion._prepare_echo_prompts(request)
+ assert echo_prompts == ["Hello", "World"]
+
+ # Test with token IDs
+ request = CompletionRequest(model="test-model", prompt=[1, 2, 3], echo=True)
+
+ serving_completion.tokenizer_manager.tokenizer.decode.return_value = "decoded"
+ echo_prompts = serving_completion._prepare_echo_prompts(request)
+ assert echo_prompts == ["decoded"]
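+
+    # Hedged sketch: assumes a batch of token-id prompts decodes to one echo
+    # prompt per entry; decode is mocked, so only the shape is checked.
+    def test_prepare_echo_prompts_multiple_token_ids(self, serving_completion):
+        """Test prepare echo prompts for a batch of token ID prompts"""
+        request = CompletionRequest(
+            model="test-model", prompt=[[1, 2], [3, 4]], echo=True
+        )
+
+        serving_completion.tokenizer_manager.tokenizer.decode.return_value = "decoded"
+        echo_prompts = serving_completion._prepare_echo_prompts(request)
+        assert echo_prompts == ["decoded", "decoded"]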
diff --git a/test/srt/openai/test_serving_embedding.py b/test/srt/openai/test_serving_embedding.py
new file mode 100644
index 00000000000..58438bba884
--- /dev/null
+++ b/test/srt/openai/test_serving_embedding.py
@@ -0,0 +1,316 @@
+"""
+Unit tests for the OpenAIServingEmbedding class from serving_embedding.py.
+
+These tests ensure that the embedding serving implementation maintains compatibility
+with the original adapter.py functionality and follows OpenAI API specifications.
+"""
+
+import asyncio
+import json
+import time
+import uuid
+from typing import Any, Dict, List
+from unittest.mock import AsyncMock, Mock, patch
+
+import pytest
+from fastapi import Request
+from fastapi.responses import ORJSONResponse
+from pydantic_core import ValidationError
+
+from sglang.srt.entrypoints.openai.protocol import (
+ EmbeddingObject,
+ EmbeddingRequest,
+ EmbeddingResponse,
+ ErrorResponse,
+ MultimodalEmbeddingInput,
+ UsageInfo,
+)
+from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from sglang.srt.managers.io_struct import EmbeddingReqInput
+
+
+# Mock TokenizerManager for embedding tests
+class MockTokenizerManager:
+ def __init__(self):
+ self.model_config = Mock()
+ self.model_config.is_multimodal = False
+ self.server_args = Mock()
+ self.server_args.enable_cache_report = False
+ self.model_path = "test-model"
+
+ # Mock tokenizer
+ self.tokenizer = Mock()
+ self.tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
+ self.tokenizer.decode = Mock(return_value="Test embedding input")
+ self.tokenizer.chat_template = None
+ self.tokenizer.bos_token_id = 1
+
+ # Mock generate_request method for embeddings
+ async def mock_generate_embedding():
+ yield {
+ "embedding": [0.1, 0.2, 0.3, 0.4, 0.5] * 20, # 100-dim embedding
+ "meta_info": {
+ "id": f"embd-{uuid.uuid4()}",
+ "prompt_tokens": 5,
+ },
+ }
+
+ self.generate_request = Mock(return_value=mock_generate_embedding())
+
+
+@pytest.fixture
+def mock_tokenizer_manager():
+ """Create a mock tokenizer manager for testing."""
+ return MockTokenizerManager()
+
+
+@pytest.fixture
+def serving_embedding(mock_tokenizer_manager):
+ """Create an OpenAIServingEmbedding instance for testing."""
+ return OpenAIServingEmbedding(mock_tokenizer_manager)
+
+
+@pytest.fixture
+def mock_request():
+ """Create a mock FastAPI request."""
+ request = Mock(spec=Request)
+ request.headers = {}
+ return request
+
+
+@pytest.fixture
+def basic_embedding_request():
+ """Create a basic embedding request."""
+ return EmbeddingRequest(
+ model="test-model",
+ input="Hello, how are you?",
+ encoding_format="float",
+ )
+
+
+@pytest.fixture
+def list_embedding_request():
+ """Create an embedding request with list input."""
+ return EmbeddingRequest(
+ model="test-model",
+ input=["Hello, how are you?", "I am fine, thank you!"],
+ encoding_format="float",
+ )
+
+
+@pytest.fixture
+def multimodal_embedding_request():
+ """Create a multimodal embedding request."""
+ return EmbeddingRequest(
+ model="test-model",
+ input=[
+ MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
+ MultimodalEmbeddingInput(text="World", image=None),
+ ],
+ encoding_format="float",
+ )
+
+
+@pytest.fixture
+def token_ids_embedding_request():
+ """Create an embedding request with token IDs."""
+ return EmbeddingRequest(
+ model="test-model",
+ input=[1, 2, 3, 4, 5],
+ encoding_format="float",
+ )
+
+
+class TestOpenAIServingEmbeddingConversion:
+ """Test request conversion methods."""
+
+ def test_convert_single_string_request(
+ self, serving_embedding, basic_embedding_request
+ ):
+ """Test converting single string request to internal format."""
+ adapted_request, processed_request = (
+ serving_embedding._convert_to_internal_request(
+ [basic_embedding_request], ["test-id"]
+ )
+ )
+
+ assert isinstance(adapted_request, EmbeddingReqInput)
+ assert adapted_request.text == "Hello, how are you?"
+ assert adapted_request.rid == "test-id"
+ assert processed_request == basic_embedding_request
+
+ def test_convert_list_string_request(
+ self, serving_embedding, list_embedding_request
+ ):
+ """Test converting list of strings request to internal format."""
+ adapted_request, processed_request = (
+ serving_embedding._convert_to_internal_request(
+ [list_embedding_request], ["test-id"]
+ )
+ )
+
+ assert isinstance(adapted_request, EmbeddingReqInput)
+ assert adapted_request.text == ["Hello, how are you?", "I am fine, thank you!"]
+ assert adapted_request.rid == "test-id"
+ assert processed_request == list_embedding_request
+
+ def test_convert_token_ids_request(
+ self, serving_embedding, token_ids_embedding_request
+ ):
+ """Test converting token IDs request to internal format."""
+ adapted_request, processed_request = (
+ serving_embedding._convert_to_internal_request(
+ [token_ids_embedding_request], ["test-id"]
+ )
+ )
+
+ assert isinstance(adapted_request, EmbeddingReqInput)
+ assert adapted_request.input_ids == [1, 2, 3, 4, 5]
+ assert adapted_request.rid == "test-id"
+ assert processed_request == token_ids_embedding_request
+
+ def test_convert_multimodal_request(
+ self, serving_embedding, multimodal_embedding_request
+ ):
+ """Test converting multimodal request to internal format."""
+ adapted_request, processed_request = (
+ serving_embedding._convert_to_internal_request(
+ [multimodal_embedding_request], ["test-id"]
+ )
+ )
+
+ assert isinstance(adapted_request, EmbeddingReqInput)
+ # Should extract text and images separately
+ assert len(adapted_request.text) == 2
+ assert "Hello" in adapted_request.text
+ assert "World" in adapted_request.text
+ assert adapted_request.image_data[0] == "base64_image_data"
+ assert adapted_request.image_data[1] is None
+ assert adapted_request.rid == "test-id"
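+
+    # Hedged sketch: assumes the protocol accepts a batch of token-id lists and
+    # that nested lists pass through into input_ids unchanged, mirroring the
+    # single-list case above.
+    def test_convert_batch_token_ids_request(self, serving_embedding):
+        """Test converting a batch of token ID prompts to internal format."""
+        request = EmbeddingRequest(
+            model="test-model",
+            input=[[1, 2, 3], [4, 5, 6]],
+            encoding_format="float",
+        )
+
+        adapted_request, _ = serving_embedding._convert_to_internal_request(
+            [request], ["test-id"]
+        )
+
+        assert isinstance(adapted_request, EmbeddingReqInput)
+        assert adapted_request.input_ids == [[1, 2, 3], [4, 5, 6]]
+        assert adapted_request.rid == "test-id"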
+
+
+class TestEmbeddingResponseBuilding:
+ """Test response building methods."""
+
+ def test_build_single_embedding_response(self, serving_embedding):
+ """Test building response for single embedding."""
+ ret_data = [
+ {
+ "embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
+ "meta_info": {"prompt_tokens": 5},
+ }
+ ]
+
+ response = serving_embedding._build_embedding_response(ret_data, "test-model")
+
+ assert isinstance(response, EmbeddingResponse)
+ assert response.model == "test-model"
+ assert len(response.data) == 1
+ assert response.data[0].embedding == [0.1, 0.2, 0.3, 0.4, 0.5]
+ assert response.data[0].index == 0
+ assert response.data[0].object == "embedding"
+ assert response.usage.prompt_tokens == 5
+ assert response.usage.total_tokens == 5
+ assert response.usage.completion_tokens == 0
+
+ def test_build_multiple_embedding_response(self, serving_embedding):
+ """Test building response for multiple embeddings."""
+ ret_data = [
+ {
+ "embedding": [0.1, 0.2, 0.3],
+ "meta_info": {"prompt_tokens": 3},
+ },
+ {
+ "embedding": [0.4, 0.5, 0.6],
+ "meta_info": {"prompt_tokens": 4},
+ },
+ ]
+
+ response = serving_embedding._build_embedding_response(ret_data, "test-model")
+
+ assert isinstance(response, EmbeddingResponse)
+ assert len(response.data) == 2
+ assert response.data[0].embedding == [0.1, 0.2, 0.3]
+ assert response.data[0].index == 0
+ assert response.data[1].embedding == [0.4, 0.5, 0.6]
+ assert response.data[1].index == 1
+ assert response.usage.prompt_tokens == 7 # 3 + 4
+ assert response.usage.total_tokens == 7
+
+
+@pytest.mark.asyncio
+class TestOpenAIServingEmbeddingAsyncMethods:
+ """Test async methods of OpenAIServingEmbedding."""
+
+ async def test_handle_request_success(
+ self, serving_embedding, basic_embedding_request, mock_request
+ ):
+ """Test successful embedding request handling."""
+
+ # Mock the generate_request to return expected data
+ async def mock_generate():
+ yield {
+ "embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
+ "meta_info": {"prompt_tokens": 5},
+ }
+
+ serving_embedding.tokenizer_manager.generate_request = Mock(
+ return_value=mock_generate()
+ )
+
+ response = await serving_embedding.handle_request(
+ basic_embedding_request, mock_request
+ )
+
+ assert isinstance(response, EmbeddingResponse)
+ assert len(response.data) == 1
+ assert response.data[0].embedding == [0.1, 0.2, 0.3, 0.4, 0.5]
+
+ async def test_handle_request_validation_error(
+ self, serving_embedding, mock_request
+ ):
+ """Test handling request with validation error."""
+ invalid_request = EmbeddingRequest(model="test-model", input="")
+
+ response = await serving_embedding.handle_request(invalid_request, mock_request)
+
+ assert isinstance(response, ORJSONResponse)
+ assert response.status_code == 400
+
+ async def test_handle_request_generation_error(
+ self, serving_embedding, basic_embedding_request, mock_request
+ ):
+ """Test handling request with generation error."""
+
+ # Mock generate_request to raise an error
+ async def mock_generate_error():
+ raise ValueError("Generation failed")
+            yield  # unreachable, but required to make this an async generator
+
+ serving_embedding.tokenizer_manager.generate_request = Mock(
+ return_value=mock_generate_error()
+ )
+
+ response = await serving_embedding.handle_request(
+ basic_embedding_request, mock_request
+ )
+
+ assert isinstance(response, ORJSONResponse)
+ assert response.status_code == 400
+
+ async def test_handle_request_internal_error(
+ self, serving_embedding, basic_embedding_request, mock_request
+ ):
+ """Test handling request with internal server error."""
+ # Mock _convert_to_internal_request to raise an exception
+ with patch.object(
+ serving_embedding,
+ "_convert_to_internal_request",
+ side_effect=Exception("Internal error"),
+ ):
+ response = await serving_embedding.handle_request(
+ basic_embedding_request, mock_request
+ )
+
+ assert isinstance(response, ORJSONResponse)
+ assert response.status_code == 500