Skip to content
43 changes: 40 additions & 3 deletions python/sglang/compile_deep_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,36 @@ def launch_server_process_and_send_one_request(
headers = {
"Content-Type": "application/json; charset=utf-8",
}
response = requests.get(f"{base_url}/v1/models", headers=headers)
if server_args.node_rank == 0:
response = requests.get(f"{base_url}/v1/models", headers=headers)
else:
# This http api is created by launch_dummy_health_check_server for none-rank0 node.
response = requests.get(f"{base_url}/health", headers=headers)
if response.status_code == 200:
# Rank-0 node send a request to sync with other node and then return.
if server_args.node_rank == 0:
response = requests.post(
f"{base_url}/generate",
json={
"input_ids": [0, 1, 2, 3],
"sampling_params": {
"max_new_tokens": 8,
"temperature": 0,
},
},
timeout=600,
)
if response.status_code != 200:
error = response.json()
raise RuntimeError(f"Sync request failed: {error}")
# Other nodes should wait for the exit signal from Rank-0 node.
else:
start_time_waiting = time.time()
while proc.is_alive():
if time.time() - start_time_waiting < timeout:
time.sleep(10)
else:
raise TimeoutError("Waiting for main node timeout!")
return proc
except requests.RequestException:
pass
Expand Down Expand Up @@ -122,10 +150,19 @@ def run_compile(server_args: ServerArgs, compile_args: CompileArgs):

proc = launch_server_process_and_send_one_request(server_args, compile_args)

kill_process_tree(proc.pid)

print("\nDeepGEMM Kernels compilation finished successfully.")

# Sleep for safety
time.sleep(10)
if proc.is_alive():
# This is the rank0 node.
kill_process_tree(proc.pid)
else:
try:
kill_process_tree(proc.pid)
except Exception:
pass


if __name__ == "__main__":
parser = argparse.ArgumentParser()
Expand Down
Loading