Skip to content

Commit 747dd45

Browse files
authored
feat: throttle requests at scheduler based on --max_queued_requests (#7565)
1 parent b582159 commit 747dd45

File tree

10 files changed

+218
-6
lines changed

10 files changed

+218
-6
lines changed

python/sglang/srt/entrypoints/http_server.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
import requests
3939
import uvicorn
4040
import uvloop
41-
from fastapi import Depends, FastAPI, Request, UploadFile
41+
from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile
4242
from fastapi.exceptions import RequestValidationError
4343
from fastapi.middleware.cors import CORSMiddleware
4444
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
@@ -174,6 +174,18 @@ async def lifespan(fast_api_app: FastAPI):
174174
)
175175

176176

177+
@app.exception_handler(HTTPException)
async def validation_exception_handler(request: Request, exc: HTTPException):
    """Render an ``HTTPException`` as an ``ErrorResponse`` JSON body.

    The exception's status code is used as the HTTP status of the response
    and as both the ``type`` (stringified) and ``code`` fields of the error
    payload; ``exc.detail`` becomes the error message.

    NOTE: the ``RequestValidationError`` handler defined right after this one
    reuses the same function name. Registration happens through the decorator,
    so both handlers stay active, but the module-level name ends up bound to
    the later definition.

    Args:
        request: The incoming request (unused, required by the handler API).
        exc: The HTTP exception being handled.

    Returns:
        An ``ORJSONResponse`` carrying the serialized ``ErrorResponse``.
    """
    error = ErrorResponse(
        object="error",
        message=exc.detail,
        type=str(exc.status_code),
        code=exc.status_code,
    )
    return ORJSONResponse(
        content=error.model_dump(),
        status_code=exc.status_code,
        # Forward any headers attached to the exception (e.g. Retry-After or
        # WWW-Authenticate) instead of silently dropping them.
        headers=getattr(exc, "headers", None),
    )
187+
188+
177189
# Custom exception handlers to change validation error status codes
178190
@app.exception_handler(RequestValidationError)
179191
async def validation_exception_handler(request: Request, exc: RequestValidationError):

python/sglang/srt/entrypoints/openai/serving_base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from abc import ABC, abstractmethod
55
from typing import Any, Optional, Union
66

7-
from fastapi import Request
7+
from fastapi import HTTPException, Request
88
from fastapi.responses import ORJSONResponse, StreamingResponse
99

