diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py
index 017097e1144..9acd50d02b0 100644
--- a/python/sglang/srt/entrypoints/openai/protocol.py
+++ b/python/sglang/srt/entrypoints/openai/protocol.py
@@ -16,7 +16,13 @@
 import time
 from typing import Dict, List, Optional, Union

-from pydantic import BaseModel, Field, field_validator, model_validator
+from pydantic import (
+    BaseModel,
+    Field,
+    field_validator,
+    model_serializer,
+    model_validator,
+)
 from typing_extensions import Literal

@@ -167,6 +173,7 @@ class CompletionRequest(BaseModel):
     temperature: float = 1.0
     top_p: float = 1.0
     user: Optional[str] = None
+    return_hidden_states: bool = False

     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     top_k: int = -1
@@ -202,6 +209,14 @@ class CompletionResponseChoice(BaseModel):
     logprobs: Optional[LogProbs] = None
     finish_reason: Literal["stop", "length", "content_filter", "abort"]
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        return data


 class CompletionResponse(BaseModel):
@@ -219,6 +234,14 @@ class CompletionResponseStreamChoice(BaseModel):
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        return data


 class CompletionStreamResponse(BaseModel):
@@ -376,6 +399,7 @@ class ChatCompletionRequest(BaseModel):
     tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
         default="auto", examples=["none"]
     )  # noqa
+    return_hidden_states: bool = False

     @model_validator(mode="before")
     @classmethod
@@ -437,6 +461,14 @@ class ChatCompletionResponseChoice(BaseModel):
         "stop", "length", "tool_calls", "content_filter", "function_call", "abort"
     ]
     matched_stop: Union[None, int, str] = None
+    hidden_states: Optional[object] = None
+
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        return data


 class ChatCompletionResponse(BaseModel):
@@ -453,6 +485,14 @@ class DeltaMessage(BaseModel):
     content: Optional[str] = None
     reasoning_content: Optional[str] = None
     tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
+    hidden_states: Optional[object] = None
+
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        return data


 class ChatCompletionResponseStreamChoice(BaseModel):
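Note for reviewers: because the wrap-mode serializer drops the key itself, `hidden_states` disappears from serialized responses whenever it is `None`, even without `exclude_none=True`. A minimal standalone sketch of the same pattern (the `Choice` model here is illustrative, not part of the PR):

```python
from typing import Optional

from pydantic import BaseModel, model_serializer


class Choice(BaseModel):
    text: str
    hidden_states: Optional[object] = None

    @model_serializer(mode="wrap")
    def _serialize(self, handler):
        # handler(self) builds the default field dict; dropping the key (rather
        # than emitting "hidden_states": null) keeps responses OpenAI-shaped.
        data = handler(self)
        if self.hidden_states is None:
            data.pop("hidden_states", None)
        return data


print(Choice(text="hi").model_dump())
# -> {'text': 'hi'}
print(Choice(text="hi", hidden_states=[0.1, 0.2]).model_dump())
# -> {'text': 'hi', 'hidden_states': [0.1, 0.2]}
```
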
diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py
index 98e622819e3..cb128ff41d5 100644
--- a/python/sglang/srt/entrypoints/openai/serving_chat.py
+++ b/python/sglang/srt/entrypoints/openai/serving_chat.py
@@ -30,6 +30,7 @@
 from sglang.srt.entrypoints.openai.utils import (
     detect_template_content_format,
     process_content_for_template_format,
+    process_hidden_states_from_ret,
     to_openai_style_logprobs,
 )
 from sglang.srt.function_call.function_call_parser import FunctionCallParser
@@ -99,6 +100,7 @@ def _convert_to_internal_request(
             bootstrap_host=request.bootstrap_host,
             bootstrap_port=request.bootstrap_port,
             bootstrap_room=request.bootstrap_room,
+            return_hidden_states=request.return_hidden_states,
         )

         return adapted_request, request
@@ -402,6 +404,7 @@ async def _generate_chat_stream(
         prompt_tokens = {}
         completion_tokens = {}
         cached_tokens = {}
+        hidden_states = {}

         try:
             async for content in self.tokenizer_manager.generate_request(
@@ -412,6 +415,7 @@
                 prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
                 completion_tokens[index] = content["meta_info"]["completion_tokens"]
                 cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
+                hidden_states[index] = content["meta_info"].get("hidden_states", None)

                 # Handle logprobs
                 choice_logprobs = None
@@ -544,6 +548,31 @@
                 )
                 yield f"data: {finish_reason_chunk.model_dump_json()}\n\n"

+            # Send hidden states if requested
+            if request.return_hidden_states and hidden_states:
+                for index, choice_hidden_states in hidden_states.items():
+                    if choice_hidden_states:
+                        last_token_hidden_states = (
+                            choice_hidden_states[-1]
+                            if len(choice_hidden_states) > 1
+                            else []
+                        )
+                        hidden_states_chunk = ChatCompletionStreamResponse(
+                            id=content["meta_info"]["id"],
+                            created=int(time.time()),
+                            choices=[
+                                ChatCompletionResponseStreamChoice(
+                                    index=index,
+                                    delta=DeltaMessage(
+                                        hidden_states=last_token_hidden_states
+                                    ),
+                                    finish_reason=finish_reason_type,
+                                )
+                            ],
+                            model=request.model,
+                        )
+                        yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
+
             # Additional usage chunk
             if request.stream_options and request.stream_options.include_usage:
                 usage = UsageProcessor.calculate_streaming_usage(
@@ -608,6 +637,9 @@ def _build_chat_response(
         if request.logprobs:
             choice_logprobs = self._process_response_logprobs(ret_item)

+        # Handle hidden states
+        hidden_states = process_hidden_states_from_ret(ret_item, request)
+
         finish_reason = ret_item["meta_info"]["finish_reason"]
         text = ret_item["text"]

@@ -654,6 +686,7 @@ def _build_chat_response(
                 if finish_reason and "matched" in finish_reason
                 else None
            ),
+            hidden_states=hidden_states,
         )
         choices.append(choice_data)
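End to end, the chat streaming path above emits one extra SSE chunk per choice after the finish-reason chunk, carrying the last token's hidden state in `delta.hidden_states`. A client-side sketch, assuming an SGLang server at `localhost:30000` launched with hidden-states return enabled (a server-side setting, not part of this diff); `extra_body` is the openai-python escape hatch for non-standard request fields:

```python
import openai

client = openai.OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

stream = client.chat.completions.create(
    model="default",  # assumed served model name
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
    extra_body={"return_hidden_states": True},  # SRT-only field added by this PR
)

for chunk in stream:
    if not chunk.choices:
        continue  # e.g. a usage-only chunk when include_usage is set
    delta = chunk.choices[0].delta
    if delta.content:
        print(delta.content, end="")
    # hidden_states is not part of the SDK's DeltaMessage schema, so it arrives
    # as an extra field on the trailing chunk; read it defensively.
    hs = getattr(delta, "hidden_states", None)
    if hs:
        print(f"\nlast-token hidden state dim: {len(hs)}")
```
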
diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py
index eea6dbccc1b..af715b32dfb 100644
--- a/python/sglang/srt/entrypoints/openai/serving_completions.py
+++ b/python/sglang/srt/entrypoints/openai/serving_completions.py
@@ -19,7 +19,10 @@
 )
 from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
 from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
-from sglang.srt.entrypoints.openai.utils import to_openai_style_logprobs
+from sglang.srt.entrypoints.openai.utils import (
+    process_hidden_states_from_ret,
+    to_openai_style_logprobs,
+)
 from sglang.srt.managers.io_struct import GenerateReqInput

 logger = logging.getLogger(__name__)
@@ -76,6 +79,7 @@ def _convert_to_internal_request(
             bootstrap_host=request.bootstrap_host,
             bootstrap_port=request.bootstrap_port,
             bootstrap_room=request.bootstrap_room,
+            return_hidden_states=request.return_hidden_states,
         )

         return adapted_request, request
@@ -160,6 +164,7 @@ async def _generate_completion_stream(
         prompt_tokens = {}
         completion_tokens = {}
         cached_tokens = {}
+        hidden_states = {}

         try:
             async for content in self.tokenizer_manager.generate_request(
@@ -188,6 +193,7 @@ async def _generate_completion_stream(
                     delta = text[len(stream_buffer) :]
                     stream_buffers[index] = stream_buffer + delta
                     finish_reason = content["meta_info"]["finish_reason"]
+                    hidden_states[index] = content["meta_info"].get("hidden_states", None)

                     choice_data = CompletionResponseStreamChoice(
                         index=index,
@@ -210,6 +216,30 @@ async def _generate_completion_stream(
                     yield f"data: {chunk.model_dump_json()}\n\n"

+            if request.return_hidden_states and hidden_states:
+                for index, choice_hidden_states in hidden_states.items():
+                    if choice_hidden_states:
+                        last_token_hidden_states = (
+                            choice_hidden_states[-1]
+                            if len(choice_hidden_states) > 1
+                            else []
+                        )
+                        hidden_states_chunk = CompletionStreamResponse(
+                            id=content["meta_info"]["id"],
+                            created=created,
+                            object="text_completion",
+                            choices=[
+                                CompletionResponseStreamChoice(
+                                    index=index,
+                                    text="",
+                                    hidden_states=last_token_hidden_states,
+                                    finish_reason=None,
+                                )
+                            ],
+                            model=request.model,
+                        )
+                        yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"
+
             # Handle final usage chunk
             if request.stream_options and request.stream_options.include_usage:
                 usage = UsageProcessor.calculate_streaming_usage(
@@ -304,6 +334,9 @@ def _build_completion_response(
                 output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
             )

+        # Handle hidden states
+        hidden_states = process_hidden_states_from_ret(ret_item, request)
+
         finish_reason = ret_item["meta_info"]["finish_reason"]

         choice_data = CompletionResponseChoice(
@@ -316,6 +349,7 @@ def _build_completion_response(
                 if finish_reason and "matched" in finish_reason
                 else None
            ),
+            hidden_states=hidden_states,
         )
         choices.append(choice_data)
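The completions path mirrors the chat path. A non-streaming sketch with plain `requests` (the server URL and model name are assumptions); per the serializers in `protocol.py`, `choices[0]` contains a `hidden_states` key only when the request opted in:

```python
import requests

resp = requests.post(
    "http://localhost:30000/v1/completions",  # assumed local SGLang server
    json={
        "model": "default",  # assumed served model name
        "prompt": "The capital of France is",
        "max_tokens": 8,
        "return_hidden_states": True,
    },
    timeout=60,
)
resp.raise_for_status()
choice = resp.json()["choices"][0]

# process_hidden_states_from_ret() keeps only the last generated token's
# hidden state, so this is a single vector rather than one per token.
print(choice["text"])
print(len(choice.get("hidden_states", [])))
```
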
diff --git a/python/sglang/srt/entrypoints/openai/utils.py b/python/sglang/srt/entrypoints/openai/utils.py
index 06e5e4dee10..e80125cf64b 100644
--- a/python/sglang/srt/entrypoints/openai/utils.py
+++ b/python/sglang/srt/entrypoints/openai/utils.py
@@ -1,9 +1,14 @@
 import logging
+from typing import Any, Dict, List, Optional, Union

 import jinja2.nodes
 import transformers.utils.chat_template_utils as hf_chat_utils

-from sglang.srt.entrypoints.openai.protocol import LogProbs
+from sglang.srt.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    CompletionRequest,
+    LogProbs,
+)

 logger = logging.getLogger(__name__)
@@ -205,3 +210,28 @@ def append_top_logprobs(top_logprobs):
         append_top_logprobs(output_top_logprobs)

     return ret_logprobs
+
+
+def process_hidden_states_from_ret(
+    ret_item: Dict[str, Any],
+    request: Union[
+        ChatCompletionRequest,
+        CompletionRequest,
+    ],
+) -> Optional[List]:
+    """Process hidden states from a ret item in a non-streaming response.
+
+    Args:
+        ret_item: Response item containing meta_info
+        request: The original request object
+
+    Returns:
+        Processed hidden states for the last token, or None
+    """
+    if not request.return_hidden_states:
+        return None
+
+    hidden_states = ret_item["meta_info"].get("hidden_states", None)
+    if hidden_states is not None:
+        hidden_states = hidden_states[-1] if len(hidden_states) > 1 else []
+    return hidden_states
diff --git a/test/srt/openai/test_protocol.py b/test/srt/openai/test_protocol.py
index d096ee97cd1..06260024ae3 100644
--- a/test/srt/openai/test_protocol.py
+++ b/test/srt/openai/test_protocol.py
@@ -632,6 +632,51 @@ def test_chat_completion_stream_response(self):
         self.assertEqual(response.choices[0].delta.content, "Hello")


+class TestModelSerialization(unittest.TestCase):
+    """Test model serialization with hidden states"""
+
+    def test_hidden_states_excluded_when_none(self):
+        """Test that None hidden_states are excluded with exclude_none=True"""
+        choice = ChatCompletionResponseChoice(
+            index=0,
+            message=ChatMessage(role="assistant", content="Hello"),
+            finish_reason="stop",
+            hidden_states=None,
+        )
+
+        response = ChatCompletionResponse(
+            id="test-id",
+            model="test-model",
+            choices=[choice],
+            usage=UsageInfo(prompt_tokens=5, completion_tokens=1, total_tokens=6),
+        )
+
+        # Test exclude_none serialization (should exclude None hidden_states)
+        data = response.model_dump(exclude_none=True)
+        self.assertNotIn("hidden_states", data["choices"][0])
+
+    def test_hidden_states_included_when_not_none(self):
+        """Test that non-None hidden_states are included"""
+        choice = ChatCompletionResponseChoice(
+            index=0,
+            message=ChatMessage(role="assistant", content="Hello"),
+            finish_reason="stop",
+            hidden_states=[0.1, 0.2, 0.3],
+        )
+
+        response = ChatCompletionResponse(
+            id="test-id",
+            model="test-model",
+            choices=[choice],
+            usage=UsageInfo(prompt_tokens=5, completion_tokens=1, total_tokens=6),
+        )
+
+        # Test exclude_none serialization (should include non-None hidden_states)
+        data = response.model_dump(exclude_none=True)
+        self.assertIn("hidden_states", data["choices"][0])
+        self.assertEqual(data["choices"][0]["hidden_states"], [0.1, 0.2, 0.3])
+
+
 class TestValidationEdgeCases(unittest.TestCase):
     """Test edge cases and validation scenarios"""
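One behavioral subtlety of `process_hidden_states_from_ret` worth pinning down in review: `meta_info["hidden_states"]` holds one entry per position, and the helper keeps only the final one; with a single entry it returns `[]`, and with `return_hidden_states` unset it returns `None`. A quick illustrative check against the function as written (the literal vectors are made up):

```python
from sglang.srt.entrypoints.openai.protocol import CompletionRequest
from sglang.srt.entrypoints.openai.utils import process_hidden_states_from_ret

req = CompletionRequest(model="default", prompt="hi", return_hidden_states=True)

# Two entries: the helper returns the last one (the final token's state).
ret = {"meta_info": {"hidden_states": [[0.1, 0.2], [0.3, 0.4]]}}
assert process_hidden_states_from_ret(ret, req) == [0.3, 0.4]

# A single entry yields [] rather than that entry.
ret = {"meta_info": {"hidden_states": [[0.1, 0.2]]}}
assert process_hidden_states_from_ret(ret, req) == []

# With return_hidden_states unset, the helper short-circuits to None.
req_off = CompletionRequest(model="default", prompt="hi")
assert process_hidden_states_from_ret(ret, req_off) is None
```
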