Skip to content

Commit f87a6ab

Browse files
authored
Resolves the 404 Not Found error when running compile_deep_gemm.py in multi-node setups (#5720)
1 parent eebfdb9 commit f87a6ab

File tree

1 file changed

+40
-3
lines changed

1 file changed

+40
-3
lines changed

python/sglang/compile_deep_gemm.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,36 @@ def launch_server_process_and_send_one_request(
8888
headers = {
8989
"Content-Type": "application/json; charset=utf-8",
9090
}
91-
response = requests.get(f"{base_url}/v1/models", headers=headers)
91+
if server_args.node_rank == 0:
92+
response = requests.get(f"{base_url}/v1/models", headers=headers)
93+
else:
94+
# This HTTP endpoint is created by launch_dummy_health_check_server on non-rank-0 nodes.
95+
response = requests.get(f"{base_url}/health", headers=headers)
9296
if response.status_code == 200:
97+
# The rank-0 node sends a request to sync with the other nodes and then returns.
98+
if server_args.node_rank == 0:
99+
response = requests.post(
100+
f"{base_url}/generate",
101+
json={
102+
"input_ids": [0, 1, 2, 3],
103+
"sampling_params": {
104+
"max_new_tokens": 8,
105+
"temperature": 0,
106+
},
107+
},
108+
timeout=600,
109+
)
110+
if response.status_code != 200:
111+
error = response.json()
112+
raise RuntimeError(f"Sync request failed: {error}")
113+
# Other nodes should wait for the exit signal from the rank-0 node.
114+
else:
115+
start_time_waiting = time.time()
116+
while proc.is_alive():
117+
if time.time() - start_time_waiting < timeout:
118+
time.sleep(10)
119+
else:
120+
raise TimeoutError("Waiting for main node timeout!")
93121
return proc
94122
except requests.RequestException:
95123
pass
@@ -122,10 +150,19 @@ def run_compile(server_args: ServerArgs, compile_args: CompileArgs):
122150

123151
proc = launch_server_process_and_send_one_request(server_args, compile_args)
124152

125-
kill_process_tree(proc.pid)
126-
127153
print("\nDeepGEMM Kernels compilation finished successfully.")
128154

155+
# Sleep for safety
156+
time.sleep(10)
157+
if proc.is_alive():
158+
# This is the rank0 node.
159+
kill_process_tree(proc.pid)
160+
else:
161+
try:
162+
kill_process_tree(proc.pid)
163+
except Exception:
164+
pass
165+
129166

130167
if __name__ == "__main__":
131168
parser = argparse.ArgumentParser()

0 commit comments

Comments
 (0)