1010
from sglang.srt.entrypoints.openai.protocol import ErrorResponse, OpenAIServingRequest
@@ -45,7 +45,10 @@ async def handle_request(
4545
return await self._handle_non_streaming_request(
4646
adapted_request, processed_request, raw_request
4747
)
48-
48+
except HTTPException as e:
49+
return self.create_error_response(
50+
message=e.detail, err_type=str(e.status_code), status_code=e.status_code
51+
)
4952
except Exception as e:
5053
logger.exception(f"Error in request: {e}")
5154
return self.create_error_response(

python/sglang/srt/managers/io_struct.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -911,6 +911,8 @@ class AbortReq:
911911
rid: str = ""
912912
# Whether to abort all requests
913913
abort_all: bool = False
914+
# The finished reason data
915+
finished_reason: Optional[Dict[str, Any]] = None
914916

915917

916918
@dataclass

python/sglang/srt/managers/scheduler.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from collections import defaultdict, deque
2525
from concurrent import futures
2626
from dataclasses import dataclass
27+
from http import HTTPStatus
2728
from pathlib import Path
2829
from types import SimpleNamespace
2930
from typing import Dict, List, Optional, Tuple, Union
@@ -370,6 +371,7 @@ def __init__(
370371
self.max_total_num_tokens,
371372
self.max_prefill_tokens,
372373
self.max_running_requests,
374+
self.max_queued_requests,
373375
self.max_req_len,
374376
self.max_req_input_len,
375377
self.random_seed,
@@ -1086,6 +1088,19 @@ def process_input_requests(self, recv_reqs: List):
10861088
self.return_health_check_ct += 1
10871089
continue
10881090

1091+
# If it is a work request, accept or reject the request based on the request queue size.
1092+
if is_work_request(recv_req):
1093+
if len(self.waiting_queue) + 1 > self.max_queued_requests:
1094+
abort_req = AbortReq(
1095+
recv_req.rid,
1096+
finished_reason={
1097+
"type": "abort",
1098+
"status_code": HTTPStatus.SERVICE_UNAVAILABLE,
1099+
"message": "The request queue is full.",
1100+
},
1101+
)
1102+
self.send_to_tokenizer.send_pyobj(abort_req)
1103+
continue
10891104
output = self._request_dispatcher(recv_req)
10901105
if output is not None:
10911106
if isinstance(output, RpcReqOutput):
@@ -2902,6 +2917,10 @@ def is_health_check_generate_req(recv_req):
29022917
return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
29032918

29042919

2920+
def is_work_request(recv_req):
    """Return True when ``recv_req`` is a tokenized generate or embedding request."""
    work_request_types = (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
    return isinstance(recv_req, work_request_types)
2922+
2923+
29052924
def _export_static_state(model):
29062925
return dict(
29072926
buffers=[

python/sglang/srt/managers/tokenizer_manager.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,19 @@ async def _wait_one_response(
766766
):
767767
raise ValueError(finish_reason["message"])
768768

769+
if (
770+
finish_reason.get("type") == "abort"
771+
and finish_reason.get("status_code")
772+
== HTTPStatus.SERVICE_UNAVAILABLE
773+
):
774+
# This is an abort request initiated by scheduler.
775+
# Delete the key to prevent resending abort request to the scheduler and
776+
# to ensure aborted request state is cleaned up.
777+
del self.rid_to_state[state.obj.rid]
778+
raise fastapi.HTTPException(
779+
status_code=finish_reason["status_code"],
780+
detail=finish_reason["message"],
781+
)
769782
yield out
770783
break
771784

@@ -1705,8 +1718,15 @@ def record_request_for_crash_dump(self, state: ReqState, out_dict: dict):
17051718
def _handle_abort_req(self, recv_obj):
17061719
state = self.rid_to_state[recv_obj.rid]
17071720
state.finished = True
1708-
state.out_list.append(
1709-
{
1721+
if recv_obj.finished_reason:
1722+
out = {
1723+
"meta_info": {
1724+
"id": recv_obj.rid,
1725+
"finish_reason": recv_obj.finished_reason,
1726+
},
1727+
}
1728+
else:
1729+
out = {
17101730
"text": "",
17111731
"meta_info": {
17121732
"id": recv_obj.rid,
@@ -1718,7 +1738,7 @@ def _handle_abort_req(self, recv_obj):
17181738
"completion_tokens": 0,
17191739
},
17201740
}
1721-
)
1741+
state.out_list.append(out)
17221742
state.event.set()
17231743

17241744
def _handle_open_session_req_output(self, recv_obj):
@@ -1910,8 +1930,10 @@ def handle_recv(self, recv_obj: T):
19101930
#
19111931
# | entrypoint | is_streaming | status | abort engine | cancel asyncio task | rid_to_state |
19121932
# | ---------- | ------------ | --------------- | --------------- | --------------------- | --------------------------- |
1933+
# | http | yes | validation | background task | fast api | del in _handle_abort_req |
19131934
# | http | yes | waiting queue | background task | fast api | del in _handle_abort_req |
19141935
# | http | yes | running | background task | fast api | del in _handle_batch_output |
1936+
# | http | no | validation | http exception | http exception | del in _handle_abort_req |
19151937
# | http | no | waiting queue | type 1 | type 1 exception | del in _handle_abort_req |
19161938
# | http | no | running | type 3 | type 3 exception | del in _handle_batch_output |
19171939
#

python/sglang/srt/managers/tp_worker.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,10 @@ def __init__(
130130
self.model_runner.req_to_token_pool.size,
131131
)
132132
assert self.max_running_requests > 0, "max_running_request is zero"
133+
self.max_queued_requests = server_args.max_queued_requests
# Validate the queue limit itself (not max_running_requests, which is already
# checked above): a non-positive limit would make the scheduler reject every
# work request, since it compares `len(waiting_queue) + 1` against this value.
assert (
    self.max_queued_requests > 0
), "max_queued_requests is zero. We need to be at least 1 to schedule a request."
133137
self.max_req_len = min(
134138
self.model_config.context_len - 1,
135139
self.max_total_num_tokens - 1,
@@ -165,6 +169,7 @@ def get_worker_info(self):
165169
self.max_total_num_tokens,
166170
self.max_prefill_tokens,
167171
self.max_running_requests,
172+
self.max_queued_requests,
168173
self.max_req_len,
169174
self.max_req_input_len,
170175
self.random_seed,

python/sglang/srt/server_args.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import logging
2020
import os
2121
import random
22+
import sys
2223
import tempfile
2324
from typing import List, Literal, Optional, Union
2425

@@ -74,6 +75,7 @@ class ServerArgs:
7475
# Memory and scheduling
7576
mem_fraction_static: Optional[float] = None
7677
max_running_requests: Optional[int] = None
78+
max_queued_requests: Optional[int] = sys.maxsize
7779
max_total_tokens: Optional[int] = None
7880
chunked_prefill_size: Optional[int] = None
7981
max_prefill_tokens: int = 16384
@@ -805,6 +807,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
805807
default=ServerArgs.max_running_requests,
806808
help="The maximum number of running requests.",
807809
)
810+
parser.add_argument(
811+
"--max-queued-requests",
812+
type=int,
813+
default=ServerArgs.max_queued_requests,
814+
help="The maximum number of queued requests. This option is ignored when using disaggregation-mode.",
815+
)
808816
parser.add_argument(
809817
"--max-total-tokens",
810818
type=int,

python/sglang/test/test_utils.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from types import SimpleNamespace
2020
from typing import Awaitable, Callable, List, Optional, Tuple
2121

22+
import aiohttp
2223
import numpy as np
2324
import requests
2425
import torch
@@ -1303,6 +1304,58 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):
13031304
raise
13041305

13051306

1307+
def send_generate_requests(base_url: str, num_requests: int) -> List[int]:
    """Send generate requests serially and return their status codes.

    Each request is a blocking POST to ``{base_url}/generate`` that completes
    before the next one is issued, so the max concurrency is 1.

    Args:
        base_url: Base URL of the server under test.
        num_requests: Number of requests to send.

    Returns:
        The HTTP status code of every response, in the order the requests
        were sent.
    """

    def generate():
        prompt = """
        System: You are a helpful assistant.
        User: What is the capital of France?
        Assistant: The capital of France is
        """
        response = requests.post(
            f"{base_url}/generate",
            json={
                "text": prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 50,
                },
            },
        )
        return response.status_code

    return [generate() for _ in range(num_requests)]
1329+
1330+
1331+
async def send_concurrent_generate_requests(
    base_url: str, num_requests: int
) -> List[int]:
    """Send generate requests concurrently and return their status codes.

    One asyncio task is created per request, each with its own
    ``aiohttp.ClientSession``, so the max concurrency is ``num_requests``.

    Args:
        base_url: Base URL of the server under test.
        num_requests: Number of requests to issue at once.

    Returns:
        The HTTP status code of every response, ordered like the tasks that
        were created (``asyncio.gather`` preserves argument order).
    """

    async def async_generate():
        async with aiohttp.ClientSession() as session:
            prompt = """
            System: You are a helpful assistant.
            User: What is the capital of France?
            Assistant: The capital of France is
            """
            async with session.post(
                f"{base_url}/generate",
                json={
                    "text": prompt,
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 50,
                    },
                },
            ) as response:
                return response.status

    tasks = [asyncio.create_task(async_generate()) for _ in range(num_requests)]
    return await asyncio.gather(*tasks)
1357+
1358+
13061359
class CustomTestCase(unittest.TestCase):
13071360
def _callTestMethod(self, method):
13081361
max_retry = int(

test/srt/run_suite.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ class TestFile:
8686
TestFile("test_radix_attention.py", 105),
8787
TestFile("test_regex_constrained.py", 64),
8888
TestFile("test_retract_decode.py", 54),
89+
TestFile("test_request_queue_validation.py", 30),
8990
TestFile("test_server_args.py", 1),
9091
TestFile("test_skip_tokenizer_init.py", 117),
9192
TestFile("test_srt_engine.py", 261),
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import asyncio
2+
import os
3+
import re
4+
import unittest
5+
from concurrent.futures import ThreadPoolExecutor
6+
7+
from sglang.srt.utils import kill_process_tree
8+
from sglang.test.test_utils import (
9+
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
10+
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
11+
DEFAULT_URL_FOR_TEST,
12+
STDERR_FILENAME,
13+
STDOUT_FILENAME,
14+
CustomTestCase,
15+
popen_launch_server,
16+
send_concurrent_generate_requests,
17+
send_generate_requests,
18+
)
19+
20+
21+
class TestMaxQueuedRequests(CustomTestCase):
    """End-to-end checks for request throttling via ``--max-queued-requests``.

    A single server is launched for the whole class with both
    ``--max-running-requests`` and ``--max-queued-requests`` set to 1, and its
    stdout/stderr are captured to files so the log-based test can inspect them.
    """

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST

        cls.stdout = open(STDOUT_FILENAME, "w")
        cls.stderr = open(STDERR_FILENAME, "w")

        try:
            cls.process = popen_launch_server(
                cls.model,
                cls.base_url,
                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                other_args=(
                    "--max-running-requests",  # Enforce max request concurrency is 1
                    "1",
                    "--max-queued-requests",  # Enforce max queued request number is 1
                    "1",
                ),
                return_stdout_stderr=(cls.stdout, cls.stderr),
            )
        except Exception:
            # tearDownClass is not invoked when setUpClass raises, so release
            # the log file handles here before propagating the failure.
            cls.stdout.close()
            cls.stderr.close()
            raise

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)
        cls.stdout.close()
        cls.stderr.close()
        os.remove(STDOUT_FILENAME)
        os.remove(STDERR_FILENAME)

    def test_max_queued_requests_validation_with_serial_requests(self):
        """Verify request is not throttled when the max concurrency is 1."""
        status_codes = send_generate_requests(
            self.base_url,
            num_requests=10,
        )

        for status_code in status_codes:
            # Request shouldn't be throttled.
            self.assertEqual(status_code, 200)

    def test_max_queued_requests_validation_with_concurrent_requests(self):
        """Verify request throttling with concurrent requests."""
        status_codes = asyncio.run(
            send_concurrent_generate_requests(self.base_url, num_requests=10)
        )

        # Some requests must succeed, some must be rejected, and nothing else
        # may come back.
        self.assertIn(200, status_codes)
        self.assertIn(503, status_codes)
        self.assertLessEqual(set(status_codes), {200, 503})

    def test_max_running_requests_and_max_queued_request_validation(self):
        """Verify running request and queued request numbers based on server logs."""
        rr_pattern = re.compile(r"#running-req:\s*(\d+)")
        qr_pattern = re.compile(r"#queue-req:\s*(\d+)")

        with open(STDERR_FILENAME) as lines:
            for line in lines:
                rr_match, qr_match = rr_pattern.search(line), qr_pattern.search(line)
                if rr_match:
                    self.assertLessEqual(int(rr_match.group(1)), 1, msg=line)
                if qr_match:
                    self.assertLessEqual(int(qr_match.group(1)), 1, msg=line)
84+
85+
86+
if __name__ == "__main__":
87+
unittest.main()

0 commit comments

Comments
 (0)