Skip to content

Commit 5094845

Browse files
committed
add unit test for aborting all requests during update weights
Signed-off-by: Tianyu Zhou <[email protected]>
1 parent 0116262 commit 5094845

File tree

1 file changed

+78
-0
lines changed

1 file changed

+78
-0
lines changed

test/srt/test_update_weights_from_disk.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import json
22
import random
3+
import time
34
import unittest
5+
from concurrent.futures import ThreadPoolExecutor, as_completed
46

57
import requests
68

@@ -153,6 +155,82 @@ def test_update_weights_unexist_model(self):
153155
self.assertEqual(origin_response[:32], updated_response[:32])
154156

155157

158+
class TestServerUpdateWeightsFromDiskAbortAllRequests(CustomTestCase):
    """Check that updating weights with ``abort_all_requests=True`` aborts in-flight requests.

    A server is launched once for the whole class. The single test floods it
    with long-running ``/generate`` requests, triggers
    ``/update_weights_from_disk`` with ``abort_all_requests=True``, and then
    asserts that every request finished with reason ``"abort"`` and that the
    server reports the new model path.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the server shared by all tests in this class."""
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            # Command-line arguments must be strings: subprocess rejects a
            # bare int in an argv list, so pass "8" rather than 8.
            # NOTE(review): presumably this cap keeps part of the 32 test
            # requests queued while the rest are running -- confirm intent.
            other_args=["--max-running-requests", "8"],
        )

    @classmethod
    def tearDownClass(cls):
        """Terminate the server process started in ``setUpClass``."""
        kill_process_tree(cls.process.pid)

    def run_decode(self, max_new_tokens=32):
        """POST a fixed prompt to ``/generate`` and return the decoded JSON body.

        Args:
            max_new_tokens: Value sent as ``sampling_params.max_new_tokens``.

        Returns:
            The parsed JSON response.
        """
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": max_new_tokens,
                    # Sent so generation is not cut short by an EOS token,
                    # keeping the request in flight until it is aborted.
                    "ignore_eos": True,
                },
            },
        )
        return response.json()

    def get_model_info(self):
        """Query ``/get_model_info``, print the payload, and return its ``model_path``."""
        response = requests.get(self.base_url + "/get_model_info")
        info = response.json()
        model_path = info["model_path"]
        print(json.dumps(info))
        return model_path

    def run_update_weights(self, model_path, abort_all_requests=False):
        """POST to ``/update_weights_from_disk`` and return the decoded JSON body.

        Args:
            model_path: Sent as the ``model_path`` field of the request.
            abort_all_requests: Sent as the ``abort_all_requests`` field.

        Returns:
            The parsed JSON response (also printed).
        """
        response = requests.post(
            self.base_url + "/update_weights_from_disk",
            json={
                "model_path": model_path,
                "abort_all_requests": abort_all_requests,
            },
        )
        ret = response.json()
        print(json.dumps(ret))
        return ret

    def test_update_weights_abort_all_requests(self):
        """In-flight requests are aborted and the model path changes after the update."""
        origin_model_path = self.get_model_info()
        print(f"[Server Mode] origin_model_path: {origin_model_path}")

        num_requests = 32
        with ThreadPoolExecutor(num_requests) as executor:
            # Very long generations so the requests are still outstanding
            # when the weight update is issued.
            futures = [
                executor.submit(self.run_decode, 16000) for _ in range(num_requests)
            ]

            # ensure the decode has been started
            time.sleep(2)

            # The update must be issued while the pool is still open: leaving
            # the ``with`` block would first wait for every request to finish.
            new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "")
            ret = self.run_update_weights(new_model_path, abort_all_requests=True)
            self.assertTrue(ret["success"])

            for future in as_completed(futures):
                self.assertEqual(
                    future.result()["meta_info"]["finish_reason"]["type"], "abort"
                )

        updated_model_path = self.get_model_info()
        print(f"[Server Mode] updated_model_path: {updated_model_path}")
        self.assertEqual(updated_model_path, new_model_path)
        self.assertNotEqual(updated_model_path, origin_model_path)
232+
233+
156234
###############################################################################
157235
# Parameterized Tests for update_weights_from_disk
158236
# Test coverage is determined based on the value of is_in_ci:

0 commit comments

Comments
 (0)