Skip to content

Commit 0354142

Browse files
tianyuzhou95 and zhuzilin
authored and committed
Support updating weights at once by stopping all requests (sgl-project#6698)
Signed-off-by: Tianyu Zhou <[email protected]> Co-authored-by: Zilin Zhu <[email protected]>
1 parent b108a75 commit 0354142

File tree

7 files changed

+190
-13
lines changed

7 files changed

+190
-13
lines changed

python/sglang/srt/entrypoints/http_server.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,9 @@ async def configure_logging(obj: ConfigureLoggingReq, request: Request):
662662
async def abort_request(obj: AbortReq, request: Request):
663663
"""Abort a request."""
664664
try:
665-
_global_state.tokenizer_manager.abort_request(rid=obj.rid)
665+
_global_state.tokenizer_manager.abort_request(
666+
rid=obj.rid, abort_all=obj.abort_all
667+
)
666668
return Response(status_code=200)
667669
except Exception as e:
668670
return _create_error_response(e)

python/sglang/srt/entrypoints/openai/protocol.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ class CompletionResponseStreamChoice(BaseModel):
236236
index: int
237237
text: str
238238
logprobs: Optional[LogProbs] = None
239-
finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
239+
finish_reason: Optional[Literal["stop", "length", "content_filter", "abort"]] = None
240240
matched_stop: Union[None, int, str] = None
241241
hidden_states: Optional[object] = None
242242

@@ -510,7 +510,9 @@ class ChatCompletionResponseStreamChoice(BaseModel):
510510
delta: DeltaMessage
511511
logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
512512
finish_reason: Optional[
513-
Literal["stop", "length", "tool_calls", "content_filter", "function_call"]
513+
Literal[
514+
"stop", "length", "tool_calls", "content_filter", "function_call", "abort"
515+
]
514516
] = None
515517
matched_stop: Union[None, int, str] = None
516518

python/sglang/srt/managers/io_struct.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,8 @@ class UpdateWeightFromDiskReqInput:
740740
model_path: str
741741
# The format to load the weights
742742
load_format: Optional[str] = None
743+
# Whether to abort all requests before updating weights
744+
abort_all_requests: bool = False
743745

744746

745747
@dataclass
@@ -759,6 +761,8 @@ class UpdateWeightsFromDistributedReqInput:
759761
group_name: str = "weight_update_group"
760762
# Whether to flush the cache after updating weights
761763
flush_cache: bool = True
764+
# Whether to abort all requests before updating weights
765+
abort_all_requests: bool = False
762766

763767

764768
@dataclass
@@ -780,6 +784,8 @@ class UpdateWeightsFromTensorReqInput:
780784
load_format: Optional[str] = None
781785
# Whether to flush the cache after updating weights
782786
flush_cache: bool = True
787+
# Whether to abort all requests before updating weights
788+
abort_all_requests: bool = False
783789

784790

785791
@dataclass
@@ -858,7 +864,9 @@ class SlowDownReqOutput:
858864
@dataclass
859865
class AbortReq:
860866
# The request id
861-
rid: str
867+
rid: str = ""
868+
# Whether to abort all requests
869+
abort_all: bool = False
862870

863871

864872
@dataclass

python/sglang/srt/managers/scheduler.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2211,7 +2211,7 @@ def abort_request(self, recv_req: AbortReq):
22112211
# Delete requests in the waiting queue
22122212
to_del = []
22132213
for i, req in enumerate(self.waiting_queue):
2214-
if req.rid.startswith(recv_req.rid):
2214+
if recv_req.abort_all or req.rid.startswith(recv_req.rid):
22152215
to_del.append(i)
22162216

22172217
# Sort in reverse order to avoid index issues when deleting
@@ -2228,7 +2228,7 @@ def abort_request(self, recv_req: AbortReq):
22282228
# Abort method 2: call `set_finish_with_abort`
22292229
# The request will still run one prefill forward pass.
22302230
# In this case, we change the input_ids to be only one token to make this prefill cheap.
2231-
if req.rid.startswith(recv_req.rid):
2231+
if recv_req.abort_all or req.rid.startswith(recv_req.rid):
22322232
logger.debug(f"Abort grammar queue request. {req.rid=}")
22332233
if req.grammar:
22342234
req.grammar.cancel()
@@ -2241,7 +2241,9 @@ def abort_request(self, recv_req: AbortReq):
22412241
reqs = self.running_batch.reqs + self.cur_batch.reqs
22422242

22432243
for req in reqs:
2244-
if req.rid.startswith(recv_req.rid) and not req.finished():
2244+
if not req.finished() and (
2245+
recv_req.abort_all or req.rid.startswith(recv_req.rid)
2246+
):
22452247
# Abort method 3: set `to_abort=True`
22462248
# The request will still run one decode forward pass.
22472249
# Then we reuse all existing code to clean up the KV cache allocation.

python/sglang/srt/managers/tokenizer_manager.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -846,10 +846,10 @@ async def _handle_batch_request(
846846
async def flush_cache(self) -> FlushCacheReqOutput:
847847
return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
848848

849-
def abort_request(self, rid: str):
850-
if rid not in self.rid_to_state:
849+
def abort_request(self, rid: str = "", abort_all: bool = False):
850+
if not abort_all and rid not in self.rid_to_state:
851851
return
852-
req = AbortReq(rid)
852+
req = AbortReq(rid, abort_all)
853853
self.send_to_scheduler.send_pyobj(req)
854854

855855
if self.enable_metrics:
@@ -914,6 +914,9 @@ async def update_weights_from_disk(
914914
obj.load_format = self.server_args.load_format
915915
logger.info("Start update_weights. Load format=%s", obj.load_format)
916916

917+
if obj.abort_all_requests:
918+
self.abort_request(abort_all=True)
919+
917920
if True: # Keep this redundant check to simplify some internal code sync
918921
# Hold the lock if it is not async. This means that weight sync
919922
# cannot run while requests are in progress.
@@ -969,6 +972,9 @@ async def update_weights_from_distributed(
969972
self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
970973
), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
971974

975+
if obj.abort_all_requests:
976+
self.abort_request(abort_all=True)
977+
972978
# This means that weight sync
973979
# cannot run while requests are in progress.
974980
async with self.model_update_lock.writer_lock:
@@ -985,6 +991,9 @@ async def update_weights_from_tensor(
985991
self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
986992
), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"
987993

994+
if obj.abort_all_requests:
995+
self.abort_request(abort_all=True)
996+
988997
# This means that weight sync
989998
# cannot run while requests are in progress.
990999
async with self.model_update_lock.writer_lock:
@@ -1619,7 +1628,23 @@ def record_request_for_crash_dump(self, state: ReqState, out_dict: dict):
16191628
self.crash_dump_request_list.popleft()
16201629

16211630
def _handle_abort_req(self, recv_obj):
1622-
self.rid_to_state.pop(recv_obj.rid, None)
1631+
state = self.rid_to_state[recv_obj.rid]
1632+
state.finished = True
1633+
state.out_list.append(
1634+
{
1635+
"text": "",
1636+
"meta_info": {
1637+
"id": recv_obj.rid,
1638+
"finish_reason": {
1639+
"type": "abort",
1640+
"message": "Abort before prefill",
1641+
},
1642+
"prompt_tokens": 0,
1643+
"completion_tokens": 0,
1644+
},
1645+
}
1646+
)
1647+
state.event.set()
16231648

16241649
def _handle_open_session_req_output(self, recv_obj):
16251650
self.session_futures[recv_obj.session_id].set_result(

test/srt/test_abort.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
1+
import json
12
import multiprocessing
23
import time
34
import unittest
4-
from concurrent.futures import ThreadPoolExecutor
5+
from concurrent.futures import ThreadPoolExecutor, as_completed
56

67
import requests
78

8-
from sglang.test.test_utils import CustomTestCase, run_and_check_memory_leak
9+
from sglang.srt.utils import kill_process_tree
10+
from sglang.test.test_utils import (
11+
DEFAULT_MODEL_NAME_FOR_TEST,
12+
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
13+
DEFAULT_URL_FOR_TEST,
14+
CustomTestCase,
15+
popen_launch_server,
16+
run_and_check_memory_leak,
17+
)
918

1019

1120
class TestAbort(CustomTestCase):
@@ -50,5 +59,56 @@ def test_memory_leak(self):
5059
)
5160

5261

62+
class TestAbortAll(CustomTestCase):
63+
@classmethod
64+
def setUpClass(cls):
65+
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
66+
cls.base_url = DEFAULT_URL_FOR_TEST
67+
cls.process = popen_launch_server(
68+
cls.model,
69+
cls.base_url,
70+
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
71+
other_args=["--max-running-requests", 8],
72+
)
73+
74+
@classmethod
75+
def tearDownClass(cls):
76+
kill_process_tree(cls.process.pid)
77+
78+
def _run_decode(self):
79+
response = requests.post(
80+
self.base_url + "/generate",
81+
json={
82+
"text": "The capital of France is",
83+
"sampling_params": {
84+
"temperature": 0,
85+
"max_new_tokens": 16000,
86+
"ignore_eos": True,
87+
},
88+
},
89+
)
90+
return response.json()
91+
92+
def test_abort_all(self):
93+
num_requests = 32
94+
with ThreadPoolExecutor(num_requests) as executor:
95+
futures = [executor.submit(self._run_decode) for _ in range(num_requests)]
96+
97+
# ensure the decode has been started
98+
time.sleep(2)
99+
100+
requests.post(
101+
self.base_url + "/abort_request",
102+
json={
103+
"abort_all": True,
104+
},
105+
)
106+
107+
for future in as_completed(futures):
108+
self.assertEqual(
109+
future.result()["meta_info"]["finish_reason"]["type"], "abort"
110+
)
111+
112+
53113
if __name__ == "__main__":
54114
unittest.main()

test/srt/test_update_weights_from_disk.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import json
22
import random
3+
import time
34
import unittest
5+
from concurrent.futures import ThreadPoolExecutor, as_completed
46

57
import requests
68

@@ -153,6 +155,82 @@ def test_update_weights_unexist_model(self):
153155
self.assertEqual(origin_response[:32], updated_response[:32])
154156

155157

158+
class TestServerUpdateWeightsFromDiskAbortAllRequests(CustomTestCase):
159+
@classmethod
160+
def setUpClass(cls):
161+
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
162+
cls.base_url = DEFAULT_URL_FOR_TEST
163+
cls.process = popen_launch_server(
164+
cls.model,
165+
cls.base_url,
166+
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
167+
other_args=["--max-running-requests", 8],
168+
)
169+
170+
@classmethod
171+
def tearDownClass(cls):
172+
kill_process_tree(cls.process.pid)
173+
174+
def run_decode(self, max_new_tokens=32):
175+
response = requests.post(
176+
self.base_url + "/generate",
177+
json={
178+
"text": "The capital of France is",
179+
"sampling_params": {
180+
"temperature": 0,
181+
"max_new_tokens": max_new_tokens,
182+
"ignore_eos": True,
183+
},
184+
},
185+
)
186+
return response.json()
187+
188+
def get_model_info(self):
189+
response = requests.get(self.base_url + "/get_model_info")
190+
model_path = response.json()["model_path"]
191+
print(json.dumps(response.json()))
192+
return model_path
193+
194+
def run_update_weights(self, model_path, abort_all_requests=False):
195+
response = requests.post(
196+
self.base_url + "/update_weights_from_disk",
197+
json={
198+
"model_path": model_path,
199+
"abort_all_requests": abort_all_requests,
200+
},
201+
)
202+
ret = response.json()
203+
print(json.dumps(ret))
204+
return ret
205+
206+
def test_update_weights_abort_all_requests(self):
207+
origin_model_path = self.get_model_info()
208+
print(f"[Server Mode] origin_model_path: {origin_model_path}")
209+
210+
num_requests = 32
211+
with ThreadPoolExecutor(num_requests) as executor:
212+
futures = [
213+
executor.submit(self.run_decode, 16000) for _ in range(num_requests)
214+
]
215+
216+
# ensure the decode has been started
217+
time.sleep(2)
218+
219+
new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "")
220+
ret = self.run_update_weights(new_model_path, abort_all_requests=True)
221+
self.assertTrue(ret["success"])
222+
223+
for future in as_completed(futures):
224+
self.assertEqual(
225+
future.result()["meta_info"]["finish_reason"]["type"], "abort"
226+
)
227+
228+
updated_model_path = self.get_model_info()
229+
print(f"[Server Mode] updated_model_path: {updated_model_path}")
230+
self.assertEqual(updated_model_path, new_model_path)
231+
self.assertNotEqual(updated_model_path, origin_model_path)
232+
233+
156234
###############################################################################
157235
# Parameterized Tests for update_weights_from_disk
158236
# Test coverage is determined based on the value of is_in_ci:

0 commit comments

Comments (0)