Skip to content

Commit beddf1e

Browse files
whybeyoung authored and ch-wan committed
[Feature] Simple Improve Health Check Mechanism for Production-Grade Stability (#8115)
Signed-off-by: ybyang <[email protected]>
1 parent fcb2246 commit beddf1e

File tree

6 files changed

+82
-11
lines changed

6 files changed

+82
-11
lines changed

python/sglang/srt/entrypoints/engine.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
6666
from sglang.srt.utils import (
6767
MultiprocessingSerializer,
68+
ServerStatus,
6869
assert_pkg_version,
6970
configure_logger,
7071
get_zmq_socket,
@@ -73,6 +74,7 @@
7374
launch_dummy_health_check_server,
7475
maybe_set_triton_cache_manager,
7576
prepare_model_and_tokenizer,
77+
report_health,
7678
set_prometheus_multiproc_dir,
7779
set_ulimit,
7880
)
@@ -661,6 +663,7 @@ def _set_envs_and_config(server_args: ServerArgs):
661663
def sigchld_handler(signum, frame):
662664
pid, exitcode = os.waitpid(0, os.WNOHANG)
663665
if exitcode != 0:
666+
report_health(ServerStatus.Crashed, server_args.host, server_args.port)
664667
logger.warning(
665668
f"Child process unexpectedly failed with {exitcode=}. {pid=}"
666669
)
@@ -674,6 +677,7 @@ def sigquit_handler(signum, frame):
674677
logger.error(
675678
"Received sigquit from a child process. It usually means the child failed."
676679
)
680+
report_health(ServerStatus.Crashed, server_args.host, server_args.port)
677681
kill_process_tree(os.getpid())
678682

679683
signal.signal(signal.SIGQUIT, sigquit_handler)

python/sglang/srt/entrypoints/http_server.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
ParseFunctionCallReq,
7878
ProfileReqInput,
7979
ReleaseMemoryOccupationReqInput,
80+
ReportHealthInput,
8081
ResumeMemoryOccupationReqInput,
8182
SeparateReasoningReqInput,
8283
SetInternalStateReq,
@@ -93,6 +94,7 @@
9394
from sglang.srt.reasoning_parser import ReasoningParser
9495
from sglang.srt.server_args import ServerArgs
9596
from sglang.srt.utils import (
97+
ServerStatus,
9698
add_api_key_middleware,
9799
add_prometheus_middleware,
98100
delete_directory,
@@ -220,8 +222,31 @@ async def validate_json_request(raw_request: Request):
220222

221223
@app.get("/health")
async def health() -> Response:
    """Report the current status of the http server.

    Responds with HTTP 200 when the tokenizer manager's ``server_status`` is
    ``ServerStatus.Up`` and HTTP 503 for any other status. The body is a JSON
    object of the form ``{"status": "<status value>"}``.
    """
    current_status = _global_state.tokenizer_manager.server_status

    if current_status == ServerStatus.Up:
        http_code = HTTPStatus.OK.value
    else:
        http_code = HTTPStatus.SERVICE_UNAVAILABLE.value

    payload = json.dumps({"status": current_status.value})
    return Response(status_code=http_code, content=payload)
235+
236+
237+
@app.post("/health")
async def health_update(obj: ReportHealthInput, request: Request) -> Response:
    """Update the status of the http server from a reported health event.

    Args:
        obj: Reported status; ``obj.status`` must be the value of a
            ``ServerStatus`` member, ``obj.msg`` is an optional free-form
            message echoed back when the new status is not ``Up``.
        request: The incoming HTTP request (unused).

    Returns:
        HTTP 200 when the status was set to ``ServerStatus.Up``; HTTP 503
        (with ``obj.msg`` as body) when it was set to any other status; HTTP
        503 with an empty body when the update itself failed (for example an
        unknown status string).
    """
    try:
        server_status = ServerStatus(obj.status)
        _global_state.tokenizer_manager.server_status = server_status
    except Exception as e:
        # Best-effort endpoint: an invalid status string or a missing
        # tokenizer manager is logged and reported as "unavailable".
        logger.error(e)
        return Response(status_code=HTTPStatus.SERVICE_UNAVAILABLE.value)

    if server_status != ServerStatus.Up:
        return Response(
            status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, content=obj.msg
        )

    # Always hand back an explicit Response, including on the healthy path,
    # so the declared return type holds for every branch.
    return Response(status_code=HTTPStatus.OK.value)
225250

226251

227252
@app.get("/health_generate")
@@ -256,7 +281,7 @@ async def gen():
256281
if _global_state.tokenizer_manager.last_receive_tstamp > tic:
257282
task.cancel()
258283
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
259-
_global_state.tokenizer_manager.health_check_failed = False
284+
_global_state.tokenizer_manager.server_status = ServerStatus.Up
260285
return Response(status_code=200)
261286

262287
task.cancel()
@@ -270,7 +295,7 @@ async def gen():
270295
f"last_heartbeat time: {last_receive_time}"
271296
)
272297
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
273-
_global_state.tokenizer_manager.health_check_failed = True
298+
_global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy
274299
return Response(status_code=503)
275300

276301

@@ -1022,9 +1047,13 @@ def _execute_server_warmup(
10221047
headers=headers,
10231048
timeout=600,
10241049
)
1025-
assert res.status_code == 200, f"{res}"
1050+
if res.status_code == 200:
1051+
_global_state.tokenizer_manager.server_status = ServerStatus.Up
1052+
else:
1053+
_global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy
1054+
logger.info(f"{res}")
10261055
else:
1027-
logger.info(f"Start of prefill warmup ...")
1056+
logger.info(f"Start of prefill/decode warmup ...")
10281057
json_data = {
10291058
"sampling_params": {
10301059
"temperature": 0.0,
@@ -1046,15 +1075,25 @@ def _execute_server_warmup(
10461075
headers=headers,
10471076
timeout=1800, # because of deep gemm precache is very long if not precache.
10481077
)
1049-
logger.info(
1050-
f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
1051-
)
1078+
if res.status_code == 200:
1079+
logger.info(
1080+
f"End of prefill disaggregation mode warmup with status {res.status_code}, resp: {res.json()}"
1081+
)
1082+
_global_state.tokenizer_manager.server_status = ServerStatus.Up
1083+
else:
1084+
logger.info(
1085+
"Prefill disaggregation mode warm Up Failed, status code: {}".format(
1086+
res.status_code
1087+
)
1088+
)
1089+
_global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy
10521090

10531091
except Exception:
10541092
last_traceback = get_exception_traceback()
10551093
if pipe_finish_writer is not None:
10561094
pipe_finish_writer.send(last_traceback)
10571095
logger.error(f"Initialization failed. warmup error: {last_traceback}")
1096+
_global_state.tokenizer_manager.server_status = ServerStatus.Crashed
10581097
kill_process_tree(os.getpid())
10591098
return False
10601099

python/sglang/srt/managers/io_struct.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,3 +1083,9 @@ class LoRAUpdateResult:
10831083

10841084

10851085
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
1086+
1087+
1088+
@dataclass
class ReportHealthInput:
    """Request body used to report a server status to the ``/health`` endpoint."""

    # String form of the reported status (e.g. "Up", "Crashed").
    status: str
    # Optional free-form message accompanying the report; empty by default.
    msg: Optional[str] = ""

python/sglang/srt/managers/scheduler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@
143143
from sglang.srt.utils import (
144144
DeepEPMode,
145145
DynamicGradMode,
146+
ServerStatus,
146147
broadcast_pyobj,
147148
configure_gc_logger,
148149
configure_logger,
@@ -154,6 +155,7 @@
154155
kill_itself_when_parent_died,
155156
point_to_point_pyobj,
156157
pyspy_dump_schedulers,
158+
report_health,
157159
require_mlp_sync,
158160
require_mlp_tp_gather,
159161
set_gpu_proc_affinity,
@@ -2964,4 +2966,5 @@ def run_scheduler_process(
29642966
except Exception:
29652967
traceback = get_exception_traceback()
29662968
logger.error(f"Scheduler hit an exception: {traceback}")
2969+
report_health(ServerStatus.Crashed, server_args.host, ServerArgs.port)
29672970
parent_process.send_signal(signal.SIGQUIT)

python/sglang/srt/managers/tokenizer_manager.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@
116116
from sglang.srt.sampling.sampling_params import SamplingParams
117117
from sglang.srt.server_args import PortArgs, ServerArgs
118118
from sglang.srt.utils import (
119+
ServerStatus,
119120
dataclass_to_string_truncated,
120121
get_bool_env_var,
121122
get_zmq_socket,
@@ -173,6 +174,9 @@ def __init__(
173174
server_args: ServerArgs,
174175
port_args: PortArgs,
175176
):
177+
# Server Status
178+
self.server_status = ServerStatus.Starting
179+
176180
# Parse args
177181
self.server_args = server_args
178182
self.enable_metrics = server_args.enable_metrics
@@ -251,7 +255,6 @@ def __init__(
251255
# Store states
252256
self.no_create_loop = False
253257
self.rid_to_state: Dict[str, ReqState] = {}
254-
self.health_check_failed = False
255258
self.gracefully_exit = False
256259
self.last_receive_tstamp = 0
257260
self.dump_requests_folder = "" # By default do not dump
@@ -1332,7 +1335,7 @@ async def sigterm_watchdog(self):
13321335
while True:
13331336
remain_num_req = len(self.rid_to_state)
13341337

1335-
if self.health_check_failed:
1338+
if not self.server_status.is_healthy():
13361339
# if health check failed, we should exit immediately
13371340
logger.error(
13381341
"Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",

python/sglang/srt/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,22 @@
9393
HIP_FP8_E4M3_FNUZ_MAX = 224.0
9494

9595

96+
class ServerStatus(Enum):
    """Lifecycle/health states of the server; each value is its own name."""

    Up = "Up"
    Starting = "Starting"
    UnHealthy = "UnHealthy"
    Crashed = "Crashed"

    def is_healthy(self) -> bool:
        """Return True only for the ``Up`` state."""
        # Enum members are singletons, so an identity check is equivalent
        # to comparing for equality.
        return self is ServerStatus.Up
104+
105+
106+
def report_health(
    status: ServerStatus,
    host: str,
    http_port: int,
    msg: str = "",
    timeout: float = 5.0,
):
    """Best-effort report of ``status`` to the HTTP server's ``/health`` endpoint.

    Sends ``{"status": status.value, "msg": msg}`` as a JSON POST to
    ``http://{host}:{http_port}/health``.

    Args:
        status: The status to report.
        host: Host of the HTTP server.
        http_port: Port of the HTTP server.
        msg: Optional free-form message sent along with the status.
        timeout: Seconds to wait for the HTTP server before giving up.

    This function never raises on network failure: it is meant to be called
    from crash paths and signal handlers, where an exception or an unbounded
    wait would prevent the caller's subsequent shutdown steps from running.
    """
    import logging

    try:
        requests.post(
            f"http://{host}:{http_port}/health",
            json={"status": status.value, "msg": msg},
            # Without a timeout, requests may block indefinitely when the
            # HTTP server is hung, which is exactly when this gets called.
            timeout=timeout,
        )
    except requests.RequestException as e:
        # The HTTP server may already be down; reporting is best-effort.
        logging.getLogger(__name__).warning(
            f"Failed to report health status {status.value}: {e}"
        )
110+
111+
96112
# https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
97113
def is_hip() -> bool:
98114
return torch.version.hip is not None

0 commit comments

Comments
 (0)