|
1 | 1 | import json
|
2 | 2 | import random
|
| 3 | +import time |
3 | 4 | import unittest
|
| 5 | +from concurrent.futures import ThreadPoolExecutor, as_completed |
4 | 6 |
|
5 | 7 | import requests
|
6 | 8 |
|
@@ -153,6 +155,82 @@ def test_update_weights_unexist_model(self):
|
153 | 155 | self.assertEqual(origin_response[:32], updated_response[:32])
|
154 | 156 |
|
155 | 157 |
|
class TestServerUpdateWeightsFromDiskAbortAllRequests(CustomTestCase):
    """Verify that `/update_weights_from_disk` with `abort_all_requests=True`
    aborts every in-flight generation request before swapping in the new
    weights, and that the server reports the new model path afterwards."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            # Cap concurrency so the 32 requests submitted below queue up and
            # are still in flight when the weight update arrives.
            # NOTE: pass the value as a string — other_args ends up on a
            # subprocess command line, which requires string arguments.
            other_args=["--max-running-requests", "8"],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def run_decode(self, max_new_tokens=32):
        """Issue one `/generate` request and return its decoded JSON body.

        `ignore_eos=True` forces the server to produce exactly
        `max_new_tokens` tokens, keeping the request alive long enough for
        the weight update to abort it.
        """
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": max_new_tokens,
                    "ignore_eos": True,
                },
            },
        )
        return response.json()

    def get_model_info(self):
        """Return the server's current model path via `/get_model_info`."""
        response = requests.get(self.base_url + "/get_model_info")
        # Decode the body once; the original parsed it twice (once for the
        # path, once for the debug print).
        info = response.json()
        print(json.dumps(info))
        return info["model_path"]

    def run_update_weights(self, model_path, abort_all_requests=False):
        """POST `/update_weights_from_disk` and return the parsed response."""
        response = requests.post(
            self.base_url + "/update_weights_from_disk",
            json={
                "model_path": model_path,
                "abort_all_requests": abort_all_requests,
            },
        )
        ret = response.json()
        print(json.dumps(ret))
        return ret

    def test_update_weights_abort_all_requests(self):
        origin_model_path = self.get_model_info()
        print(f"[Server Mode] origin_model_path: {origin_model_path}")

        num_requests = 32
        with ThreadPoolExecutor(num_requests) as executor:
            # Long generations (16000 tokens) so the requests are guaranteed
            # to still be running when the update is issued.
            futures = [
                executor.submit(self.run_decode, 16000) for _ in range(num_requests)
            ]

            # Give the server a moment so decoding has actually started
            # before the weight update is triggered.
            time.sleep(2)

            new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "")
            ret = self.run_update_weights(new_model_path, abort_all_requests=True)
            self.assertTrue(ret["success"])

            # Every in-flight request must have been aborted, not completed.
            for future in as_completed(futures):
                self.assertEqual(
                    future.result()["meta_info"]["finish_reason"]["type"], "abort"
                )

            updated_model_path = self.get_model_info()
            print(f"[Server Mode] updated_model_path: {updated_model_path}")
            self.assertEqual(updated_model_path, new_model_path)
            self.assertNotEqual(updated_model_path, origin_model_path)
| 233 | + |
156 | 234 | ###############################################################################
|
157 | 235 | # Parameterized Tests for update_weights_from_disk
|
158 | 236 | # Test coverage is determined based on the value of is_in_ci:
|
|
0 commit comments