
Commit 4df5fc2

Feat/refactor embedding server (#7322)
1 parent a06912a · commit 4df5fc2

4 files changed · +76 −120 lines changed

python/sglang/srt/entrypoints/openai/api_server.py

Lines changed: 16 additions & 2 deletions
@@ -40,9 +40,10 @@
     register_disaggregation_server,
 )
 from sglang.srt.entrypoints.engine import Engine, _launch_subprocesses
+from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.metrics.func_timer import enable_func_timer
-from sglang.srt.openai_api.protocol import ModelCard, ModelList
+from sglang.srt.openai_api.protocol import EmbeddingRequest, ModelCard, ModelList
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     add_prometheus_middleware,
@@ -64,6 +65,7 @@ class AppState:
     server_args: Optional[ServerArgs] = None
     tokenizer_manager: Optional[TokenizerManager] = None
     scheduler_info: Optional[Dict] = None
+    embedding_server: Optional[OpenAIServingEmbedding] = None


 @asynccontextmanager
@@ -78,6 +80,9 @@ async def lifespan(app: FastAPI):
     tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
     app.state.tokenizer_manager = tokenizer_manager
     app.state.scheduler_info = scheduler_info
+    app.state.serving_embedding = OpenAIServingEmbedding(
+        tokenizer_manager=tokenizer_manager
+    )

     if server_args.enable_metrics:
         add_prometheus_middleware(app)
@@ -169,7 +174,16 @@ async def openai_v1_chat_completions(raw_request: Request):

 @app.post("/v1/embeddings")
 async def openai_v1_embeddings(raw_request: Request):
-    pass
+    try:
+        request_json = await raw_request.json()
+        request = EmbeddingRequest(**request_json)
+    except Exception as e:
+        return app.state.serving_embedding.create_error_response(
+            f"Invalid request body, error: {str(e)}"
+        )
+
+    ret = await app.state.serving_embedding.handle_request(request, raw_request)
+    return ret


 @app.post("/v1/score")
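
The new handler validates the JSON body against EmbeddingRequest before delegating to the serving object created in lifespan. A minimal client-side sketch of calling the route; the host, port, and model name below are assumptions for illustration, not values taken from this commit:

import requests

# Hypothetical deployment; SGLang's default port is typically 30000, but
# adjust the URL and model name to your own server.
resp = requests.post(
    "http://localhost:30000/v1/embeddings",
    json={"model": "my-embedding-model", "input": "Hello, how are you?"},
)
resp.raise_for_status()
# The response follows the OpenAI embeddings schema: data[i].embedding holds the vector.
vector = resp.json()["data"][0]["embedding"]
print(len(vector))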

python/sglang/srt/entrypoints/openai/serving_base.py

Lines changed: 3 additions & 3 deletions
@@ -37,7 +37,7 @@ async def handle_request(

         # Convert to internal format
         adapted_request, processed_request = self._convert_to_internal_request(
-            [request], [self._generate_request_id_base(request)]
+            request, self._generate_request_id_base(request)
         )

         # Note(Xinyuan): raw_request below is only used for detecting the connection of the client
@@ -73,8 +73,8 @@ def _generate_request_id_base(self, request: OpenAIServingRequest) -> str:
     @abstractmethod
     def _convert_to_internal_request(
         self,
-        all_requests: List[OpenAIServingRequest],
-        request_ids: List[str],
+        request: OpenAIServingRequest,
+        request_id: str,
     ) -> tuple[
        GenerateReqInput, Union[OpenAIServingRequest, List[OpenAIServingRequest]]
    ]:
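
The refactor narrows the base-class contract from single-element lists to one request and one pre-generated id. A self-contained toy sketch of the new flow; all names here are stand-ins for illustration, not sglang's real types:

import asyncio
from abc import ABC, abstractmethod


class ToyServingBase(ABC):
    async def handle_request(self, request):
        # One request in, one id in -- no list wrapping on either side.
        adapted, processed = self._convert_to_internal_request(
            request, self._generate_request_id_base(request)
        )
        return adapted, processed

    def _generate_request_id_base(self, request) -> str:
        return "embd-0"  # the real base class derives a fresh id per request

    @abstractmethod
    def _convert_to_internal_request(self, request, request_id): ...


class ToyServingEmbedding(ToyServingBase):
    def _convert_to_internal_request(self, request, request_id):
        # Mirrors the new signature: returns (internal input, original request).
        return {"text": request}, request


print(asyncio.run(ToyServingEmbedding().handle_request("hello")))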

python/sglang/srt/entrypoints/openai/serving_embedding.py

Lines changed: 47 additions & 101 deletions
@@ -71,111 +71,61 @@ def _validate_request(self, request: EmbeddingRequest) -> Optional[str]:

     def _convert_to_internal_request(
         self,
-        all_requests: List[EmbeddingRequest],
-        request_ids: List[str],
+        request: EmbeddingRequest,
+        request_id: str,
     ) -> tuple[EmbeddingReqInput, Union[EmbeddingRequest, List[EmbeddingRequest]]]:
         """Convert OpenAI embedding request to internal format"""
-        prompts = [request.input for request in all_requests]
-
-        # Handle single vs multiple requests
-        if len(all_requests) == 1:
-            prompt = prompts[0]
-            if isinstance(prompt, str):
-                # Single string input
+        prompt = request.input
+        if isinstance(prompt, str):
+            # Single string input
+            prompt_kwargs = {"text": prompt}
+        elif isinstance(prompt, list):
+            if len(prompt) > 0 and isinstance(prompt[0], str):
+                # List of strings
                 prompt_kwargs = {"text": prompt}
-            elif isinstance(prompt, list):
-                if len(prompt) > 0 and isinstance(prompt[0], str):
-                    # List of strings
-                    prompt_kwargs = {"text": prompt}
-                elif len(prompt) > 0 and isinstance(
-                    prompt[0], MultimodalEmbeddingInput
-                ):
-                    # Handle multimodal embedding inputs
-                    texts = []
-                    images = []
-                    for item in prompt:
-                        # Use padding for text if None - this could be improved
-                        texts.append(item.text if item.text is not None else "padding")
-                        images.append(item.image if item.image is not None else None)
-
-                    generate_prompts = []
-                    # Check if we have a chat template for multimodal embeddings
-                    # This would need to be passed in from the server configuration
-                    chat_template_name = getattr(
-                        self.tokenizer_manager, "chat_template_name", None
-                    )
-                    if chat_template_name is not None:
-                        convs = generate_embedding_convs(
-                            texts, images, chat_template_name
-                        )
-                        for conv in convs:
-                            generate_prompts.append(conv.get_prompt())
-                    else:
-                        generate_prompts = texts
-
-                    if len(generate_prompts) == 1:
-                        prompt_kwargs = {
-                            "text": generate_prompts[0],
-                            "image_data": images[0],
-                        }
-                    else:
-                        prompt_kwargs = {
-                            "text": generate_prompts,
-                            "image_data": images,
-                        }
+            elif len(prompt) > 0 and isinstance(prompt[0], MultimodalEmbeddingInput):
+                # Handle multimodal embedding inputs
+                texts = []
+                images = []
+                for item in prompt:
+                    # Use padding for text if None - this could be improved
+                    texts.append(item.text if item.text is not None else "padding")
+                    images.append(item.image if item.image is not None else None)
+
+                generate_prompts = []
+                # Check if we have a chat template for multimodal embeddings
+                # This would need to be passed in from the server configuration
+                chat_template_name = getattr(
+                    self.tokenizer_manager, "chat_template_name", None
+                )
+                if chat_template_name is not None:
+                    convs = generate_embedding_convs(texts, images, chat_template_name)
+                    for conv in convs:
+                        generate_prompts.append(conv.get_prompt())
+                else:
+                    generate_prompts = texts
+
+                if len(generate_prompts) == 1:
+                    prompt_kwargs = {
+                        "text": generate_prompts[0],
+                        "image_data": images[0],
+                    }
                 else:
-                    # List of integers (token IDs) or empty list
-                    prompt_kwargs = {"input_ids": prompt}
+                    prompt_kwargs = {
+                        "text": generate_prompts,
+                        "image_data": images,
+                    }
             else:
-                # Other types (should not happen but handle gracefully)
+                # List of integers (token IDs) or empty list
                 prompt_kwargs = {"input_ids": prompt}
-            # Use the passed request_ids for single request
-            final_request_id = request_ids[0] if len(all_requests) == 1 else request_ids
         else:
-            # Handle batch requests
-            if len(prompts) > 0:
-                # Validate that all prompts have the same type
-                first_prompt = prompts[0]
-                first_type = type(first_prompt)
-                for i, prompt in enumerate(prompts[1:], 1):
-                    if type(prompt) != first_type:
-                        raise AssertionError(
-                            f"All prompts in batch must have the same type, but prompt at index {i} has different type"
-                        )
-
-                if isinstance(first_prompt, str):
-                    # Batch of strings
-                    prompt_kwargs = {"text": prompts}
-                elif isinstance(first_prompt, list):
-                    if len(first_prompt) > 0 and isinstance(first_prompt[0], str):
-                        # Batch of lists of strings
-                        prompt_kwargs = {"text": prompts}
-                    elif len(first_prompt) > 0 and isinstance(
-                        first_prompt[0], MultimodalEmbeddingInput
-                    ):
-                        # Handle multimodal batch requests
-                        raise NotImplementedError(
-                            "Multiple requests with multimodal inputs are not supported yet"
-                        )
-                    else:
-                        # Batch of token ID lists
-                        prompt_kwargs = {"input_ids": prompts}
-                else:
-                    # Other types
-                    prompt_kwargs = {"input_ids": prompts}
-            else:
-                prompt_kwargs = {"input_ids": prompts}
-            # Use the passed request_ids for batch requests
-            final_request_id = request_ids
-
+            # Other types (should not happen but handle gracefully)
+            prompt_kwargs = {"input_ids": prompt}
         adapted_request = EmbeddingReqInput(
-            rid=final_request_id,
             **prompt_kwargs,
         )

-        return adapted_request, (
-            all_requests[0] if len(all_requests) == 1 else all_requests
-        )
+        return adapted_request, request

     async def _handle_non_streaming_request(
         self,
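
To summarize the hunk above: with batching gone, the converter reduces to a type dispatch on the request's input field. A toy reimplementation of just the text and token-id paths (plain Python, no sglang imports; the multimodal branch is omitted):

def to_prompt_kwargs(prompt):
    """Mirror of the type dispatch in _convert_to_internal_request."""
    if isinstance(prompt, str):
        return {"text": prompt}         # single string
    if isinstance(prompt, list):
        if prompt and isinstance(prompt[0], str):
            return {"text": prompt}     # list of strings
        return {"input_ids": prompt}    # token IDs or empty list
    return {"input_ids": prompt}        # fallback for unexpected types

assert to_prompt_kwargs("hi") == {"text": "hi"}
assert to_prompt_kwargs(["a", "b"]) == {"text": ["a", "b"]}
assert to_prompt_kwargs([1, 2, 3]) == {"input_ids": [1, 2, 3]}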
@@ -194,14 +144,10 @@ async def _handle_non_streaming_request(
         if not isinstance(ret, list):
             ret = [ret]

-        response = self._build_embedding_response(
-            ret, self.tokenizer_manager.model_path
-        )
+        response = self._build_embedding_response(ret)
         return response

-    def _build_embedding_response(
-        self, ret: List[Dict[str, Any]], model_path: str
-    ) -> EmbeddingResponse:
+    def _build_embedding_response(self, ret: List[Dict[str, Any]]) -> EmbeddingResponse:
         """Build the embedding response"""
         embedding_objects = []
         prompt_tokens = 0
@@ -219,7 +165,7 @@ def _build_embedding_response(

         return EmbeddingResponse(
             data=embedding_objects,
-            model=model_path,
+            model=self.tokenizer_manager.model_path,
             usage=UsageInfo(
                 prompt_tokens=prompt_tokens,
                 total_tokens=prompt_tokens,
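
The response builder now reads the model path from the tokenizer manager instead of taking it as a parameter. For orientation, an illustrative payload shape; the values and the exact embedding-object keys are assumptions based on the OpenAI embeddings schema that this endpoint mirrors, not values from this commit:

example_response = {
    "data": [
        {"object": "embedding", "index": 0, "embedding": [0.01, -0.02, 0.03]},
    ],
    "model": "my-embedding-model",  # now sourced from tokenizer_manager.model_path
    "usage": {"prompt_tokens": 5, "total_tokens": 5},
}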

test/srt/openai/test_serving_embedding.py

Lines changed: 10 additions & 14 deletions
@@ -95,48 +95,48 @@ def test_convert_single_string_request(self):
         """Test converting single string request to internal format."""
         adapted_request, processed_request = (
             self.serving_embedding._convert_to_internal_request(
-                [self.basic_req], ["test-id"]
+                self.basic_req, "test-id"
             )
         )

         self.assertIsInstance(adapted_request, EmbeddingReqInput)
         self.assertEqual(adapted_request.text, "Hello, how are you?")
-        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(adapted_request.rid, None)
         self.assertEqual(processed_request, self.basic_req)

     def test_convert_list_string_request(self):
         """Test converting list of strings request to internal format."""
         adapted_request, processed_request = (
             self.serving_embedding._convert_to_internal_request(
-                [self.list_req], ["test-id"]
+                self.list_req, "test-id"
             )
         )

         self.assertIsInstance(adapted_request, EmbeddingReqInput)
         self.assertEqual(
             adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
         )
-        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(adapted_request.rid, None)
         self.assertEqual(processed_request, self.list_req)

     def test_convert_token_ids_request(self):
         """Test converting token IDs request to internal format."""
         adapted_request, processed_request = (
             self.serving_embedding._convert_to_internal_request(
-                [self.token_ids_req], ["test-id"]
+                self.token_ids_req, "test-id"
             )
         )

         self.assertIsInstance(adapted_request, EmbeddingReqInput)
         self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
-        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(adapted_request.rid, None)
         self.assertEqual(processed_request, self.token_ids_req)

     def test_convert_multimodal_request(self):
         """Test converting multimodal request to internal format."""
         adapted_request, processed_request = (
             self.serving_embedding._convert_to_internal_request(
-                [self.multimodal_req], ["test-id"]
+                self.multimodal_req, "test-id"
             )
         )

@@ -147,7 +147,7 @@ def test_convert_multimodal_request(self):
         self.assertIn("World", adapted_request.text)
         self.assertEqual(adapted_request.image_data[0], "base64_image_data")
         self.assertIsNone(adapted_request.image_data[1])
-        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(adapted_request.rid, None)

     def test_build_single_embedding_response(self):
         """Test building response for single embedding."""
@@ -158,9 +158,7 @@ def test_build_single_embedding_response(self):
             }
         ]

-        response = self.serving_embedding._build_embedding_response(
-            ret_data, "test-model"
-        )
+        response = self.serving_embedding._build_embedding_response(ret_data)

         self.assertIsInstance(response, EmbeddingResponse)
         self.assertEqual(response.model, "test-model")
@@ -185,9 +183,7 @@ def test_build_multiple_embedding_response(self):
             },
         ]

-        response = self.serving_embedding._build_embedding_response(
-            ret_data, "test-model"
-        )
+        response = self.serving_embedding._build_embedding_response(ret_data)

         self.assertIsInstance(response, EmbeddingResponse)
         self.assertEqual(len(response.data), 2)
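
The updated tests pass a bare request and id and assert that rid stays unset, since the converter no longer forwards the request id into EmbeddingReqInput. Assuming a standard unittest setup, they can presumably be run from the repository root with:

python3 -m unittest test/srt/openai/test_serving_embedding.py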
