Merged
42 changes: 41 additions & 1 deletion python/sglang/srt/entrypoints/openai/protocol.py
@@ -16,7 +16,13 @@
import time
from typing import Dict, List, Optional, Union

from pydantic import BaseModel, Field, field_validator, model_validator
from pydantic import (
BaseModel,
Field,
field_validator,
model_serializer,
model_validator,
)
from typing_extensions import Literal


@@ -167,6 +173,7 @@ class CompletionRequest(BaseModel):
temperature: float = 1.0
top_p: float = 1.0
user: Optional[str] = None
return_hidden_states: bool = False

# Extra parameters for the SRT backend only; ignored by OpenAI models.
top_k: int = -1
@@ -202,6 +209,14 @@ class CompletionResponseChoice(BaseModel):
logprobs: Optional[LogProbs] = None
finish_reason: Literal["stop", "length", "content_filter", "abort"]
matched_stop: Union[None, int, str] = None
hidden_states: Optional[object] = None

@model_serializer(mode="wrap")
def _serialize(self, handler):
data = handler(self)
if self.hidden_states is None:
data.pop("hidden_states", None)
return data


class CompletionResponse(BaseModel):
@@ -219,6 +234,14 @@ class CompletionResponseStreamChoice(BaseModel):
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
matched_stop: Union[None, int, str] = None
hidden_states: Optional[object] = None

@model_serializer(mode="wrap")
def _serialize(self, handler):
data = handler(self)
if self.hidden_states is None:
data.pop("hidden_states", None)
return data


class CompletionStreamResponse(BaseModel):
@@ -376,6 +399,7 @@ class ChatCompletionRequest(BaseModel):
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
default="auto", examples=["none"]
) # noqa
return_hidden_states: bool = False

@model_validator(mode="before")
@classmethod
@@ -437,6 +461,14 @@ class ChatCompletionResponseChoice(BaseModel):
"stop", "length", "tool_calls", "content_filter", "function_call", "abort"
]
matched_stop: Union[None, int, str] = None
hidden_states: Optional[object] = None

@model_serializer(mode="wrap")
def _serialize(self, handler):
data = handler(self)
if self.hidden_states is None:
data.pop("hidden_states", None)
return data


class ChatCompletionResponse(BaseModel):
@@ -453,6 +485,14 @@ class DeltaMessage(BaseModel):
content: Optional[str] = None
reasoning_content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
hidden_states: Optional[object] = None

@model_serializer(mode="wrap")
def _serialize(self, handler):
data = handler(self)
if self.hidden_states is None:
data.pop("hidden_states", None)
return data


class ChatCompletionResponseStreamChoice(BaseModel):
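For context, the wrap-mode serializer above is what keeps hidden_states out of every response in which it was not requested, so default payloads stay schema-compatible with stock OpenAI clients. A minimal standalone sketch of the pattern (the Choice model below is illustrative, not part of the PR):

from typing import Optional

from pydantic import BaseModel, model_serializer


class Choice(BaseModel):
    index: int
    hidden_states: Optional[object] = None

    @model_serializer(mode="wrap")
    def _serialize(self, handler):
        # Run the default serialization, then drop the key when unset.
        data = handler(self)
        if self.hidden_states is None:
            data.pop("hidden_states", None)
        return data


print(Choice(index=0).model_dump())                       # {'index': 0}
print(Choice(index=0, hidden_states=[0.1]).model_dump())  # {'index': 0, 'hidden_states': [0.1]}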
33 changes: 33 additions & 0 deletions python/sglang/srt/entrypoints/openai/serving_chat.py
@@ -30,6 +30,7 @@
from sglang.srt.entrypoints.openai.utils import (
detect_template_content_format,
process_content_for_template_format,
process_hidden_states_from_ret,
to_openai_style_logprobs,
)
from sglang.srt.function_call.function_call_parser import FunctionCallParser
@@ -99,6 +100,7 @@ def _convert_to_internal_request(
bootstrap_host=request.bootstrap_host,
bootstrap_port=request.bootstrap_port,
bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states,
)

return adapted_request, request
@@ -402,6 +404,7 @@ async def _generate_chat_stream(
prompt_tokens = {}
completion_tokens = {}
cached_tokens = {}
hidden_states = {}

try:
async for content in self.tokenizer_manager.generate_request(
@@ -412,6 +415,7 @@
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
completion_tokens[index] = content["meta_info"]["completion_tokens"]
cached_tokens[index] = content["meta_info"].get("cached_tokens", 0)
hidden_states[index] = content["meta_info"].get("hidden_states", None)

# Handle logprobs
choice_logprobs = None
@@ -544,6 +548,31 @@ async def _generate_chat_stream(
)
yield f"data: {finish_reason_chunk.model_dump_json()}\n\n"

# Send hidden states if requested
if request.return_hidden_states and hidden_states:
for index, choice_hidden_states in hidden_states.items():
if choice_hidden_states:
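# Keep only the last generated token's hidden states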
last_token_hidden_states = (
choice_hidden_states[-1]
if len(choice_hidden_states) > 1
else []
)
hidden_states_chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
created=int(time.time()),
choices=[
ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(
hidden_states=last_token_hidden_states
),
finish_reason=finish_reason_type,
)
],
model=request.model,
)
yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"

# Additional usage chunk
if request.stream_options and request.stream_options.include_usage:
usage = UsageProcessor.calculate_streaming_usage(
@@ -608,6 +637,9 @@ def _build_chat_response(
if request.logprobs:
choice_logprobs = self._process_response_logprobs(ret_item)

# Handle hidden states
hidden_states = process_hidden_states_from_ret(ret_item, request)

finish_reason = ret_item["meta_info"]["finish_reason"]
text = ret_item["text"]

Expand Down Expand Up @@ -654,6 +686,7 @@ def _build_chat_response(
if finish_reason and "matched" in finish_reason
else None
),
hidden_states=hidden_states,
)
choices.append(choice_data)

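End to end, the chat streaming path above can be exercised with the stock openai client, passing return_hidden_states through extra_body. The base URL, API key, and model name below are assumptions for a locally launched server; reading delta.hidden_states relies on the openai SDK exposing undocumented response fields, which it permits:

from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

stream = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Hi"}],
    stream=True,
    extra_body={"return_hidden_states": True},
)

for chunk in stream:
    if not chunk.choices:
        continue
    delta = chunk.choices[0].delta
    if delta.content:
        print(delta.content, end="")
    # Per the handler above, the hidden-states chunk is emitted after the
    # finish-reason chunk; ordinary content chunks omit the field entirely.
    hs = getattr(delta, "hidden_states", None)
    if hs:
        print(f"\nlast-token hidden states: {len(hs)} values")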
35 changes: 34 additions & 1 deletion python/sglang/srt/entrypoints/openai/serving_completions.py
@@ -19,7 +19,10 @@
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
from sglang.srt.entrypoints.openai.utils import to_openai_style_logprobs
from sglang.srt.entrypoints.openai.utils import (
process_hidden_states_from_ret,
to_openai_style_logprobs,
)
from sglang.srt.managers.io_struct import GenerateReqInput

logger = logging.getLogger(__name__)
@@ -76,6 +79,7 @@ def _convert_to_internal_request(
bootstrap_host=request.bootstrap_host,
bootstrap_port=request.bootstrap_port,
bootstrap_room=request.bootstrap_room,
return_hidden_states=request.return_hidden_states,
)

return adapted_request, request
@@ -188,6 +192,7 @@ async def _generate_completion_stream(
delta = text[len(stream_buffer) :]
stream_buffers[index] = stream_buffer + delta
finish_reason = content["meta_info"]["finish_reason"]
# Collect hidden states per choice index so the emission loop below can
# iterate .items(); assumes hidden_states is initialized as {} before
# this loop, mirroring the chat handler.
hidden_states[index] = content["meta_info"].get("hidden_states", None)

choice_data = CompletionResponseStreamChoice(
index=index,
@@ -210,6 +215,30 @@

yield f"data: {chunk.model_dump_json()}\n\n"

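# Send hidden states if requested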
if request.return_hidden_states and hidden_states:
for index, choice_hidden_states in hidden_states.items():
if choice_hidden_states:
last_token_hidden_states = (
choice_hidden_states[-1]
if len(choice_hidden_states) > 1
else []
)
hidden_states_chunk = CompletionStreamResponse(
id=content["meta_info"]["id"],
created=created,
object="text_completion",
choices=[
CompletionResponseStreamChoice(
index=index,
text="",
hidden_states=last_token_hidden_states,
finish_reason=None,
)
],
model=request.model,
)
yield f"data: {hidden_states_chunk.model_dump_json()}\n\n"

# Handle final usage chunk
if request.stream_options and request.stream_options.include_usage:
usage = UsageProcessor.calculate_streaming_usage(
@@ -304,6 +333,9 @@ def _build_completion_response(
output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
)

# Handle hidden states
hidden_states = process_hidden_states_from_ret(ret_item, request)

finish_reason = ret_item["meta_info"]["finish_reason"]

choice_data = CompletionResponseChoice(
@@ -316,6 +348,7 @@
if finish_reason and "matched" in finish_reason
else None
),
hidden_states=hidden_states,
)
choices.append(choice_data)

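The completions path accepts the same flag; a quick non-streaming check with plain requests (server address and model are again assumptions):

import requests

resp = requests.post(
    "http://127.0.0.1:30000/v1/completions",
    json={
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "prompt": "The capital of France is",
        "max_tokens": 8,
        "return_hidden_states": True,
    },
).json()

choice = resp["choices"][0]
# hidden_states appears only because the flag was set; it holds the last
# generated token's hidden-state vector.
print(choice["text"], len(choice.get("hidden_states", [])))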
2 changes: 1 addition & 1 deletion python/sglang/srt/entrypoints/openai/usage_processor.py
@@ -2,7 +2,7 @@

from typing import Any, Dict, List, Mapping, Optional, final

-from python.sglang.srt.entrypoints.openai.protocol import UsageInfo
+from sglang.srt.entrypoints.openai.protocol import UsageInfo


@final
33 changes: 32 additions & 1 deletion python/sglang/srt/entrypoints/openai/utils.py
@@ -1,9 +1,15 @@
import logging
from typing import Any, Dict, List, Optional, Union

import jinja2.nodes
import transformers.utils.chat_template_utils as hf_chat_utils

from sglang.srt.entrypoints.openai.protocol import LogProbs
from sglang.srt.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
LogProbs,
UsageInfo,
)

logger = logging.getLogger(__name__)

@@ -205,3 +211,28 @@ def append_top_logprobs(top_logprobs):
append_top_logprobs(output_top_logprobs)

return ret_logprobs


def process_hidden_states_from_ret(
ret_item: Dict[str, Any],
request: Union[
ChatCompletionRequest,
CompletionRequest,
],
) -> Optional[List]:
"""Process hidden states from a ret item in non-streaming response.

Args:
ret_item: Response item containing meta_info
request: The original request object

Returns:
Processed hidden states for the last token, or None
"""
if not request.return_hidden_states:
return None

hidden_states = ret_item["meta_info"].get("hidden_states", None)
if hidden_states is not None:
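# Keep only the last generated token's hidden states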
hidden_states = hidden_states[-1] if len(hidden_states) > 1 else []
return hidden_states
45 changes: 45 additions & 0 deletions test/srt/openai/test_protocol.py
@@ -632,6 +632,51 @@ def test_chat_completion_stream_response(self):
self.assertEqual(response.choices[0].delta.content, "Hello")


class TestModelSerialization(unittest.TestCase):
"""Test model serialization with hidden states"""

def test_hidden_states_excluded_when_none(self):
"""Test that None hidden_states are excluded with exclude_none=True"""
choice = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content="Hello"),
finish_reason="stop",
hidden_states=None,
)

response = ChatCompletionResponse(
id="test-id",
model="test-model",
choices=[choice],
usage=UsageInfo(prompt_tokens=5, completion_tokens=1, total_tokens=6),
)

# Test exclude_none serialization (should exclude None hidden_states)
data = response.model_dump(exclude_none=True)
self.assertNotIn("hidden_states", data["choices"][0])

def test_hidden_states_included_when_not_none(self):
"""Test that non-None hidden_states are included"""
choice = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content="Hello"),
finish_reason="stop",
hidden_states=[0.1, 0.2, 0.3],
)

response = ChatCompletionResponse(
id="test-id",
model="test-model",
choices=[choice],
usage=UsageInfo(prompt_tokens=5, completion_tokens=1, total_tokens=6),
)

# Test exclude_none serialization (should include non-None hidden_states)
data = response.model_dump(exclude_none=True)
self.assertIn("hidden_states", data["choices"][0])
self.assertEqual(data["choices"][0]["hidden_states"], [0.1, 0.2, 0.3])


class TestValidationEdgeCases(unittest.TestCase):
"""Test edge cases and validation scenarios"""

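The serialization tests added above can be run on their own; assuming pytest is installed:

pytest test/srt/openai/test_protocol.py -k hidden_states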